"""Learn sender-specific patterns and rules.""" import logging from typing import Dict, List, Any, Tuple from collections import defaultdict logger = logging.getLogger(__name__) class SenderPattern: """Pattern for a specific sender.""" def __init__(self, sender: str): """Initialize sender pattern.""" self.sender = sender self.categories = defaultdict(int) # category -> count self.total_emails = 0 self.confidence_avg = 0.0 def record_classification(self, category: str, confidence: float) -> None: """Record a classification for this sender.""" self.categories[category] += 1 self.total_emails += 1 # Update running average confidence self.confidence_avg = ( (self.confidence_avg * (self.total_emails - 1) + confidence) / self.total_emails ) def get_predicted_category(self) -> Tuple[str, float]: """ Get predicted category for this sender based on history. Returns: (category, confidence) where confidence is how confident we are """ if not self.categories: return None, 0.0 # Most common category top_category = max(self.categories.items(), key=lambda x: x[1]) category = top_category[0] count = top_category[1] # Confidence = proportion of emails in top category confidence = count / self.total_emails return category, confidence def is_confident(self, threshold: float = 0.8) -> bool: """Check if we're confident about this sender's category.""" _, confidence = self.get_predicted_category() return confidence >= threshold class PatternLearner: """ Learn sender-specific patterns to improve classification. Tracks: - What category emails from each sender typically belong to - Sender domain patterns - Special cases and exceptions """ def __init__(self, min_samples_per_sender: int = 3): """Initialize pattern learner.""" self.min_samples_per_sender = min_samples_per_sender self.sender_patterns: Dict[str, SenderPattern] = {} self.domain_patterns: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) def record_classification( self, sender: str, category: str, confidence: float ) -> None: """Record a classification result.""" # Track sender pattern if sender not in self.sender_patterns: self.sender_patterns[sender] = SenderPattern(sender) self.sender_patterns[sender].record_classification(category, confidence) # Track domain pattern if '@' in sender: domain = sender.split('@')[1].lower() self.domain_patterns[domain][category] += 1 def predict_category(self, sender: str) -> Tuple[str, float, str]: """ Predict category for email from sender based on learned patterns. Returns: (category, confidence, method) or (None, 0.0, 'no_pattern') """ if sender not in self.sender_patterns: return None, 0.0, 'no_pattern' pattern = self.sender_patterns[sender] # Need minimum samples to make prediction if pattern.total_emails < self.min_samples_per_sender: return None, 0.0, 'insufficient_samples' # Check if confident if not pattern.is_confident(threshold=0.7): return None, 0.0, 'low_confidence' category, confidence = pattern.get_predicted_category() return category, confidence, 'sender_pattern' def get_domain_category(self, domain: str) -> Tuple[str, float]: """Get most common category for emails from domain.""" if domain not in self.domain_patterns or not self.domain_patterns[domain]: return None, 0.0 categories = self.domain_patterns[domain] total = sum(categories.values()) # Most common category top_category = max(categories.items(), key=lambda x: x[1]) category = top_category[0] confidence = top_category[1] / total return category, confidence def get_learned_senders(self, min_emails: int = 3) -> Dict[str, Dict[str, Any]]: """Get senders with enough data to have learned patterns.""" learned = {} for sender, pattern in self.sender_patterns.items(): if pattern.total_emails >= min_emails: category, confidence = pattern.get_predicted_category() if confidence > 0.7: # Only confident patterns learned[sender] = { 'category': category, 'confidence': confidence, 'total_emails': pattern.total_emails, 'category_distribution': dict(pattern.categories) } return learned def get_domain_patterns(self, min_emails: int = 10) -> Dict[str, Dict[str, Any]]: """Get domain patterns with sufficient data.""" patterns = {} for domain, categories in self.domain_patterns.items(): total = sum(categories.values()) if total >= min_emails: top_category = max(categories.items(), key=lambda x: x[1]) category = top_category[0] confidence = top_category[1] / total if confidence > 0.6: # Only confident patterns patterns[domain] = { 'category': category, 'confidence': confidence, 'total_emails': total, 'distribution': dict(categories) } return patterns def suggest_hard_rule(self, sender: str) -> Dict[str, Any]: """ Suggest a hard rule for a sender. If a sender's emails are consistently in one category, we can add a hard rule to instantly classify future emails. """ if sender not in self.sender_patterns: return None pattern = self.sender_patterns[sender] # Need high confidence to suggest rule category, confidence = pattern.get_predicted_category() if confidence < 0.95: # Very high confidence required return None if pattern.total_emails < 10: # Need substantial data return None return { 'sender': sender, 'category': category, 'confidence': confidence, 'emails_seen': pattern.total_emails, 'recommendation': f'Add hard rule: emails from {sender} → {category}' } def get_stats(self) -> Dict[str, Any]: """Get learning statistics.""" learned_senders = self.get_learned_senders(min_emails=3) domain_patterns = self.get_domain_patterns(min_emails=10) return { 'total_senders': len(self.sender_patterns), 'learned_senders': len(learned_senders), 'learned_domains': len(domain_patterns), 'total_classifications': sum( p.total_emails for p in self.sender_patterns.values() ), 'suggested_hard_rules': sum( 1 for sender in self.sender_patterns if self.suggest_hard_rule(sender) is not None ) }