FSSCoding 930f53a0fb Major code quality improvements and structural organization
- Applied Black formatter and isort across entire codebase for professional consistency
- Moved implementation scripts (rag-mini.py, rag-tui.py) to bin/ directory for cleaner root
- Updated shell scripts to reference new bin/ locations maintaining user compatibility
- Added comprehensive linting configuration (.flake8, pyproject.toml) with dedicated .venv-linting
- Removed development artifacts (commit_message.txt, GET_STARTED.md duplicate) from root
- Consolidated documentation and fixed script references across all guides
- Relocated test_fixes.py to proper tests/ directory
- Enhanced project structure following Python packaging standards

All user commands work identically while improving code organization and beginner accessibility.
2025-08-28 15:29:54 +10:00

399 lines
14 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Persistent RAG server that keeps models loaded in memory.
No more loading/unloading madness!
"""
import json
import logging
import os
import socket
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional
# Fix Windows console: request Python's UTF-8 mode when running on Windows.
# NOTE(review): PYTHONUTF8 is normally read at interpreter start-up, so setting
# it here presumably only affects Python child processes launched later (they
# inherit os.environ) — confirm this is the intent.
if sys.platform == "win32":
    os.environ["PYTHONUTF8"] = "1"
from .ollama_embeddings import OllamaEmbedder as CodeEmbedder
from .performance import PerformanceMonitor
from .search import CodeSearcher
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class RAGServer:
    """Persistent server that keeps embeddings and DB loaded.

    The embedder and searcher are created once in :meth:`start` and reused for
    every query. Clients connect over a TCP socket bound to ``localhost``;
    every message, in both directions, is a 4-byte big-endian length followed
    by that many bytes of UTF-8 encoded JSON.
    """

    def __init__(self, project_path: Path, port: int = 7777):
        """Store configuration only; nothing is loaded or bound until start().

        Args:
            project_path: Project directory handed to ``CodeSearcher``.
            port: TCP port on localhost to listen on.
        """
        self.project_path = project_path
        self.port = port
        self.searcher = None
        self.embedder = None
        self.running = False
        self.socket = None
        self.start_time = None
        self.query_count = 0

    def _kill_existing_server(self):
        """Kill any existing process listening on our port (best effort).

        Failures are printed or logged and otherwise ignored; this method
        never raises.
        """
        try:
            # A successful connect means something is already listening.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_sock:
                port_in_use = test_sock.connect_ex(("localhost", self.port)) == 0
            if not port_in_use:
                return

            print(f" Port {self.port} is already in use, attempting to free it...")
            try:
                if sys.platform == "win32":
                    self._kill_port_owner_windows()
                else:
                    self._kill_port_owner_unix()
            except Exception as e:
                print(f" Could not auto-kill process: {e}")
        except Exception as e:
            # Non-critical error, just log it
            logger.debug(f"Error checking port: {e}")

    def _kill_port_owner_windows(self):
        """Windows: find the listener's PID with netstat and taskkill it."""
        netstat = subprocess.run(["netstat", "-ano"], capture_output=True, text=True)
        port_suffix = f":{self.port}"
        for line in netstat.stdout.splitlines():
            parts = line.split()
            # Typical TCP row: Proto, Local Address, Foreign Address, State, PID.
            # Match the local address exactly so port 777 does not match 7777.
            if len(parts) >= 5 and parts[1].endswith(port_suffix) and "LISTENING" in parts:
                pid = parts[-1]
                print(f" Found process {pid} using port {self.port}")
                # Arguments go straight to taskkill (no shell), so use the
                # plain "/PID" and "/F" switches.
                subprocess.run(["taskkill", "/PID", pid, "/F"], check=False)
                print(f" Killed process {pid}")
                time.sleep(1)  # Give it a moment to release the port
                break

    def _kill_port_owner_unix(self):
        """Unix/Linux: find listener PIDs with lsof and kill -9 each of them."""
        lsof = subprocess.run(
            ["lsof", "-t", f"-iTCP:{self.port}", "-sTCP:LISTEN"],
            capture_output=True,
            text=True,
        )
        # lsof -t prints one PID per line; there may be more than one.
        pids = lsof.stdout.split()
        if not pids:
            return
        for pid in pids:
            subprocess.run(["kill", "-9", pid], check=False)
            print(f" Killed process {pid}")
        time.sleep(1)

    def start(self):
        """Start the RAG server and serve queries until stopped.

        Loads the embedder and searcher once, binds the listening socket and
        then blocks in the accept loop. Each connection is handled on its own
        daemon thread. The loop ends when :meth:`stop` is called (for example
        by a ``shutdown`` command) or on Ctrl+C.

        Raises:
            OSError: If the listening socket cannot be bound.
        """
        # Kill any existing process on our port first
        self._kill_existing_server()
        print(f" Starting RAG server on port {self.port}...")

        # Load everything once
        perf = PerformanceMonitor()
        with perf.measure("Load Embedder"):
            self.embedder = CodeEmbedder()
        with perf.measure("Connect Database"):
            self.searcher = CodeSearcher(self.project_path, embedder=self.embedder)
        perf.print_summary()

        # Start server; do not leak the socket if binding fails.
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listener.bind(("localhost", self.port))
            listener.listen(5)
            # Wake up periodically so a stop() issued from a handler thread is
            # noticed; closing a socket does not reliably interrupt a blocking
            # accept() running in another thread.
            listener.settimeout(1.0)
        except Exception:
            listener.close()
            raise

        self.socket = listener
        self.running = True
        self.start_time = time.time()
        print(f"\n RAG server ready on localhost:{self.port}")
        print(" Model loaded, database connected")
        print(" Waiting for queries...\n")

        # Handle connections
        while self.running:
            try:
                client, _addr = self.socket.accept()
                thread = threading.Thread(target=self._handle_client, args=(client,))
                thread.daemon = True
                thread.start()
            except socket.timeout:
                # No connection within the poll interval; re-check self.running.
                continue
            except KeyboardInterrupt:
                break
            except Exception as e:
                if self.running:
                    logger.error(f"Server error: {e}")

        # Left the loop without stop() having run (Ctrl+C): release the socket.
        if self.running:
            self.stop()

    def _handle_client(self, client: socket.socket):
        """Handle a single client connection: one request, one response.

        A request is a JSON object. ``{"command": "shutdown"}`` stops the
        server; anything else is treated as a search with optional ``query``
        (default ``""``) and ``top_k`` (default 10). The client socket is
        always closed before returning.
        """
        try:
            # Receive query with proper message framing
            data = self._receive_json(client)
            request = json.loads(data)

            # Check for shutdown command
            if request.get("command") == "shutdown":
                print("\n Shutdown requested")
                response = {"success": True, "message": "Server shutting down"}
                self._send_json(client, response)
                self.stop()
                return

            query = request.get("query", "")
            top_k = request.get("top_k", 10)
            self.query_count += 1
            print(f"[Query #{self.query_count}] {query}")

            # Perform search
            start = time.time()
            results = self.searcher.search(query, top_k=top_k)
            search_time = time.time() - start

            # Prepare response
            response = {
                "success": True,
                "query": query,
                "count": len(results),
                "search_time_ms": int(search_time * 1000),
                "results": [r.to_dict() for r in results],
                "server_uptime": int(time.time() - self.start_time),
                "total_queries": self.query_count,
            }

            # Send response with proper framing
            self._send_json(client, response)
            print(f" Found {len(results)} results in {search_time*1000:.0f}ms")
        except ConnectionError:
            # Normal disconnection - client closed connection.
            # This is expected behavior, don't log as error.
            pass
        except Exception as e:
            # Only log actual errors, not normal disconnections
            if "Connection closed" not in str(e):
                logger.error(f"Client handler error: {e}")
                error_response = {"success": False, "error": str(e)}
                try:
                    self._send_json(client, error_response)
                except (ConnectionError, OSError, TypeError, ValueError, socket.error):
                    # The client is gone or the error itself cannot be sent.
                    pass
        finally:
            client.close()

    @staticmethod
    def _recv_exact(sock: socket.socket, size: int, what: str) -> bytes:
        """Read exactly *size* bytes from *sock*.

        Raises:
            ConnectionError: If the peer closes the connection first. The
                message is ``"Connection closed while receiving <what>"``.
        """
        buffer = bytearray()
        while len(buffer) < size:
            chunk = sock.recv(min(65536, size - len(buffer)))
            if not chunk:
                raise ConnectionError(f"Connection closed while receiving {what}")
            buffer += chunk
        return bytes(buffer)

    def _receive_json(self, sock: socket.socket) -> str:
        """Receive a complete JSON message with length prefix.

        Returns:
            The message payload decoded as UTF-8 text (not yet parsed).

        Raises:
            ConnectionError: If the connection closes before the 4-byte
                length or the full payload has arrived.
        """
        # First receive the length (4 bytes, big-endian)
        length = int.from_bytes(self._recv_exact(sock, 4, "length"), "big")
        # Now receive the actual data
        return self._recv_exact(sock, length, "data").decode("utf-8")

    def _send_json(self, sock: socket.socket, data: dict):
        """Send a JSON message with length prefix.

        Raises:
            TypeError, ValueError: If *data* is not JSON-serialisable.
            OSError: If the socket write fails.
        """
        json_str = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
        json_bytes = json_str.encode("utf-8")
        # Send length prefix (4 bytes). sendall() is used because send() may
        # transmit only part of the buffer.
        sock.sendall(len(json_bytes).to_bytes(4, "big"))
        # Send the data
        sock.sendall(json_bytes)

    def stop(self):
        """Stop the server: clear the running flag and close the listener."""
        self.running = False
        if self.socket:
            self.socket.close()
        print("\n RAG server stopped")
