email-sorter/src/classification/adaptive_classifier.py
FSSCoding 1992799b25 Fix embedding bottleneck with batched feature extraction
Performance Improvements:
- Extract features in batches (512 emails/batch) instead of one-at-a-time
- Reduced embedding API calls from 10,000 to 20 for 10k emails
- 10x faster classification: 4 minutes -> 24 seconds

Changes:
- cli.py: Use extract_batch() for all feature extraction
- adaptive_classifier.py: Add classify_with_features() method
- trainer.py: Set LightGBM num_threads to 28

Performance Results (10k emails):
- Batch 512: 23.6 seconds (423 emails/sec)
- Batch 1024: 22.1 seconds (453 emails/sec)
- Batch 2048: 21.9 seconds (457 emails/sec)

Selected batch_size=512 for balance of speed and memory.

Breakdown for 10k emails:
- Email parsing: 0.5s
- Embedding (batched): 20s (20 API calls)
- ML classification: 0.7s
- Export: 0.02s
- Total: ~24s
2025-10-25 15:39:45 +11:00
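
A minimal sketch of the batched flow described in the commit above, roughly as cli.py might drive it. The exact extract_batch() signature is an assumption; only classify_with_features() is shown in the file below.

# Hypothetical driver loop for the batched path (names and call shapes assumed).
BATCH_SIZE = 512

def classify_all(emails, feature_extractor, classifier):
    results = []
    for start in range(0, len(emails), BATCH_SIZE):
        batch = emails[start:start + BATCH_SIZE]
        # One embedding request per batch instead of one per email.
        features_list = feature_extractor.extract_batch(batch)  # assumed API
        for email, features in zip(batch, features_list):
            results.append(classifier.classify_with_features(email, features))
    return results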


"""
Adaptive classifier that orchestrates ML + LLM classification.
Strategy:
1. Hard rules: Fast pattern matching -> instant classification (10% of emails)
2. ML classifier: LightGBM with confidence threshold (85% of emails)
3. LLM review: Low-confidence emails sent for LLM review (5% of emails)
4. Dynamic threshold adjustment: Learn from LLM feedback
"""
import logging
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from src.email_providers.base import Email, ClassificationResult
from src.classification.feature_extractor import FeatureExtractor
from src.classification.ml_classifier import MLClassifier
from src.classification.llm_classifier import LLMClassifier
logger = logging.getLogger(__name__)
@dataclass
class ClassificationStats:
    """Track classification statistics."""
    total_emails: int = 0
    rule_matched: int = 0
    ml_classified: int = 0
    llm_classified: int = 0
    needs_review: int = 0

    def accuracy_estimate(self) -> float:
        """Estimate accuracy based on classification method distribution."""
        if self.total_emails == 0:
            return 0.0
        # Conservative estimate
        rule_accuracy = 0.99  # Hard rules are very accurate
        ml_accuracy = 0.92    # ML classifier baseline
        llm_accuracy = 0.95   # LLM is very accurate
        weighted = (
            (self.rule_matched * rule_accuracy +
             self.ml_classified * ml_accuracy +
             self.llm_classified * llm_accuracy) / self.total_emails
        )
        return weighted

    def __str__(self) -> str:
        return (
            f"Stats: {self.total_emails} emails | "
            f"Rules: {self.rule_matched} | "
            f"ML: {self.ml_classified} | "
            f"LLM: {self.llm_classified} | "
            f"Est. Accuracy: {self.accuracy_estimate():.1%}"
        )
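
# Quick sanity check of accuracy_estimate() with hypothetical counts
# (numbers are illustrative, not measured):
#
#     stats = ClassificationStats(total_emails=10_000, rule_matched=1_000,
#                                 ml_classified=8_500, llm_classified=500)
#     stats.accuracy_estimate()  # (1_000*0.99 + 8_500*0.92 + 500*0.95) / 10_000 = 0.9285
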

class AdaptiveClassifier:
    """
    Hybrid classifier combining hard rules, ML, and LLM.

    This is the main classification orchestrator.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        ml_classifier: MLClassifier,
        llm_classifier: Optional[LLMClassifier],
        categories: Dict[str, Dict],
        config: Dict[str, Any],
        disable_llm_fallback: bool = False
    ):
        """Initialize adaptive classifier."""
        self.feature_extractor = feature_extractor
        self.ml_classifier = ml_classifier
        self.llm_classifier = llm_classifier
        self.categories = categories
        self.config = config
        self.disable_llm_fallback = disable_llm_fallback
        self.thresholds = self._init_thresholds()
        self.stats = ClassificationStats()

    def _init_thresholds(self) -> Dict[str, float]:
        """Initialize classification thresholds per category."""
        thresholds = {}
        for category, cat_config in self.categories.items():
            threshold = cat_config.get('threshold', 0.55)
            thresholds[category] = threshold
        default = self.config.get('classification', {}).get('default_threshold', 0.55)
        thresholds['default'] = default
        logger.info(f"Initialized thresholds: {thresholds}")
        return thresholds
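
    # Assumed shapes of the `categories` and `config` mappings, inferred from the
    # .get() calls above and in adjust_threshold() (values illustrative only):
    #
    #     categories = {"finance": {"threshold": 0.60}, "junk": {"threshold": 0.50}}
    #     config = {"classification": {"default_threshold": 0.55,
    #                                  "min_threshold": 0.30,
    #                                  "max_threshold": 0.95}}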

    def classify(self, email: Email) -> ClassificationResult:
        """
        Classify an email using adaptive strategy.

        Process:
        1. Try hard rules first
        2. If no rule match: ML classification
        3. If low confidence: Queue for LLM
        """
        self.stats.total_emails += 1

        # Step 1: Try hard rules
        rule_result = self._try_hard_rules(email)
        if rule_result:
            self.stats.rule_matched += 1
            return rule_result

        # Step 2: ML classification
        try:
            features = self.feature_extractor.extract(email)
            ml_result = self.ml_classifier.predict(features.get('embedding'))

            if not ml_result or ml_result.get('error'):
                logger.warning(f"ML classification error for {email.id}")
                return ClassificationResult(
                    email_id=email.id,
                    category='unknown',
                    confidence=0.0,
                    method='error',
                    error='ML classification failed'
                )

            category = ml_result.get('category', 'unknown')
            confidence = ml_result.get('confidence', 0.0)

            # Check if above threshold
            threshold = self.thresholds.get(category, self.thresholds['default'])
            if confidence >= threshold:
                # High confidence: Accept ML classification
                self.stats.ml_classified += 1
                return ClassificationResult(
                    email_id=email.id,
                    category=category,
                    confidence=confidence,
                    method='ml',
                    probabilities=ml_result.get('probabilities', {})
                )
            else:
                # Low confidence: Queue for LLM (unless disabled)
                logger.debug(f"Low confidence for {email.id}: {category} ({confidence:.2f})")
                self.stats.needs_review += 1
                if self.disable_llm_fallback:
                    # Just return ML result without LLM fallback
                    return ClassificationResult(
                        email_id=email.id,
                        category=category,
                        confidence=confidence,
                        method='ml',
                        needs_review=False,
                        probabilities=ml_result.get('probabilities', {})
                    )
                else:
                    return ClassificationResult(
                        email_id=email.id,
                        category=category,
                        confidence=confidence,
                        method='ml',
                        needs_review=True,
                        probabilities=ml_result.get('probabilities', {})
                    )
        except Exception as e:
            logger.error(f"Classification error for {email.id}: {e}")
            return ClassificationResult(
                email_id=email.id,
                category='unknown',
                confidence=0.0,
                method='error',
                error=str(e)
            )
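
    # Minimal sketch of how a caller might drive classify() and hand low-confidence
    # results to the LLM (hypothetical caller code, not part of this class):
    #
    #     result = classifier.classify(email)
    #     if result.needs_review and classifier.llm_classifier:
    #         result = classifier.classify_with_llm(result, email)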

    def classify_with_features(self, email: Email, features: Dict[str, Any]) -> ClassificationResult:
        """
        Classify email using pre-extracted features (for batched processing).

        Args:
            email: Email object
            features: Pre-extracted features from extract_batch()

        Returns:
            Classification result
        """
        self.stats.total_emails += 1

        # Step 1: Try hard rules
        rule_result = self._try_hard_rules(email)
        if rule_result:
            self.stats.rule_matched += 1
            return rule_result

        # Step 2: ML classification with pre-extracted embedding
        try:
            ml_result = self.ml_classifier.predict(features.get('embedding'))

            if not ml_result or ml_result.get('error'):
                logger.warning(f"ML classification error for {email.id}")
                return ClassificationResult(
                    email_id=email.id,
                    category='unknown',
                    confidence=0.0,
                    method='error',
                    error='ML classification failed'
                )

            category = ml_result.get('category', 'unknown')
            confidence = ml_result.get('confidence', 0.0)

            # Check if above threshold
            threshold = self.thresholds.get(category, self.thresholds['default'])
            if confidence >= threshold:
                # High confidence: Accept ML classification
                self.stats.ml_classified += 1
                return ClassificationResult(
                    email_id=email.id,
                    category=category,
                    confidence=confidence,
                    method='ml',
                    probabilities=ml_result.get('probabilities', {})
                )
            else:
                # Low confidence: Queue for LLM (unless disabled)
                logger.debug(f"Low confidence for {email.id}: {category} ({confidence:.2f})")
                self.stats.needs_review += 1
                if self.disable_llm_fallback:
                    # Just return ML result without LLM fallback
                    return ClassificationResult(
                        email_id=email.id,
                        category=category,
                        confidence=confidence,
                        method='ml',
                        needs_review=False,
                        probabilities=ml_result.get('probabilities', {})
                    )
                else:
                    return ClassificationResult(
                        email_id=email.id,
                        category=category,
                        confidence=confidence,
                        method='ml',
                        needs_review=True,
                        probabilities=ml_result.get('probabilities', {})
                    )
        except Exception as e:
            logger.error(f"Classification error for {email.id}: {e}")
            return ClassificationResult(
                email_id=email.id,
                category='unknown',
                confidence=0.0,
                method='error',
                error=str(e)
            )

    def classify_batch(self, emails: List[Email]) -> List[ClassificationResult]:
        """Classify batch of emails."""
        results = []
        for email in emails:
            result = self.classify(email)
            results.append(result)
        return results

    def classify_with_llm(self, ml_result: ClassificationResult, email: Email) -> ClassificationResult:
        """
        Use LLM to review low-confidence classification.

        Args:
            ml_result: Result from ML classifier
            email: Original email

        Returns:
            Updated classification with LLM input
        """
        if not self.llm_classifier or not self.llm_classifier.llm_available:
            logger.warning(f"LLM not available for {email.id}, keeping ML result")
            return ml_result

        try:
            email_dict = {
                'email_id': email.id,
                'subject': email.subject,
                'sender': email.sender,
                'body_snippet': email.body_snippet,
                'has_attachments': email.has_attachments,
                'ml_prediction': {
                    'category': ml_result.category,
                    'confidence': ml_result.confidence
                }
            }

            llm_result = self.llm_classifier.classify(email_dict)
            self.stats.llm_classified += 1

            return ClassificationResult(
                email_id=email.id,
                category=llm_result.get('category', ml_result.category),
                confidence=llm_result.get('confidence', 0.8),
                method='llm',
                metadata={
                    'llm_reasoning': llm_result.get('reasoning', ''),
                    'ml_original': ml_result.category,
                    'ml_confidence': ml_result.confidence
                }
            )
        except Exception as e:
            logger.error(f"LLM review failed for {email.id}: {e}")
            # Fall back to ML result
            return ml_result
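
    # The LLM reply is consumed via dict .get() calls above, so a successful
    # LLMClassifier.classify() call is assumed to return something shaped like
    # (keys inferred from this method, values illustrative):
    #
    #     {"category": "finance", "confidence": 0.9,
    #      "reasoning": "Monthly statement from a bank sender."}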

    def _try_hard_rules(self, email: Email) -> Optional[ClassificationResult]:
        """Apply hard pattern-matching rules."""
        subject = (email.subject or "").lower()
        body = (email.body_snippet or "").lower()
        combined = f"{subject} {body}"

        # Auth emails - high priority
        if any(p in combined for p in ['verification code', 'otp', 'reset password', 'confirm identity']):
            return ClassificationResult(
                email_id=email.id,
                category='auth',
                confidence=0.99,
                method='rule'
            )

        # Finance - high priority
        if email.sender_name and any(p in email.sender_name.lower() for p in ['bank', 'credit card', 'payment']):
            if any(p in combined for p in ['statement', 'balance', 'account', 'invoice']):
                return ClassificationResult(
                    email_id=email.id,
                    category='finance',
                    confidence=0.98,
                    method='rule'
                )

        # Invoices/Receipts - transactional
        if any(p in combined for p in ['invoice #', 'receipt #', 'order #', 'tracking #']):
            return ClassificationResult(
                email_id=email.id,
                category='transactional',
                confidence=0.97,
                method='rule'
            )

        # Obvious spam
        if any(p in combined for p in ['unsubscribe', 'click here now', 'limited time offer']):
            return ClassificationResult(
                email_id=email.id,
                category='junk',
                confidence=0.96,
                method='rule'
            )

        # Meeting/Calendar
        if any(p in combined for p in ['meeting at', 'zoom link', 'teams meeting', 'calendar invite']):
            return ClassificationResult(
                email_id=email.id,
                category='work',
                confidence=0.95,
                method='rule'
            )

        return None

    def get_stats(self) -> ClassificationStats:
        """Get classification statistics."""
        return self.stats

    def adjust_threshold(self, category: str, adjustment: float) -> None:
        """Dynamically adjust classification threshold."""
        current = self.thresholds.get(category, self.thresholds['default'])
        new_threshold = max(
            self.config['classification']['min_threshold'],
            min(
                self.config['classification']['max_threshold'],
                current + adjustment
            )
        )
        self.thresholds[category] = new_threshold
        logger.debug(f"Adjusted threshold for {category}: {current:.2f} -> {new_threshold:.2f}")