diff --git a/src/calibration/enron_parser.py b/src/calibration/enron_parser.py
new file mode 100644
index 0000000..7b00490
--- /dev/null
+++ b/src/calibration/enron_parser.py
@@ -0,0 +1,163 @@
+"""Parse Enron dataset for training."""
+import logging
+import email
+from pathlib import Path
+from typing import List, Optional
+from email.utils import parsedate_to_datetime
+
+from src.email_providers.base import Email
+
+logger = logging.getLogger(__name__)
+
+
+class EnronParser:
+    """
+    Parse Enron email dataset.
+
+    The Enron dataset is in maildir format:
+    - enron_mail_20150507/
+      - maildir/
+        - lastname-f/
+          - folder/
+            - 1, 2, 3, ... (individual emails)
+    """
+
+    def __init__(self, dataset_path: str):
+        """Initialize Enron parser."""
+        self.dataset_path = Path(dataset_path)
+
+        if not self.dataset_path.exists():
+            raise ValueError(f"Dataset path not found: {self.dataset_path}")
+
+        self.maildir_path = self.dataset_path / "maildir"
+
+        if not self.maildir_path.exists():
+            raise ValueError(f"Maildir not found at {self.maildir_path}")
+
+        logger.info(f"Initialized Enron parser: {self.dataset_path}")
+
+    def parse_emails(self, limit: Optional[int] = None) -> List[Email]:
+        """
+        Parse emails from Enron dataset.
+
+        Args:
+            limit: Maximum number of emails to parse
+
+        Returns:
+            List of Email objects
+        """
+        emails = []
+        email_count = 0
+
+        logger.info(f"Starting Enron parsing (limit: {limit})")
+
+        # Iterate through users
+        for user_dir in sorted(self.maildir_path.iterdir()):
+            if not user_dir.is_dir():
+                continue
+
+            # Iterate through folders
+            for folder_dir in sorted(user_dir.iterdir()):
+                if not folder_dir.is_dir():
+                    continue
+
+                # Parse emails in folder
+                for email_file in sorted(folder_dir.iterdir()):
+                    if email_file.is_file():
+                        try:
+                            parsed_email = self._parse_email_file(email_file)
+                            if parsed_email:
+                                emails.append(parsed_email)
+                                email_count += 1
+
+                                if limit and email_count >= limit:
+                                    logger.info(f"Reached limit: {email_count} emails parsed")
+                                    return emails
+
+                                if email_count % 1000 == 0:
+                                    logger.info(f"Progress: {email_count} emails parsed")
+
+                        except Exception as e:
+                            logger.debug(f"Error parsing {email_file}: {e}")
+
+        logger.info(f"Parsing complete: {email_count} emails")
+        return emails
+
+    def _parse_email_file(self, filepath: Path) -> Optional[Email]:
+        """Parse single email file."""
+        try:
+            with open(filepath, 'rb') as f:
+                msg = email.message_from_bytes(f.read())
+
+            # Use the sanitized file path as a stable message id
+            msg_id = str(filepath).replace('/', '_').replace('\\', '_')
+            subject = msg.get('subject', 'No Subject')
+            sender = msg.get('from', '')
+            date_str = msg.get('date')
+
+            # Parse date
+            date = None
+            if date_str:
+                try:
+                    date = parsedate_to_datetime(date_str)
+                except Exception:
+                    pass
+
+            # Extract body
+            body = self._extract_body(msg)
+            body_snippet = body[:500] if body else ""
+
+            return Email(
+                id=msg_id,
+                subject=subject,
+                sender=sender,
+                date=date,
+                body=body,
+                body_snippet=body_snippet,
+                has_attachments=self._has_attachments(msg),
+                provider='enron'
+            )
+
+        except Exception as e:
+            logger.debug(f"Error parsing email file: {e}")
+            return None
+
+    def _extract_body(self, msg: email.message.Message) -> str:
+        """Extract email body."""
+        body = ""
+
+        if msg.is_multipart():
+            for part in msg.walk():
+                if part.get_content_type() == 'text/plain':
+                    try:
+                        payload = part.get_payload(decode=True)
+                        if payload:
+                            body = payload.decode('utf-8', errors='ignore')
+                            break
+                    except Exception:
+                        pass
+        else:
+            try:
+                payload = msg.get_payload(decode=True)
+                if payload:
+                    body = payload.decode('utf-8', errors='ignore')
+                else:
+                    body = msg.get_payload(decode=False)
+                    if not isinstance(body, str):
+                        body = str(body)
+            except Exception:
+                pass
+
+        return body.strip() if isinstance(body, str) else ""
+
+    def _has_attachments(self, msg: email.message.Message) -> bool:
+        """Check if email has attachments."""
+        # Messages parsed with the default (compat32) policy do not provide
+        # iter_attachments(), so walk the parts and check dispositions instead.
+        if msg.is_multipart():
+            for part in msg.walk():
+                if part.get_content_disposition() == 'attachment' or part.get_filename():
+                    return True
+        return False
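+
+# Illustrative usage sketch (not exercised by the pipeline; the dataset path
+# below is a placeholder for wherever the extracted Enron archive lives):
+#
+#     parser = EnronParser("enron_mail_20150507")
+#     sample = parser.parse_emails(limit=1000)
+#     print(f"Parsed {len(sample)} emails from {parser.maildir_path}")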
diff --git a/src/calibration/llm_analyzer.py b/src/calibration/llm_analyzer.py
new file mode 100644
index 0000000..aaa1a6c
--- /dev/null
+++ b/src/calibration/llm_analyzer.py
@@ -0,0 +1,141 @@
+"""LLM-based calibration analysis."""
+import logging
+import json
+import re
+from typing import List, Dict, Any, Optional, Tuple
+
+from src.email_providers.base import Email
+from src.llm.base import BaseLLMProvider
+
+logger = logging.getLogger(__name__)
+
+
+class CalibrationAnalyzer:
+    """
+    Use LLM to discover natural categories in email sample.
+
+    This runs ONCE during calibration to understand what categories
+    exist naturally in this inbox.
+    """
+
+    def __init__(
+        self,
+        llm_provider: BaseLLMProvider,
+        config: Dict[str, Any]
+    ):
+        """Initialize calibration analyzer."""
+        self.llm_provider = llm_provider
+        self.config = config
+        self.llm_available = llm_provider.is_available()
+
+        if not self.llm_available:
+            logger.warning("LLM not available for calibration analysis")
+
+    def discover_categories(
+        self,
+        sample_emails: List[Email]
+    ) -> Tuple[Dict[str, Any], List[Tuple[str, str]]]:
+        """
+        Discover natural categories in email sample.
+
+        Args:
+            sample_emails: Stratified sample of emails
+
+        Returns:
+            (category_map, email_labels) where:
+            - category_map: discovered categories with descriptions
+            - email_labels: list of (email_id, assigned_category)
+        """
+        if not self.llm_available:
+            logger.warning("LLM unavailable, using default categories")
+            return self._default_categories(), []
+
+        logger.info(f"Starting LLM category discovery on {len(sample_emails)} emails")
+
+        # Batch emails for analysis
+        batch_size = 20
+        discovered_categories = {}
+        email_labels = []
+
+        for batch_idx in range(0, len(sample_emails), batch_size):
+            batch = sample_emails[batch_idx:batch_idx + batch_size]
+
+            try:
+                batch_results = self._analyze_batch(batch)
+
+                # Merge categories
+                for category, desc in batch_results.get('categories', {}).items():
+                    if category not in discovered_categories:
+                        discovered_categories[category] = desc
+
+                # Collect labels
+                for email_id, category in batch_results.get('labels', []):
+                    email_labels.append((email_id, category))
+
+            except Exception as e:
+                logger.error(f"Error analyzing batch: {e}")
+
+        logger.info(f"Discovery complete: {len(discovered_categories)} categories found")
+
+        return discovered_categories, email_labels
+
+    def _analyze_batch(self, batch: List[Email]) -> Dict[str, Any]:
+        """Analyze single batch of emails."""
+        # Build email summary. Include the real email id so the LLM can
+        # reference it in its labels.
+        email_summary = "\n".join([
+            f"Email {i+1} (id: {e.id}):\n"
+            f"  From: {e.sender}\n"
+            f"  Subject: {e.subject}\n"
+            f"  Preview: {e.body_snippet[:100]}...\n"
+            for i, e in enumerate(batch)
+        ])
+
+        prompt = f"""Analyze these emails and identify natural categories they belong to.
+For each email, assign ONE category. Create new categories as needed based on the emails.
+Refer to each email by the id shown.
+
+EMAILS:
+{email_summary}
+
+Respond with JSON only:
+{{
+  "categories": {{"category_name": "brief description", ...}},
+  "labels": [["<email id>", "category_name"], ["<email id>", "category_name"], ...]
+}}
+"""
+
+        try:
+            response = self.llm_provider.complete(
+                prompt,
+                temperature=0.1,
+                max_tokens=1000
+            )
+
+            return self._parse_response(response)
+
+        except Exception as e:
+            logger.error(f"LLM analysis failed: {e}")
+            return {'categories': {}, 'labels': []}
+
+    def _parse_response(self, response: str) -> Dict[str, Any]:
+        """Parse LLM JSON response."""
+        try:
+            json_match = re.search(r'\{.*\}', response, re.DOTALL)
+            if json_match:
+                return json.loads(json_match.group())
+        except json.JSONDecodeError as e:
+            logger.debug(f"JSON parse error: {e}")
+
+        return {'categories': {}, 'labels': []}
+
+    def _default_categories(self) -> Dict[str, Any]:
+        """Return default categories."""
+        return {
+            'junk': 'Spam and unwanted emails',
+            'transactional': 'Receipts and confirmations',
+            'auth': 'Authentication and security',
+            'newsletters': 'Newsletters and subscriptions',
+            'work': 'Work correspondence',
+            'personal': 'Personal emails',
+            'finance': 'Financial documents',
+            'unknown': 'Unclassified'
+        }
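+
+# Illustrative usage sketch. `provider` stands in for any BaseLLMProvider
+# implementation and `sample` for the calibration sample; both are assumptions
+# here, not defined in this module:
+#
+#     analyzer = CalibrationAnalyzer(provider, config={})
+#     categories, labels = analyzer.discover_categories(sample)
+#     print(sorted(categories))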
diff --git a/src/calibration/sampler.py b/src/calibration/sampler.py
new file mode 100644
index 0000000..a582803
--- /dev/null
+++ b/src/calibration/sampler.py
@@ -0,0 +1,114 @@
+"""Email sampling for calibration."""
+import logging
+import random
+from email.utils import parseaddr
+from typing import List, Tuple
+
+from src.email_providers.base import Email
+
+logger = logging.getLogger(__name__)
+
+
+class EmailSampler:
+    """Sample emails for calibration phase."""
+
+    def __init__(self, random_seed: int = 42):
+        """Initialize sampler."""
+        self.random_seed = random_seed
+        random.seed(random_seed)
+
+    def stratified_sample(
+        self,
+        emails: List[Email],
+        sample_size: int,
+        field_getter=None
+    ) -> Tuple[List[Email], List[Email]]:
+        """
+        Stratified sampling based on sender domain type.
+
+        Args:
+            emails: All emails
+            sample_size: Target sample size
+            field_getter: Function to group by (default: sender_domain_type)
+
+        Returns:
+            (sample, remaining)
+        """
+        if sample_size >= len(emails):
+            logger.warning(f"Sample size ({sample_size}) >= total ({len(emails)}), returning all")
+            return emails, []
+
+        if field_getter is None:
+            # Default: group by sender domain type
+            field_getter = lambda e: self._extract_domain_type(e)
+
+        # Group emails by field
+        groups = {}
+        for email in emails:
+            try:
+                key = field_getter(email)
+                if key not in groups:
+                    groups[key] = []
+                groups[key].append(email)
+            except Exception:
+                if 'unknown' not in groups:
+                    groups['unknown'] = []
+                groups['unknown'].append(email)
+
+        logger.info(f"Stratified sampling: {len(groups)} groups, total {len(emails)} emails")
+
+        # Proportional sampling from each group
+        sample = []
+        remaining = []
+
+        for group_name, group_emails in groups.items():
+            proportion = len(group_emails) / len(emails)
+            group_sample_size = max(1, int(sample_size * proportion))
+
+            group_sample = random.sample(group_emails, min(group_sample_size, len(group_emails)))
+            group_remaining = [e for e in group_emails if e not in group_sample]
+
+            sample.extend(group_sample)
+            remaining.extend(group_remaining)
+
+            logger.debug(f"Group '{group_name}': {len(group_sample)} sampled, {len(group_remaining)} remaining")
+
+        # Adjust if we didn't get exactly the right size
+        if len(sample) < sample_size and remaining:
+            shortage = sample_size - len(sample)
+            additional = random.sample(remaining, min(shortage, len(remaining)))
+            sample.extend(additional)
+            remaining = [e for e in remaining if e not in additional]
+
+        logger.info(f"Stratified sample complete: {len(sample)} sampled")
+
+        return sample[:sample_size], remaining
+
+    def random_sample(
+        self,
+        emails: List[Email],
+        sample_size: int
+    ) -> Tuple[List[Email], List[Email]]:
+        """Random sampling without stratification."""
+        if sample_size >= len(emails):
+            return emails, []
+
+        sample = random.sample(emails, sample_size)
+        remaining = [e for e in emails if e not in sample]
+
+        logger.info(f"Random sample: {len(sample)} sampled")
+        return sample, remaining
+
+    def _extract_domain_type(self, email: Email) -> str:
+        """Extract domain type for stratification."""
+        # Use parseaddr so display names ("Jane <jane@example.com>") don't
+        # corrupt the domain.
+        _, address = parseaddr(email.sender or "")
+        if '@' in address:
+            domain = address.split('@')[1].lower()
+            freemail_domains = {'gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com'}
+
+            if domain in freemail_domains:
+                return 'freemail'
+            elif 'noreply' in address.lower():
+                return 'noreply'
+            else:
+                return 'corporate'
+
+        return 'unknown'
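+
+# Illustrative usage sketch (`all_emails` is assumed to be a list of Email
+# objects loaded elsewhere):
+#
+#     sampler = EmailSampler(random_seed=42)
+#     sample, remaining = sampler.stratified_sample(all_emails, sample_size=1500)
+#     assert len(sample) <= 1500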
diff --git a/src/export/exporter.py b/src/export/exporter.py
new file mode 100644
index 0000000..1e77d85
--- /dev/null
+++ b/src/export/exporter.py
@@ -0,0 +1,218 @@
+"""Export classification results."""
+import logging
+import json
+import csv
+from pathlib import Path
+from typing import List, Dict, Any
+from datetime import datetime
+
+from src.email_providers.base import ClassificationResult
+
+logger = logging.getLogger(__name__)
+
+
+class ResultsExporter:
+    """Export classification results in multiple formats."""
+
+    def __init__(self, output_dir: str = "results"):
+        """Initialize exporter."""
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def export_json(
+        self,
+        results: List[ClassificationResult],
+        metadata: Dict[str, Any],
+        filename: str = "results.json"
+    ) -> Path:
+        """Export results as JSON."""
+        output_file = self.output_dir / filename
+
+        data = {
+            'metadata': metadata,
+            'classifications': [
+                {
+                    'email_id': r.email_id,
+                    'category': r.category,
+                    'confidence': r.confidence,
+                    'method': r.method,
+                    'needs_review': r.needs_review,
+                    'probabilities': r.probabilities,
+                    'error': r.error
+                }
+                for r in results
+            ]
+        }
+
+        try:
+            with open(output_file, 'w') as f:
+                json.dump(data, f, indent=2)
+            logger.info(f"Exported JSON to {output_file}")
+            return output_file
+        except Exception as e:
+            logger.error(f"Error exporting JSON: {e}")
+            raise
+
+    def export_csv(
+        self,
+        results: List[ClassificationResult],
+        filename: str = "results.csv"
+    ) -> Path:
+        """Export results as CSV."""
+        output_file = self.output_dir / filename
+
+        try:
+            with open(output_file, 'w', newline='') as f:
+                writer = csv.DictWriter(
+                    f,
+                    fieldnames=['email_id', 'category', 'confidence', 'method', 'needs_review', 'error']
+                )
+                writer.writeheader()
+
+                for result in results:
+                    writer.writerow({
+                        'email_id': result.email_id,
+                        'category': result.category,
+                        'confidence': round(result.confidence, 3),
+                        'method': result.method,
+                        'needs_review': result.needs_review,
+                        'error': result.error or ''
+                    })
+
+            logger.info(f"Exported CSV to {output_file}")
+            return output_file
+        except Exception as e:
+            logger.error(f"Error exporting CSV: {e}")
+            raise
+
+    def export_by_category(
+        self,
+        results: List[ClassificationResult],
+        dirname: str = "by_category"
+    ) -> Path:
+        """Export results organized by category."""
+        output_dir = self.output_dir / dirname
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Group by category
+        by_category = {}
+        for result in results:
+            if result.category not in by_category:
+                by_category[result.category] = []
+            by_category[result.category].append(result)
+
+        # Export each category
+        for category, category_results in by_category.items():
+            filename = output_dir / f"{category}.json"
+
+            data = {
+                'category': category,
+                'count': len(category_results),
+                'emails': [
+                    {
+                        'email_id': r.email_id,
+                        'confidence': r.confidence,
+                        'method': r.method
+                    }
+                    for r in category_results
+                ]
+            }
+
+            try:
+                with open(filename, 'w') as f:
+                    json.dump(data, f, indent=2)
+                logger.debug(f"Exported {len(category_results)} emails to {filename}")
+            except Exception as e:
+                logger.error(f"Error exporting {category}: {e}")
+
+        logger.info(f"Exported results organized by category to {output_dir}")
+        return output_dir
+
+
+class ReportGenerator:
+    """Generate classification reports."""
+
+    def __init__(self, output_dir: str = "results"):
+        """Initialize report generator."""
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def generate_report(
+        self,
+        results: List[ClassificationResult],
+        metadata: Dict[str, Any],
+        filename: str = "report.txt"
+    ) -> Path:
+        """Generate text report."""
+        output_file = self.output_dir / filename
+
+        # Calculate statistics
+        by_category = {}
+        by_method = {}
+        needs_review_count = 0
+
+        for result in results:
+            # By category
+            if result.category not in by_category:
+                by_category[result.category] = 0
+            by_category[result.category] += 1
+
+            # By method
+            if result.method not in by_method:
+                by_method[result.method] = 0
+            by_method[result.method] += 1
+
+            # Needs review
+            if result.needs_review:
+                needs_review_count += 1
+
+        # Generate report
+        report_lines = [
+            "=" * 80,
+            "EMAIL SORTER CLASSIFICATION REPORT",
+            "=" * 80,
+            "",
+            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
+            f"Total Emails: {len(results):,}",
+            f"Accuracy Estimate: {metadata.get('accuracy_estimate', 0):.1%}",
+            "",
+            "CATEGORY DISTRIBUTION:",
+            "-" * 80,
+        ]
+
+        # Sort by count
+        for category in sorted(by_category.keys(), key=lambda c: by_category[c], reverse=True):
+            count = by_category[category]
+            pct = (count / len(results) * 100) if results else 0
+            report_lines.append(f"  {category:20s} {count:7,d} ({pct:5.1f}%)")
+
+        report_lines.extend([
+            "",
+            "CLASSIFICATION METHOD BREAKDOWN:",
+            "-" * 80,
+        ])
+
+        for method in sorted(by_method.keys()):
+            count = by_method[method]
+            pct = (count / len(results) * 100) if results else 0
+            report_lines.append(f"  {method:20s} {count:7,d} ({pct:5.1f}%)")
+
+        report_lines.extend([
+            "",
+            "QUALITY METRICS:",
+            "-" * 80,
+            f"  Needs Review: {needs_review_count:,} ({(needs_review_count / len(results) * 100) if results else 0:.1f}%)",
+            f"  Processing Time: {metadata.get('processing_time', 'unknown')}s",
+            "",
+            "=" * 80,
+        ])
+
+        # Write report
+        try:
+            with open(output_file, 'w') as f:
+                f.write("\n".join(report_lines))
+
+            logger.info(f"Report generated: {output_file}")
+            return output_file
+        except Exception as e:
+            logger.error(f"Error generating report: {e}")
+            raise
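+
+# Illustrative usage sketch (`results` is a list of ClassificationResult
+# produced by the classifier; the metadata keys mirror what the orchestrator
+# passes in and are otherwise optional):
+#
+#     exporter = ResultsExporter(output_dir="results")
+#     exporter.export_json(results, metadata={'total_emails': len(results)})
+#     exporter.export_csv(results)
+#     ReportGenerator("results").generate_report(results, metadata={})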
diff --git a/src/orchestration.py b/src/orchestration.py
new file mode 100644
index 0000000..7eb5e5e
--- /dev/null
+++ b/src/orchestration.py
@@ -0,0 +1,278 @@
+"""Main orchestration for complete pipeline."""
+import logging
+import time
+from typing import Dict, List, Any, Optional
+from pathlib import Path
+
+from src.utils.config import Config, load_config, load_categories
+from src.email_providers.base import Email, ClassificationResult
+from src.classification.feature_extractor import FeatureExtractor
+from src.classification.ml_classifier import MLClassifier
+from src.classification.llm_classifier import LLMClassifier
+from src.classification.adaptive_classifier import AdaptiveClassifier
+from src.processing.bulk_processor import BulkProcessor
+from src.calibration.sampler import EmailSampler
+from src.calibration.llm_analyzer import CalibrationAnalyzer
+from src.llm.base import BaseLLMProvider
+from src.export.exporter import ResultsExporter, ReportGenerator
+
+logger = logging.getLogger(__name__)
+
+
+class EmailSorterOrchestrator:
+    """
+    Main orchestrator for complete email sorting pipeline.
+
+    Pipeline phases:
+    1. Calibration: Sample emails, discover categories, train ML
+    2. Bulk processing: Classify all emails with checkpointing
+    3. LLM review: Process uncertain cases
+    4. Export: Results, reports, and sync
+    """
+
+    def __init__(
+        self,
+        config: Config,
+        llm_provider: Optional[BaseLLMProvider] = None
+    ):
+        """Initialize orchestrator."""
+        self.config = config
+        self.llm_provider = llm_provider
+        self.categories = load_categories()
+
+        # Components (lazy-initialized)
+        self.feature_extractor = None
+        self.ml_classifier = None
+        self.llm_classifier = None
+        self.adaptive_classifier = None
+        self.bulk_processor = None
+        self.exporter = None
+        self.report_generator = None
+
+        self.start_time = None
+        self.stats = {}
+
+    def initialize_components(self) -> None:
+        """Initialize all classifier components."""
+        logger.info("Initializing classifier components...")
+
+        # Feature extraction
+        self.feature_extractor = FeatureExtractor(self.config.features.dict())
+
+        # ML Classifier
+        self.ml_classifier = MLClassifier()
+
+        # LLM Classifier
+        if self.llm_provider:
+            self.llm_classifier = LLMClassifier(
+                self.llm_provider,
+                self.categories,
+                self.config.dict()
+            )
+        else:
+            logger.warning("No LLM provider, LLM classification disabled")
+            self.llm_classifier = None
+
+        # Adaptive classifier
+        self.adaptive_classifier = AdaptiveClassifier(
+            self.feature_extractor,
+            self.ml_classifier,
+            self.llm_classifier,
+            self.categories,
+            self.config.dict()
+        )
+
+        # Bulk processor
+        self.bulk_processor = BulkProcessor(
+            self.adaptive_classifier,
+            batch_size=self.config.processing.batch_size,
+            checkpoint_dir=self.config.processing.checkpoint_dir,
+            checkpoint_interval=self.config.processing.checkpoint_interval
+        )
+
+        # Export
+        self.exporter = ResultsExporter(self.config.export.output_dir)
+        self.report_generator = ReportGenerator(self.config.export.output_dir)
+
+        logger.info("Initialization complete")
+
+    def run_calibration(self, sample_emails: List[Email]) -> None:
+        """Run calibration phase."""
+        logger.info("=" * 80)
+        logger.info("PHASE 1: CALIBRATION")
+        logger.info("=" * 80)
+
+        # Analyze with LLM if available
+        if self.llm_classifier and self.llm_classifier.llm_available:
+            logger.info("Analyzing sample with LLM...")
+            analyzer = CalibrationAnalyzer(self.llm_provider, self.config.dict())
+            categories, labels = analyzer.discover_categories(sample_emails)
+            logger.info(f"Discovered {len(categories)} categories from LLM")
+
+            # TODO: Use discovered categories to adjust thresholds
+
+        logger.info("Calibration phase complete")
+
+    def run_bulk_processing(
+        self,
+        emails: List[Email],
+        resume: bool = True,
+        progress_callback=None
+    ) -> List[ClassificationResult]:
+        """Run bulk processing phase."""
+        logger.info("=" * 80)
+        logger.info("PHASE 2: BULK PROCESSING")
+        logger.info("=" * 80)
+
+        results, llm_queue = self.bulk_processor.process(
+            emails,
+            resume=resume,
+            progress_callback=progress_callback
+        )
+
+        self.stats['results'] = results
+        self.stats['llm_queue'] = llm_queue
+
+        return results
+
+    def run_llm_review(self, emails_dict: Dict[str, Email]) -> None:
+        """Run LLM review phase for uncertain classifications."""
+        if not self.llm_classifier or not self.llm_classifier.llm_available:
+            logger.warning("LLM unavailable, skipping LLM review phase")
+            return
+
+        logger.info("=" * 80)
+        logger.info("PHASE 3: LLM REVIEW")
+        logger.info("=" * 80)
+
+        llm_queue = set(self.stats.get('llm_queue', []))
+        if not llm_queue:
+            logger.info("No emails need LLM review")
+            return
+
+        logger.info(f"Processing {len(llm_queue)} emails with LLM...")
+
+        results = self.stats.get('results', [])
+
+        # Update results with LLM review (emails_dict already maps id -> Email)
+        for idx, result in enumerate(results):
+            if result.email_id in llm_queue:
+                email = emails_dict.get(result.email_id)
+                if email:
+                    results[idx] = self.adaptive_classifier.classify_with_llm(result, email)
+
+        logger.info("LLM review phase complete")
+
+    def run_export(self, results: List[ClassificationResult]) -> Dict[str, Path]:
+        """Run export phase."""
+        logger.info("=" * 80)
+        logger.info("PHASE 4: EXPORT & REPORTING")
+        logger.info("=" * 80)
+
+        # Prepare metadata
+        metadata = {
+            'total_emails': len(results),
+            'accuracy_estimate': self.adaptive_classifier.get_stats().accuracy_estimate(),
+            'processing_time': int(time.time() - self.start_time) if self.start_time else 0,
+            'classification_stats': {
+                'rule_matched': self.adaptive_classifier.get_stats().rule_matched,
+                'ml_classified': self.adaptive_classifier.get_stats().ml_classified,
+                'llm_classified': self.adaptive_classifier.get_stats().llm_classified,
+                'needs_review': self.adaptive_classifier.get_stats().needs_review,
+            }
+        }
+
+        # Export results
+        export_files = {}
+
+        try:
+            export_files['json'] = self.exporter.export_json(results, metadata)
+            export_files['csv'] = self.exporter.export_csv(results)
+            export_files['by_category'] = self.exporter.export_by_category(results)
+        except Exception as e:
+            logger.error(f"Error during export: {e}")
+
+        # Generate report
+        try:
+            export_files['report'] = self.report_generator.generate_report(results, metadata)
+        except Exception as e:
+            logger.error(f"Error generating report: {e}")
+
+        logger.info("Export phase complete")
+        return export_files
+
+    def run_full_pipeline(
+        self,
+        all_emails: List[Email],
+        sample_size: int = 1500,
+        resume: bool = True,
+        progress_callback=None
+    ) -> Dict[str, Any]:
+        """
+        Run complete pipeline from start to finish.
+
+        Args:
+            all_emails: All emails to process
+            sample_size: Size of calibration sample
+            resume: Resume from checkpoint if exists
+            progress_callback: Function for progress updates
+
+        Returns:
+            Pipeline results
+        """
+        self.start_time = time.time()
+
+        logger.info("=" * 80)
+        logger.info("EMAIL SORTER v1.0 - FULL PIPELINE")
+        logger.info(f"Total emails: {len(all_emails):,}")
+        logger.info("=" * 80)
+
+        # Initialize
+        self.initialize_components()
+
+        # Phase 1: Calibration
+        sampler = EmailSampler()
+        sample_emails, remaining_emails = sampler.stratified_sample(
+            all_emails,
+            sample_size
+        )
+
+        self.run_calibration(sample_emails)
+
+        # Phase 2: Bulk processing
+        results = self.run_bulk_processing(
+            all_emails,  # Process all, not just remaining
+            resume=resume,
+            progress_callback=progress_callback
+        )
+
+        # Phase 3: LLM review
+        emails_dict = {e.id: e for e in all_emails}
+        self.run_llm_review(emails_dict)
+
+        # Phase 4: Export
+        export_files = self.run_export(results)
+
+        # Summary
+        elapsed = time.time() - self.start_time
+        logger.info("=" * 80)
+        logger.info("PIPELINE COMPLETE")
+        logger.info(f"Time: {elapsed:.1f}s ({elapsed/60:.1f}m)")
+        logger.info(f"Accuracy Estimate: {self.adaptive_classifier.get_stats().accuracy_estimate():.1%}")
+        logger.info(f"Results: {export_files}")
+        logger.info("=" * 80)
+
+        return {
+            'success': True,
+            'elapsed_time': elapsed,
+            'total_emails': len(all_emails),
+            'results_processed': len(results),
+            'export_files': export_files,
+            'stats': self.stats
+        }
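+
+# Illustrative end-to-end invocation; the config path and `provider` are
+# placeholders supplied by the caller, not defined in this module:
+#
+#     config = load_config("config.yaml")
+#     orchestrator = EmailSorterOrchestrator(config, llm_provider=provider)
+#     summary = orchestrator.run_full_pipeline(all_emails, sample_size=1500)
+#     print(summary['export_files'])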
diff --git a/src/processing/bulk_processor.py b/src/processing/bulk_processor.py
new file mode 100644
index 0000000..567ae14
--- /dev/null
+++ b/src/processing/bulk_processor.py
@@ -0,0 +1,209 @@
+"""Bulk email processing pipeline."""
+import logging
+from typing import List, Dict, Any, Optional, Callable, Tuple
+from dataclasses import dataclass, field
+from pathlib import Path
+import json
+from datetime import datetime
+
+from src.email_providers.base import Email, ClassificationResult
+from src.classification.adaptive_classifier import AdaptiveClassifier
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ProcessingCheckpoint:
+    """Checkpoint for resumable processing."""
+    total_emails: int
+    processed_count: int
+    completed_emails: List[str] = field(default_factory=list)
+    queued_for_llm: List[str] = field(default_factory=list)
+    failed_emails: List[Dict[str, Any]] = field(default_factory=list)
+    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            'total_emails': self.total_emails,
+            'processed_count': self.processed_count,
+            'completed_emails': self.completed_emails,
+            'queued_for_llm': self.queued_for_llm,
+            'failed_emails': self.failed_emails,
+            'timestamp': self.timestamp,
+            'progress': f"{self.processed_count}/{self.total_emails}"
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ProcessingCheckpoint":
+        """Create from dictionary."""
+        return cls(
+            total_emails=data.get('total_emails', 0),
+            processed_count=data.get('processed_count', 0),
+            completed_emails=data.get('completed_emails', []),
+            queued_for_llm=data.get('queued_for_llm', []),
+            failed_emails=data.get('failed_emails', []),
+            timestamp=data.get('timestamp', datetime.now().isoformat())
+        )
+
+
+class BulkProcessor:
+    """Process large batches of emails with checkpointing."""
+
+    def __init__(
+        self,
+        classifier: AdaptiveClassifier,
+        batch_size: int = 100,
+        checkpoint_dir: str = "checkpoints",
+        checkpoint_interval: int = 1000
+    ):
+        """Initialize bulk processor."""
+        self.classifier = classifier
+        self.batch_size = batch_size
+        self.checkpoint_dir = checkpoint_dir
+        self.checkpoint_interval = checkpoint_interval
+        self.checkpoint = None
+
+        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
+
+    def process(
+        self,
+        emails: List[Email],
+        resume: bool = True,
+        progress_callback: Optional[Callable[[int, int], None]] = None
+    ) -> Tuple[List[ClassificationResult], List[str]]:
+        """
+        Process batch of emails with checkpointing.
+
+        Args:
+            emails: List of emails to process
+            resume: Try to resume from checkpoint if exists
+            progress_callback: Function to call with (processed, total)
+
+        Returns:
+            (classified_results, llm_queue_ids)
+        """
+        logger.info(f"Starting bulk processing of {len(emails)} emails")
+
+        # Try to load checkpoint
+        if resume:
+            self.checkpoint = self._load_checkpoint(len(emails))
+            if self.checkpoint:
+                logger.info(f"Resumed from checkpoint: {self.checkpoint.processed_count}/{len(emails)} done")
+            else:
+                self.checkpoint = ProcessingCheckpoint(total_emails=len(emails), processed_count=0)
+        else:
+            self.checkpoint = ProcessingCheckpoint(total_emails=len(emails), processed_count=0)
+
+        results = []
+        llm_queue = []
+        completed_ids = set(self.checkpoint.completed_emails)
+        last_checkpoint_count = self.checkpoint.processed_count
+
+        # Process in batches
+        start_idx = self.checkpoint.processed_count
+        for batch_idx in range(start_idx, len(emails), self.batch_size):
+            batch_end = min(batch_idx + self.batch_size, len(emails))
+            batch = emails[batch_idx:batch_end]
+
+            logger.debug(f"Processing batch {batch_idx}-{batch_end}")
+
+            # Process batch
+            for email in batch:
+                try:
+                    # Skip if already processed
+                    if email.id in completed_ids:
+                        logger.debug(f"Skipping already processed: {email.id}")
+                        continue
+
+                    # Classify
+                    result = self.classifier.classify(email)
+                    results.append(result)
+
+                    # Track for LLM if needed
+                    if result.needs_review:
+                        llm_queue.append(email.id)
+                        self.checkpoint.queued_for_llm.append(email.id)
+
+                    self.checkpoint.completed_emails.append(email.id)
+                    completed_ids.add(email.id)
+                    self.checkpoint.processed_count += 1
+
+                except Exception as e:
+                    logger.error(f"Error processing {email.id}: {e}")
+                    self.checkpoint.failed_emails.append({
+                        'email_id': email.id,
+                        'error': str(e),
+                        'timestamp': datetime.now().isoformat()
+                    })
+
+            # Checkpoint periodically (once at least checkpoint_interval emails
+            # have been processed since the last save)
+            if self.checkpoint.processed_count - last_checkpoint_count >= self.checkpoint_interval:
+                self._save_checkpoint()
+                last_checkpoint_count = self.checkpoint.processed_count
+                logger.info(f"Checkpoint saved: {self.checkpoint.processed_count}/{len(emails)}")
+
+            # Progress callback
+            if progress_callback:
+                progress_callback(self.checkpoint.processed_count, len(emails))
+
+        # Final checkpoint
+        self._save_checkpoint()
+        logger.info(f"Processing complete: {len(results)} classified, {len(llm_queue)} need LLM review")
+
+        return results, llm_queue
+
+    def _save_checkpoint(self) -> None:
+        """Save current checkpoint."""
+        if not self.checkpoint:
+            return
+
+        checkpoint_file = Path(self.checkpoint_dir) / f"checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+
+        try:
+            with open(checkpoint_file, 'w') as f:
+                json.dump(self.checkpoint.to_dict(), f, indent=2)
+            logger.debug(f"Checkpoint saved to {checkpoint_file}")
+        except Exception as e:
+            logger.error(f"Error saving checkpoint: {e}")
+
+    def _load_checkpoint(self, expected_total: int) -> Optional[ProcessingCheckpoint]:
+        """Load latest checkpoint if it matches."""
+        checkpoint_dir = Path(self.checkpoint_dir)
+
+        if not checkpoint_dir.exists():
+            return None
+
+        # Find latest checkpoint
+        checkpoints = sorted(checkpoint_dir.glob('checkpoint_*.json'), reverse=True)
+
+        if not checkpoints:
+            return None
+
+        try:
+            latest = checkpoints[0]
+            with open(latest, 'r') as f:
+                data = json.load(f)
+
+            # Verify checkpoint matches this job
+            if data.get('total_emails') == expected_total:
+                logger.info(f"Loaded checkpoint from {latest.name}")
+                return ProcessingCheckpoint.from_dict(data)
+            else:
+                logger.warning(f"Checkpoint total mismatch: {data.get('total_emails')} vs {expected_total}")
+                return None
+
+        except Exception as e:
+            logger.error(f"Error loading checkpoint: {e}")
+            return None
+
+    def get_checkpoint_status(self) -> Dict[str, Any]:
+        """Get current checkpoint status."""
+        if not self.checkpoint:
+            return {'status': 'no_checkpoint'}
+
+        return {
+            'status': 'active',
+            'processed': self.checkpoint.processed_count,
+            'total': self.checkpoint.total_emails,
+            'progress_percent': (self.checkpoint.processed_count / self.checkpoint.total_emails * 100) if self.checkpoint.total_emails > 0 else 0,
+            'queued_for_llm': len(self.checkpoint.queued_for_llm),
+            'failed': len(self.checkpoint.failed_emails),
+            'timestamp': self.checkpoint.timestamp
+        }
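+
+# Illustrative usage sketch (`classifier` is an AdaptiveClassifier built
+# elsewhere; `emails` is any list of Email objects):
+#
+#     processor = BulkProcessor(classifier, batch_size=100, checkpoint_interval=1000)
+#     results, llm_queue = processor.process(emails, resume=True)
+#     print(processor.get_checkpoint_status())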