# File: src/calibration/workflow.py
"""Complete calibration workflow."""
import logging
from typing import List, Dict, Any, Tuple, Optional, Iterable
from dataclasses import dataclass

from src.email_providers.base import Email
from src.calibration.sampler import EmailSampler
from src.calibration.llm_analyzer import CalibrationAnalyzer
from src.calibration.trainer import ModelTrainer
from src.classification.feature_extractor import FeatureExtractor
from src.llm.base import BaseLLMProvider

logger = logging.getLogger(__name__)


@dataclass
class CalibrationConfig:
    """Calibration workflow configuration."""
    # Number of emails requested from the stratified sampler.
    sample_size: int = 1500
    # Maximum number of non-sampled emails set aside for validation.
    validation_size: int = 300
    # NOTE(review): not read anywhere in this module; presumably meant for
    # the analyzer - confirm.
    llm_batch_size: int = 50
    # Hyper-parameters forwarded unchanged to ``ModelTrainer.train``.
    model_n_estimators: int = 200
    model_learning_rate: float = 0.1
    model_max_depth: int = 8


class CalibrationWorkflow:
    """
    Complete calibration workflow.

    Steps:
    1. Sample emails stratified by sender type
    2. LLM analyzes sample to discover categories
    3. Pair sampled emails with their LLM labels
    4. Train the model, validating on LLM-labeled held-out emails
    5. Save trained model
    """

    def __init__(
        self,
        llm_provider: BaseLLMProvider,
        feature_extractor: FeatureExtractor,
        categories: Dict[str, Dict],
        config: Optional[CalibrationConfig] = None
    ):
        """
        Initialize calibration workflow.

        Args:
            llm_provider: Provider handed to the calibration analyzer.
            feature_extractor: Extractor handed to the model trainer.
            categories: Mapping keyed by category name; only the keys are used.
            config: Workflow settings; defaults to ``CalibrationConfig()``.
        """
        self.llm_provider = llm_provider
        self.feature_extractor = feature_extractor
        self.categories = list(categories.keys())
        self.config = config or CalibrationConfig()

        self.sampler = EmailSampler()
        self.analyzer = CalibrationAnalyzer(llm_provider, {})
        self.trainer = ModelTrainer(feature_extractor, self.categories)

        self.results: Dict[str, Any] = {}

    def run(
        self,
        all_emails: List[Email],
        model_output_path: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Run complete calibration workflow.

        Args:
            all_emails: Emails to sample from.
            model_output_path: Where to save the trained model; the save
                step is skipped when this is falsy.

        Returns:
            On success, a dict with ``success``, ``sample_size``,
            ``categories_discovered``, ``training_results`` and
            ``model_path``. On failure, ``{'success': False, 'error': ...}``.
        """
        # Start from a clean slate so a second run never reports stale keys.
        self.results = {}

        logger.info("=" * 80)
        logger.info("CALIBRATION WORKFLOW")
        logger.info("=" * 80)

        # Step 1: Sample
        logger.info("\nStep 1: Sampling emails...")
        sample_emails, remaining_emails = self.sampler.stratified_sample(
            all_emails,
            self.config.sample_size
        )

        validation_emails = remaining_emails[:self.config.validation_size]
        logger.info(f"Sample: {len(sample_emails)}, Validation: {len(validation_emails)}")

        # Step 2: LLM Analysis
        logger.info("\nStep 2: LLM category discovery...")
        discovered_categories, sample_labels = self.analyzer.discover_categories(sample_emails)

        logger.info(f"Discovered {len(discovered_categories)} categories:")
        for cat, desc in discovered_categories.items():
            logger.info(f"  - {cat}: {desc}")

        # Step 3: Label emails
        logger.info("\nStep 3: Labeling emails...")
        training_data = self._pair_with_labels(sample_emails, sample_labels)
        logger.info(f"Training data: {len(training_data)} labeled emails")

        if not training_data:
            logger.error("No labeled training data!")
            return self._fail('No labeled data')

        # Step 4: Train model
        logger.info("\nStep 4: Training LightGBM model...")
        validation_data = self._label_validation_emails(validation_emails)

        try:
            train_results = self.trainer.train(
                training_data,
                validation_emails=validation_data,
                n_estimators=self.config.model_n_estimators,
                learning_rate=self.config.model_learning_rate,
                max_depth=self.config.model_max_depth
            )
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Training failed: {e}")
            return self._fail(str(e))

        logger.info(f"Training accuracy: {train_results.get('training_accuracy', 0):.1%}")
        if 'validation_accuracy' in train_results:
            logger.info(f"Validation accuracy: {train_results['validation_accuracy']:.1%}")

        self.results['training_results'] = train_results

        # Step 5: Save model
        if model_output_path:
            logger.info(f"\nStep 5: Saving model to {model_output_path}...")

            try:
                self.trainer.save_model(model_output_path)
            except Exception as e:
                logger.exception(f"Error saving model: {e}")
                return self._fail(f'Save failed: {e}')

            logger.info("Model saved successfully")
            self.results['model_path'] = model_output_path

        # Summary
        logger.info("\n" + "=" * 80)
        logger.info("CALIBRATION COMPLETE")
        logger.info(f"Sample size: {len(sample_emails)}")
        logger.info(f"Categories discovered: {len(discovered_categories)}")
        logger.info(f"Training accuracy: {train_results.get('training_accuracy', 0):.1%}")
        logger.info("=" * 80)

        self.results.update({
            'success': True,
            'sample_size': len(sample_emails),
            'categories_discovered': discovered_categories,
        })

        return {
            'success': True,
            'sample_size': len(sample_emails),
            'categories_discovered': discovered_categories,
            'training_results': train_results,
            'model_path': model_output_path
        }

    def get_results(self) -> Dict[str, Any]:
        """Get the results recorded by the most recent ``run`` call."""
        return self.results

    def _pair_with_labels(
        self,
        emails: List[Email],
        labels: Iterable[Tuple[Any, str]]
    ) -> List[Tuple[Email, str]]:
        """Pair each email with its label, keeping only configured categories."""
        label_map = dict(labels)

        labeled = []
        for email in emails:
            category = label_map.get(email.id)
            if category and category in self.categories:
                labeled.append((email, category))
        return labeled

    def _label_validation_emails(
        self,
        validation_emails: List[Email]
    ) -> List[Tuple[Email, str]]:
        """
        Label the held-out emails with the analyzer.

        Validation labels used to be hard-coded to the first configured
        category, which made the reported validation accuracy meaningless.
        The same analyzer call that labels the training sample is used here.

        NOTE(review): this issues one more analyzer call per run - confirm
        the extra LLM cost is acceptable.

        Returns:
            ``(email, category)`` pairs; empty when there is nothing to
            label or the analyzer call fails.
        """
        if not validation_emails:
            return []

        try:
            _, validation_labels = self.analyzer.discover_categories(validation_emails)
        except Exception as e:
            # Validation is optional: train without it rather than abort.
            logger.warning(f"Could not label validation emails, skipping validation: {e}")
            return []

        validation_data = self._pair_with_labels(validation_emails, validation_labels)
        logger.info(f"Validation data: {len(validation_data)} labeled emails")
        return validation_data

    def _fail(self, message: str) -> Dict[str, Any]:
        """Record a failure in ``self.results`` and build the failure payload."""
        self.results.update({'success': False, 'error': message})
        return {'success': False, 'error': message}
# File: src/classification/embedding_cache.py
"""Caching and batch processing for embeddings."""
import logging
import hashlib
import numpy as np
from typing import Any, Dict, List, Tuple, Optional
import json
from pathlib import Path

logger = logging.getLogger(__name__)


class EmbeddingCache:
    """
    Cache embeddings to avoid recomputing.

    Stores hash(text) → embedding mapping in memory and, when a cache
    directory is given, as one ``<hash>.npy`` file per text on disk.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize cache.

        Args:
            cache_dir: Directory for the on-disk cache; created if missing.
                When omitted the cache is memory-only.
        """
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.memory_cache: Dict[str, np.ndarray] = {}
        self.hits = 0
        self.misses = 0

        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _hash_text(self, text: str) -> str:
        """Generate the MD5 hex digest used as the cache key for text."""
        # surrogatepass yields the same bytes as plain UTF-8 for ordinary
        # text, but does not raise on lone surrogates.
        return hashlib.md5(text.encode("utf-8", errors="surrogatepass")).hexdigest()

    def get(self, text: str) -> Optional[np.ndarray]:
        """
        Get cached embedding for text.

        Returns:
            The cached embedding, or None on a miss. Hit and miss counters
            are updated on every call.
        """
        text_hash = self._hash_text(text)

        # Try memory cache first
        if text_hash in self.memory_cache:
            self.hits += 1
            return self.memory_cache[text_hash]

        # Try disk cache
        if self.cache_dir:
            cache_file = self.cache_dir / f"{text_hash}.npy"
            if cache_file.exists():
                try:
                    embedding = np.load(cache_file)
                except Exception as e:
                    # An unreadable file is treated as a miss.
                    logger.debug(f"Error loading cache {cache_file}: {e}")
                else:
                    self.memory_cache[text_hash] = embedding
                    self.hits += 1
                    return embedding

        self.misses += 1
        return None

    def set(self, text: str, embedding: np.ndarray) -> bool:
        """
        Cache embedding for text.

        Returns:
            False when the disk write fails (the memory entry is still
            stored), True otherwise.
        """
        text_hash = self._hash_text(text)
        self.memory_cache[text_hash] = embedding

        # Optionally save to disk
        if self.cache_dir:
            try:
                cache_file = self.cache_dir / f"{text_hash}.npy"
                np.save(cache_file, embedding)
            except Exception as e:
                logger.debug(f"Error saving cache: {e}")
                return False

        return True

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics; ``hit_rate`` is a percentage (0-100)."""
        total_lookups = self.hits + self.misses
        hit_rate = (self.hits / total_lookups * 100) if total_lookups > 0 else 0

        return {
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
            'memory_cache_size': len(self.memory_cache),
            'total_lookups': total_lookups
        }

    def clear(self) -> None:
        """Clear memory cache and reset counters; disk files are kept."""
        self.memory_cache.clear()
        self.hits = 0
        self.misses = 0


class EmbeddingBatcher:
    """
    Batch process embeddings efficiently.

    Texts already in the in-memory cache are reused; the rest are sent to
    the embedder in chunks of ``batch_size``.
    """

    def __init__(self, embedder, batch_size: int = 32, fallback_dim: int = 384):
        """
        Initialize batcher.

        Args:
            embedder: Object exposing
                ``encode(texts, convert_to_numpy=True, batch_size=...)``.
            batch_size: Number of texts per ``encode`` call.
            fallback_dim: Width of zero placeholders when no real embedding
                is available and the embedder does not report its dimension.
        """
        self.embedder = embedder
        self.batch_size = batch_size
        self.fallback_dim = fallback_dim
        self.cache = EmbeddingCache()

    def batch_encode(self, texts: List[str]) -> np.ndarray:
        """
        Encode texts in batches.

        A batch whose ``encode`` call raises is logged and its rows are
        filled with zeros. Placeholders are not cached, so those texts are
        retried on the next call.

        Returns:
            Array of shape (len(texts), embedding_dim); an empty 1-D array
            when texts is empty.
        """
        if not texts:
            return np.array([])

        # One slot per input keeps the original order without sorting.
        slots: List[Optional[np.ndarray]] = [None] * len(texts)

        # Check cache first
        uncached_indices = []
        uncached_texts = []
        for i, text in enumerate(texts):
            cached = self.cache.get(text)
            if cached is not None:
                slots[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)

        # Process uncached in batches
        if uncached_texts:
            logger.debug(f"Processing {len(uncached_texts)} uncached embeddings in batches of {self.batch_size}")

            for batch_start in range(0, len(uncached_texts), self.batch_size):
                batch_end = batch_start + self.batch_size
                batch_texts = uncached_texts[batch_start:batch_end]
                batch_indices = uncached_indices[batch_start:batch_end]

                try:
                    batch_embeddings = self.embedder.encode(
                        batch_texts,
                        convert_to_numpy=True,
                        batch_size=self.batch_size
                    )
                except Exception as e:
                    # Slots stay None and are zero-filled below.
                    logger.error(f"Error encoding batch: {e}")
                    continue

                # Cache and collect
                for idx, text, embedding in zip(batch_indices, batch_texts, batch_embeddings):
                    self.cache.set(text, embedding)
                    slots[idx] = embedding

        # Zero-fill only after every batch ran, so placeholders can copy
        # the shape and dtype of a real embedding.
        missing = [i for i, slot in enumerate(slots) if slot is None]
        if missing:
            placeholder = self._zero_embedding(slots)
            for i in missing:
                slots[i] = placeholder

        result = np.array(slots)

        logger.debug(f"Cache stats: {self.cache.get_stats()}")

        return result

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        return self.cache.get_stats()

    def _zero_embedding(self, slots: List[Optional[np.ndarray]]) -> np.ndarray:
        """Build a zero vector matching the embeddings obtained so far."""
        reference = next((slot for slot in slots if slot is not None), None)
        if reference is not None:
            return np.zeros_like(np.asarray(reference))
        return np.zeros(self._embedding_dim())

    def _embedding_dim(self) -> int:
        """Ask the embedder for its output width, else use ``fallback_dim``."""
        # presumably a sentence-transformers model; the getter is optional
        # so other embedders keep working.
        getter = getattr(self.embedder, "get_sentence_embedding_dimension", None)
        if callable(getter):
            try:
                dim = getter()
            except Exception as e:
                logger.debug(f"Could not query embedding dimension: {e}")
            else:
                if isinstance(dim, int) and dim > 0:
                    return dim
        return self.fallback_dim
# File: src/processing/queue_manager.py
"""Queue management for LLM batch processing."""
import logging
import json
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
from datetime import datetime
from collections import deque

logger = logging.getLogger(__name__)


class LLMQueue:
    """
    Queue for emails awaiting LLM review.

    Manages:
    - Batching emails for efficient LLM processing
    - Persistence to disk
    - Processing status tracking
    """

    def __init__(
        self,
        batch_size: int = 50,
        persist_dir: Optional[str] = None,
        max_retries: int = 3
    ):
        """
        Initialize queue.

        Args:
            batch_size: Maximum number of items handed out per batch.
            persist_dir: Directory created on construction when given.
                ``persist``/``restore`` take explicit file paths and do not
                read this attribute.
            max_retries: Failure count at which an item is given up on.
        """
        self.batch_size = batch_size
        self.persist_dir = Path(persist_dir) if persist_dir else None
        self.max_retries = max_retries
        self.queue: deque = deque()
        self.processing: Dict[str, Any] = {}  # Currently processing batch
        self.completed: List[str] = []  # Completed email IDs
        self.failed: List[Dict[str, Any]] = []  # Failed items

        if self.persist_dir:
            self.persist_dir.mkdir(parents=True, exist_ok=True)

    def add(self, email_id: str, email_data: Dict[str, Any]) -> None:
        """Add email to queue."""
        self.queue.append({
            'email_id': email_id,
            'data': email_data,
            'added_at': datetime.now().isoformat(),
            'retries': 0
        })

        logger.debug(f"Added to LLM queue: {email_id} (queue size: {len(self.queue)})")

    def add_batch(self, items: List[Tuple[str, Dict[str, Any]]]) -> None:
        """Add multiple ``(email_id, email_data)`` items to queue."""
        for email_id, email_data in items:
            self.add(email_id, email_data)

        logger.info(f"Added {len(items)} items to LLM queue (total: {len(self.queue)})")

    def get_batch(self) -> Optional[List[Dict[str, Any]]]:
        """
        Get next batch for processing.

        Returns:
            Up to ``batch_size`` queued items, or None when the queue is
            empty. The batch becomes the current in-flight batch.
        """
        if not self.queue:
            return None

        # Unacknowledged items are dropped, as before. Re-queuing them would
        # make get_batch() never return None for callers that do not
        # acknowledge items, so only warn.
        unacknowledged = self.processing.get('items')
        if unacknowledged:
            logger.warning(f"Replacing in-flight batch with {len(unacknowledged)} unacknowledged items")

        batch = []
        for _ in range(min(self.batch_size, len(self.queue))):
            batch.append(self.queue.popleft())

        started_at = datetime.now().isoformat()
        self.processing = {
            'batch_id': f"batch_{started_at}",
            'items': batch,
            'started_at': started_at
        }

        logger.info(f"Got batch for processing: {len(batch)} items (queue remaining: {len(self.queue)})")
        return batch

    def mark_completed(self, email_ids: List[str]) -> None:
        """Mark items as successfully processed and release them from the in-flight batch."""
        # Build the lookup once per call instead of scanning the list per id.
        already_completed = set(self.completed)
        for email_id in email_ids:
            if email_id not in already_completed:
                already_completed.add(email_id)
                self.completed.append(email_id)

        finished = set(email_ids)
        self._set_processing_items([
            item for item in self.processing.get('items', [])
            if item['email_id'] not in finished
        ])

        logger.debug(f"Marked {len(email_ids)} items completed (total completed: {len(self.completed)})")

    def mark_failed(self, email_id: str, error: str, retry: bool = True) -> bool:
        """
        Mark item as failed.

        The item leaves the in-flight batch either way: it is re-queued
        while it has retries left, otherwise recorded in ``failed``.

        Returns:
            True if should retry, False if should give up
        """
        items = self.processing.get('items', [])

        # Find in processing batch
        for item in items:
            if item['email_id'] != email_id:
                continue

            item['retries'] += 1
            self._set_processing_items([other for other in items if other is not item])

            if item['retries'] < self.max_retries and retry:
                # Requeue for retry
                self.queue.append(item)
                logger.warning(f"Requeuing {email_id} (retry {item['retries']}/{self.max_retries})")
                return True

            # Give up
            self.failed.append({
                'email_id': email_id,
                'error': error,
                'retries': item['retries'],
                'failed_at': datetime.now().isoformat()
            })
            logger.error(f"Failed to process {email_id} after {item['retries']} retries: {error}")
            return False

        logger.warning(f"Failed item {email_id} not found in processing batch")
        return False

    def persist(self, filepath: str) -> bool:
        """
        Save queue state to disk.

        The state is written to ``<filepath>.tmp`` and then renamed, so a
        failed write does not leave a truncated state file.

        Returns:
            True on success, False if the state could not be written.
        """
        scratch = None
        try:
            target = Path(filepath)
            scratch = target.with_name(target.name + '.tmp')

            state = {
                'timestamp': datetime.now().isoformat(),
                'queue_size': len(self.queue),
                'queue': list(self.queue),
                'processing': self.processing,
                'completed': self.completed,
                'failed': self.failed
            }

            with open(scratch, 'w', encoding='utf-8') as f:
                json.dump(state, f, indent=2)
            scratch.replace(target)

            logger.debug(f"Queue state persisted to {filepath}")
            return True

        except Exception as e:
            logger.error(f"Error persisting queue: {e}")
            if scratch is not None:
                try:
                    scratch.unlink()
                except OSError:
                    # Best-effort cleanup; the temp file may not exist.
                    pass
            return False

    def restore(self, filepath: str) -> bool:
        """
        Restore queue state from disk.

        Returns:
            True on success, False if the file could not be read or parsed.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                state = json.load(f)

            self.queue = deque(state.get('queue', []))
            self.processing = state.get('processing', {})
            self.completed = state.get('completed', [])
            self.failed = state.get('failed', [])

            logger.info(f"Queue state restored from {filepath}")
            logger.info(f"Queue size: {len(self.queue)}, Completed: {len(self.completed)}, Failed: {len(self.failed)}")
            return True

        except Exception as e:
            logger.error(f"Error restoring queue: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Get queue statistics."""
        return {
            'queue_size': len(self.queue),
            'processing': len(self.processing.get('items', [])),
            'completed': len(self.completed),
            'failed': len(self.failed),
            'batch_size': self.batch_size,
            'total_processed': len(self.completed) + len(self.failed)
        }

    def is_empty(self) -> bool:
        """Check that nothing is queued and nothing is in flight."""
        return len(self.queue) == 0 and not self.processing.get('items')

    def clear(self) -> None:
        """Clear all queues."""
        self.queue.clear()
        self.processing = {}
        self.completed = []
        self.failed = []
        logger.info("Queue cleared")

    def _set_processing_items(self, remaining: List[Dict[str, Any]]) -> None:
        """Replace the in-flight item list, resetting the batch once drained."""
        # Rebind rather than mutate: get_batch() returned the previous list
        # and the caller may still be iterating over it.
        if remaining:
            self.processing['items'] = remaining
        else:
            self.processing = {}