Fss-Rag-Mini/mini_rag/claude_rag/ollama_embeddings.py
"""
Hybrid code embedding module - Ollama primary with ML fallback.
Tries Ollama first, falls back to local ML stack if needed.
"""
import requests
import numpy as np
from typing import List, Union, Optional, Dict, Any
import logging
from functools import lru_cache
import time
import json
from concurrent.futures import ThreadPoolExecutor
import threading
logger = logging.getLogger(__name__)
# Try to import fallback ML dependencies
FALLBACK_AVAILABLE = False
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    from sentence_transformers import SentenceTransformer
    FALLBACK_AVAILABLE = True
    logger.debug("ML fallback dependencies available")
except ImportError:
    logger.debug("ML fallback not available - Ollama only mode")


class OllamaEmbedder:
    """Hybrid embeddings: Ollama primary with ML fallback."""

    def __init__(self, model_name: str = "nomic-embed-text:latest",
                 base_url: str = "http://localhost:11434",
                 enable_fallback: bool = True):
        """
        Initialize the hybrid embedder.

        Args:
            model_name: Ollama model to use for embeddings
            base_url: Base URL for the Ollama API
            enable_fallback: Whether to use the ML fallback if Ollama fails
        """
        self.model_name = model_name
        self.base_url = base_url
        self.embedding_dim = 768  # Standard for nomic-embed-text
        self.enable_fallback = enable_fallback and FALLBACK_AVAILABLE

        # State tracking
        self.ollama_available = False
        self.fallback_embedder = None
        self.mode = "unknown"  # "ollama", "fallback", or "hash"

        # Try to initialize Ollama first
        self._initialize_providers()

    def _initialize_providers(self):
        """Initialize embedding providers in priority order."""
        # Try Ollama first
        try:
            self._verify_ollama_connection()
            self.ollama_available = True
            self.mode = "ollama"
            logger.info(f"✅ Ollama embeddings active: {self.model_name}")
        except Exception as e:
            logger.debug(f"Ollama not available: {e}")
            self.ollama_available = False

            # Try the ML fallback
            if self.enable_fallback:
                try:
                    self._initialize_fallback_embedder()
                    self.mode = "fallback"
                    fallback_type = getattr(self.fallback_embedder, 'model_type', 'transformer')
                    logger.info(f"✅ ML fallback active: {fallback_type}")
                except Exception as fallback_error:
                    logger.warning(f"ML fallback failed: {fallback_error}")
                    self.mode = "hash"
                    logger.info("⚠️ Using hash-based embeddings (deterministic fallback)")
            else:
                self.mode = "hash"
                logger.info("⚠️ Using hash-based embeddings (no fallback enabled)")

    def _verify_ollama_connection(self):
        """Verify that the Ollama server is running and the model is available."""
        # Check server status
        response = requests.get(f"{self.base_url}/api/tags", timeout=5)
        response.raise_for_status()

        # Check whether our model is available
        models = response.json().get('models', [])
        model_names = [model['name'] for model in models]
        if self.model_name not in model_names:
            logger.warning(f"Model {self.model_name} not found. Available: {model_names}")
            # Try to pull the model
            self._pull_model()

    def _initialize_fallback_embedder(self):
        """Initialize the ML fallback embedder."""
        if not FALLBACK_AVAILABLE:
            raise RuntimeError("ML dependencies not available for fallback")

        # Try lightweight models first for better compatibility
        fallback_models = [
            ("sentence-transformers/all-MiniLM-L6-v2", 384, self._init_sentence_transformer),
            ("microsoft/codebert-base", 768, self._init_transformer_model),
            ("microsoft/unixcoder-base", 768, self._init_transformer_model),
        ]
        for model_name, dim, init_func in fallback_models:
            try:
                init_func(model_name)
                self.embedding_dim = dim
                logger.info(f"Loaded fallback model: {model_name}")
                return
            except Exception as e:
                logger.debug(f"Failed to load {model_name}: {e}")
                continue
        raise RuntimeError("Could not initialize any fallback embedding model")

    def _init_sentence_transformer(self, model_name: str):
        """Initialize a sentence-transformers model."""
        self.fallback_embedder = SentenceTransformer(model_name)
        self.fallback_embedder.model_type = 'sentence_transformer'

    def _init_transformer_model(self, model_name: str):
        """Initialize a raw transformers model behind a thin wrapper."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(device)
        model.eval()

        # Create a simple wrapper
        class TransformerWrapper:
            def __init__(self, model, tokenizer, device):
                self.model = model
                self.tokenizer = tokenizer
                self.device = device
                self.model_type = 'transformer'

        self.fallback_embedder = TransformerWrapper(model, tokenizer, device)

    def _pull_model(self):
        """Pull the embedding model if it is not available."""
        logger.info(f"Pulling model {self.model_name}...")
        try:
            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": self.model_name},
                timeout=300  # 5 minutes for model download
            )
            response.raise_for_status()
            logger.info(f"Successfully pulled {self.model_name}")
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Failed to pull model {self.model_name}: {e}")

    def _get_embedding(self, text: str) -> np.ndarray:
        """Get an embedding using the best available provider."""
        if self.mode == "ollama" and self.ollama_available:
            return self._get_ollama_embedding(text)
        elif self.mode == "fallback" and self.fallback_embedder:
            return self._get_fallback_embedding(text)
        else:
            # Hash fallback
            return self._hash_embedding(text)

    def _get_ollama_embedding(self, text: str) -> np.ndarray:
        """Get an embedding from the Ollama API."""
        try:
            response = requests.post(
                f"{self.base_url}/api/embeddings",
                json={
                    "model": self.model_name,
                    "prompt": text
                },
                timeout=30
            )
            response.raise_for_status()
            result = response.json()
            embedding = result.get('embedding', [])
            if not embedding:
                raise ValueError("No embedding returned from Ollama")
            return np.array(embedding, dtype=np.float32)
        except requests.exceptions.RequestException as e:
            logger.error(f"Ollama API request failed: {e}")
            # Degrade gracefully - try the fallback if available
            if self.mode == "ollama" and self.enable_fallback and self.fallback_embedder:
                logger.info("Falling back to ML embeddings due to Ollama failure")
                self.mode = "fallback"  # Switch mode temporarily
                return self._get_fallback_embedding(text)
            return self._hash_embedding(text)
        except (ValueError, KeyError) as e:
            logger.error(f"Invalid response from Ollama: {e}")
            return self._hash_embedding(text)

    def _get_fallback_embedding(self, text: str) -> np.ndarray:
        """Get an embedding from the ML fallback."""
        try:
            if self.fallback_embedder.model_type == 'sentence_transformer':
                embedding = self.fallback_embedder.encode([text], convert_to_numpy=True)[0]
                return embedding.astype(np.float32)
            elif self.fallback_embedder.model_type == 'transformer':
                # Tokenize and generate the embedding
                inputs = self.fallback_embedder.tokenizer(
                    text,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                ).to(self.fallback_embedder.device)
                with torch.no_grad():
                    outputs = self.fallback_embedder.model(**inputs)

                # Use pooler output if available, otherwise mean pooling
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    embedding = outputs.pooler_output[0]
                else:
                    # Mean pooling over sequence length; drop the batch dim from
                    # the mask so it expands against (seq_len, hidden) correctly
                    attention_mask = inputs['attention_mask'][0]
                    token_embeddings = outputs.last_hidden_state[0]
                    # Mask and average
                    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0)
                    sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9)
                    embedding = sum_embeddings / sum_mask
                return embedding.cpu().numpy().astype(np.float32)
            else:
                raise ValueError(f"Unknown fallback model type: {self.fallback_embedder.model_type}")
        except Exception as e:
            logger.error(f"Fallback embedding failed: {e}")
            return self._hash_embedding(text)

    def _hash_embedding(self, text: str) -> np.ndarray:
        """Generate a deterministic hash-based embedding as a last-resort fallback."""
        import hashlib

        # Create a deterministic hash
        hash_obj = hashlib.sha256(text.encode('utf-8'))
        hash_bytes = hash_obj.digest()

        # Convert to numbers and normalize
        hash_nums = np.frombuffer(hash_bytes, dtype=np.uint8)

        # Expand to the target dimension using repetition
        while len(hash_nums) < self.embedding_dim:
            hash_nums = np.concatenate([hash_nums, hash_nums])

        # Take exactly the dimension we need
        embedding = hash_nums[:self.embedding_dim].astype(np.float32)

        # Normalize to the [-1, 1] range
        embedding = (embedding / 127.5) - 1.0
        logger.debug(f"Using hash fallback embedding for text: {text[:50]}...")
        return embedding
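
    # Determinism note (illustrative): the hash fallback carries no semantic
    # signal, but identical text always maps to the identical vector, e.g.
    #     np.array_equal(embedder._hash_embedding("x"), embedder._hash_embedding("x"))  # True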

    def embed_code(self, code: Union[str, List[str]], language: str = "python") -> np.ndarray:
        """
        Generate embeddings for code snippet(s).

        Args:
            code: Single code string or list of code strings
            language: Programming language (used for context)

        Returns:
            Embedding vector(s) as a numpy array
        """
        if isinstance(code, str):
            code = [code]
            single_input = True
        else:
            single_input = False

        # Preprocess code for better embeddings
        processed_code = [self._preprocess_code(c, language) for c in code]

        # Generate embeddings
        embeddings = []
        for text in processed_code:
            embedding = self._get_embedding(text)
            embeddings.append(embedding)

        embeddings = np.array(embeddings, dtype=np.float32)
        if single_input:
            return embeddings[0]
        return embeddings

    def _preprocess_code(self, code: str, language: str = "python") -> str:
        """
        Preprocess code for better embedding quality.
        Adds language context and cleans up formatting.
        """
        # Remove leading/trailing whitespace
        code = code.strip()

        # Normalize whitespace but preserve structure
        lines = code.split('\n')
        processed_lines = []
        for line in lines:
            # Remove trailing whitespace
            line = line.rstrip()
            # Keep non-empty lines
            if line:
                processed_lines.append(line)
        cleaned_code = '\n'.join(processed_lines)

        # Add language context for better embeddings
        if language and cleaned_code:
            return f"```{language}\n{cleaned_code}\n```"
        return cleaned_code
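
    # Illustration of the preprocessing above (not executed): the input
    # "x = 1\n\ny = 2  " with language="python" becomes the fenced string
    # "```python\nx = 1\ny = 2\n```", which gives the model language context.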

    @lru_cache(maxsize=1000)
    def embed_query(self, query: str) -> np.ndarray:
        """
        Embed a search query, with caching.
        Queries are often repeated, so we cache them.
        """
        # Enhance the query for code search
        enhanced_query = f"Search for code related to: {query}"
        return self._get_embedding(enhanced_query)

    def batch_embed_files(self, file_contents: List[dict], max_workers: int = 4) -> List[dict]:
        """
        Embed multiple files efficiently using concurrent requests to Ollama.

        Args:
            file_contents: List of dicts with 'content' and optionally 'language' keys
            max_workers: Maximum number of concurrent Ollama requests

        Returns:
            List of dicts with an added 'embedding' key (preserves original order)
        """
        if not file_contents:
            return []

        # For small batches, use sequential processing to avoid thread overhead
        if len(file_contents) <= 2:
            return self._batch_embed_sequential(file_contents)
        return self._batch_embed_concurrent(file_contents, max_workers)

    def _batch_embed_sequential(self, file_contents: List[dict]) -> List[dict]:
        """Sequential processing for small batches."""
        results = []
        for file_dict in file_contents:
            content = file_dict['content']
            language = file_dict.get('language', 'python')
            embedding = self.embed_code(content, language)
            result = file_dict.copy()
            result['embedding'] = embedding
            results.append(result)
        return results

    def _batch_embed_concurrent(self, file_contents: List[dict], max_workers: int) -> List[dict]:
        """Concurrent processing for larger batches."""
        def embed_single(item_with_index):
            index, file_dict = item_with_index
            content = file_dict['content']
            language = file_dict.get('language', 'python')
            try:
                embedding = self.embed_code(content, language)
                result = file_dict.copy()
                result['embedding'] = embedding
                return index, result
            except Exception as e:
                logger.error(f"Failed to embed content at index {index}: {e}")
                # Return with the hash fallback
                result = file_dict.copy()
                result['embedding'] = self._hash_embedding(content)
                return index, result

        # Create indexed items to preserve order
        indexed_items = list(enumerate(file_contents))

        # Process concurrently
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            indexed_results = list(executor.map(embed_single, indexed_items))

        # Sort by original index and extract results
        indexed_results.sort(key=lambda x: x[0])
        return [result for _, result in indexed_results]
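
    # Batch usage sketch (illustrative; 'language' is optional and defaults to
    # "python"):
    #     docs = [{"content": f"def f{i}(): pass"} for i in range(8)]
    #     out = embedder.batch_embed_files(docs, max_workers=4)
    #     # out[i]["embedding"] aligns with docs[i]; batches of <= 2 skip the pool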

    def get_embedding_dim(self) -> int:
        """Return the dimension of embeddings produced by this model."""
        return self.embedding_dim

    def get_mode(self) -> str:
        """Return the current embedding mode: 'ollama', 'fallback', or 'hash'."""
        return self.mode

    def get_status(self) -> Dict[str, Any]:
        """Get detailed status of the embedding system."""
        return {
            "mode": self.mode,
            "ollama_available": self.ollama_available,
            "fallback_available": FALLBACK_AVAILABLE and self.enable_fallback,
            "fallback_model": getattr(self.fallback_embedder, 'model_type', None) if self.fallback_embedder else None,
            "embedding_dim": self.embedding_dim,
            "ollama_model": self.model_name if self.mode == "ollama" else None,
            "ollama_url": self.base_url if self.mode == "ollama" else None
        }

    def warmup(self):
        """Warm up the embedding system with a dummy request."""
        dummy_code = "def hello(): pass"
        _ = self.embed_code(dummy_code)
        logger.info(f"Hybrid embedder ready - Mode: {self.mode}")
        return self.get_status()


# Convenience function for quick embedding
def embed_code(code: Union[str, List[str]], model_name: str = "nomic-embed-text:latest") -> np.ndarray:
    """
    Quick function to embed code without managing an embedder instance.

    Args:
        code: Code string(s) to embed
        model_name: Ollama model name to use

    Returns:
        Embedding vector(s)
    """
    embedder = OllamaEmbedder(model_name=model_name)
    return embedder.embed_code(code)


# Compatibility alias for drop-in replacement
CodeEmbedder = OllamaEmbedder
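

if __name__ == "__main__":
    # Minimal smoke-test sketch: exercises whichever provider is available
    # (Ollama, ML fallback, or hash) and reports the resulting mode and dim.
    logging.basicConfig(level=logging.INFO)
    embedder = OllamaEmbedder()
    status = embedder.warmup()
    vector = embedder.embed_code("def add(a, b):\n    return a + b")
    print(f"mode={status['mode']} dim={vector.shape[0]}")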