Performance Improvements:
- Extract features in batches (512 emails/batch) instead of one-at-a-time
- Reduced embedding API calls from 10,000 to 20 for 10k emails
- 10x faster classification: 4 minutes -> 24 seconds

Changes:
- cli.py: Use extract_batch() for all feature extraction
- adaptive_classifier.py: Add classify_with_features() method
- trainer.py: Set LightGBM num_threads to 28

Performance Results (10k emails):
- Batch 512: 23.6 seconds (423 emails/sec)
- Batch 1024: 22.1 seconds (453 emails/sec)
- Batch 2048: 21.9 seconds (457 emails/sec)

Selected batch_size=512 for balance of speed and memory.

Breakdown for 10k emails:
- Email parsing: 0.5s
- Embedding (batched): 20s (20 API calls)
- ML classification: 0.7s
- Export: 0.02s
- Total: ~24s (the itemized steps sum to ~21.2s; the remainder is not broken out)
393 lines
14 KiB
Python
393 lines
14 KiB
Python
"""
|
|
Adaptive classifier that orchestrates ML + LLM classification.
|
|
|
|
Strategy:
|
|
1. Hard rules: Fast pattern matching -> instant classification (10% of emails)
|
|
2. ML classifier: LightGBM with confidence threshold (85% of emails)
|
|
3. LLM review: Low-confidence emails sent for LLM review (5% of emails)
|
|
4. Dynamic threshold adjustment: Learn from LLM feedback
|
|
"""
|
|
import logging
import re
from dataclasses import dataclass
from typing import Dict, List, Any, Optional

from src.email_providers.base import Email, ClassificationResult
from src.classification.feature_extractor import FeatureExtractor
from src.classification.ml_classifier import MLClassifier
from src.classification.llm_classifier import LLMClassifier
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ClassificationStats:
    """Running counters describing how emails were classified.

    Each counter is a plain integer tally. ``accuracy_estimate`` turns the
    per-method tallies into a rough overall accuracy figure.
    """

    total_emails: int = 0    # every email submitted for classification
    rule_matched: int = 0    # emails resolved by a hard pattern rule
    ml_classified: int = 0   # emails resolved by the ML model
    llm_classified: int = 0  # emails resolved by LLM review
    needs_review: int = 0    # ML predictions that fell below their threshold

    def accuracy_estimate(self) -> float:
        """Return a conservative accuracy estimate weighted by classification method.

        Each method contributes its assumed baseline accuracy once per email it
        resolved. Emails not tallied under any method contribute nothing.
        Returns 0.0 when no emails have been counted.
        """
        if not self.total_emails:
            return 0.0

        # (count, assumed accuracy) per method, summed in a fixed order.
        per_method = (
            (self.rule_matched, 0.99),    # hard rules are very accurate
            (self.ml_classified, 0.92),   # ML classifier baseline
            (self.llm_classified, 0.95),  # LLM is very accurate
        )
        expected_correct = sum(count * accuracy for count, accuracy in per_method)
        return expected_correct / self.total_emails

    def __str__(self) -> str:
        """One-line human-readable summary of the counters and accuracy estimate."""
        segments = [
            f"Stats: {self.total_emails} emails",
            f"Rules: {self.rule_matched}",
            f"ML: {self.ml_classified}",
            f"LLM: {self.llm_classified}",
            f"Est. Accuracy: {self.accuracy_estimate():.1%}",
        ]
        return " | ".join(segments)
