"""Download and integrate a pre-trained LightGBM model for email classification.

This script can:

1. Download a pre-trained LightGBM model from an online source
   (e.g. GitHub releases, S3).
2. Validate the model format and compatibility.
3. Replace the mock model with the real model.
4. Update the configuration to use the real model.

Security note: models are stored and loaded with :mod:`pickle`, which can
execute arbitrary code while loading. Only download models from sources you
trust, and pass ``--hash`` so the file is verified before it is used.
"""

import argparse
import hashlib
import json
import logging
import os
import pickle
import shutil
import sys
import tempfile
import urllib.request
from pathlib import Path
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# File names (inside the models directory) of the real and mock models.
REAL_MODEL_FILENAME = "lightgbm_real.pkl"
MOCK_MODEL_FILENAME = "lightgbm_mock.pkl"

# Default network timeout, in seconds, for model downloads.
DEFAULT_DOWNLOAD_TIMEOUT = 60.0

# Read size used when hashing files.
_HASH_CHUNK_SIZE = 64 * 1024


class ModelDownloader:
    """Download, validate and register pre-trained models.

    Model files live in ``<project_root>/models``; the model configuration is
    written to ``<project_root>/config/model_config.json``.
    """

    def __init__(self, project_root: Optional[Path] = None):
        """Initialize the downloader and make sure the models directory exists.

        Args:
            project_root: Path to the email-sorter project root. Defaults to
                the grandparent directory of this file (assumes the script
                lives one level below the project root, e.g. ``scripts/``).
        """
        if project_root is None:
            project_root = Path(__file__).parent.parent
        self.project_root = Path(project_root)
        self.models_dir = self.project_root / "models"
        # parents=True so a custom root whose parent chain does not exist yet
        # does not make construction fail.
        self.models_dir.mkdir(parents=True, exist_ok=True)

    def download_model(
        self,
        url: str,
        filename: str = REAL_MODEL_FILENAME,
        expected_hash: Optional[str] = None,
        timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
    ) -> bool:
        """Download a model from ``url`` into the models directory.

        The file is first written to a temporary file next to the target and
        only moved into place once the download finished and (if requested)
        the hash matched. A failed or unverified download therefore never
        replaces or deletes an existing model.

        Args:
            url: URL to download the model from (any scheme supported by
                ``urllib.request.urlopen``).
            filename: Local filename (relative to the models directory).
            expected_hash: Optional SHA256 hex digest to verify; compared
                case-insensitively, surrounding whitespace is ignored.
            timeout: Network timeout in seconds.

        Returns:
            True if the model was downloaded (and verified, if a hash was
            given) and stored at its final location, False otherwise.
        """
        filepath = self.models_dir / filename
        logger.info(f"Downloading model from {url}...")
        if not expected_hash:
            # The file will later be unpickled; without a hash nothing
            # guarantees it is the model we expect.
            logger.warning(
                "No expected hash supplied; the downloaded model will not be "
                "verified before use. Only download from trusted sources."
            )

        try:
            # Same directory as the target so os.replace is an atomic rename
            # on one filesystem.
            fd, tmp_name = tempfile.mkstemp(
                dir=str(filepath.parent), prefix=".download-", suffix=".part"
            )
        except OSError as e:
            logger.error(f"Download failed: {e}")
            return False
        tmp_path = Path(tmp_name)

        try:
            with os.fdopen(fd, 'wb') as out, \
                    urllib.request.urlopen(url, timeout=timeout) as response:
                shutil.copyfileobj(response, out)

            if expected_hash:
                file_hash = self._compute_hash(tmp_path)
                if file_hash != expected_hash.strip().lower():
                    logger.error(f"Hash mismatch! Expected {expected_hash}, got {file_hash}")
                    return False
                logger.info("Hash verification passed")

            os.replace(tmp_path, filepath)
            logger.info(f"Downloaded to {filepath}")
            return True
        except Exception as e:
            logger.error(f"Download failed: {e}")
            return False
        finally:
            # After a successful os.replace the temporary name no longer
            # exists; in every other case remove the partial/unverified file.
            try:
                tmp_path.unlink()
            except FileNotFoundError:
                pass

    def load_model(self, filename: str = REAL_MODEL_FILENAME) -> Optional[Any]:
        """Load a pickled model from the models directory.

        Args:
            filename: Model filename (relative to the models directory).

        Returns:
            The unpickled object, or None if the file is missing or could not
            be unpickled.
        """
        filepath = self.models_dir / filename

        if not filepath.exists():
            logger.error(f"Model not found: {filepath}")
            return None

        try:
            with open(filepath, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code embedded in
                # the file. Only load models obtained from trusted sources,
                # ideally verified via download_model(expected_hash=...).
                model = pickle.load(f)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return None

        logger.info(f"Loaded model from {filepath}")
        return model

    def validate_model(self, model: Any) -> bool:
        """Check that ``model`` exposes the estimator methods this project uses.

        This is a duck-typing check for callable ``predict``,
        ``predict_proba``, ``get_params`` and ``set_params`` attributes; it
        does not prove the object is actually a LightGBM model.

        Args:
            model: Model object to validate.

        Returns:
            True if all required methods are present and callable.
        """
        required_methods = ('predict', 'predict_proba', 'get_params', 'set_params')
        try:
            for method in required_methods:
                if not callable(getattr(model, method, None)):
                    logger.error(f"Model missing method: {method}")
                    return False
        except Exception as e:
            # getattr only swallows AttributeError; a misbehaving property may
            # raise something else.
            logger.error(f"Model validation failed: {e}")
            return False

        logger.info("Model validation passed")
        return True

    def configure_model_usage(
        self,
        use_real_model: bool = True,
        filename: str = REAL_MODEL_FILENAME,
    ) -> bool:
        """Write ``config/model_config.json`` selecting real or mock model.

        Args:
            use_real_model: True to use the real model, False for the mock.
            filename: Real-model filename recorded as ``model_path``.

        Returns:
            True if the configuration file was written.
        """
        config_file = self.project_root / "config" / "model_config.json"

        config = {
            'use_real_model': use_real_model,
            'model_path': str(self.models_dir / filename),
            'fallback_to_mock': True,
            'mock_warning': 'MOCK MODEL - Framework testing ONLY. Not for production use.'
        }

        try:
            config_file.parent.mkdir(parents=True, exist_ok=True)
            with open(config_file, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to update configuration: {e}")
            return False

        logger.info(f"Configuration updated: {config_file}")
        return True

    def _compute_hash(self, filepath: Path) -> str:
        """Return the lowercase SHA256 hex digest of ``filepath``."""
        sha256 = hashlib.sha256()
        with open(filepath, 'rb') as f:
            for chunk in iter(lambda: f.read(_HASH_CHUNK_SIZE), b''):
                sha256.update(chunk)
        return sha256.hexdigest()

    def get_model_info(self) -> Dict[str, Any]:
        """Describe which model files are present.

        Returns:
            Dict with the models directory, availability/path of the real and
            mock models, and the real model's size in MB (None when absent).
        """
        real_model_path = self.models_dir / REAL_MODEL_FILENAME
        mock_model_path = self.models_dir / MOCK_MODEL_FILENAME

        real_exists = real_model_path.exists()
        mock_exists = mock_model_path.exists()
        real_size = (
            f"{real_model_path.stat().st_size / 1024 / 1024:.2f} MB"
            if real_exists else None
        )

        return {
            'models_directory': str(self.models_dir),
            'real_model_available': real_exists,
            'real_model_path': str(real_model_path) if real_exists else None,
            'real_model_size': real_size,
            'mock_model_available': mock_exists,
            'mock_model_path': str(mock_model_path) if mock_exists else None,
        }


def main() -> int:
    """Command-line interface.

    Exactly one action runs, in this order of precedence: ``--info``,
    ``--url``, ``--load``, ``--enable``, ``--disable``. With no action the
    help text and usage examples are printed.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Download and integrate pre-trained LightGBM model")
    parser.add_argument('--url', help='URL to download model from')
    parser.add_argument('--hash', help='Expected SHA256 hash of model file')
    parser.add_argument('--load', action='store_true', help='Load and validate existing model')
    parser.add_argument('--info', action='store_true', help='Show model information')
    # --enable and --disable contradict each other; let argparse reject both.
    toggle = parser.add_mutually_exclusive_group()
    toggle.add_argument('--enable', action='store_true', help='Enable real model usage')
    toggle.add_argument('--disable', action='store_true', help='Disable real model usage (use mock)')

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    if args.hash and not args.url:
        logger.warning("--hash has no effect without --url")

    downloader = ModelDownloader()

    # Show info
    if args.info:
        info = downloader.get_model_info()
        print("\n=== Model Information ===")
        for key, value in info.items():
            print(f"{key}: {value}")
        return 0

    # Download model
    if args.url:
        if not downloader.download_model(args.url, expected_hash=args.hash):
            return 1

        # Validate
        model = downloader.load_model()
        if model is None or not downloader.validate_model(model):
            return 1

        # Configure
        if not downloader.configure_model_usage(use_real_model=True):
            return 1

        print("\nModel successfully downloaded and integrated!")
        return 0

    # Load existing model
    if args.load:
        model = downloader.load_model()
        if model is None:
            return 1

        if not downloader.validate_model(model):
            return 1

        print("\nModel validation successful!")
        return 0

    # Enable real model
    if args.enable:
        if not downloader.configure_model_usage(use_real_model=True):
            return 1
        print("Real model usage enabled")
        return 0

    # Disable real model
    if args.disable:
        if not downloader.configure_model_usage(use_real_model=False):
            return 1
        print("Switched to mock model")
        return 0

    # No action requested: show usage.
    parser.print_help()
    print("\nExample usage:")
    print("  python download_pretrained_model.py --info")
    print("  python download_pretrained_model.py --url https://example.com/model.pkl --hash abc123")
    print("  python download_pretrained_model.py --load")
    print("  python download_pretrained_model.py --enable")
    return 0


if __name__ == '__main__':
    sys.exit(main())