Model Integration Design Guide

Overview

This guide provides comprehensive instructions for integrating new models into Atlas. Whether you’re adding a machine learning model, statistical model, or external API, this guide will help you create robust, maintainable integrations that leverage the full power of the framework.

Model Integration Principles

Before diving into implementation, understand these core principles:

  1. Separation of Concerns: Models should focus on prediction, not data transformation

  2. Standardized Interfaces: All models implement the same abstract interface

  3. Comprehensive Validation: Input and output validation ensures reliability

  4. Performance Optimization: Consider caching and parallel execution

  5. Error Handling: Graceful degradation and informative error messages

Integration Approaches

Approach 1: Direct Python Integration

Best for models implemented in Python or easily callable from Python.

Step 1: Implement the AbstractModel Interface

from typing import Dict, List, Optional

import numpy as np
import xarray as xr

from atlas.models import AbstractModel

class MyRevenueModel(AbstractModel):
    """
    Example revenue prediction model integration.

    Wraps an estimator loaded with joblib behind the ``AbstractModel``
    interface. Every configured channel is read from the input dataset as
    one feature column; predictions and contributions are returned with the
    dimensions and coordinates of the channel variables.
    """

    def __init__(self, model_path: str, config: Optional[Dict] = None):
        """
        Initialize your model.

        Args:
            model_path: Path to model artifacts
            config: Optional configuration dictionary. Must contain the
                keys ``channels`` and ``time_period``.

        Raises:
            ValueError: If the configuration is missing a required key or
                lists no channels.
        """
        super().__init__()
        # Validate the configuration before touching the disk so that a bad
        # config fails fast instead of after a potentially slow model load.
        self.config = config or {}
        self._validate_config()
        self.model = self._load_model(model_path)

    def _load_model(self, model_path: str):
        """Load model from disk."""
        # Example: Load scikit-learn model
        # SECURITY: joblib deserializes with pickle, which can execute
        # arbitrary code. Only load artifacts from a trusted location.
        import joblib
        return joblib.load(model_path)

    def _validate_config(self):
        """
        Validate configuration parameters.

        Raises:
            ValueError: If a required key is absent or ``channels`` is empty.
        """
        required_keys = ['channels', 'time_period']
        for key in required_keys:
            if key not in self.config:
                raise ValueError(f"Missing required config key: {key}")

        # An empty channel list would otherwise surface later as an opaque
        # numpy error when the feature matrix is assembled.
        if not self.config['channels']:
            raise ValueError("Config key 'channels' must list at least one channel")

    @property
    def model_type(self) -> str:
        """Return model type identifier."""
        return "revenue_regression"

    @property
    def required_dimensions(self) -> List[str]:
        """Specify required data dimensions."""
        return ["time", "channel", "geography"]

    def predict(self, x: xr.Dataset) -> xr.DataArray:
        """
        Generate predictions from input data.

        Args:
            x: Input dataset with budget allocations

        Returns:
            Predictions as DataArray

        Raises:
            ValueError: If ``x`` fails input validation.
        """
        # Validate input
        self._validate_input(x)

        # Transform xarray to model input format
        model_input = self._prepare_input(x)

        # Generate predictions
        predictions = self.model.predict(model_input)

        # Convert back to xarray
        return self._format_output(predictions, x)

    def contributions(self, x: xr.Dataset) -> xr.Dataset:
        """
        Calculate channel contributions.

        Args:
            x: Input dataset

        Returns:
            Dataset with one data variable per configured channel

        Raises:
            ValueError: If ``x`` fails input validation.
        """
        # Same checks as predict(), so bad input gives the same clear error.
        self._validate_input(x)

        # Use SHAP or other attribution method
        import shap

        explainer = shap.Explainer(self.model)
        model_input = self._prepare_input(x)
        shap_values = explainer(model_input)

        return self._format_contributions(shap_values, x)

    def _validate_input(self, x: xr.Dataset):
        """
        Validate input data structure.

        Raises:
            ValueError: If a required dimension or channel variable is
                missing, or the channel variables do not share the same
                dimensions.
        """
        for dim in self.required_dimensions:
            if dim not in x.dims:
                raise ValueError(f"Missing required dimension: {dim}")

        # Check for required variables
        reference_dims = None
        for channel in self.config['channels']:
            if channel not in x.data_vars:
                raise ValueError(f"Missing channel data: {channel}")

            # The feature matrix is built by flattening each channel, so all
            # channels must be laid out along identical axes.
            if reference_dims is None:
                reference_dims = x[channel].dims
            elif x[channel].dims != reference_dims:
                raise ValueError(
                    f"Channel {channel} has dimensions {x[channel].dims}, "
                    f"expected {reference_dims}"
                )

    def _prepare_input(self, x: xr.Dataset) -> np.ndarray:
        """
        Transform xarray Dataset to model input format.

        Returns:
            2D array with one row per flattened cell and one column per
            configured channel, in configuration order.
        """
        # Example: Stack data into 2D array
        features = [
            x[channel].values.flatten() for channel in self.config['channels']
        ]
        return np.column_stack(features)

    def _format_output(self, predictions: np.ndarray,
                      template: xr.Dataset) -> xr.DataArray:
        """
        Format predictions as xarray DataArray.

        Args:
            predictions: Flat model output, one value per row produced by
                ``_prepare_input``.
            template: Dataset the predictions were generated from.

        Returns:
            Predictions shaped and labelled like the first configured
            channel variable.
        """
        # Rows were produced by flattening the channel variables, so the
        # inverse reshape must use a channel variable's own axes. The
        # dataset-level ``dims`` covers every dimension in the dataset and
        # its order need not match the variable's axis order.
        reference = template[self.config['channels'][0]]
        reshaped = np.asarray(predictions).reshape(reference.shape)

        return xr.DataArray(
            data=reshaped,
            dims=reference.dims,
            coords=reference.coords,
            name="revenue_prediction",
            attrs={"units": "USD", "model": self.model_type}
        )

    def _format_contributions(self, shap_values,
                              template: xr.Dataset) -> xr.Dataset:
        """
        Format attribution values as an xarray Dataset.

        Args:
            shap_values: Result of calling the SHAP explainer on the output
                of ``_prepare_input``.
            template: Dataset the attributions were generated from.

        Returns:
            Dataset with one variable per configured channel, each shaped
            like the corresponding input variable.
        """
        # NOTE(review): assumes ``shap_values.values`` is 2D
        # (rows, channels), i.e. a single-output model - confirm for
        # multi-output estimators.
        values = np.asarray(shap_values.values)

        data_vars = {}
        for column, channel in enumerate(self.config['channels']):
            reference = template[channel]
            data_vars[channel] = xr.DataArray(
                data=values[:, column].reshape(reference.shape),
                dims=reference.dims,
                coords=reference.coords,
                attrs={"model": self.model_type}
            )

        return xr.Dataset(data_vars)

Step 2: Create Model Configuration

# model_config.py
from atlas.config import ModelConfiguration, LeverSpecification

class RevenueModelConfig(ModelConfiguration):
    """Levers, data mapping and output spec for the revenue predictor."""

    def __init__(self):
        super().__init__(
            model_name="Revenue Predictor",
            model_version="2.0",
            model_type="regression"
        )

        def spend_lever(name, baseline, lower, upper, description):
            # All levers in this example are USD spend levers; only the
            # name, range and description vary.
            return LeverSpecification(
                name=name,
                lever_type="spend",
                baseline_value=baseline,
                min_value=lower,
                max_value=upper,
                units="USD",
                description=description
            )

        # Optimization levers: (name, baseline, min, max, description)
        lever_rows = [
            ("tv", 100_000, 0, 500_000, "Television advertising spend"),
            ("digital", 150_000, 10_000, 1_000_000, "Digital marketing spend"),
            ("radio", 50_000, 0, 200_000, "Radio advertising spend"),
        ]
        self.levers = {row[0]: spend_lever(*row) for row in lever_rows}

        # Data mapping: which variables feed each lever and how they are
        # transformed and scaled
        tv_mapping = dict(
            variables=["tv_grps", "tv_reach"],
            transformation="logarithmic",
            scaling="min_max",
        )
        digital_mapping = dict(
            variables=["digital_impressions", "digital_clicks"],
            transformation="none",
            scaling="standard",
        )
        self.data_mapping = dict(tv=tv_mapping, digital=digital_mapping)

        # Output specification
        revenue_output = dict(
            type="continuous",
            dimensions=["time", "geography"],
            aggregation="sum",
            units="USD",
        )
        self.outputs = dict(revenue=revenue_output)

Step 3: Register and Test the Model

# test_model_integration.py
import pytest
from atlas import ModelRegistry
from mymodels import MyRevenueModel, RevenueModelConfig

def test_model_integration():
    """End-to-end check: build the model, predict, attribute, then register."""

    # Build the model under test
    model_settings = {
        "channels": ["tv", "digital", "radio"],
        "time_period": "weekly",
    }
    revenue_model = MyRevenueModel(
        model_path="models/revenue_model.pkl",
        config=model_settings,
    )

    # Input fixture
    dataset = create_test_dataset()

    # Predictions keep the expected axes and are never negative
    forecast = revenue_model.predict(dataset)
    expected_dims = ("time", "channel", "geography")
    assert forecast.dims == expected_dims
    assert forecast.min() >= 0  # Revenue should be non-negative

    # Contributions expose a variable for the tv channel
    attribution = revenue_model.contributions(dataset)
    assert "tv" in attribution.data_vars

    # Finally, register the model with the framework
    ModelRegistry().register(
        model=revenue_model,
        config=RevenueModelConfig(),
        tags=["production", "revenue"],
    )

Approach 2: Docker Container Integration

Best for models in different languages, complex dependencies, or third-party models.

Step 1: Create Model Service

# model_service/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, List
import numpy as np

# Application object; the route decorators further down register their
# handlers on it.
app = FastAPI(title="Revenue Model Service")

# Request body shared by the POST endpoints.
class PredictionRequest(BaseModel):
    # Name -> amount mapping; the /predict handler hands it to
    # validate_budget() and prepare_input().
    budget: Dict[str, float]
    # Free-form extras, empty by default; no handler in this snippet reads it.
    options: Dict = {}

# Response body returned by /predict.
class PredictionResponse(BaseModel):
    # Single predicted value.
    prediction: float
    # Extra information about the prediction; /predict fills in the
    # model version.
    metadata: Dict = {}

# Load model at startup
# This runs once, when the module is imported. load_your_model is not
# defined in this snippet - replace it with your own loader.
model = load_your_model()

@app.get("/")
def info():
    """Describe the served model: name, version, type and channels."""
    description = dict(
        name="Revenue Model",
        version="2.0",
        type="regression",
    )
    description["supported_channels"] = ["tv", "digital", "radio"]
    return description

@app.get("/schema")
def schema():
    """Report the required input variables and their allowed ranges."""
    # (min, max) per variable, in the order the variables are reported
    limits = {
        "tv": (0, 500000),
        "digital": (10000, 1000000),
        "radio": (0, 200000),
    }
    constraints = {
        variable: {"min": lower, "max": upper}
        for variable, (lower, upper) in limits.items()
    }
    return {
        "required_variables": list(limits),
        "constraints": constraints,
    }

@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """
    Generate prediction.

    Args:
        request: Budget (and optional options) to score.

    Returns:
        PredictionResponse carrying the first value returned by the model
        and the model version.

    Raises:
        HTTPException: 400 when the budget cannot be validated or
            transformed; 500 when the model itself fails.
    """
    # Anything that goes wrong while checking or shaping the request is the
    # caller's problem, so report it as a client error with the reason.
    try:
        # Validate input
        validate_budget(request.budget)

        # Transform to model input
        model_input = prepare_input(request.budget)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # A failure past this point happened with an accepted request, so it is
    # a server-side error. Keep the detail generic rather than exposing
    # internals; the original exception stays attached for server logs.
    try:
        # Generate prediction
        prediction = model.predict(model_input)[0]

        return PredictionResponse(
            prediction=float(prediction),
            metadata={"model_version": "2.0"}
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail="Model prediction failed"
        ) from e

@app.post("/contributions")
def contributions(request: PredictionRequest):
    """
    Calculate feature contributions.

    Placeholder: replace the body with your attribution logic.

    Raises:
        HTTPException: 501 until the endpoint is implemented.
    """
    # Returning nothing would be sent to the client as a successful
    # ``null`` result; fail explicitly so callers can tell the feature is
    # missing.
    raise HTTPException(
        status_code=501,
        detail="Feature contributions are not implemented"
    )

Step 2: Create Dockerfile

# Base image: slim variant of the official Python 3.12 image
FROM python:3.12-slim

# Relative paths in the instructions below resolve against /app
WORKDIR /app

# Install dependencies
# requirements.txt is copied on its own, ahead of the code, so the install
# layer can be reused from the build cache when only code or model files change
COPY requirements.txt .
# --no-cache-dir keeps pip's download cache out of the image layer
RUN pip install --no-cache-dir -r requirements.txt

# Copy model and code
COPY model/ ./model/
COPY model_service/ ./model_service/

# Expose port
# Documents the port the service listens on (matches --port below)
EXPOSE 8000

# Run service
# Serves the `app` object from model_service/app.py on all interfaces
CMD ["uvicorn", "model_service.app:app", "--host", "0.0.0.0", "--port", "8000"]

Step 3: Create Docker Wrapper

# docker_wrapper.py
from atlas.models import DockerModelWrapper
import requests
import xarray as xr

class RevenueModelDocker(DockerModelWrapper):
    """
    Docker wrapper for revenue model.

    Forwards predictions to the model service over HTTP.
    """

    def __init__(self, service_url: str = "http://revenue-model:8000",
                 timeout: float = 30.0):
        """
        Args:
            service_url: Base URL of the model service.
            timeout: Seconds to wait for the service before giving up.
        """
        super().__init__(
            name="Revenue Model",
            version="2.0",
            service_url=service_url
        )
        self.timeout = timeout

    def predict(self, x: xr.Dataset) -> xr.DataArray:
        """
        Call Docker service for prediction.

        Args:
            x: Input dataset with budget allocations.

        Returns:
            Prediction built from the service response.

        Raises:
            requests.exceptions.RequestException: If the service is
                unreachable, times out, or answers with an error status.
        """
        # Convert xarray to budget dict
        budget = self._dataset_to_budget(x)

        # Call service
        # requests waits forever by default; without a timeout a stalled
        # container would block the caller indefinitely.
        response = requests.post(
            f"{self.service_url}/predict",
            json={"budget": budget},
            timeout=self.timeout
        )
        response.raise_for_status()

        # Convert response to xarray
        result = response.json()
        return self._create_prediction_array(
            result["prediction"], x
        )

Approach 3: External API Integration

Best for cloud services, vendor APIs, or remote models.

Step 1: Create API Wrapper

# api_wrapper.py
from atlas.models import AbstractModel
import requests
from typing import Dict
import xarray as xr

class ExternalAPIModel(AbstractModel):
    """
    Wrapper for external prediction API.

    NOTE(review): ``_prepare_api_request`` and ``_process_api_response``
    are called by ``predict`` but are not defined in this snippet - they
    must be provided for your API's payload format.
    """

    def __init__(self, api_key: str, endpoint: str, timeout: float = 30.0):
        """
        Args:
            api_key: Bearer token sent with every request.
            endpoint: Base URL of the API.
            timeout: Seconds to wait for each request before giving up.
        """
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.timeout = timeout
        self.session = self._create_session()

    def _create_session(self):
        """Create authenticated session."""
        session = requests.Session()
        session.headers.update({
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        })
        return session

    def predict(self, x: xr.Dataset) -> xr.DataArray:
        """
        Call external API for predictions.

        Raises:
            requests.exceptions.RequestException: If the call still fails
                after all retries, or fails with a non-retryable error.
        """
        # Prepare request
        request_data = self._prepare_api_request(x)

        # Make API call with retry logic
        # Strip a trailing slash so "https://host/" does not become
        # "https://host//predict".
        response = self._call_api_with_retry(
            self.endpoint.rstrip("/") + "/predict",
            json=request_data
        )

        # Process response
        return self._process_api_response(response, x)

    @staticmethod
    def _is_retryable(error) -> bool:
        """Return True when repeating the request could plausibly succeed."""
        response = getattr(error, "response", None)
        if response is None:
            # No HTTP response at all: connection problem or timeout.
            return True

        # Client errors (bad request, bad credentials, ...) will fail the
        # same way again; only server errors, request timeouts and rate
        # limiting are worth another attempt.
        status = response.status_code
        return status >= 500 or status in (408, 429)

    def _call_api_with_retry(self, url: str, **kwargs):
        """
        Call API with exponential backoff retry.

        Args:
            url: Full URL to POST to.
            **kwargs: Passed through to ``session.post``. ``timeout``
                defaults to the instance timeout when not given.

        Returns:
            Decoded JSON body of the response.

        Raises:
            requests.exceptions.RequestException: On the final failed
                attempt, or immediately for a non-retryable error.
        """
        import time

        max_retries = 3
        retry_delay = 1

        # requests waits forever by default; always bound the wait.
        kwargs.setdefault("timeout", self.timeout)

        for attempt in range(max_retries):
            try:
                response = self.session.post(url, **kwargs)
                response.raise_for_status()
                return response.json()

            except requests.exceptions.RequestException as e:
                out_of_attempts = attempt == max_retries - 1
                if out_of_attempts or not self._is_retryable(e):
                    raise

                # Back off 1s, 2s, 4s, ... between attempts
                time.sleep(retry_delay * (2 ** attempt))

Best Practices

1. Input Validation

Always validate inputs thoroughly:

def _validate_input(self, x: xr.Dataset):
    """Comprehensive input validation."""
    
    # Check dimensions
    required_dims = self.required_dimensions
    missing_dims = set(required_dims) - set(x.dims)
    if missing_dims:
        raise ValueError(f"Missing dimensions: {missing_dims}")
    
    # Check data types
    for var in x.data_vars:
        if not np.issubdtype(x[var].dtype, np.number):
            raise TypeError(f"Variable {var} must be numeric")
    
    # Check value ranges
    for var, bounds in self.variable_bounds.items():
        if var in x.data_vars:
            min_val, max_val = bounds
            actual_min = float(x[var].min())
            actual_max = float(x[var].max())
            
            if actual_min < min_val or actual_max > max_val:
                raise ValueError(
                    f"{var} values outside bounds [{min_val}, {max_val}]"
                )

2. Error Handling

Implement comprehensive error handling:

class ModelError(Exception):
    """Base exception for model errors."""
    pass

class ModelPredictionError(ModelError):
    """Error during prediction."""
    pass

class ModelValidationError(ModelError):
    """Input validation error."""
    pass

def predict(self, x: xr.Dataset) -> xr.DataArray:
    """
    Predict with error handling.

    Args:
        x: Input dataset.

    Returns:
        The validated prediction.

    Raises:
        ModelValidationError: If input validation rejects ``x`` with a
            ``ValueError`` or ``TypeError``.
        ModelPredictionError: For any other failure, including a rejected
            output.
    """
    try:
        # Only failures of the input check are reported as "Invalid input";
        # a rejected *output* is a prediction problem, not a caller error.
        try:
            self._validate_input(x)
        except (ValueError, TypeError) as e:
            raise ModelValidationError(f"Invalid input: {e}") from e

        result = self._internal_predict(x)
        self._validate_output(result)
        return result

    except ModelError:
        # Already one of ours - do not wrap it a second time.
        raise

    except Exception as e:
        self.logger.error(f"Prediction failed: {e}")
        # Chain the cause so the original traceback is preserved.
        raise ModelPredictionError(f"Prediction failed: {e}") from e

3. Performance Optimization

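Implement caching so that repeated inputs skip the expensive model call (the example below covers caching; batch processing can be added with the same wrapper pattern):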

from functools import lru_cache
import hashlib

class CachedModel(AbstractModel):
    """
    Model with prediction caching.

    Wraps another model and remembers its predictions, keyed by a hash of
    the input dataset. When the cache is full the least recently used
    entry is dropped.
    """

    def __init__(self, base_model, cache_size: int = 128):
        """
        Args:
            base_model: Model whose ``predict`` results are cached.
            cache_size: Maximum number of cached predictions. Zero or a
                negative value disables caching.
        """
        import logging

        super().__init__()
        self.base_model = base_model
        self.cache_size = cache_size
        # Plain dicts keep insertion order, which is used as recency order.
        self._cache = {}
        # predict() logs cache hits; make sure a logger exists without
        # replacing one the base class may already provide.
        if not hasattr(self, "logger"):
            self.logger = logging.getLogger(__name__)

    def predict(self, x: xr.Dataset) -> xr.DataArray:
        """
        Predict with caching.

        Args:
            x: Input dataset.

        Returns:
            The cached prediction for an identical earlier input, otherwise
            a fresh prediction from the wrapped model. The cached object
            itself is returned, so callers should not modify it in place.
        """
        # Create cache key
        cache_key = self._create_cache_key(x)

        # Check cache
        if cache_key in self._cache:
            self.logger.info("Cache hit")
            # Re-insert so this entry becomes the most recently used.
            result = self._cache.pop(cache_key)
            self._cache[cache_key] = result
            return result

        # Generate prediction
        result = self.base_model.predict(x)

        # Caching disabled: nothing to store (and nothing to evict).
        if self.cache_size <= 0:
            return result

        # Update cache
        while len(self._cache) >= self.cache_size:
            # Remove least recently used entry (first in insertion order)
            self._cache.pop(next(iter(self._cache)))

        self._cache[cache_key] = result
        return result

    def _create_cache_key(self, x: xr.Dataset) -> str:
        """Create unique cache key for dataset."""
        # Hash the serialized dataset
        # NOTE(review): serializing the whole dataset on every call can be
        # costly for large inputs, and the key is only stable if the
        # serialization is deterministic - confirm for your backend.
        data_bytes = x.to_netcdf()
        return hashlib.sha256(data_bytes).hexdigest()

4. Monitoring and Logging

Add comprehensive monitoring:

import logging
from datetime import datetime
import time

class MonitoredModel(AbstractModel):
    """
    Model with performance monitoring.

    Wraps another model, logs each prediction request and keeps running
    counters in ``metrics``: ``prediction_count`` and ``total_time``
    (seconds) for successful calls, and ``errors`` for failed ones.
    """

    def __init__(self, base_model):
        """
        Args:
            base_model: Model whose ``predict`` calls are monitored.
        """
        super().__init__()
        self.base_model = base_model
        self.logger = logging.getLogger(__name__)
        self.metrics = {
            "prediction_count": 0,
            "total_time": 0,
            "errors": 0
        }

    def predict(self, x: xr.Dataset) -> xr.DataArray:
        """
        Predict with monitoring.

        Args:
            x: Input dataset.

        Returns:
            The wrapped model's prediction, unchanged.

        Raises:
            Exception: Whatever the wrapped model raises; the error is
                counted and logged, then re-raised.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()

        try:
            # Log input characteristics
            self.logger.info(
                "Prediction request - shape: %s, variables: %s",
                x.dims, list(x.data_vars)
            )

            # Make prediction
            result = self.base_model.predict(x)

        except Exception as e:
            self.metrics["errors"] += 1
            # exception() logs at ERROR level and attaches the traceback.
            self.logger.exception("Prediction failed: %s", e)
            raise

        # Update metrics
        elapsed = time.perf_counter() - start_time
        self.metrics["prediction_count"] += 1
        self.metrics["total_time"] += elapsed

        # Log success
        self.logger.info("Prediction successful - time: %.2fs", elapsed)

        return result

5. Testing Strategy

Implement comprehensive tests:

# test_model.py
import pytest
import numpy as np
from mymodel import MyRevenueModel

class TestRevenueModel:
    """Test suite for revenue model."""

    @pytest.fixture
    def model(self):
        """Create model instance."""
        # The model rejects a config without 'channels' and 'time_period',
        # so the fixture has to supply both. Channels match sample_data.
        return MyRevenueModel(
            "test_model.pkl",
            config={"channels": ["tv", "digital"], "time_period": "weekly"},
        )

    @pytest.fixture
    def sample_data(self):
        """Create sample test data."""
        return create_sample_dataset(
            channels=["tv", "digital"],
            time_periods=52,
            geographies=["US", "EU"]
        )

    def test_prediction_shape(self, model, sample_data):
        """Test prediction output shape."""
        result = model.predict(sample_data)
        assert result.shape == sample_data["tv"].shape

    def test_prediction_bounds(self, model, sample_data):
        """Test prediction value bounds."""
        result = model.predict(sample_data)
        assert result.min() >= 0  # Revenue non-negative
        assert result.max() <= 1e9  # Reasonable upper bound

    def test_zero_budget(self, model, sample_data):
        """Test zero budget scenario."""
        # Set all budgets to zero
        zero_data = sample_data.copy()
        for var in zero_data.data_vars:
            # Multiply instead of assigning the scalar 0: assigning a bare
            # scalar would replace the variable with a dimensionless value
            # and drop its dimensions.
            zero_data[var] = zero_data[var] * 0

        result = model.predict(zero_data)
        # NOTE(review): only holds for models without a baseline/intercept
        # term - confirm for your model.
        assert result.sum() == 0  # No spend = no revenue

    @pytest.mark.parametrize("channel", ["tv", "digital"])
    def test_channel_contribution(self, model, sample_data, channel):
        """Test individual channel contributions."""
        # Isolate single channel
        single_channel = sample_data.copy()
        for var in single_channel.data_vars:
            if var != channel:
                # Zero out while keeping the variable's dimensions
                single_channel[var] = single_channel[var] * 0

        result = model.predict(single_channel)
        assert result.sum() > 0  # Channel has positive impact

Model Lifecycle Management

Version Control

class VersionedModel(AbstractModel):
    """
    Model with version tracking.

    NOTE(review): ``_load_metadata`` is called by the constructor but is
    not defined in this snippet - it must be provided and is expected to
    return a dictionary.
    """

    def __init__(self, model_path: str, version: str):
        """
        Args:
            model_path: Path to model artifacts.
            version: Version label for this model.
        """
        # Initialize the base class, as the other model integrations do.
        super().__init__()
        self.version = version
        self.model_path = model_path
        # Time this wrapper object was created, not when the model was
        # trained.
        self.created_at = datetime.now()
        # Fall back to an empty dict so get_version_info() keeps working
        # when no metadata is available.
        self.metadata = self._load_metadata() or {}

    def get_version_info(self) -> Dict:
        """
        Get model version information.

        Returns:
            Dictionary with the version, creation time, model path, and the
            ``metrics`` and ``training_info`` metadata entries (empty
            dictionaries when absent).
        """
        return {
            "version": self.version,
            "created_at": self.created_at,
            "model_path": self.model_path,
            "performance_metrics": self.metadata.get("metrics", {}),
            "training_data": self.metadata.get("training_info", {})
        }

Model Registry Integration

# Register model with framework
from atlas import ModelRegistry

registry = ModelRegistry()

# Register model
# MyRevenueModel requires a config with 'channels' and 'time_period';
# constructing it from the path alone raises ValueError.
model_id = registry.register(
    model=MyRevenueModel(
        "model.pkl",
        config={"channels": ["tv", "digital", "radio"],
                "time_period": "weekly"}
    ),
    version="2.0.0",
    tags=["production", "revenue", "mmm"],
    metadata={
        "training_date": "2024-01-15",
        "performance": {"rmse": 0.05, "mape": 0.03}
    }
)

# Use registered model
# Look the model up again by the id that register() returned
model = registry.get_model(model_id)

Troubleshooting Guide

Common Issues and Solutions

  1. Dimension Mismatch

    • Ensure input data has all required dimensions

    • Check coordinate alignment

    • Verify dimension ordering

  2. Performance Issues

    • Implement caching for expensive operations

    • Use batch processing for multiple scenarios

    • Consider model simplification

  3. Integration Failures

    • Verify API endpoints and authentication

    • Check network connectivity

    • Review error logs

  4. Validation Errors

    • Double-check input data ranges

    • Verify data types

    • Ensure all required variables are present

Next Steps

After integrating your model:

  1. Write comprehensive tests

  2. Document model assumptions and limitations

  3. Set up monitoring and alerting

  4. Create usage examples

  5. Submit for code review

For more information, see: