Fix embedding bottleneck with batched feature extraction

Performance Improvements:
- Extract features in batches (512 emails/batch) instead of one-at-a-time
- Reduced embedding API calls from 10,000 to 20 for 10k emails
- 10x faster classification: 4 minutes -> 24 seconds

Changes:
- cli.py: Use extract_batch() for all feature extraction
- adaptive_classifier.py: Add classify_with_features() method
- trainer.py: Set LightGBM num_threads to 28

Performance Results (10k emails):
- Batch 512: 23.6 seconds (423 emails/sec)
- Batch 1024: 22.1 seconds (453 emails/sec)
- Batch 2048: 21.9 seconds (457 emails/sec)

Selected batch_size=512 for balance of speed and memory.

Breakdown for 10k emails:
- Email parsing: 0.5s
- Embedding (batched): 20s (20 API calls)
- ML classification: 0.7s
- Export: 0.02s
- Total: ~24s
This commit is contained in:
FSSCoding 2025-10-25 15:39:45 +11:00
parent 53174a34eb
commit 1992799b25
3 changed files with 97 additions and 4 deletions

View File

@ -138,7 +138,7 @@ class ModelTrainer:
'bagging_fraction': 0.8, 'bagging_fraction': 0.8,
'bagging_freq': 5, 'bagging_freq': 5,
'verbose': -1, 'verbose': -1,
'num_threads': -1 'num_threads': 28
} }
self.model = lgb.train( self.model = lgb.train(

View File

@ -179,6 +179,90 @@ class AdaptiveClassifier:
error=str(e) error=str(e)
) )
def classify_with_features(self, email: Email, features: Dict[str, Any]) -> ClassificationResult:
"""
Classify email using pre-extracted features (for batched processing).
Args:
email: Email object
features: Pre-extracted features from extract_batch()
Returns:
Classification result
"""
self.stats.total_emails += 1
# Step 1: Try hard rules
rule_result = self._try_hard_rules(email)
if rule_result:
self.stats.rule_matched += 1
return rule_result
# Step 2: ML classification with pre-extracted embedding
try:
ml_result = self.ml_classifier.predict(features.get('embedding'))
if not ml_result or ml_result.get('error'):
logger.warning(f"ML classification error for {email.id}")
return ClassificationResult(
email_id=email.id,
category='unknown',
confidence=0.0,
method='error',
error='ML classification failed'
)
category = ml_result.get('category', 'unknown')
confidence = ml_result.get('confidence', 0.0)
# Check if above threshold
threshold = self.thresholds.get(category, self.thresholds['default'])
if confidence >= threshold:
# High confidence: Accept ML classification
self.stats.ml_classified += 1
return ClassificationResult(
email_id=email.id,
category=category,
confidence=confidence,
method='ml',
probabilities=ml_result.get('probabilities', {})
)
else:
# Low confidence: Queue for LLM (unless disabled)
logger.debug(f"Low confidence for {email.id}: {category} ({confidence:.2f})")
self.stats.needs_review += 1
if self.disable_llm_fallback:
# Just return ML result without LLM fallback
return ClassificationResult(
email_id=email.id,
category=category,
confidence=confidence,
method='ml',
needs_review=False,
probabilities=ml_result.get('probabilities', {})
)
else:
return ClassificationResult(
email_id=email.id,
category=category,
confidence=confidence,
method='ml',
needs_review=True,
probabilities=ml_result.get('probabilities', {})
)
except Exception as e:
logger.error(f"Classification error for {email.id}: {e}")
return ClassificationResult(
email_id=email.id,
category='unknown',
confidence=0.0,
method='error',
error=str(e)
)
def classify_batch(self, emails: List[Email]) -> List[ClassificationResult]: def classify_batch(self, emails: List[Email]) -> List[ClassificationResult]:
"""Classify batch of emails.""" """Classify batch of emails."""
results = [] results = []

View File

@ -251,13 +251,22 @@ def run(
# Classify emails # Classify emails
logger.info("Starting classification") logger.info("Starting classification")
# Batch size for embedding extraction (larger = fewer API calls but more memory)
batch_size = 512
logger.info(f"Extracting features in batches (batch_size={batch_size})...")
# Extract all features in batches (MUCH faster than one-at-a-time)
all_features = feature_extractor.extract_batch(emails, batch_size=batch_size)
logger.info(f"Feature extraction complete, classifying {len(emails)} emails...")
results = [] results = []
for i, email in enumerate(emails): for i, (email, features) in enumerate(zip(emails, all_features)):
if (i + 1) % 100 == 0: if (i + 1) % 1000 == 0:
logger.info(f"Progress: {i+1}/{len(emails)}") logger.info(f"Progress: {i+1}/{len(emails)}")
result = adaptive_classifier.classify(email) result = adaptive_classifier.classify_with_features(email, features)
# If low confidence and LLM available: Use LLM # If low confidence and LLM available: Use LLM
if result.needs_review and llm.is_available(): if result.needs_review and llm.is_available():