|
|
|
|
|
|
class AdaptiveClassifier:
    """
    Hybrid classifier combining hard rules, ML, and LLM.

    This is the main classification orchestrator. Each email is first checked
    against hard pattern-matching rules. If none match, the ML classifier
    predicts a category from the email's embedding. Predictions below the
    per-category confidence threshold are flagged ``needs_review=True`` so the
    caller can pass them to ``classify_with_llm``. When LLM fallback is
    disabled, they are instead returned as final ML results.
    """

    # Matches "otp" only when it is not part of a longer alphabetic word. A plain
    # substring test also fires on ordinary words like "footprint" or "hotpot".
    _OTP_PATTERN = re.compile(r'(?<![a-z])otp(?![a-z])')

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        ml_classifier: MLClassifier,
        llm_classifier: Optional[LLMClassifier],
        categories: Dict[str, Dict],
        config: Dict[str, Any],
        disable_llm_fallback: bool = False
    ):
        """
        Initialize adaptive classifier.

        Args:
            feature_extractor: Produces a feature dict (read via its 'embedding'
                key) for an email.
            ml_classifier: Model whose predict() takes an embedding and returns a
                dict with 'category', 'confidence', 'probabilities' and/or 'error'.
            llm_classifier: Optional LLM reviewer for low-confidence predictions.
            categories: Category name -> category config (may contain 'threshold').
            config: Global config. Its 'classification' section supplies
                'default_threshold', 'min_threshold' and 'max_threshold'.
            disable_llm_fallback: If True, low-confidence ML predictions are
                returned as final results instead of being flagged for review.
        """
        self.feature_extractor = feature_extractor
        self.ml_classifier = ml_classifier
        self.llm_classifier = llm_classifier
        self.categories = categories
        self.config = config
        self.disable_llm_fallback = disable_llm_fallback

        self.thresholds = self._init_thresholds()
        self.stats = ClassificationStats()

    def _init_thresholds(self) -> Dict[str, float]:
        """
        Build the per-category confidence thresholds.

        The default comes from ``classification.default_threshold`` (0.55 if
        unset). Categories without an explicit 'threshold' use that default.
        It is also stored under the 'default' key, which is used for predicted
        categories that are not listed in ``self.categories``.
        """
        default = self.config.get('classification', {}).get('default_threshold', 0.55)

        thresholds = {
            # `or {}` tolerates a category declared with an empty (None) config.
            category: (cat_config or {}).get('threshold', default)
            for category, cat_config in self.categories.items()
        }
        thresholds['default'] = default

        logger.info("Initialized thresholds: %s", thresholds)
        return thresholds

    def classify(self, email: Email) -> ClassificationResult:
        """
        Classify an email using the adaptive strategy.

        Process:
        1. Try hard rules first
        2. If no rule match: extract features and run ML classification
        3. If low confidence: flag for LLM review (unless fallback is disabled)

        Failures during feature extraction or prediction are logged and returned
        as a result with method='error' rather than raised.
        """
        return self._classify(email, None)

    def classify_with_features(self, email: Email, features: Dict[str, Any]) -> ClassificationResult:
        """
        Classify email using pre-extracted features (for batched processing).

        Args:
            email: Email object
            features: Pre-extracted features from extract_batch(). If None, the
                features are extracted from the email on demand.

        Returns:
            Classification result
        """
        return self._classify(email, features)

    def _classify(self, email: Email, features: Optional[Dict[str, Any]]) -> ClassificationResult:
        """Shared pipeline behind classify() and classify_with_features()."""
        self.stats.total_emails += 1

        # Step 1: Try hard rules
        rule_result = self._try_hard_rules(email)
        if rule_result:
            self.stats.rule_matched += 1
            return rule_result

        # Step 2: ML classification. All of it sits inside the try so that a
        # failure on one email becomes an error result instead of propagating.
        try:
            if features is None:
                features = self.feature_extractor.extract(email)
            ml_result = self.ml_classifier.predict(features.get('embedding'))

            if not ml_result or ml_result.get('error'):
                logger.warning("ML classification error for %s", email.id)
                return self._error_result(email, 'ML classification failed')

            return self._apply_threshold(email, ml_result)

        except Exception as e:
            logger.error("Classification error for %s: %s", email.id, e)
            return self._error_result(email, str(e))

    def _apply_threshold(self, email: Email, ml_result: Dict[str, Any]) -> ClassificationResult:
        """Turn an ML prediction into a result, flagging low-confidence ones for review."""
        category = ml_result.get('category', 'unknown')
        confidence = ml_result.get('confidence', 0.0)
        probabilities = ml_result.get('probabilities', {})

        threshold = self.thresholds.get(category, self.thresholds['default'])

        if confidence >= threshold:
            # High confidence: accept the ML classification
            self.stats.ml_classified += 1
            return ClassificationResult(
                email_id=email.id,
                category=category,
                confidence=confidence,
                method='ml',
                probabilities=probabilities
            )

        # Low confidence: flag for LLM review (unless disabled)
        logger.debug("Low confidence for %s: %s (%.2f)", email.id, category, confidence)
        self.stats.needs_review += 1

        if self.disable_llm_fallback:
            # No LLM review will follow, so this ML prediction is final. Count it
            # so the stats and accuracy estimate account for this email.
            self.stats.ml_classified += 1

        return ClassificationResult(
            email_id=email.id,
            category=category,
            confidence=confidence,
            method='ml',
            needs_review=not self.disable_llm_fallback,
            probabilities=probabilities
        )

    @staticmethod
    def _error_result(email: Email, error: str) -> ClassificationResult:
        """Build the result returned when classifying an email fails."""
        return ClassificationResult(
            email_id=email.id,
            category='unknown',
            confidence=0.0,
            method='error',
            error=error
        )

    def classify_batch(self, emails: List[Email]) -> List[ClassificationResult]:
        """Classify a batch of emails, one at a time, via classify()."""
        return [self.classify(email) for email in emails]

    def classify_with_llm(self, ml_result: ClassificationResult, email: Email) -> ClassificationResult:
        """
        Use LLM to review a low-confidence classification.

        Args:
            ml_result: Result from ML classifier
            email: Original email

        Returns:
            A new result with method='llm'. It carries the LLM's category and
            confidence, falling back to the ML category and 0.8 when the LLM
            omits them. If no LLM is available, or the LLM call fails,
            ml_result is returned unchanged.
        """
        if not self.llm_classifier or not self.llm_classifier.llm_available:
            logger.warning("LLM not available for %s, keeping ML result", email.id)
            return ml_result

        try:
            email_dict = {
                'email_id': email.id,
                'subject': email.subject,
                'sender': email.sender,
                'body_snippet': email.body_snippet,
                'has_attachments': email.has_attachments,
                'ml_prediction': {
                    'category': ml_result.category,
                    'confidence': ml_result.confidence
                }
            }

            llm_result = self.llm_classifier.classify(email_dict)

            self.stats.llm_classified += 1

            return ClassificationResult(
                email_id=email.id,
                category=llm_result.get('category', ml_result.category),
                confidence=llm_result.get('confidence', 0.8),
                method='llm',
                metadata={
                    'llm_reasoning': llm_result.get('reasoning', ''),
                    'ml_original': ml_result.category,
                    'ml_confidence': ml_result.confidence
                }
            )

        except Exception as e:
            logger.error("LLM review failed for %s: %s", email.id, e)
            # Fall back to ML result
            return ml_result

    def _try_hard_rules(self, email: Email) -> Optional[ClassificationResult]:
        """
        Apply hard pattern-matching rules.

        Rules run in priority order against the lowercased subject plus body
        snippet (and, for finance, the sender name). The first matching rule
        wins. Returns None when no rule matches.
        """
        subject = (email.subject or "").lower()
        body = (email.body_snippet or "").lower()
        combined = f"{subject} {body}"

        def matched(category: str, confidence: float) -> ClassificationResult:
            return ClassificationResult(
                email_id=email.id,
                category=category,
                confidence=confidence,
                method='rule'
            )

        # Auth emails - high priority
        if (any(p in combined for p in ('verification code', 'reset password', 'confirm identity'))
                or self._OTP_PATTERN.search(combined)):
            return matched('auth', 0.99)

        # Finance - high priority
        if email.sender_name and any(p in email.sender_name.lower() for p in ('bank', 'credit card', 'payment')):
            if any(p in combined for p in ('statement', 'balance', 'account', 'invoice')):
                return matched('finance', 0.98)

        # Invoices/Receipts - transactional
        if any(p in combined for p in ('invoice #', 'receipt #', 'order #', 'tracking #')):
            return matched('transactional', 0.97)

        # Obvious spam
        if any(p in combined for p in ('unsubscribe', 'click here now', 'limited time offer')):
            return matched('junk', 0.96)

        # Meeting/Calendar
        if any(p in combined for p in ('meeting at', 'zoom link', 'teams meeting', 'calendar invite')):
            return matched('work', 0.95)

        return None

    def get_stats(self) -> ClassificationStats:
        """Get classification statistics."""
        return self.stats

    def adjust_threshold(self, category: str, adjustment: float) -> None:
        """
        Dynamically adjust a category's classification threshold.

        The new value is the current threshold plus ``adjustment``, clamped to
        [classification.min_threshold, classification.max_threshold]. The bounds
        default to [0.0, 1.0] when those keys are not configured.
        """
        cls_config = self.config.get('classification', {})
        lower = cls_config.get('min_threshold', 0.0)
        upper = cls_config.get('max_threshold', 1.0)

        current = self.thresholds.get(category, self.thresholds['default'])
        new_threshold = max(lower, min(upper, current + adjustment))
        self.thresholds[category] = new_threshold
        logger.debug("Adjusted threshold for %s: %.2f -> %.2f", category, current, new_threshold)
|