Complete rebrand for v1.0-simple-search branch: Directory Changes: - claude_rag/ → mini_rag/ (preserving git history) Content Changes: - Updated all imports: from claude_rag → from mini_rag - Updated all file paths: .claude-rag → .mini-rag - Updated documentation and comments - Updated configuration files and examples - Updated all tests to use mini_rag imports This ensures complete independence from Claude/Anthropic branding while maintaining all functionality and git history. Simple branch contains the basic RAG system without LLM features.
411 lines
14 KiB
Python
411 lines
14 KiB
Python
"""
|
||
Persistent RAG server that keeps models loaded in memory.
|
||
No more loading/unloading madness!
|
||
"""
|
||
|
||
import json
|
||
import socket
|
||
import threading
|
||
import time
|
||
import subprocess
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Optional
|
||
import logging
|
||
import sys
|
||
import os
|
||
|
||
# Fix Windows console
|
||
if sys.platform == 'win32':
|
||
os.environ['PYTHONUTF8'] = '1'
|
||
|
||
from .search import CodeSearcher
|
||
from .ollama_embeddings import OllamaEmbedder as CodeEmbedder
|
||
from .performance import PerformanceMonitor
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RAGServer:
|
||
"""Persistent server that keeps embeddings and DB loaded."""
|
||
|
||
def __init__(self, project_path: Path, port: int = 7777):
|
||
self.project_path = project_path
|
||
self.port = port
|
||
self.searcher = None
|
||
self.embedder = None
|
||
self.running = False
|
||
self.socket = None
|
||
self.start_time = None
|
||
self.query_count = 0
|
||
|
||
def _kill_existing_server(self):
|
||
"""Kill any existing process using our port."""
|
||
try:
|
||
# Check if port is in use
|
||
test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
result = test_sock.connect_ex(('localhost', self.port))
|
||
test_sock.close()
|
||
|
||
if result == 0: # Port is in use
|
||
print(f"️ Port {self.port} is already in use, attempting to free it...")
|
||
|
||
if sys.platform == 'win32':
|
||
# Windows: Find and kill process using netstat
|
||
import subprocess
|
||
try:
|
||
# Get process ID using the port
|
||
result = subprocess.run(
|
||
['netstat', '-ano'],
|
||
capture_output=True,
|
||
text=True
|
||
)
|
||
|
||
for line in result.stdout.split('\n'):
|
||
if f':{self.port}' in line and 'LISTENING' in line:
|
||
parts = line.split()
|
||
pid = parts[-1]
|
||
print(f" Found process {pid} using port {self.port}")
|
||
|
||
# Kill the process
|
||
subprocess.run(['taskkill', '//PID', pid, '//F'], check=False)
|
||
print(f" Killed process {pid}")
|
||
time.sleep(1) # Give it a moment to release the port
|
||
break
|
||
except Exception as e:
|
||
print(f" ️ Could not auto-kill process: {e}")
|
||
else:
|
||
# Unix/Linux: Use lsof and kill
|
||
import subprocess
|
||
try:
|
||
result = subprocess.run(
|
||
['lsof', '-ti', f':{self.port}'],
|
||
capture_output=True,
|
||
text=True
|
||
)
|
||
if result.stdout.strip():
|
||
pid = result.stdout.strip()
|
||
subprocess.run(['kill', '-9', pid], check=False)
|
||
print(f" Killed process {pid}")
|
||
time.sleep(1)
|
||
except Exception as e:
|
||
print(f" ️ Could not auto-kill process: {e}")
|
||
except Exception as e:
|
||
# Non-critical error, just log it
|
||
logger.debug(f"Error checking port: {e}")
|
||
|
||
def start(self):
|
||
"""Start the RAG server."""
|
||
# Kill any existing process on our port first
|
||
self._kill_existing_server()
|
||
|
||
print(f" Starting RAG server on port {self.port}...")
|
||
|
||
# Load everything once
|
||
perf = PerformanceMonitor()
|
||
|
||
with perf.measure("Load Embedder"):
|
||
self.embedder = CodeEmbedder()
|
||
|
||
with perf.measure("Connect Database"):
|
||
self.searcher = CodeSearcher(self.project_path, embedder=self.embedder)
|
||
|
||
perf.print_summary()
|
||
|
||
# Start server
|
||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||
self.socket.bind(('localhost', self.port))
|
||
self.socket.listen(5)
|
||
|
||
self.running = True
|
||
self.start_time = time.time()
|
||
|
||
print(f"\n RAG server ready on localhost:{self.port}")
|
||
print(" Model loaded, database connected")
|
||
print(" Waiting for queries...\n")
|
||
|
||
# Handle connections
|
||
while self.running:
|
||
try:
|
||
client, addr = self.socket.accept()
|
||
thread = threading.Thread(target=self._handle_client, args=(client,))
|
||
thread.daemon = True
|
||
thread.start()
|
||
except KeyboardInterrupt:
|
||
break
|
||
except Exception as e:
|
||
if self.running:
|
||
logger.error(f"Server error: {e}")
|
||
|
||
def _handle_client(self, client: socket.socket):
|
||
"""Handle a client connection."""
|
||
try:
|
||
# Receive query with proper message framing
|
||
data = self._receive_json(client)
|
||
request = json.loads(data)
|
||
|
||
# Check for shutdown command
|
||
if request.get('command') == 'shutdown':
|
||
print("\n Shutdown requested")
|
||
response = {'success': True, 'message': 'Server shutting down'}
|
||
self._send_json(client, response)
|
||
self.stop()
|
||
return
|
||
|
||
query = request.get('query', '')
|
||
top_k = request.get('top_k', 10)
|
||
|
||
self.query_count += 1
|
||
print(f"[Query #{self.query_count}] {query}")
|
||
|
||
# Perform search
|
||
start = time.time()
|
||
results = self.searcher.search(query, top_k=top_k)
|
||
search_time = time.time() - start
|
||
|
||
# Prepare response
|
||
response = {
|
||
'success': True,
|
||
'query': query,
|
||
'count': len(results),
|
||
'search_time_ms': int(search_time * 1000),
|
||
'results': [r.to_dict() for r in results],
|
||
'server_uptime': int(time.time() - self.start_time),
|
||
'total_queries': self.query_count,
|
||
}
|
||
|
||
# Send response with proper framing
|
||
self._send_json(client, response)
|
||
|
||
print(f" Found {len(results)} results in {search_time*1000:.0f}ms")
|
||
|
||
except ConnectionError as e:
|
||
# Normal disconnection - client closed connection
|
||
# This is expected behavior, don't log as error
|
||
pass
|
||
except Exception as e:
|
||
# Only log actual errors, not normal disconnections
|
||
if "Connection closed" not in str(e):
|
||
logger.error(f"Client handler error: {e}")
|
||
error_response = {
|
||
'success': False,
|
||
'error': str(e)
|
||
}
|
||
try:
|
||
self._send_json(client, error_response)
|
||
except:
|
||
pass
|
||
finally:
|
||
client.close()
|
||
|
||
def _receive_json(self, sock: socket.socket) -> str:
|
||
"""Receive a complete JSON message with length prefix."""
|
||
# First receive the length (4 bytes)
|
||
length_data = b''
|
||
while len(length_data) < 4:
|
||
chunk = sock.recv(4 - len(length_data))
|
||
if not chunk:
|
||
raise ConnectionError("Connection closed while receiving length")
|
||
length_data += chunk
|
||
|
||
length = int.from_bytes(length_data, 'big')
|
||
|
||
# Now receive the actual data
|
||
data = b''
|
||
while len(data) < length:
|
||
chunk = sock.recv(min(65536, length - len(data)))
|
||
if not chunk:
|
||
raise ConnectionError("Connection closed while receiving data")
|
||
data += chunk
|
||
|
||
return data.decode('utf-8')
|
||
|
||
def _send_json(self, sock: socket.socket, data: dict):
|
||
"""Send a JSON message with length prefix."""
|
||
# Sanitize the data to ensure JSON compatibility
|
||
json_str = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
|
||
json_bytes = json_str.encode('utf-8')
|
||
|
||
# Send length prefix (4 bytes)
|
||
length = len(json_bytes)
|
||
sock.send(length.to_bytes(4, 'big'))
|
||
|
||
# Send the data
|
||
sock.sendall(json_bytes)
|
||
|
||
def stop(self):
|
||
"""Stop the server."""
|
||
self.running = False
|
||
if self.socket:
|
||
self.socket.close()
|
||
print("\n RAG server stopped")
|
||
|
||
|
||
class RAGClient:
|
||
"""Client to communicate with RAG server."""
|
||
|
||
def __init__(self, port: int = 7777):
|
||
self.port = port
|
||
self.use_legacy = False
|
||
|
||
def search(self, query: str, top_k: int = 10) -> Dict[str, Any]:
|
||
"""Send search query to server."""
|
||
try:
|
||
# Connect to server
|
||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
sock.connect(('localhost', self.port))
|
||
|
||
# Send request with proper framing
|
||
request = {
|
||
'query': query,
|
||
'top_k': top_k
|
||
}
|
||
self._send_json(sock, request)
|
||
|
||
# Receive response with proper framing
|
||
data = self._receive_json(sock)
|
||
response = json.loads(data)
|
||
|
||
sock.close()
|
||
return response
|
||
|
||
except ConnectionRefusedError:
|
||
return {
|
||
'success': False,
|
||
'error': 'RAG server not running. Start with: mini-rag server'
|
||
}
|
||
except ConnectionError as e:
|
||
# Try legacy mode without message framing
|
||
if not self.use_legacy and "receiving length" in str(e):
|
||
self.use_legacy = True
|
||
return self._search_legacy(query, top_k)
|
||
return {
|
||
'success': False,
|
||
'error': str(e)
|
||
}
|
||
except Exception as e:
|
||
return {
|
||
'success': False,
|
||
'error': str(e)
|
||
}
|
||
|
||
def _receive_json(self, sock: socket.socket) -> str:
|
||
"""Receive a complete JSON message with length prefix."""
|
||
# First receive the length (4 bytes)
|
||
length_data = b''
|
||
while len(length_data) < 4:
|
||
chunk = sock.recv(4 - len(length_data))
|
||
if not chunk:
|
||
raise ConnectionError("Connection closed while receiving length")
|
||
length_data += chunk
|
||
|
||
length = int.from_bytes(length_data, 'big')
|
||
|
||
# Now receive the actual data
|
||
data = b''
|
||
while len(data) < length:
|
||
chunk = sock.recv(min(65536, length - len(data)))
|
||
if not chunk:
|
||
raise ConnectionError("Connection closed while receiving data")
|
||
data += chunk
|
||
|
||
return data.decode('utf-8')
|
||
|
||
def _send_json(self, sock: socket.socket, data: dict):
|
||
"""Send a JSON message with length prefix."""
|
||
json_str = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
|
||
json_bytes = json_str.encode('utf-8')
|
||
|
||
# Send length prefix (4 bytes)
|
||
length = len(json_bytes)
|
||
sock.send(length.to_bytes(4, 'big'))
|
||
|
||
# Send the data
|
||
sock.sendall(json_bytes)
|
||
|
||
def _search_legacy(self, query: str, top_k: int = 10) -> Dict[str, Any]:
|
||
"""Legacy search without message framing for old servers."""
|
||
try:
|
||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
sock.connect(('localhost', self.port))
|
||
|
||
# Send request (old way)
|
||
request = {
|
||
'query': query,
|
||
'top_k': top_k
|
||
}
|
||
sock.send(json.dumps(request).encode('utf-8'))
|
||
|
||
# Receive response (accumulate until we get valid JSON)
|
||
data = b''
|
||
while True:
|
||
chunk = sock.recv(65536)
|
||
if not chunk:
|
||
break
|
||
data += chunk
|
||
try:
|
||
# Try to decode as JSON
|
||
response = json.loads(data.decode('utf-8'))
|
||
sock.close()
|
||
return response
|
||
except json.JSONDecodeError:
|
||
# Keep receiving
|
||
continue
|
||
|
||
sock.close()
|
||
return {
|
||
'success': False,
|
||
'error': 'Incomplete response from server'
|
||
}
|
||
except Exception as e:
|
||
return {
|
||
'success': False,
|
||
'error': str(e)
|
||
}
|
||
|
||
def is_running(self) -> bool:
|
||
"""Check if server is running."""
|
||
try:
|
||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
result = sock.connect_ex(('localhost', self.port))
|
||
sock.close()
|
||
return result == 0
|
||
except:
|
||
return False
|
||
|
||
|
||
def start_server(project_path: Path, port: int = 7777):
|
||
"""Start the RAG server."""
|
||
server = RAGServer(project_path, port)
|
||
try:
|
||
server.start()
|
||
except KeyboardInterrupt:
|
||
server.stop()
|
||
|
||
|
||
def auto_start_if_needed(project_path: Path) -> Optional[subprocess.Popen]:
|
||
"""Auto-start server if not running."""
|
||
client = RAGClient()
|
||
if not client.is_running():
|
||
# Start server in background
|
||
import subprocess
|
||
cmd = [sys.executable, "-m", "mini_rag.cli", "server", "--path", str(project_path)]
|
||
process = subprocess.Popen(
|
||
cmd,
|
||
stdout=subprocess.PIPE,
|
||
stderr=subprocess.PIPE,
|
||
creationflags=subprocess.CREATE_NEW_CONSOLE if sys.platform == 'win32' else 0
|
||
)
|
||
|
||
# Wait for server to start
|
||
for _ in range(30): # 30 second timeout
|
||
time.sleep(1)
|
||
if client.is_running():
|
||
print(" RAG server started automatically")
|
||
return process
|
||
|
||
# Failed to start
|
||
process.terminate()
|
||
raise RuntimeError("Failed to start RAG server")
|
||
|
||
return None |