Fix embedding bottleneck with batched feature extraction
Performance Improvements:
- Extract features in batches (512 emails/batch) instead of one at a time
- Reduced embedding API calls from 10,000 to 20 for 10k emails
- 10x faster classification: 4 minutes -> 24 seconds

Changes:
- cli.py: Use extract_batch() for all feature extraction
- adaptive_classifier.py: Add classify_with_features() method
- trainer.py: Set LightGBM num_threads to 28

Performance Results (10k emails):
- Batch 512: 23.6 seconds (423 emails/sec)
- Batch 1024: 22.1 seconds (453 emails/sec)
- Batch 2048: 21.9 seconds (457 emails/sec)

Selected batch_size=512 as a balance of speed and memory.

Breakdown for 10k emails:
- Email parsing: 0.5s
- Embedding (batched): 20s (20 API calls)
- ML classification: 0.7s
- Export: 0.02s
- Total: ~24s
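The feature extractor itself isn't part of this diff, so the following is only a sketch of the batching idea: it assumes each Email exposes a .body string and injects the embedding API call as a hypothetical embed_texts callable (the real extract_batch() is a method on feature_extractor and may differ in shape):

from typing import Any, Callable, Dict, List, Sequence

def extract_batch(
    emails: Sequence[Any],
    embed_texts: Callable[[List[str]], List[List[float]]],
    batch_size: int = 512,
) -> List[Dict[str, Any]]:
    """Embed emails chunk by chunk: ceil(10000 / 512) = 20 API calls for 10k emails."""
    features: List[Dict[str, Any]] = []
    for start in range(0, len(emails), batch_size):
        chunk = emails[start:start + batch_size]
        # One embedding API call per chunk instead of one per email.
        embeddings = embed_texts([email.body for email in chunk])
        features.extend({'embedding': emb} for emb in embeddings)
    return features

The returned list stays index-aligned with the input emails, which is exactly what the zip(emails, all_features) loop in cli.py below relies on.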
parent 53174a34eb
commit 1992799b25
src/trainer.py
@@ -138,7 +138,7 @@ class ModelTrainer:
             'bagging_fraction': 0.8,
             'bagging_freq': 5,
             'verbose': -1,
-            'num_threads': -1
+            'num_threads': 28
         }
 
         self.model = lgb.train(
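Hardcoding num_threads to 28 ties training throughput to this particular machine. Not what the commit does, but a portable variant would derive the thread count at runtime, e.g.:

import os

# Use all detectable cores minus one, leaving headroom for the rest of the
# pipeline; fall back to 1 if the core count can't be determined.
params['num_threads'] = max(1, (os.cpu_count() or 2) - 1)

Here params refers to the LightGBM parameter dict shown in the hunk above.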
src/adaptive_classifier.py
@@ -179,6 +179,90 @@ class AdaptiveClassifier:
                 error=str(e)
             )
 
+    def classify_with_features(self, email: Email, features: Dict[str, Any]) -> ClassificationResult:
+        """
+        Classify email using pre-extracted features (for batched processing).
+
+        Args:
+            email: Email object
+            features: Pre-extracted features from extract_batch()
+
+        Returns:
+            Classification result
+        """
+        self.stats.total_emails += 1
+
+        # Step 1: Try hard rules
+        rule_result = self._try_hard_rules(email)
+        if rule_result:
+            self.stats.rule_matched += 1
+            return rule_result
+
+        # Step 2: ML classification with pre-extracted embedding
+        try:
+            ml_result = self.ml_classifier.predict(features.get('embedding'))
+
+            if not ml_result or ml_result.get('error'):
+                logger.warning(f"ML classification error for {email.id}")
+                return ClassificationResult(
+                    email_id=email.id,
+                    category='unknown',
+                    confidence=0.0,
+                    method='error',
+                    error='ML classification failed'
+                )
+
+            category = ml_result.get('category', 'unknown')
+            confidence = ml_result.get('confidence', 0.0)
+
+            # Check if above threshold
+            threshold = self.thresholds.get(category, self.thresholds['default'])
+
+            if confidence >= threshold:
+                # High confidence: Accept ML classification
+                self.stats.ml_classified += 1
+                return ClassificationResult(
+                    email_id=email.id,
+                    category=category,
+                    confidence=confidence,
+                    method='ml',
+                    probabilities=ml_result.get('probabilities', {})
+                )
+            else:
+                # Low confidence: Queue for LLM (unless disabled)
+                logger.debug(f"Low confidence for {email.id}: {category} ({confidence:.2f})")
+                self.stats.needs_review += 1
+
+                if self.disable_llm_fallback:
+                    # Just return ML result without LLM fallback
+                    return ClassificationResult(
+                        email_id=email.id,
+                        category=category,
+                        confidence=confidence,
+                        method='ml',
+                        needs_review=False,
+                        probabilities=ml_result.get('probabilities', {})
+                    )
+                else:
+                    return ClassificationResult(
+                        email_id=email.id,
+                        category=category,
+                        confidence=confidence,
+                        method='ml',
+                        needs_review=True,
+                        probabilities=ml_result.get('probabilities', {})
+                    )
+
+        except Exception as e:
+            logger.error(f"Classification error for {email.id}: {e}")
+            return ClassificationResult(
+                email_id=email.id,
+                category='unknown',
+                confidence=0.0,
+                method='error',
+                error=str(e)
+            )
+
     def classify_batch(self, emails: List[Email]) -> List[ClassificationResult]:
         """Classify batch of emails."""
         results = []
||||
src/cli.py
@@ -251,13 +251,22 @@ def run(
 
     # Classify emails
     logger.info("Starting classification")
+
+    # Batch size for embedding extraction (larger = fewer API calls but more memory)
+    batch_size = 512
+    logger.info(f"Extracting features in batches (batch_size={batch_size})...")
+
+    # Extract all features in batches (MUCH faster than one-at-a-time)
+    all_features = feature_extractor.extract_batch(emails, batch_size=batch_size)
+
+    logger.info(f"Feature extraction complete, classifying {len(emails)} emails...")
     results = []
 
-    for i, email in enumerate(emails):
-        if (i + 1) % 100 == 0:
+    for i, (email, features) in enumerate(zip(emails, all_features)):
+        if (i + 1) % 1000 == 0:
             logger.info(f"Progress: {i+1}/{len(emails)}")
 
-        result = adaptive_classifier.classify(email)
+        result = adaptive_classifier.classify_with_features(email, features)
 
         # If low confidence and LLM available: Use LLM
        if result.needs_review and llm.is_available():
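One caveat with the zip(emails, all_features) pairing: zip() truncates to the shorter sequence, so if extract_batch() ever returned fewer feature dicts than emails, the trailing emails would be skipped silently. A defensive check before the loop (my addition, not part of this commit):

# Fail loudly if features and emails drift out of alignment, rather than
# letting zip() silently drop the trailing emails.
if len(all_features) != len(emails):
    raise RuntimeError(
        f"extract_batch() returned {len(all_features)} feature sets "
        f"for {len(emails)} emails"
    )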