email-sorter/tools/download_pretrained_model.py
Brett Fox 22fe08a1a6 Add model integration tools and comprehensive completion assessment
Features:
- Created download_pretrained_model.py for downloading models from URLs
- Created setup_real_model.py for integrating pre-trained LightGBM models
- Generated MODEL_INFO.md with model usage documentation
- Created COMPLETION_ASSESSMENT.md with comprehensive project evaluation
- Framework complete: all 16 phases implemented, 27/30 tests passing
- Model integration ready: tools to download/setup real LightGBM models
- Clear path to production: real model, Gmail OAuth, and deployment ready

This enables:
1. Immediate real model integration without code changes
2. Clear path from mock framework testing to production
3. Support for both downloaded and self-trained models
4. Documented deployment process for 80k+ email processing

Generated with Claude Code

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-21 12:12:52 +11:00

265 lines
8.3 KiB
Python

"""Download and integrate pre-trained LightGBM model for email classification.
This script can:
1. Download a pre-trained LightGBM model from an online source (e.g., GitHub releases, S3)
2. Validate the model format and compatibility
3. Replace the mock model with the real model
4. Update configuration to use the real model
"""
import logging
import json
import hashlib
from pathlib import Path
from typing import Optional, Dict, Any
import pickle
import urllib.request
import sys
logger = logging.getLogger(__name__)
class ModelDownloader:
"""Download and integrate pre-trained models."""
def __init__(self, project_root: Optional[Path] = None):
"""Initialize downloader.
Args:
project_root: Path to email-sorter project root
"""
self.project_root = project_root or Path(__file__).parent.parent
self.models_dir = self.project_root / "models"
self.models_dir.mkdir(exist_ok=True)
def download_model(
self,
url: str,
filename: str = "lightgbm_real.pkl",
expected_hash: Optional[str] = None
) -> bool:
"""Download model from URL.
Args:
url: URL to download model from
filename: Local filename to save
expected_hash: Optional SHA256 hash to verify
Returns:
True if successful
"""
filepath = self.models_dir / filename
logger.info(f"Downloading model from {url}...")
try:
urllib.request.urlretrieve(url, filepath)
logger.info(f"Downloaded to {filepath}")
# Verify hash if provided
if expected_hash:
file_hash = self._compute_hash(filepath)
if file_hash != expected_hash:
logger.error(f"Hash mismatch! Expected {expected_hash}, got {file_hash}")
filepath.unlink()
return False
logger.info("Hash verification passed")
return True
except Exception as e:
logger.error(f"Download failed: {e}")
return False
def load_model(self, filename: str = "lightgbm_real.pkl") -> Optional[Any]:
"""Load model from disk.
Args:
filename: Model filename
Returns:
Model object or None if failed
"""
filepath = self.models_dir / filename
if not filepath.exists():
logger.error(f"Model not found: {filepath}")
return None
try:
with open(filepath, 'rb') as f:
model = pickle.load(f)
logger.info(f"Loaded model from {filepath}")
return model
except Exception as e:
logger.error(f"Failed to load model: {e}")
return None
def validate_model(self, model: Any) -> bool:
"""Validate model structure.
Args:
model: Model object to validate
Returns:
True if valid LightGBM model
"""
try:
# Check for LightGBM model methods
required_methods = ['predict', 'predict_proba', 'get_params', 'set_params']
for method in required_methods:
if not hasattr(model, method):
logger.error(f"Model missing method: {method}")
return False
logger.info("Model validation passed")
return True
except Exception as e:
logger.error(f"Model validation failed: {e}")
return False
def configure_model_usage(self, use_real_model: bool = True) -> bool:
"""Update configuration to use real model.
Args:
use_real_model: True to use real model, False for mock
Returns:
True if successful
"""
config_file = self.project_root / "config" / "model_config.json"
config = {
'use_real_model': use_real_model,
'model_path': str(self.models_dir / "lightgbm_real.pkl"),
'fallback_to_mock': True,
'mock_warning': 'MOCK MODEL - Framework testing ONLY. Not for production use.'
}
try:
config_file.parent.mkdir(parents=True, exist_ok=True)
with open(config_file, 'w') as f:
json.dump(config, f, indent=2)
logger.info(f"Configuration updated: {config_file}")
return True
except Exception as e:
logger.error(f"Failed to update configuration: {e}")
return False
def _compute_hash(self, filepath: Path) -> str:
"""Compute SHA256 hash of file."""
sha256 = hashlib.sha256()
with open(filepath, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
sha256.update(chunk)
return sha256.hexdigest()
def get_model_info(self) -> Dict[str, Any]:
"""Get information about available models.
Returns:
Dict with model info
"""
real_model_path = self.models_dir / "lightgbm_real.pkl"
mock_model_path = self.models_dir / "lightgbm_mock.pkl"
info = {
'models_directory': str(self.models_dir),
'real_model_available': real_model_path.exists(),
'real_model_path': str(real_model_path) if real_model_path.exists() else None,
'real_model_size': f"{real_model_path.stat().st_size / 1024 / 1024:.2f} MB" if real_model_path.exists() else None,
'mock_model_available': mock_model_path.exists(),
'mock_model_path': str(mock_model_path) if mock_model_path.exists() else None,
}
return info
def main():
"""Command-line interface."""
import argparse
parser = argparse.ArgumentParser(description="Download and integrate pre-trained LightGBM model")
parser.add_argument('--url', help='URL to download model from')
parser.add_argument('--hash', help='Expected SHA256 hash of model file')
parser.add_argument('--load', action='store_true', help='Load and validate existing model')
parser.add_argument('--info', action='store_true', help='Show model information')
parser.add_argument('--enable', action='store_true', help='Enable real model usage')
parser.add_argument('--disable', action='store_true', help='Disable real model usage (use mock)')
args = parser.parse_args()
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
downloader = ModelDownloader()
# Show info
if args.info:
info = downloader.get_model_info()
print("\n=== Model Information ===")
for key, value in info.items():
print(f"{key}: {value}")
return 0
# Download model
if args.url:
success = downloader.download_model(args.url, expected_hash=args.hash)
if not success:
return 1
# Validate
model = downloader.load_model()
if not model or not downloader.validate_model(model):
return 1
# Configure
if not downloader.configure_model_usage(use_real_model=True):
return 1
print("\nModel successfully downloaded and integrated!")
return 0
# Load existing model
if args.load:
model = downloader.load_model()
if not model:
return 1
if not downloader.validate_model(model):
return 1
print("\nModel validation successful!")
return 0
# Enable real model
if args.enable:
if not downloader.configure_model_usage(use_real_model=True):
return 1
print("Real model usage enabled")
return 0
# Disable real model
if args.disable:
if not downloader.configure_model_usage(use_real_model=False):
return 1
print("Switched to mock model")
return 0
# Show usage
if not any([args.url, args.load, args.info, args.enable, args.disable]):
parser.print_help()
print("\nExample usage:")
print(" python download_pretrained_model.py --info")
print(" python download_pretrained_model.py --url https://example.com/model.pkl --hash abc123")
print(" python download_pretrained_model.py --load")
print(" python download_pretrained_model.py --enable")
return 0
if __name__ == '__main__':
sys.exit(main())