""" Adaptive classifier that orchestrates ML + LLM classification. Strategy: 1. Hard rules: Fast pattern matching -> instant classification (10% of emails) 2. ML classifier: LightGBM with confidence threshold (85% of emails) 3. LLM review: Low-confidence emails sent for LLM review (5% of emails) 4. Dynamic threshold adjustment: Learn from LLM feedback """ import logging from typing import Dict, List, Any, Optional from dataclasses import dataclass from src.email_providers.base import Email, ClassificationResult from src.classification.feature_extractor import FeatureExtractor from src.classification.ml_classifier import MLClassifier from src.classification.llm_classifier import LLMClassifier logger = logging.getLogger(__name__) @dataclass class ClassificationStats: """Track classification statistics.""" total_emails: int = 0 rule_matched: int = 0 ml_classified: int = 0 llm_classified: int = 0 needs_review: int = 0 def accuracy_estimate(self) -> float: """Estimate accuracy based on classification method distribution.""" if self.total_emails == 0: return 0.0 # Conservative estimate rule_accuracy = 0.99 # Hard rules are very accurate ml_accuracy = 0.92 # ML classifier baseline llm_accuracy = 0.95 # LLM is very accurate weighted = ( (self.rule_matched * rule_accuracy + self.ml_classified * ml_accuracy + self.llm_classified * llm_accuracy) / self.total_emails ) return weighted def __str__(self) -> str: return ( f"Stats: {self.total_emails} emails | " f"Rules: {self.rule_matched} | " f"ML: {self.ml_classified} | " f"LLM: {self.llm_classified} | " f"Est. Accuracy: {self.accuracy_estimate():.1%}" ) class AdaptiveClassifier: """ Hybrid classifier combining hard rules, ML, and LLM. This is the main classification orchestrator. """ def __init__( self, feature_extractor: FeatureExtractor, ml_classifier: MLClassifier, llm_classifier: Optional[LLMClassifier], categories: Dict[str, Dict], config: Dict[str, Any], disable_llm_fallback: bool = False ): """Initialize adaptive classifier.""" self.feature_extractor = feature_extractor self.ml_classifier = ml_classifier self.llm_classifier = llm_classifier self.categories = categories self.config = config self.disable_llm_fallback = disable_llm_fallback self.thresholds = self._init_thresholds() self.stats = ClassificationStats() def _init_thresholds(self) -> Dict[str, float]: """Initialize classification thresholds per category.""" thresholds = {} for category, cat_config in self.categories.items(): threshold = cat_config.get('threshold', 0.55) thresholds[category] = threshold default = self.config.get('classification', {}).get('default_threshold', 0.55) thresholds['default'] = default logger.info(f"Initialized thresholds: {thresholds}") return thresholds def classify(self, email: Email) -> ClassificationResult: """ Classify an email using adaptive strategy. Process: 1. Try hard rules first 2. If no rule match: ML classification 3. 
If low confidence: Queue for LLM """ self.stats.total_emails += 1 # Step 1: Try hard rules rule_result = self._try_hard_rules(email) if rule_result: self.stats.rule_matched += 1 return rule_result # Step 2: ML classification try: features = self.feature_extractor.extract(email) ml_result = self.ml_classifier.predict(features.get('embedding')) if not ml_result or ml_result.get('error'): logger.warning(f"ML classification error for {email.id}") return ClassificationResult( email_id=email.id, category='unknown', confidence=0.0, method='error', error='ML classification failed' ) category = ml_result.get('category', 'unknown') confidence = ml_result.get('confidence', 0.0) # Check if above threshold threshold = self.thresholds.get(category, self.thresholds['default']) if confidence >= threshold: # High confidence: Accept ML classification self.stats.ml_classified += 1 return ClassificationResult( email_id=email.id, category=category, confidence=confidence, method='ml', probabilities=ml_result.get('probabilities', {}) ) else: # Low confidence: Queue for LLM (unless disabled) logger.debug(f"Low confidence for {email.id}: {category} ({confidence:.2f})") self.stats.needs_review += 1 if self.disable_llm_fallback: # Just return ML result without LLM fallback return ClassificationResult( email_id=email.id, category=category, confidence=confidence, method='ml', needs_review=False, probabilities=ml_result.get('probabilities', {}) ) else: return ClassificationResult( email_id=email.id, category=category, confidence=confidence, method='ml', needs_review=True, probabilities=ml_result.get('probabilities', {}) ) except Exception as e: logger.error(f"Classification error for {email.id}: {e}") return ClassificationResult( email_id=email.id, category='unknown', confidence=0.0, method='error', error=str(e) ) def classify_with_features(self, email: Email, features: Dict[str, Any]) -> ClassificationResult: """ Classify email using pre-extracted features (for batched processing). 
Args: email: Email object features: Pre-extracted features from extract_batch() Returns: Classification result """ self.stats.total_emails += 1 # Step 1: Try hard rules rule_result = self._try_hard_rules(email) if rule_result: self.stats.rule_matched += 1 return rule_result # Step 2: ML classification with pre-extracted embedding try: ml_result = self.ml_classifier.predict(features.get('embedding')) if not ml_result or ml_result.get('error'): logger.warning(f"ML classification error for {email.id}") return ClassificationResult( email_id=email.id, category='unknown', confidence=0.0, method='error', error='ML classification failed' ) category = ml_result.get('category', 'unknown') confidence = ml_result.get('confidence', 0.0) # Check if above threshold threshold = self.thresholds.get(category, self.thresholds['default']) if confidence >= threshold: # High confidence: Accept ML classification self.stats.ml_classified += 1 return ClassificationResult( email_id=email.id, category=category, confidence=confidence, method='ml', probabilities=ml_result.get('probabilities', {}) ) else: # Low confidence: Queue for LLM (unless disabled) logger.debug(f"Low confidence for {email.id}: {category} ({confidence:.2f})") self.stats.needs_review += 1 if self.disable_llm_fallback: # Just return ML result without LLM fallback return ClassificationResult( email_id=email.id, category=category, confidence=confidence, method='ml', needs_review=False, probabilities=ml_result.get('probabilities', {}) ) else: return ClassificationResult( email_id=email.id, category=category, confidence=confidence, method='ml', needs_review=True, probabilities=ml_result.get('probabilities', {}) ) except Exception as e: logger.error(f"Classification error for {email.id}: {e}") return ClassificationResult( email_id=email.id, category='unknown', confidence=0.0, method='error', error=str(e) ) def classify_batch(self, emails: List[Email]) -> List[ClassificationResult]: """Classify batch of emails.""" results = [] for email in emails: result = self.classify(email) results.append(result) return results def classify_with_llm(self, ml_result: ClassificationResult, email: Email) -> ClassificationResult: """ Use LLM to review low-confidence classification. 
Args: ml_result: Result from ML classifier email: Original email Returns: Updated classification with LLM input """ if not self.llm_classifier or not self.llm_classifier.llm_available: logger.warning(f"LLM not available for {email.id}, keeping ML result") return ml_result try: email_dict = { 'email_id': email.id, 'subject': email.subject, 'sender': email.sender, 'body_snippet': email.body_snippet, 'has_attachments': email.has_attachments, 'ml_prediction': { 'category': ml_result.category, 'confidence': ml_result.confidence } } llm_result = self.llm_classifier.classify(email_dict) self.stats.llm_classified += 1 return ClassificationResult( email_id=email.id, category=llm_result.get('category', ml_result.category), confidence=llm_result.get('confidence', 0.8), method='llm', metadata={ 'llm_reasoning': llm_result.get('reasoning', ''), 'ml_original': ml_result.category, 'ml_confidence': ml_result.confidence } ) except Exception as e: logger.error(f"LLM review failed for {email.id}: {e}") # Fall back to ML result return ml_result def _try_hard_rules(self, email: Email) -> Optional[ClassificationResult]: """Apply hard pattern-matching rules.""" subject = (email.subject or "").lower() body = (email.body_snippet or "").lower() combined = f"{subject} {body}" # Auth emails - high priority if any(p in combined for p in ['verification code', 'otp', 'reset password', 'confirm identity']): return ClassificationResult( email_id=email.id, category='auth', confidence=0.99, method='rule' ) # Finance - high priority if email.sender_name and any(p in email.sender_name.lower() for p in ['bank', 'credit card', 'payment']): if any(p in combined for p in ['statement', 'balance', 'account', 'invoice']): return ClassificationResult( email_id=email.id, category='finance', confidence=0.98, method='rule' ) # Invoices/Receipts - transactional if any(p in combined for p in ['invoice #', 'receipt #', 'order #', 'tracking #']): return ClassificationResult( email_id=email.id, category='transactional', confidence=0.97, method='rule' ) # Obvious spam if any(p in combined for p in ['unsubscribe', 'click here now', 'limited time offer']): return ClassificationResult( email_id=email.id, category='junk', confidence=0.96, method='rule' ) # Meeting/Calendar if any(p in combined for p in ['meeting at', 'zoom link', 'teams meeting', 'calendar invite']): return ClassificationResult( email_id=email.id, category='work', confidence=0.95, method='rule' ) return None def get_stats(self) -> ClassificationStats: """Get classification statistics.""" return self.stats def adjust_threshold(self, category: str, adjustment: float) -> None: """Dynamically adjust classification threshold.""" current = self.thresholds.get(category, self.thresholds['default']) new_threshold = max( self.config['classification']['min_threshold'], min( self.config['classification']['max_threshold'], current + adjustment ) ) self.thresholds[category] = new_threshold logger.debug(f"Adjusted threshold for {category}: {current:.2f} -> {new_threshold:.2f}")
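

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the production path). It shows
# the shape of the `categories` and `config` dicts this class reads and how an
# email flows through the rule -> ML -> (optional) LLM pipeline. MagicMock
# stand-ins replace FeatureExtractor / MLClassifier because their real
# constructor signatures are not defined in this file; the SimpleNamespace
# "email" only mimics the attributes this module actually accesses, and all
# threshold numbers except the 0.55 default are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from unittest.mock import MagicMock

    categories = {
        # Per-category threshold overrides; anything missing falls back to
        # config['classification']['default_threshold'].
        "finance": {"threshold": 0.65},
        "junk": {},
    }
    config = {
        "classification": {
            "default_threshold": 0.55,
            "min_threshold": 0.30,   # floor used by adjust_threshold()
            "max_threshold": 0.95,   # ceiling used by adjust_threshold()
        }
    }

    # Stand-in components: the mock ML classifier returns a prediction whose
    # confidence sits below the 'finance' threshold above.
    feature_extractor = MagicMock()
    feature_extractor.extract.return_value = {"embedding": [0.0] * 8}
    ml_classifier = MagicMock()
    ml_classifier.predict.return_value = {
        "category": "finance",
        "confidence": 0.40,
        "probabilities": {"finance": 0.40, "junk": 0.35, "work": 0.25},
    }

    classifier = AdaptiveClassifier(
        feature_extractor=feature_extractor,
        ml_classifier=ml_classifier,
        llm_classifier=None,        # skip LLM review in this sketch
        categories=categories,
        config=config,
        disable_llm_fallback=True,  # keep the ML label instead of queueing it
    )

    # Hypothetical email that matches none of the hard rules, so it takes the
    # ML path and comes back as a low-confidence 'finance' result.
    email = SimpleNamespace(
        id="demo-1",
        subject="Quarterly planning notes",
        sender="colleague@example.com",
        sender_name="A Colleague",
        body_snippet="Sharing my notes from last week's sync.",
        has_attachments=False,
    )

    print(classifier.classify(email))
    print(classifier.get_stats())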