Root cause: The pre-trained model was loading successfully, causing the CLI to
skip calibration entirely; the system went straight to classification with the 35% model.
Changes:
- config: Set calibration_model to qwen3:8b-q4_K_M (larger model for better instruction following)
- cli: Create separate calibration_llm provider with 8b model
- llm_analyzer: Improved prompt to force exact email ID copying
- workflow: Merge discovered categories with predefined ones
- workflow: Add detailed error logging for label mismatches
- ml_classifier: Fixed model path checking (was checking None parameter)
- ml_classifier: Add dual API support (sklearn predict_proba vs LightGBM predict; see the sketch after these notes)
- ollama: Fixed model list parsing (use m.model not m.get('name'))
- feature_extractor: Switch to Ollama embeddings (instant vs 90s load time)
Result: Calibration now runs and generates 16 categories + 50 labels correctly.
Next: Investigate calibration sampling to reduce overfitting on small samples.
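For reference, a minimal sketch of the dual-API dispatch noted for ml_classifier above. The helper name and array handling are assumptions for illustration, not the actual project code: sklearn estimators expose predict_proba(), while a native lightgbm.Booster returns probabilities directly from predict().

import numpy as np

def predict_probabilities(model, features: np.ndarray) -> np.ndarray:
    """Return class probabilities from either a sklearn estimator or a LightGBM Booster."""
    if hasattr(model, "predict_proba"):
        # sklearn-style API: returns (n_samples, n_classes) probabilities
        return model.predict_proba(features)
    # Native LightGBM Booster: predict() already yields probabilities
    probs = model.predict(features)
    if probs.ndim == 1:
        # Binary model returns (n_samples,) of P(class 1); expand to two columns
        probs = np.column_stack([1.0 - probs, probs])
    return probs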
141 lines · 4.7 KiB · Python
"""Ollama LLM provider for local model inference."""
|
|
import logging
|
|
import time
|
|
from typing import Optional, Dict, Any
|
|
|
|
from .base import BaseLLMProvider
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OllamaProvider(BaseLLMProvider):
    """
    Local LLM provider using Ollama.

    Status: Requires Ollama running locally
    - Default: http://localhost:11434
    - Models: qwen3:4b (calibration), qwen3:1.7b (classification)
    - If Ollama unavailable: Returns graceful error
    """
    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "qwen3:1.7b",
        temperature: float = 0.1,
        max_tokens: int = 500,
        timeout: int = 30,
        retry_attempts: int = 3
    ):
        """Initialize Ollama provider."""
        super().__init__(name="ollama")
        self.base_url = base_url
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.retry_attempts = retry_attempts
        self.client = None
        self._available = False

        self._initialize()
    def _initialize(self) -> None:
        """Initialize Ollama client."""
        try:
            import ollama
            self.client = ollama.Client(host=self.base_url)
            self.logger.info(f"Ollama provider initialized: {self.base_url}")

            # Test connection
            if self.test_connection():
                self._available = True
                self.logger.info(f"Ollama connected, model: {self.model}")
            else:
                self.logger.warning("Ollama connection test failed")
                self._available = False

        except ImportError:
            self.logger.error("ollama package not installed: pip install ollama")
            self._available = False
        except Exception as e:
            self.logger.error(f"Ollama initialization failed: {e}")
            self._available = False
    def complete(self, prompt: str, **kwargs) -> str:
        """
        Get completion from Ollama.

        Args:
            prompt: Input prompt
            **kwargs: Override temperature, max_tokens, timeout

        Returns:
            LLM response
        """
        if not self._available or not self.client:
            self.logger.error("Ollama not available")
            raise RuntimeError("Ollama provider not initialized")

        temperature = kwargs.get('temperature', self.temperature)
        max_tokens = kwargs.get('max_tokens', self.max_tokens)
        timeout = kwargs.get('timeout', self.timeout)

        attempt = 0
        while attempt < self.retry_attempts:
            try:
                self.logger.debug(f"Ollama request: model={self.model}, tokens={max_tokens}")

                response = self.client.generate(
                    model=self.model,
                    prompt=prompt,
                    options={
                        'temperature': temperature,
                        'num_predict': max_tokens,
                        'top_k': 40,
                        'top_p': 0.9,
                    }
                )

                text = response.get('response', '')
                self.logger.debug(f"Ollama response: {len(text)} chars")
                return text

            except Exception as e:
                attempt += 1
                if attempt < self.retry_attempts:
                    wait_time = 2 ** attempt  # Exponential backoff
                    self.logger.warning(f"Ollama request failed ({attempt}/{self.retry_attempts}), retrying in {wait_time}s: {e}")
                    time.sleep(wait_time)
                else:
                    self.logger.error(f"Ollama request failed after {self.retry_attempts} attempts: {e}")
                    raise
    def test_connection(self) -> bool:
        """Test if Ollama is running and accessible."""
        if not self.client:
            self.logger.warning("Ollama client not initialized")
            return False

        try:
            # Try to list available models
            response = self.client.list()
            available_models = [m.model for m in response.models]

            # Check if requested model is available
            if any(self.model in m for m in available_models):
                self.logger.info(f"Ollama test passed, model available: {self.model}")
                return True
            else:
                self.logger.warning(f"Ollama running but model not found: {self.model}")
                self.logger.warning(f"Available models: {available_models}")
                return False

        except Exception as e:
            self.logger.error(f"Ollama connection test failed: {e}")
            return False
    def is_available(self) -> bool:
        """Check if provider is available."""
        return self._available
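A minimal usage sketch of the provider above, assuming Ollama is running locally and the model has been pulled (e.g. ollama pull qwen3:1.7b):

provider = OllamaProvider(model="qwen3:1.7b")
if provider.is_available():
    reply = provider.complete(
        "Classify this email subject into one word: 'Your invoice is ready'",
        max_tokens=50,
    )
    print(reply)
else:
    print("Ollama is not reachable; start the server and pull the model first")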