class RAGClient:
    """Client to communicate with RAG server.

    Requests and responses are framed as a 4-byte big-endian length followed
    by that many bytes of UTF-8 encoded JSON. Every call opens a fresh
    connection to ``localhost`` on the configured port.
    """

    def __init__(self, port: int = 7777):
        """Remember the server port; no connection is made here."""
        self.port = port
        # Set once a fallback to the unframed legacy protocol has been tried.
        self.use_legacy = False

    def search(self, query: str, top_k: int = 10) -> Dict[str, Any]:
        """Send search query to server.

        Returns:
            The server's decoded JSON response. On any failure a dict of the
            form ``{"success": False, "error": <message>}`` is returned
            instead of raising.
        """
        try:
            # The context manager closes the socket on every path.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.connect(("localhost", self.port))
                # Send request with proper framing
                self._send_json(sock, {"query": query, "top_k": top_k})
                # Receive response with proper framing
                return json.loads(self._receive_json(sock))
        except ConnectionRefusedError:
            return {
                "success": False,
                "error": "RAG server not running. Start with: rag-mini server",
            }
        except ConnectionError as e:
            # Try legacy mode without message framing (attempted only once
            # per client instance).
            if not self.use_legacy and "receiving length" in str(e):
                self.use_legacy = True
                return self._search_legacy(query, top_k)
            return {"success": False, "error": str(e)}
        except Exception as e:
            return {"success": False, "error": str(e)}

    @staticmethod
    def _recv_exact(sock: socket.socket, size: int, what: str) -> bytes:
        """Read exactly *size* bytes from *sock*.

        Raises:
            ConnectionError: If the peer closes the connection first. The
                message is ``"Connection closed while receiving <what>"``.
        """
        buffer = bytearray()
        while len(buffer) < size:
            chunk = sock.recv(min(65536, size - len(buffer)))
            if not chunk:
                raise ConnectionError(f"Connection closed while receiving {what}")
            buffer += chunk
        return bytes(buffer)

    def _receive_json(self, sock: socket.socket) -> str:
        """Receive a complete JSON message with length prefix.

        Returns:
            The message payload decoded as UTF-8 text (not yet parsed).

        Raises:
            ConnectionError: If the connection closes before the 4-byte
                length or the full payload has arrived.
        """
        # First receive the length (4 bytes, big-endian)
        length = int.from_bytes(self._recv_exact(sock, 4, "length"), "big")
        # Now receive the actual data
        return self._recv_exact(sock, length, "data").decode("utf-8")

    def _send_json(self, sock: socket.socket, data: dict):
        """Send a JSON message with length prefix."""
        json_str = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
        json_bytes = json_str.encode("utf-8")
        # Send length prefix (4 bytes). sendall() is used because send() may
        # transmit only part of the buffer.
        sock.sendall(len(json_bytes).to_bytes(4, "big"))
        # Send the data
        sock.sendall(json_bytes)

    def _search_legacy(self, query: str, top_k: int = 10) -> Dict[str, Any]:
        """Legacy search without message framing for old servers.

        Sends the request as bare JSON and keeps reading until the bytes
        received so far parse as a JSON document or the server closes the
        connection.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.connect(("localhost", self.port))
                # Send request (old way)
                request = {"query": query, "top_k": top_k}
                sock.sendall(json.dumps(request).encode("utf-8"))

                # Receive response (accumulate until we get valid JSON)
                data = bytearray()
                while True:
                    chunk = sock.recv(65536)
                    if not chunk:
                        break
                    data += chunk
                    try:
                        return json.loads(data.decode("utf-8"))
                    except (json.JSONDecodeError, UnicodeDecodeError):
                        # Incomplete document, or a multi-byte character split
                        # across chunks: keep receiving.
                        continue
            return {"success": False, "error": "Incomplete response from server"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    def is_running(self) -> bool:
        """Check if server is running (i.e. the port accepts a TCP connection)."""
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                return sock.connect_ex(("localhost", self.port)) == 0
        except (ConnectionError, OSError, TypeError, ValueError, socket.error):
            return False
def start_server(project_path: Path, port: int = 7777):
    """Run a RAG server for *project_path* in the foreground.

    Builds a ``RAGServer`` on the given port and calls its ``start()``
    method, which blocks until the server loop ends.
    """
    rag_server = RAGServer(project_path, port)
    try:
        rag_server.start()
    except KeyboardInterrupt:
        # Ctrl+C reached us from start(): stop the server.
        rag_server.stop()
def auto_start_if_needed(project_path: Path) -> Optional[subprocess.Popen]:
    """Auto-start server if not running.

    Checks the default client port; if nothing is listening, launches
    ``python -m mini_rag.cli server --path <project_path>`` in the background
    and waits up to roughly 30 seconds for it to accept connections.

    Returns:
        The ``Popen`` handle of the newly started server, or ``None`` if a
        server was already running.

    Raises:
        RuntimeError: If the server does not come up in time or the child
            process exits before it starts listening.
    """
    client = RAGClient()
    if client.is_running():
        return None

    # Start server in background
    cmd = [
        sys.executable,
        "-m",
        "mini_rag.cli",
        "server",
        "--path",
        str(project_path),
    ]
    # The server prints a line per query and nobody reads its output here, so
    # discard it: an unread PIPE would eventually fill and block the server.
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        creationflags=(subprocess.CREATE_NEW_CONSOLE if sys.platform == "win32" else 0),
    )

    # Wait for server to start
    for _ in range(30):  # 30 second timeout
        time.sleep(1)
        if client.is_running():
            print(" RAG server started automatically")
            return process
        if process.poll() is not None:
            # The child already exited; waiting longer cannot help.
            break

    # Failed to start: stop the child (if still alive) and reap it.
    if process.poll() is None:
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
    raise RuntimeError("Failed to start RAG server")