37.4. Scaling Forecasting: Global vs Local Models

A retailer sells 100,000 SKUs across 5,000 stores: 500 million SKU-store time series. How do you engineer a system that forecasts all of them every night?


37.4.1. The Scale Challenge

graph TB
    A[500M Time Series] --> B{Nightly Forecast}
    B --> C[8 hour window]
    C --> D[17,361 forecasts/second]
    D --> E[Infrastructure Design]
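
The 17,361 forecasts/second figure in the diagram is simple arithmetic; a quick sanity check:

# Nightly throughput required to cover every series in the batch window
n_series = 500_000_000          # 100,000 SKUs x 5,000 stores
window_seconds = 8 * 3600       # 8-hour nightly window

required_rate = n_series / window_seconds
print(f"{required_rate:,.0f} forecasts/second")   # ~17,361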
| Scale | Time Series | Compute Strategy |
|---|---|---|
| Small | 1-1,000 | Single machine, sequential |
| Medium | 1K-100K | Multi-core parallelism |
| Large | 100K-10M | Distributed compute (Spark/Ray) |
| Massive | 10M-1B | Hybrid global + distributed local |

Cost Reality Check

| Approach | Time to Forecast 1M Series | Cloud Cost |
|---|---|---|
| Sequential Python | 28 hours | Timeout |
| Parallel (32 cores) | 52 minutes | $15 |
| Spark (100 workers) | 6 minutes | $50 |
| Global Transformer | 10 minutes | $100 (GPU) |
| Hybrid Cascade | 15 minutes | $30 |
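
The parallel row follows directly from the sequential baseline under an assumption of near-perfect scaling; a quick check of the arithmetic:

# Sequential baseline from the table: 28 hours for 1M series
seq_hours = 28
n_series = 1_000_000

per_series_ms = seq_hours * 3600 * 1000 / n_series   # ~100 ms per series
parallel_minutes = seq_hours * 60 / 32                # ~52 minutes on 32 cores

print(f"{per_series_ms:.0f} ms/series, {parallel_minutes:.0f} min on 32 cores")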

37.4.2. Architectural Approaches

Comparison Matrix

| Approach | Description | Pros | Cons | Best For |
|---|---|---|---|---|
| Local | 1 model per series | Tailored, interpretable, parallel | Cold start fails, no cross-learning | High-signal series |
| Global | 1 model for all | Cross-learning, handles cold start | Expensive inference, less interpretable | Low-volume series |
| Hybrid | Clustered models | Balanced | Cluster definition complexity | Most real-world cases |
graph TB
    A[500M Time Series] --> B{Approach Selection}
    B -->|Local| C[500M ARIMA/Prophet Models]
    B -->|Global| D[1 Transformer Model]
    B -->|Hybrid| E[50 Clustered Models]
    
    C --> F[Store model coefficients only]
    D --> G[Single GPU inference batch]
    E --> H[Group by category + volume tier]

When to Use Each Approach

| If your data has… | Use | Because… |
|---|---|---|
| Strong individual patterns | Local | Each series is unique |
| Sparse history (<12 points) | Global | Cross-series learning |
| New products constantly | Global | Cold start capability |
| Regulatory requirement for explainability | Local | Interpretable coefficients |
| Similar products in categories | Hybrid | Cluster-level patterns |
| Mixed volume (80/20 rule) | Hybrid | Tier by importance |
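
These rules can be encoded as a simple routing function. A minimal sketch (the thresholds and tier names are illustrative, not prescriptive):

def choose_approach(
    history_length: int,
    is_new_product: bool,
    needs_explainability: bool,
    volume_tier: str,          # "top" / "middle" / "bottom", assigned elsewhere
) -> str:
    """Route a series to a local, global, or hybrid forecasting path."""
    if is_new_product or history_length < 12:
        return "global"        # cold start / sparse history
    if needs_explainability:
        return "local"         # interpretable coefficients
    if volume_tier == "top":
        return "local"         # high-signal series justify per-series models
    if volume_tier == "bottom":
        return "global"
    return "hybrid"            # clustered models for the middle tier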

37.4.3. Local Models at Scale

Model Registry Pattern

Don’t store full model objects—store coefficients:

from dataclasses import dataclass, asdict
from typing import Dict, List, Optional
import json
import boto3
from datetime import datetime

@dataclass
class LocalModelMetadata:
    series_id: str
    algorithm: str  # "arima", "ets", "prophet"
    params: Dict  # Model coefficients/parameters
    metrics: Dict  # {"mape": 0.05, "rmse": 10.5}
    last_trained: str
    training_samples: int
    forecast_horizon: int
    version: int

class ForecastRegistry:
    """Registry for millions of local forecast models."""
    
    def __init__(self, table_name: str, region: str = "us-east-1"):
        self.dynamodb = boto3.resource("dynamodb", region_name=region)
        self.table = self.dynamodb.Table(table_name)
    
    def save(self, model: LocalModelMetadata) -> None:
        """Save model metadata to registry."""
        item = asdict(model)
        # DynamoDB rejects Python floats, so nested numeric structures are
        # stored as JSON strings and parsed back on load
        item["params"] = json.dumps(model.params)
        item["metrics"] = json.dumps(model.metrics)
        item["pk"] = f"MODEL#{model.series_id}"
        item["sk"] = f"V#{model.version}"
        
        self.table.put_item(Item=item)
        
        # Also update "latest" pointer
        self.table.put_item(Item={
            "pk": f"MODEL#{model.series_id}",
            "sk": "LATEST",
            "version": model.version,
            "updated_at": datetime.utcnow().isoformat()
        })
    
    def load_latest(self, series_id: str) -> Optional[LocalModelMetadata]:
        """Load latest model version."""
        # Get latest version number
        response = self.table.get_item(
            Key={"pk": f"MODEL#{series_id}", "sk": "LATEST"}
        )
        
        if "Item" not in response:
            return None
        
        version = response["Item"]["version"]
        
        # Get actual model
        response = self.table.get_item(
            Key={"pk": f"MODEL#{series_id}", "sk": f"V#{version}"}
        )
        
        if "Item" not in response:
            return None
        
        item = response["Item"]
        return LocalModelMetadata(
            series_id=item["series_id"],
            algorithm=item["algorithm"],
            params=json.loads(item["params"]),
            metrics=json.loads(item["metrics"]),
            last_trained=item["last_trained"],
            training_samples=int(item["training_samples"]),
            forecast_horizon=int(item["forecast_horizon"]),
            version=int(item["version"])
        )
    
    def batch_load(self, series_ids: List[str]) -> Dict[str, LocalModelMetadata]:
        """Batch load multiple models."""
        # DynamoDB batch_get_item
        keys = [
            {"pk": f"MODEL#{sid}", "sk": "LATEST"}
            for sid in series_ids
        ]
        
        # Split into chunks of 100 (DynamoDB limit)
        results = {}
        for i in range(0, len(keys), 100):
            chunk = keys[i:i+100]
            response = self.dynamodb.batch_get_item(
                RequestItems={self.table.name: {"Keys": chunk}}
            )
            
            for item in response["Responses"][self.table.name]:
                series_id = item["pk"].replace("MODEL#", "")
                
                # Fetch the full versioned record (could be optimized with a
                # GSI or a second batch_get on the versioned keys)
                model = self.load_latest(series_id)
                if model:
                    results[series_id] = model
        
        return results
    
    def predict(self, series_id: str, horizon: int) -> Optional[List[float]]:
        """Generate forecast using stored coefficients."""
        model = self.load_latest(series_id)
        if not model:
            return None
        
        return self._inference(model, horizon)
    
    def _inference(self, model: LocalModelMetadata, horizon: int) -> List[float]:
        """Run inference using stored parameters."""
        if model.algorithm == "arima":
            return self._arima_forecast(model.params, horizon)
        elif model.algorithm == "ets":
            return self._ets_forecast(model.params, horizon)
        elif model.algorithm == "naive":
            # Fallback produced by model selection: repeat the last observed value
            return [float(model.params.get("last_value", 0.0))] * horizon
        else:
            raise ValueError(f"Unknown algorithm: {model.algorithm}")
    
    def _arima_forecast(self, params: Dict, horizon: int) -> List[float]:
        """Reconstruct ARIMA and forecast."""
        import numpy as np
        
        ar_coeffs = np.array(params.get("ar_coeffs", []))
        ma_coeffs = np.array(params.get("ma_coeffs", []))
        diff_order = params.get("d", 0)
        last_values = np.array(params.get("last_values", []))
        residuals = np.array(params.get("residuals", []))
        
        # Simplified forecast (in production, use statsmodels)
        forecasts = []
        for h in range(horizon):
            # AR component
            ar_term = 0
            for i, coef in enumerate(ar_coeffs):
                if i < len(last_values):
                    ar_term += coef * last_values[-(i+1)]
            
            # MA component (assume residuals decay)
            ma_term = 0
            for i, coef in enumerate(ma_coeffs):
                if i < len(residuals):
                    ma_term += coef * residuals[-(i+1)] * (0.9 ** h)
            
            forecast = ar_term + ma_term
            forecasts.append(float(forecast))
            
            # Update for next step
            last_values = np.append(last_values, forecast)[-len(ar_coeffs):]
        
        return forecasts
    
    def _ets_forecast(self, params: Dict, horizon: int) -> List[float]:
        """ETS forecast from stored state."""
        level = params.get("level", 0)
        trend = params.get("trend", 0)
        seasonal = params.get("seasonal", [1.0] * 12)  # multiplicative factors default to 1
        alpha = params.get("alpha", 0.2)
        beta = params.get("beta", 0.1)
        gamma = params.get("gamma", 0.1)
        
        forecasts = []
        for h in range(1, horizon + 1):
            # Holt-Winters forecast
            season_idx = (h - 1) % len(seasonal)
            forecast = (level + h * trend) * seasonal[season_idx]
            forecasts.append(float(forecast))
        
        return forecasts


# Terraform for DynamoDB
"""
resource "aws_dynamodb_table" "forecast_registry" {
  name         = "forecast-registry-${var.environment}"
  billing_mode = "PAY_PER_REQUEST"
  hash_key     = "pk"
  range_key    = "sk"
  
  attribute {
    name = "pk"
    type = "S"
  }
  
  attribute {
    name = "sk"
    type = "S"
  }
  
  ttl {
    attribute_name = "ttl"
    enabled        = true
  }
  
  tags = {
    Environment = var.environment
  }
}
"""

Kubernetes Indexed Jobs for Training

# forecast-training-job.yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: forecast-training-batch
spec:
  completions: 1000
  parallelism: 100
  completionMode: Indexed
  backoffLimit: 3
  
  template:
    metadata:
      labels:
        app: forecast-trainer
    spec:
      restartPolicy: OnFailure
      
      containers:
      - name: trainer
        image: forecast-trainer:latest
        
        resources:
          requests:
            cpu: "2"
            memory: "4Gi"
          limits:
            cpu: "4"
            memory: "8Gi"
        
        env:
        - name: SHARD_ID
          valueFrom:
            fieldRef:
              fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
        - name: TOTAL_SHARDS
          value: "1000"
        - name: REGISTRY_TABLE
          valueFrom:
            configMapKeyRef:
              name: forecast-config
              key: registry_table
        
        command:
        - python
        - train.py
        - --shard
        - $(SHARD_ID)
        - --total-shards
        - $(TOTAL_SHARDS)
        
        volumeMounts:
        - name: data-cache
          mountPath: /data
      
      volumes:
      - name: data-cache
        emptyDir:
          sizeLimit: 10Gi
      
      nodeSelector:
        workload-type: batch
      
      tolerations:
      - key: "batch"
        operator: "Equal"
        value: "true"
        effect: "NoSchedule"

Sharded Training Script

import argparse
from datetime import datetime
from typing import List, Tuple

import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.holtwinters import ExponentialSmoothing

# ForecastRegistry and LocalModelMetadata are defined in the registry module
# from the previous subsection (module path assumed)
from forecast_registry import ForecastRegistry, LocalModelMetadata

def get_shard_series(
    shard_id: int, 
    total_shards: int,
    all_series: List[str]
) -> List[str]:
    """Get series assigned to this shard."""
    return [
        s for i, s in enumerate(all_series) 
        if i % total_shards == shard_id
    ]

def train_arima(series: pd.Series, order: Tuple[int, int, int] = (1, 1, 1)) -> dict:
    """Train ARIMA and return coefficients."""
    try:
        model = ARIMA(series, order=order)
        fitted = model.fit()
        
        return {
            "ar_coeffs": fitted.arparams.tolist() if len(fitted.arparams) > 0 else [],
            "ma_coeffs": fitted.maparams.tolist() if len(fitted.maparams) > 0 else [],
            "d": order[1],
            "last_values": series.tail(max(order[0], 5)).tolist(),
            "residuals": fitted.resid.tail(max(order[2], 5)).tolist(),
            "sigma2": float(fitted.sigma2),
            "aic": float(fitted.aic)
        }
    except Exception as e:
        return {"error": str(e)}

def train_ets(series: pd.Series, seasonal_periods: int = 12) -> dict:
    """Train ETS and return state."""
    try:
        model = ExponentialSmoothing(
            series,
            trend="add",
            seasonal="mul",
            seasonal_periods=seasonal_periods
        )
        fitted = model.fit()
        
        return {
            "level": float(fitted.level.iloc[-1]),
            "trend": float(fitted.trend.iloc[-1]) if fitted.trend is not None else 0,
            "seasonal": fitted.season.tolist() if fitted.season is not None else [],
            "alpha": float(fitted.params.get("smoothing_level", 0.2)),
            "beta": float(fitted.params.get("smoothing_trend", 0.1)),
            "gamma": float(fitted.params.get("smoothing_seasonal", 0.1)),
            "aic": float(fitted.aic)
        }
    except Exception as e:
        return {"error": str(e)}

def select_best_model(series: pd.Series) -> Tuple[str, dict]:
    """Auto-select best model based on AIC."""
    candidates = []
    
    # Try ARIMA variants
    for order in [(1,1,1), (2,1,2), (1,1,0), (0,1,1)]:
        params = train_arima(series, order)
        if "error" not in params:
            candidates.append(("arima", order, params, params["aic"]))
    
    # Try ETS if enough data
    if len(series) >= 24:
        params = train_ets(series)
        if "error" not in params:
            candidates.append(("ets", None, params, params["aic"]))
    
    if not candidates:
        return "naive", {"last_value": float(series.iloc[-1])}
    
    # Select best by AIC
    best = min(candidates, key=lambda x: x[3])
    return best[0], best[2]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--shard", type=int, required=True)
    parser.add_argument("--total-shards", type=int, required=True)
    args = parser.parse_args()
    
    # Load series list
    all_series = load_series_list()  # From S3/database
    my_series = get_shard_series(args.shard, args.total_shards, all_series)
    
    print(f"Shard {args.shard}: Processing {len(my_series)} series")
    
    registry = ForecastRegistry("forecast-registry-prod")
    
    for series_id in my_series:
        # Load data
        data = load_series_data(series_id)
        if len(data) < 10:
            continue
        
        # Train
        algorithm, params = select_best_model(data)
        
        # Hold out the last 7 points for evaluation (metrics are stubbed
        # below; see the holdout-MAPE sketch after this script)
        train, test = data[:-7], data[-7:]
        
        # Save
        model = LocalModelMetadata(
            series_id=series_id,
            algorithm=algorithm,
            params=params,
            metrics={"mape": 0.0},  # Would compute properly
            last_trained=datetime.utcnow().isoformat(),
            training_samples=len(data),
            forecast_horizon=28,
            version=1
        )
        registry.save(model)
    
    print(f"Shard {args.shard}: Completed")

if __name__ == "__main__":
    main()
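
The metrics in main() are stubbed. A minimal sketch of computing a real holdout MAPE by refitting on the training split (reusing the ARIMA and pandas imports from the script above; the +1 in the denominator guards against zero actuals, mirroring the monitoring code later in this section):

import numpy as np

def holdout_mape(series: pd.Series, horizon: int = 7,
                 order: tuple = (1, 1, 1)) -> float:
    """Refit ARIMA on the training split and score MAPE on the held-out tail."""
    train, test = series[:-horizon], series[-horizon:]
    fitted = ARIMA(train, order=order).fit()
    forecast = fitted.forecast(steps=horizon)
    return float(np.mean(np.abs(forecast.values - test.values) / (test.values + 1)) * 100)

# In main(), metrics={"mape": holdout_mape(data)} would replace the stub.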

37.4.4. Cost Comparison

| Service | Cost per 1M Model Runs | Startup Time | Max Duration |
|---|---|---|---|
| Lambda | $15 | Instant | 15 min |
| Fargate | $5 | 1 min | None |
| EC2 Spot | $0.50 | 2 min | Interruption risk |
| EMR Serverless | $3 | 30 sec | None |
| GCP Dataflow | $4 | 1 min | None |

Recommendation: EC2 Spot Fleet with AWS Batch for large-scale batch forecasting.

AWS Batch Setup

# batch_forecasting.tf

resource "aws_batch_compute_environment" "forecast" {
  compute_environment_name = "forecast-compute-${var.environment}"
  type                     = "MANAGED"
  
  compute_resources {
    type                = "SPOT"
    allocation_strategy = "SPOT_CAPACITY_OPTIMIZED"
    
    min_vcpus     = 0
    max_vcpus     = 1000
    desired_vcpus = 0
    
    instance_type = ["c6i.xlarge", "c6i.2xlarge", "c5.xlarge", "c5.2xlarge"]
    
    subnets            = var.subnet_ids
    security_group_ids = [aws_security_group.batch.id]
    instance_role      = aws_iam_instance_profile.batch.arn
    
    spot_iam_fleet_role = aws_iam_role.spot_fleet.arn
  }
  
  service_role = aws_iam_role.batch_service.arn
}

resource "aws_batch_job_queue" "forecast" {
  name     = "forecast-queue-${var.environment}"
  state    = "ENABLED"
  priority = 1
  
  compute_environments = [
    aws_batch_compute_environment.forecast.arn
  ]
}

resource "aws_batch_job_definition" "forecast_train" {
  name = "forecast-train-${var.environment}"
  type = "container"
  
  platform_capabilities = ["EC2"]
  
  container_properties = jsonencode({
    image   = "${aws_ecr_repository.forecast.repository_url}:latest"
    command = ["python", "train.py", "--shard", "Ref::shard", "--total-shards", "Ref::total_shards"]
    
    resourceRequirements = [
      { type = "VCPU", value = "2" },
      { type = "MEMORY", value = "4096" }
    ]
    
    environment = [
      { name = "REGISTRY_TABLE", value = aws_dynamodb_table.forecast_registry.name }
    ]
    
    jobRoleArn = aws_iam_role.batch_job.arn
  })
  
  retry_strategy {
    attempts = 3
  }
  
  timeout {
    attempt_duration_seconds = 3600
  }
}
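
Submitting the sharded run against this job definition can be done from Python; a minimal sketch, assuming var.environment is "prod" so the queue and definition names resolve as below:

import boto3

batch = boto3.client("batch", region_name="us-east-1")

TOTAL_SHARDS = 1000
for shard in range(TOTAL_SHARDS):
    batch.submit_job(
        jobName=f"forecast-train-shard-{shard}",
        jobQueue="forecast-queue-prod",
        jobDefinition="forecast-train-prod",
        # Fills the Ref::shard / Ref::total_shards placeholders in the
        # container command defined above
        parameters={"shard": str(shard), "total_shards": str(TOTAL_SHARDS)},
    )

For very large shard counts, a single array job (arrayProperties={"size": N}) avoids the submission loop; each container then reads its index from the AWS_BATCH_JOB_ARRAY_INDEX environment variable instead of a --shard argument.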

37.4.5. Hierarchical Reconciliation

Forecasts must be coherent across the hierarchy:

Total Sales
├── Region North
│   ├── Store 001
│   │   ├── SKU A
│   │   └── SKU B
│   └── Store 002
└── Region South
    └── Store 003

Constraint: Sum(children) == Parent

import numpy as np
from typing import Dict, List

def reconcile_forecasts_ols(
    base_forecasts: Dict[str, float],
    hierarchy: Dict[str, List[str]]
) -> Dict[str, float]:
    """OLS reconciliation: Ensure Sum(children) == Parent.
    
    Args:
        base_forecasts: {series_id: forecast_value}
        hierarchy: {parent: [children]}
    
    Returns:
        Reconciled forecasts
    """
    reconciled = base_forecasts.copy()
    
    # Bottom-up: scale children to match parent
    for parent, children in hierarchy.items():
        if parent not in base_forecasts:
            continue
        
        parent_forecast = base_forecasts[parent]
        children_sum = sum(base_forecasts.get(c, 0) for c in children)
        
        if children_sum == 0:
            # Distribute evenly
            equal_share = parent_forecast / len(children)
            for child in children:
                reconciled[child] = equal_share
        else:
            # Scale proportionally
            scale = parent_forecast / children_sum
            for child in children:
                if child in base_forecasts:
                    reconciled[child] = base_forecasts[child] * scale
    
    return reconciled


def reconcile_mint(
    base_forecasts: np.ndarray,
    S: np.ndarray,
    W: np.ndarray
) -> np.ndarray:
    """MinT (Minimum Trace) reconciliation.
    
    Args:
        base_forecasts: Base forecasts for all series (n,)
        S: Summing matrix (n, m) where m is bottom level
        W: Covariance matrix of base forecast errors (n, n)
    
    Returns:
        Reconciled forecasts
    """
    # G = (S'W^{-1}S)^{-1} S'W^{-1}
    W_inv = np.linalg.inv(W)
    G = np.linalg.inv(S.T @ W_inv @ S) @ S.T @ W_inv
    
    # Reconciled bottom level
    bottom_reconciled = G @ base_forecasts
    
    # Full reconciled
    reconciled = S @ bottom_reconciled
    
    return reconciled


class HierarchicalReconciler:
    """Full hierarchical reconciliation system."""
    
    def __init__(self, hierarchy: Dict[str, List[str]]):
        self.hierarchy = hierarchy
        self.series_to_idx = {}
        self.idx_to_series = {}
        self._build_indices()
    
    def _build_indices(self):
        """Build series index mapping."""
        all_series = set(self.hierarchy.keys())
        for children in self.hierarchy.values():
            all_series.update(children)
        
        for i, series in enumerate(sorted(all_series)):
            self.series_to_idx[series] = i
            self.idx_to_series[i] = series
    
    def _build_summing_matrix(self) -> np.ndarray:
        """Build the S matrix for hierarchical structure."""
        n = len(self.series_to_idx)
        
        # Find bottom level (series that are not parents)
        parents = set(self.hierarchy.keys())
        all_children = set()
        for children in self.hierarchy.values():
            all_children.update(children)
        
        bottom_level = all_children - parents
        m = len(bottom_level)
        bottom_idx = {s: i for i, s in enumerate(sorted(bottom_level))}
        
        S = np.zeros((n, m))
        
        # Bottom level is identity
        for series, idx in bottom_idx.items():
            S[self.series_to_idx[series], idx] = 1
        
        # Parents sum children
        def get_bottom_descendants(series):
            if series in bottom_level:
                return [series]
            descendants = []
            for child in self.hierarchy.get(series, []):
                descendants.extend(get_bottom_descendants(child))
            return descendants
        
        for parent in parents:
            descendants = get_bottom_descendants(parent)
            for desc in descendants:
                if desc in bottom_idx:
                    S[self.series_to_idx[parent], bottom_idx[desc]] = 1
        
        return S
    
    def reconcile(
        self, 
        forecasts: Dict[str, float],
        method: str = "ols"
    ) -> Dict[str, float]:
        """Reconcile forecasts."""
        if method == "ols":
            return reconcile_forecasts_ols(forecasts, self.hierarchy)
        elif method == "mint":
            # Convert to array
            n = len(self.series_to_idx)
            base = np.zeros(n)
            for series, value in forecasts.items():
                if series in self.series_to_idx:
                    base[self.series_to_idx[series]] = value
            
            S = self._build_summing_matrix()
            W = np.eye(n)  # Simplified: use identity
            
            reconciled_arr = reconcile_mint(base, S, W)
            
            return {
                self.idx_to_series[i]: float(reconciled_arr[i])
                for i in range(n)
            }
        else:
            raise ValueError(f"Unknown method: {method}")


# Usage
hierarchy = {
    "total": ["north", "south"],
    "north": ["store_001", "store_002"],
    "south": ["store_003"]
}

base_forecasts = {
    "total": 1000,
    "north": 600,
    "south": 400,
    "store_001": 350,
    "store_002": 300,  # Sum = 650, but north = 600
    "store_003": 400
}

reconciler = HierarchicalReconciler(hierarchy)
reconciled = reconciler.reconcile(base_forecasts, method="ols")
# store_001: 323, store_002: 277 (scaled to match north=600)

37.4.6. Global Models (Transformer)

Single model for all series with cross-series learning:

Context Window Management

import torch
import torch.nn as nn
from typing import List, Optional, Tuple
import numpy as np

class TimeSeriesEmbedding(nn.Module):
    """Embed time series with metadata."""
    
    def __init__(
        self,
        d_model: int = 256,
        max_seq_len: int = 512,
        num_categories: int = 1000
    ):
        super().__init__()
        
        self.value_projection = nn.Linear(1, d_model)
        self.position_encoding = nn.Embedding(max_seq_len, d_model)
        self.category_embedding = nn.Embedding(num_categories, d_model)
        
        # Time features
        self.time_feature_projection = nn.Linear(7, d_model)  # dow, month, etc.
    
    def forward(
        self,
        values: torch.Tensor,  # (batch, seq_len)
        category_ids: torch.Tensor,  # (batch,)
        time_features: torch.Tensor  # (batch, seq_len, 7)
    ) -> torch.Tensor:
        batch_size, seq_len = values.shape
        
        # Value embedding
        value_emb = self.value_projection(values.unsqueeze(-1))
        
        # Position encoding
        positions = torch.arange(seq_len, device=values.device)
        pos_emb = self.position_encoding(positions)
        
        # Category embedding (broadcast)
        cat_emb = self.category_embedding(category_ids).unsqueeze(1)
        
        # Time features
        time_emb = self.time_feature_projection(time_features)
        
        # Combine
        return value_emb + pos_emb + cat_emb + time_emb


class GlobalForecaster(nn.Module):
    """Single Transformer model for all series."""
    
    def __init__(
        self,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        max_seq_len: int = 512,
        num_categories: int = 1000,
        forecast_horizon: int = 28
    ):
        super().__init__()
        
        self.embedding = TimeSeriesEmbedding(d_model, max_seq_len, num_categories)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        
        self.forecast_head = nn.Linear(d_model, forecast_horizon)
        self.quantile_heads = nn.ModuleDict({
            "q10": nn.Linear(d_model, forecast_horizon),
            "q50": nn.Linear(d_model, forecast_horizon),
            "q90": nn.Linear(d_model, forecast_horizon)
        })
    
    def forward(
        self,
        values: torch.Tensor,
        category_ids: torch.Tensor,
        time_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, dict]:
        # Embed
        x = self.embedding(values, category_ids, time_features)
        
        # Transform
        if attention_mask is not None:
            x = self.transformer(x, src_key_padding_mask=attention_mask)
        else:
            x = self.transformer(x)
        
        # Use last token for prediction
        last_hidden = x[:, -1, :]
        
        # Point forecast
        point_forecast = self.forecast_head(last_hidden)
        
        # Quantile forecasts
        quantiles = {
            name: head(last_hidden)
            for name, head in self.quantile_heads.items()
        }
        
        return point_forecast, quantiles


class GlobalForecasterPipeline:
    """Full pipeline for global model inference."""
    
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        batch_size: int = 128
    ):
        self.device = torch.device(device)
        self.batch_size = batch_size
        
        # Load model
        self.model = GlobalForecaster()
        self.model.load_state_dict(torch.load(model_path))
        self.model.to(self.device)
        self.model.eval()
    
    def preprocess(
        self,
        histories: List[np.ndarray],
        category_ids: List[int],
        max_len: int = 365
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]:
        """Preprocess a batch of series; also returns the per-series means/stds used for normalization."""
        batch_size = len(histories)
        
        # Pad/truncate to max_len
        padded = np.zeros((batch_size, max_len))
        mask = np.ones((batch_size, max_len), dtype=bool)
        
        for i, hist in enumerate(histories):
            length = min(len(hist), max_len)
            padded[i, -length:] = hist[-length:]
            mask[i, -length:] = False
        
        # Normalize
        means = np.mean(padded, axis=1, keepdims=True)
        stds = np.std(padded, axis=1, keepdims=True) + 1e-8
        normalized = (padded - means) / stds
        
        # Time features (simplified)
        time_features = np.zeros((batch_size, max_len, 7))
        
        return (
            torch.tensor(normalized, dtype=torch.float32),
            torch.tensor(category_ids, dtype=torch.long),
            torch.tensor(time_features, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.bool),
            means,
            stds
        )
    
    def predict_batch(
        self,
        histories: List[np.ndarray],
        category_ids: List[int]
    ) -> List[dict]:
        """Predict for a batch of series."""
        values, cats, time_feat, mask, means, stds = self.preprocess(
            histories, category_ids
        )
        
        values = values.to(self.device)
        cats = cats.to(self.device)
        time_feat = time_feat.to(self.device)
        mask = mask.to(self.device)
        
        with torch.no_grad():
            point, quantiles = self.model(values, cats, time_feat, mask)
        
        # Denormalize
        point = point.cpu().numpy()
        point = point * stds + means
        
        results = []
        for i in range(len(histories)):
            results.append({
                "point": point[i].tolist(),
                "q10": (quantiles["q10"][i].cpu().numpy() * stds[i] + means[i]).tolist(),
                "q50": (quantiles["q50"][i].cpu().numpy() * stds[i] + means[i]).tolist(),
                "q90": (quantiles["q90"][i].cpu().numpy() * stds[i] + means[i]).tolist()
            })
        
        return results
    
    def predict_all(
        self,
        all_histories: List[np.ndarray],
        all_categories: List[int]
    ) -> List[dict]:
        """Predict all series in batches."""
        results = []
        
        for i in range(0, len(all_histories), self.batch_size):
            batch_hist = all_histories[i:i+self.batch_size]
            batch_cats = all_categories[i:i+self.batch_size]
            
            batch_results = self.predict_batch(batch_hist, batch_cats)
            results.extend(batch_results)
        
        return results
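
The pipeline above covers inference only. Training the quantile heads typically uses the pinball (quantile) loss; a minimal sketch of one training step, assuming a dataloader that yields the same tensors plus a target window:

import torch

def pinball_loss(pred: torch.Tensor, target: torch.Tensor, q: float) -> torch.Tensor:
    """Quantile (pinball) loss: penalizes under- and over-prediction asymmetrically."""
    diff = target - pred
    return torch.mean(torch.maximum(q * diff, (q - 1) * diff))

def training_step(model: GlobalForecaster, batch, optimizer) -> float:
    values, category_ids, time_features, mask, target = batch  # target: (batch, horizon)
    point, quantiles = model(values, category_ids, time_features, mask)
    
    # Point forecast trained with MSE, quantile heads with pinball loss
    loss = torch.nn.functional.mse_loss(point, target)
    for name, q in [("q10", 0.1), ("q50", 0.5), ("q90", 0.9)]:
        loss = loss + pinball_loss(quantiles[name], target, q)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss.item())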

37.4.7. Cold Start Solution

New products with no history need special handling:

| Strategy | When to Use | Data Required |
|---|---|---|
| Metadata similarity | Similar products exist | Product attributes |
| Category average | New category | Category mapping |
| Expert judgment | Novel product | Domain knowledge |
| Analogous product | Replacement/upgrade | Linking table |
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass
from sklearn.metrics.pairwise import cosine_similarity

@dataclass
class ProductMetadata:
    product_id: str
    category: str
    subcategory: str
    price: float
    attributes: Dict[str, str]

class ColdStartForecaster:
    """Handle forecasting for new products."""
    
    def __init__(
        self,
        embedding_model,
        product_db: Dict[str, ProductMetadata],
        forecast_db: Dict[str, np.ndarray]
    ):
        self.embedder = embedding_model
        self.products = product_db
        self.forecasts = forecast_db
        
        # Pre-compute embeddings for existing products
        self.embeddings = {}
        for pid, meta in product_db.items():
            self.embeddings[pid] = self._compute_embedding(meta)
    
    def _compute_embedding(self, meta: ProductMetadata) -> np.ndarray:
        """Compute embedding from metadata."""
        # Create text representation
        text = f"{meta.category} {meta.subcategory} price:{meta.price}"
        for k, v in meta.attributes.items():
            text += f" {k}:{v}"
        
        return self.embedder.encode(text)
    
    def find_similar_products(
        self,
        new_meta: ProductMetadata,
        top_k: int = 5,
        same_category: bool = True
    ) -> List[tuple]:
        """Find most similar existing products."""
        new_embedding = self._compute_embedding(new_meta)
        
        similarities = []
        for pid, emb in self.embeddings.items():
            # Optionally filter by category
            if same_category and self.products[pid].category != new_meta.category:
                continue
            
            sim = cosine_similarity([new_embedding], [emb])[0][0]
            similarities.append((pid, sim))
        
        # Sort by similarity
        similarities.sort(key=lambda x: -x[1])
        
        return similarities[:top_k]
    
    def forecast_new_product(
        self,
        new_meta: ProductMetadata,
        horizon: int = 28,
        method: str = "weighted_average"
    ) -> dict:
        """Generate forecast for new product."""
        similar = self.find_similar_products(new_meta)
        
        if not similar:
            # Fallback to category average
            return self._category_average(new_meta.category, horizon)
        
        if method == "weighted_average":
            return self._weighted_average_forecast(similar, horizon)
        elif method == "top_1":
            return self._top_1_forecast(similar, horizon)
        else:
            raise ValueError(f"Unknown method: {method}")
    
    def _weighted_average_forecast(
        self,
        similar: List[tuple],
        horizon: int
    ) -> dict:
        """Weighted average of similar products' forecasts."""
        weights = []
        forecasts = []
        
        for pid, sim in similar:
            if pid in self.forecasts:
                weights.append(sim)
                forecasts.append(self.forecasts[pid][:horizon])
        
        if not forecasts:
            return {"point": [0] * horizon, "method": "fallback"}
        
        # Normalize weights
        weights = np.array(weights) / sum(weights)
        
        # Weighted average
        weighted = np.zeros(horizon)
        for w, f in zip(weights, forecasts):
            weighted += w * f
        
        return {
            "point": weighted.tolist(),
            "method": "weighted_average",
            "similar_products": [p[0] for p in similar],
            "weights": weights.tolist()
        }
    
    def _top_1_forecast(
        self,
        similar: List[tuple],
        horizon: int
    ) -> dict:
        """Use top similar product's forecast."""
        for pid, sim in similar:
            if pid in self.forecasts:
                return {
                    "point": self.forecasts[pid][:horizon].tolist(),
                    "method": "top_1",
                    "analog_product": pid,
                    "similarity": sim
                }
        
        return {"point": [0] * horizon, "method": "fallback"}
    
    def _category_average(
        self,
        category: str,
        horizon: int
    ) -> dict:
        """Average forecast for category."""
        category_forecasts = [
            self.forecasts[pid][:horizon]
            for pid, meta in self.products.items()
            if meta.category == category and pid in self.forecasts
        ]
        
        if not category_forecasts:
            return {"point": [0] * horizon, "method": "no_data"}
        
        avg = np.mean(category_forecasts, axis=0)
        
        return {
            "point": avg.tolist(),
            "method": "category_average",
            "category": category,
            "n_products": len(category_forecasts)
        }


# Usage (SentenceTransformer comes from the sentence-transformers package;
# load_product_metadata / load_existing_forecasts are application-specific loaders)
from sentence_transformers import SentenceTransformer

cold_start = ColdStartForecaster(
    embedding_model=SentenceTransformer("all-MiniLM-L6-v2"),
    product_db=load_product_metadata(),
    forecast_db=load_existing_forecasts()
)

new_product = ProductMetadata(
    product_id="NEW-001",
    category="Electronics",
    subcategory="Headphones",
    price=149.99,
    attributes={"wireless": "true", "brand": "Premium"}
)

forecast = cold_start.forecast_new_product(new_product)
# {'point': [...], 'method': 'weighted_average', 'similar_products': ['B001', 'B002']}

37.4.8. Monitoring Forecast Quality

import numpy as np
from typing import Dict, List
from datetime import datetime, timedelta
from prometheus_client import Gauge

# Metrics
FORECAST_MAPE = Gauge(
    "forecast_mape",
    "Mean Absolute Percentage Error",
    ["category", "model_type"]
)

FORECAST_BIAS = Gauge(
    "forecast_bias",
    "Forecast Bias (positive = over-forecast)",
    ["category", "model_type"]
)

FORECAST_COVERAGE = Gauge(
    "forecast_coverage",
    "Prediction interval coverage",
    ["category", "quantile"]
)

class ForecastMonitor:
    """Monitor forecast accuracy over time."""
    
    def __init__(self, forecast_db, actuals_db):
        self.forecasts = forecast_db
        self.actuals = actuals_db
    
    def calculate_metrics(
        self,
        series_id: str,
        forecast_date: datetime,
        horizon: int = 7
    ) -> dict:
        """Calculate accuracy metrics for a forecast."""
        forecast = self.forecasts.get(series_id, forecast_date)
        actuals = self.actuals.get(
            series_id,
            forecast_date,
            forecast_date + timedelta(days=horizon)
        )
        
        if forecast is None or actuals is None:
            return {}
        
        point = np.array(forecast["point"][:horizon])
        actuals = np.array(actuals[:horizon])
        
        # MAPE (+1 in the denominator guards against zero actuals)
        mape = np.mean(np.abs(point - actuals) / (actuals + 1)) * 100
        
        # Bias
        bias = np.mean(point - actuals)
        bias_pct = np.mean((point - actuals) / (actuals + 1)) * 100
        
        # RMSE
        rmse = np.sqrt(np.mean((point - actuals) ** 2))
        
        # Coverage (forecast is still the original dict here, so the
        # quantile keys can be checked directly)
        coverage = {}
        if "q10" in forecast and "q90" in forecast:
            q10 = np.array(forecast["q10"][:horizon])
            q90 = np.array(forecast["q90"][:horizon])
            
            in_interval = (actuals >= q10) & (actuals <= q90)
            coverage["80"] = np.mean(in_interval).item()
        
        return {
            "mape": float(mape),
            "bias": float(bias),
            "bias_pct": float(bias_pct),
            "rmse": float(rmse),
            "coverage": coverage
        }
    
    def aggregate_metrics(
        self,
        category: str,
        date_range: tuple
    ) -> dict:
        """Aggregate metrics across category."""
        series_in_category = self._get_series_by_category(category)
        
        all_metrics = []
        for series_id in series_in_category:
            for date in self._date_range(date_range):
                metrics = self.calculate_metrics(series_id, date)
                if metrics:
                    all_metrics.append(metrics)
        
        if not all_metrics:
            return {}
        
        return {
            "mape_mean": np.mean([m["mape"] for m in all_metrics]),
            "mape_median": np.median([m["mape"] for m in all_metrics]),
            "bias_mean": np.mean([m["bias_pct"] for m in all_metrics]),
            "n_forecasts": len(all_metrics)
        }
    
    def update_prometheus_metrics(self, category: str, model_type: str):
        """Push metrics to Prometheus."""
        today = datetime.utcnow()
        last_week = today - timedelta(days=7)
        metrics = self.aggregate_metrics(category, (last_week, today))
        
        FORECAST_MAPE.labels(category=category, model_type=model_type).set(
            metrics.get("mape_mean", 0)
        )
        FORECAST_BIAS.labels(category=category, model_type=model_type).set(
            metrics.get("bias_mean", 0)
        )

37.4.9. Strategy Summary

| Tier | SKU Volume | Model Type | Reason | Update Frequency |
|---|---|---|---|---|
| Tier 1 | Top 20% by value | Local ARIMA/Prophet | High signal, explainable | Weekly |
| Tier 2 | Middle 60% | Hybrid (clustered) | Balanced accuracy/cost | Weekly |
| Tier 3 | Bottom 20% | Global Transformer | Sparse data, cold start | Daily (batch) |
| New Products | 0 history | Cold start methods | No data available | On-demand |
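
Tier assignment is usually a rank cut on revenue (or another value measure); a minimal sketch of the 20/60/20 split from the table (the revenue series and its source are assumptions):

import pandas as pd

def assign_tiers(revenue_by_sku: pd.Series) -> pd.Series:
    """Split SKUs into top 20% / middle 60% / bottom 20% by revenue rank."""
    rank_pct = revenue_by_sku.rank(ascending=False, pct=True)
    
    tiers = pd.Series("tier_3", index=revenue_by_sku.index)   # bottom 20%
    tiers[rank_pct <= 0.80] = "tier_2"                        # middle 60%
    tiers[rank_pct <= 0.20] = "tier_1"                        # top 20%
    return tiers

# tiers = assign_tiers(weekly_revenue)   # weekly_revenue: pd.Series indexed by SKU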

Migration Path

graph LR
    A[Start: 100% Local] --> B[Step 1: Add Global for cold start]
    B --> C[Step 2: Tier by volume]
    C --> D[Step 3: Cluster medium tier]
    D --> E[Hybrid System]
    
    F[Measure MAPE at each step]
    G[Validate with A/B test]
    
    A --> F
    B --> F
    C --> F
    D --> F
    
    F --> G

37.4.10. Summary Checklist

| Step | Action | Priority |
|---|---|---|
| 1 | Tier series by volume/value | Critical |
| 2 | Implement local model registry | Critical |
| 3 | Set up distributed training (K8s Jobs/Batch) | High |
| 4 | Add global model for cold start | High |
| 5 | Implement hierarchical reconciliation | High |
| 6 | Set up forecast monitoring | High |
| 7 | Cluster medium tier for hybrid | Medium |
| 8 | Optimize inference batching | Medium |
| 9 | Add quantile forecasts | Medium |
| 10 | A/B test model types | Medium |

[End of Section 37.4]