18.3 Drift Detection Strategies

In traditional software, code behaves deterministically: if (a > b) always yields the same result for the same input. In Machine Learning, the “logic” is learned from data, and that logic is only valid as long as the world matches the data it was learned from.

Drift is the phenomenon where a model’s performance degrades over time not because the code changed (bugs), but because the world changed. It is the entropy of AI systems.


1. Taxonomy of Drift

Drift is often used as a catch-all term, but we must distinguish between three distinct failure modes in order to treat them correctly.

1.1. Data Drift (Covariate Shift)

$P(X)$ changes. The statistical distribution of the input variables changes, but the relationship between input and output $P(Y|X)$ remains the same.

  • Example: An autonomous car trained in sunny California is deployed in snowy Boston. The model has never seen snow; the input pixels have drifted significantly.
  • Detection: Monitor the statistics of the input features (mean, variance, null rate). This requires no ground-truth labels and can be done at the moment of inference (see the sketch below).
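
As a minimal, self-contained illustration of this kind of label-free check (the feature, the numbers, and the alert rule are invented for the sketch):

import numpy as np

def input_stats(x: np.ndarray) -> dict:
    """Summary statistics that require no ground-truth labels."""
    return {
        "mean": float(np.nanmean(x)),
        "std": float(np.nanstd(x)),
        "null_rate": float(np.mean(np.isnan(x))),
    }

baseline = input_stats(np.random.normal(25.0, 5.0, 100_000))   # e.g. image brightness in training data
live     = input_stats(np.random.normal(12.0, 8.0, 1_000))     # e.g. a batch of snowy-day images

# Simple alert rule: the mean moved by more than 3 baseline standard deviations
if abs(live["mean"] - baseline["mean"]) > 3 * baseline["std"]:
    print("Covariate shift suspected: input statistics moved away from the training data")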

1.2. Concept Drift

$P(Y|X)$ changes. The fundamental relationship changes. The input data looks the same to the statistical monitor, but the “correct answer” is now different.

  • Example: A spam classifier. “Free COVID Test” was a legitimate email in 2020; in 2023 it is likely spam. The text features ($X$) look the same, but the correct label ($Y$) has changed.
  • Detection: Requires ground truth (labels). Because labels are often delayed (users mark spam days later), this is harder to detect in real time (see the sketch below).
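
Because labels arrive late, concept-drift checks typically run as a batch job that joins past predictions with whatever labels have arrived since. A minimal sketch, assuming hypothetical prediction and label tables with the columns shown:

import pandas as pd

def delayed_label_accuracy(predictions: pd.DataFrame, labels: pd.DataFrame) -> float:
    """
    predictions: columns [request_id, predicted_label]  -- logged at inference time
    labels:      columns [request_id, true_label]       -- arrives days later
    """
    joined = predictions.merge(labels, on="request_id", how="inner")
    if joined.empty:
        return float("nan")   # no labels yet for this window
    return float((joined["predicted_label"] == joined["true_label"]).mean())

# Run daily over, say, a 7-day-old prediction window and alert when accuracy
# drops below the level measured at deployment time.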

1.3. Prior Probability Shift

$P(Y)$ changes. The distribution of the target variable changes.

  • Example: A fraud model where fraud is normally 0.1% of traffic. Suddenly, a bot attack makes fraud 5% of traffic.
  • Impact: The model might be calibrated to expect rare fraud. Even if its per-transaction accuracy is unchanged, the absolute number of flagged cases (and false positives) grows with the shift in class balance. A cheap, label-free check is to monitor the predicted positive rate, as sketched below.
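
The model's predicted positive rate is a cheap, label-free proxy for $P(Y)$. A minimal sketch (the scores, baseline rate, and alert rule are synthetic):

import numpy as np

def predicted_positive_rate(scores: np.ndarray, threshold: float = 0.5) -> float:
    """Fraction of traffic the model flags as positive (e.g. fraud)."""
    return float(np.mean(scores >= threshold))

baseline_rate = 0.001                               # ~0.1% of traffic flagged during validation
live_scores = np.random.beta(2.0, 8.0, 100_000)     # synthetic scores during a bot attack
live_rate = predicted_positive_rate(live_scores)

# Alert if the flag rate jumps by more than an order of magnitude
if live_rate > 10 * baseline_rate:
    print(f"Prior shift suspected: flag rate {live_rate:.2%} vs baseline {baseline_rate:.2%}")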

2. Statistical Detection Methods

How do we quantify the claim that “this dataset is different from that dataset”? We compare the Reference Distribution (the training data) with the Current Distribution (a window of inference data).

2.1. Rolling Your Own Drift Detector (Python)

While SageMaker and Vertex provide managed tools, understanding the math is key. Here is a minimal drift detector built on scipy.

import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import ks_2samp

class DriftDetector:
    def __init__(self, reference_data):
        self.reference = reference_data
        
    def check_drift_numerical(self, current_data, alpha=0.05):
        """
        Uses the Kolmogorov-Smirnov test (non-parametric).
        Returns a report dict with the statistic, p-value and drift flag.
        """
        statistic, p_value = ks_2samp(self.reference, current_data)
        
        # A small p-value rejects the null hypothesis that both samples
        # were drawn from the same distribution
        is_drift = p_value < alpha
        return {
            "method": "Kolmogorov-Smirnov",
            "statistic": statistic,
            "p_value": p_value,
            "drift_detected": is_drift
        }

    def check_drift_categorical(self, current_data, threshold=0.1):
        """
        Uses Jensen-Shannon Divergence on probability distributions
        """
        # 1. Calculate Probabilities (Histograms)
        ref_counts = np.unique(self.reference, return_counts=True)
        cur_counts = np.unique(current_data, return_counts=True)
        
        # Align domains (omitted for brevity)
        p = self._normalize(ref_counts)
        q = self._normalize(cur_counts)
        
        # 2. Calculate JS Distance
        js_distance = jensenshannon(p, q)
        
        return {
            "method": "Jensen-Shannon",
            "distance": js_distance,
            "drift_detected": js_distance > threshold
        }

    def _normalize(self, counts):
        return counts[1] / np.sum(counts[1])

# Usage
# detector = DriftDetector(training_df['age'])
# result = detector.check_drift_numerical(serving_df['age'])

2.2. Kullback-Leibler (KL) Divergence

A measure of how one probability distribution $P$ diverges from a second, expected probability distribution $Q$. $$ D_{KL}(P || Q) = \sum P(x) \log( \frac{P(x)}{Q(x)} ) $$

  • Pros: Theoretically well-grounded; it is the foundational divergence measure of information theory.
  • Cons: Asymmetric ($D_{KL}(P || Q) \neq D_{KL}(Q || P)$), and if $Q(x) = 0$ where $P(x) > 0$ the divergence is infinite. This makes it unstable for real-world monitoring.

2.3. Jensen-Shannon (JS) Divergence

A symmetric and smoothed version of KL divergence. $$ JSD(P || Q) = \frac{1}{2} D_{KL}(P || M) + \frac{1}{2} D_{KL}(Q || M) $$ where $M$ is the average of the two distributions.

  • Key Property: Symmetric and always finite (with a base-2 logarithm, $0 \le JSD \le 1$). It has become the de facto standard in cloud monitoring tools.
  • Threshold: A common alerting threshold is $JSD > 0.1$ (noticeable drift) or $JSD > 0.2$ (significant drift). The snippet below verifies both properties numerically.
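
A quick numerical check of these properties using scipy (the two distributions are made up for illustration; note that scipy's jensenshannon returns the JS distance, i.e. the square root of the divergence):

import numpy as np
from scipy.stats import entropy                    # entropy(p, q) computes D_KL(P || Q)
from scipy.spatial.distance import jensenshannon

# Two categorical distributions; Q has zero mass on the last category
p = np.array([0.70, 0.20, 0.10])
q = np.array([0.75, 0.25, 0.00])

print(entropy(p, q))                     # D_KL(P || Q) = inf, because Q(x)=0 where P(x)>0
print(entropy(q, p))                     # D_KL(Q || P) is finite -- the measure is asymmetric
print(jensenshannon(p, q, base=2) ** 2)  # JS divergence: symmetric and bounded by 1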

3. AWS SageMaker Model Monitor

SageMaker provides a fully managed solution to automate this.

3.1. The Mechanism

  1. Data Capture: The endpoint config is updated to capture input/output payloads to S3 (EnableCapture=True). This writes JSON Lines (.jsonl) files to S3, partitioned by hour.
  2. Baseline Job: You run a processing job on the Training Data (e.g., train.csv). It calculates statistics (mean, discrete counts, quantiles) and saves a constraints.json and statistics.json.
  3. Monitoring Schedule: A recurring cron job (e.g., hourly) spins up a temporary container.
  4. Comparison: The container reads the captured S3 data for that hour, computes its stats, compares to the Baseline, and checks against constraints.

3.2. Pre-processing Scripts (The Power Move)

SageMaker’s default monitor handles Tabular data (CSV/JSON). But what if you send Images (Base64) or Text?

  • Feature Engineering: You can supply a custom Python script (preprocessing.py) to the monitor.
# preprocessing.py for SageMaker Model Monitor
import json

def preprocess_handler(inference_record):
    """
    Transforms raw input (e.g., Text Review) into features (Length, Sentiment)
    """
    input_data = inference_record.endpoint_input.data
    output_data = inference_record.endpoint_output.data
    
    payload = json.loads(input_data)
    prediction = json.loads(output_data)
    
    # Feature 1: Review Length (Numerical Drift)
    text_len = len(payload['review_text'])
    
    # Feature 2: Confidence Score (Model Uncertainty)
    confidence = prediction['score']
    
    # Return formatted validation map
    return {
        "text_length": text_len,
        "confidence": confidence
    }

4. GCP Vertex AI Model Monitoring

Google’s approach is similar but integrates deeply with their data platform (BigQuery).

4.1. Training-Serving Skew vs. Prediction Drift

Vertex distinguishes explicitly between the two (both checks are sketched after this list):

  • Skew: Is the data I’m serving now different from the data I trained on?
    • Requires: Access to Training Data (BigQuery/GCS).
    • Detects: Integration bugs, such as a different feature-engineering version in serving than in training.
  • Drift: Is the data I’m serving today different from the data I served yesterday?
    • Requires: Only serving logs.
    • Detects: World changes.
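
Conceptually, skew and drift are the same statistical comparison run against different reference windows. A minimal sketch using the KS test from Section 2 (the distributions are synthetic stand-ins for real traffic):

import numpy as np
from scipy.stats import ks_2samp

def compare(reference, current, alpha=0.05):
    """Same test either way; only the reference window differs."""
    stat, p_value = ks_2samp(reference, current)
    return {"statistic": stat, "p_value": p_value, "alert": p_value < alpha}

# Synthetic stand-ins for the three data windows involved
training_amounts  = np.random.lognormal(4.0, 1.0, 50_000)   # training data
yesterday_amounts = np.random.lognormal(4.1, 1.0, 5_000)    # yesterday's serving traffic
today_amounts     = np.random.lognormal(4.5, 1.2, 5_000)    # today's serving traffic

skew_report  = compare(training_amounts, today_amounts)     # Skew: serving vs. training
drift_report = compare(yesterday_amounts, today_amounts)    # Drift: serving vs. recent serving
print("skew:", skew_report)
print("drift:", drift_report)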

4.2. Feature Attribution Drift

Vertex AI adds a layer of identifying which feature caused the drift.

  • It runs an XAI (Explainable AI) attribution method (Shapley values) on the incoming predictions.
  • It detects drift in the Feature Importances.
  • Alert Example: “Prediction drift detected. Main contributor: user_age feature importance increased by 40%.”
  • Why it matters: If user_id drifts (Input Drift), it might not matter if the model ignores user_id. But if user_age (a top feature) drifts, the model’s output will swing wildly. (A rough open-source sketch of the idea follows this list.)
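
Vertex computes these attributions internally. As a rough illustration of the same idea using the open-source shap package (the toy model, feature names, and data are invented for this sketch; this is not the Vertex implementation):

import numpy as np
import shap
from sklearn.ensemble import GradientBoostingRegressor

# Toy stand-in for the deployed model: the target depends mostly on "user_age"
rng = np.random.default_rng(0)
X_ref = rng.normal(size=(2000, 3))                  # columns: user_age, amount, tenure
y_ref = 2.0 * X_ref[:, 0] + X_ref[:, 1] + rng.normal(size=2000)
model = GradientBoostingRegressor().fit(X_ref, y_ref)

explainer = shap.TreeExplainer(model)

def mean_abs_attribution(X):
    # Mean |SHAP value| per feature = batch-level feature importance
    return np.abs(explainer.shap_values(X)).mean(axis=0)

ref_importance = mean_abs_attribution(X_ref)

# A "current" batch in which user_age has shifted upward
X_cur = X_ref.copy()
X_cur[:, 0] += 1.5
cur_importance = mean_abs_attribution(X_cur)

# Relative change in attribution per feature; alert on large swings
rel_change = (cur_importance - ref_importance) / (ref_importance + 1e-9)
print(dict(zip(["user_age", "amount", "tenure"], rel_change.round(2))))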

5. Unstructured Data Drift (Embedding Drift)

For NLP and Vision, monitoring raw input statistics (pixel means, token counts) tells you very little. Instead, we monitor Embeddings.

5.1. The Technique

  1. Reference: Pass your validation set through the model (e.g., ResNet50) and capture the vector from the penultimate layer (1x2048 float vector).
  2. Live: Capture the same vector for every inference request.
  3. Dimensionality Reduction: You cannot run histogram-based measures like JS Divergence directly on 2048 dimensions (curse of dimensionality).
    • Apply PCA or UMAP to reduce the vectors to 2D or 50D.
  4. Drift Check: Measure the drift in this lower-dimensional space.

5.2. Implementing Embedding Monitor

Using sklearn PCA for monitoring.

import numpy as np
from sklearn.decomposition import PCA
from scipy.spatial.distance import euclidean

class EmbeddingMonitor:
    def __init__(self, ref_embeddings):
        # ref_embeddings: [N, 2048]
        self.pca = PCA(n_components=2)
        self.pca.fit(ref_embeddings)
        
        self.ref_reduced = self.pca.transform(ref_embeddings)
        self.ref_centroid = np.mean(self.ref_reduced, axis=0)
        
    def check_drift(self, batch_embeddings):
        # 1. Project new data to same PCA space
        curr_reduced = self.pca.transform(batch_embeddings)
        
        # 2. Calculate Centroid Shift
        curr_centroid = np.mean(curr_reduced, axis=0)
        
        shift = euclidean(self.ref_centroid, curr_centroid)
        
        return shift
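
A brief usage sketch (synthetic embeddings stand in for real penultimate-layer vectors; the alert threshold would be tuned on held-out reference batches):

# Usage (synthetic 2048-dim embeddings standing in for real model outputs)
ref_embeddings  = np.random.normal(0.0, 1.0, size=(5000, 2048))
live_embeddings = np.random.normal(0.3, 1.0, size=(500, 2048))   # shifted batch

monitor = EmbeddingMonitor(ref_embeddings)
shift = monitor.check_drift(live_embeddings)
print(f"Centroid shift in PCA space: {shift:.3f}")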

6. Drift Response Playbook (Airflow)

What do you do when the pager goes off? You trigger a DAG.

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.email import EmailOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

def analyze_drift_logic(**context):
    # Placeholder: load the latest monitoring report and return a drift score
    return 0.35

def check_drift_severity(**context):
    drift_score = context['ti'].xcom_pull(task_ids='analyze_drift_magnitude')
    if drift_score > 0.5:
        return 'retrain_model'
    elif drift_score > 0.2:
        return 'send_warning_email'
    else:
        return 'do_nothing'

with DAG('drift_response_pipeline',
         start_date=datetime(2024, 1, 1),
         schedule_interval=None,
         catchup=False) as dag:
    
    analyze_drift = PythonOperator(
        task_id='analyze_drift_magnitude',
        python_callable=analyze_drift_logic
    )
    
    branch_task = BranchPythonOperator(
        task_id='decide_action',
        python_callable=check_drift_severity
    )
    
    retrain = TriggerDagRunOperator(
        task_id='retrain_model',
        trigger_dag_id='training_pipeline_v1'
    )
    
    warning = EmailOperator(
        task_id='send_warning_email',
        to='mlops-team@company.com',
        subject='Moderate Drift Detected',
        html_content='Drift score exceeded the warning threshold.'
    )
    
    do_nothing = EmptyOperator(task_id='do_nothing')
    
    analyze_drift >> branch_task >> [retrain, warning, do_nothing]

In this chapter, we have closed the loop on the MLOps lifecycle: from Strategy (Part I) to Monitoring (Part VII), you now possess the architectural blueprint to build systems that survive in the real world. The remaining sections expand the snippets above into fuller reference implementations.


7. Complete Statistical Drift Detection Library

7.1. Production-Grade Drift Detector

# drift_detector.py
import numpy as np
from scipy import stats
from scipy.spatial.distance import jensenshannon
from dataclasses import dataclass
from typing import Dict, List, Optional
from enum import Enum

class DriftSeverity(Enum):
    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

@dataclass
class DriftReport:
    feature_name: str
    method: str
    statistic: float
    p_value: Optional[float]
    drift_detected: bool
    severity: DriftSeverity
    recommendation: str

class ComprehensiveDriftDetector:
    def __init__(self, reference_data: Dict[str, np.ndarray]):
        """
        reference_data: Dict mapping feature names to arrays
        """
        self.reference = reference_data
        self.feature_types = self._infer_types()
    
    def _infer_types(self):
        """Automatically detect numerical vs categorical features."""
        types = {}
        for name, data in self.reference.items():
            unique_ratio = len(np.unique(data)) / len(data)
            if unique_ratio < 0.05 or data.dtype == object:
                types[name] = 'categorical'
            else:
                types[name] = 'numerical'
        return types
    
    def detect_drift(self, current_data: Dict[str, np.ndarray]) -> List[DriftReport]:
        """
        Run drift detection on all features.
        """
        reports = []
        
        for feature_name in self.reference.keys():
            if feature_name not in current_data:
                continue
            
            ref = self.reference[feature_name]
            curr = current_data[feature_name]
            
            if self.feature_types[feature_name] == 'numerical':
                report = self._detect_numerical_drift(feature_name, ref, curr)
            else:
                report = self._detect_categorical_drift(feature_name, ref, curr)
            
            reports.append(report)
        
        return reports
    
    def _detect_numerical_drift(self, name, ref, curr) -> DriftReport:
        """
        Multiple statistical tests for numerical features.
        """
        # Test 1: Kolmogorov-Smirnov Test
        ks_stat, p_value = stats.ks_2samp(ref, curr)
        
        # Test 2: Population Stability Index (PSI)
        psi = self._calculate_psi(ref, curr)
        
        # Severity determination
        if p_value < 0.001 or psi > 0.25:
            severity = DriftSeverity.CRITICAL
            recommendation = "Immediate retraining required"
        elif p_value < 0.01 or psi > 0.1:
            severity = DriftSeverity.HIGH
            recommendation = "Schedule retraining within 24 hours"
        elif p_value < 0.05:
            severity = DriftSeverity.MEDIUM
            recommendation = "Monitor closely, consider retraining"
        else:
            severity = DriftSeverity.NONE
            recommendation = "No action needed"
        
        return DriftReport(
            feature_name=name,
            method="KS Test + PSI",
            statistic=ks_stat,
            p_value=p_value,
            drift_detected=(p_value < 0.05),
            severity=severity,
            recommendation=recommendation
        )
    
    def _detect_categorical_drift(self, name, ref, curr) -> DriftReport:
        """
        Jensen-Shannon divergence for categorical features.
        """
        # Create probability distributions
        ref_unique, ref_counts = np.unique(ref, return_counts=True)
        curr_unique, curr_counts = np.unique(curr, return_counts=True)
        
        # Align categories
        all_categories = np.union1d(ref_unique, curr_unique)
        
        ref_probs = np.zeros(len(all_categories))
        curr_probs = np.zeros(len(all_categories))
        
        for i, cat in enumerate(all_categories):
            ref_idx = np.where(ref_unique == cat)[0]
            curr_idx = np.where(curr_unique == cat)[0]
            
            if len(ref_idx) > 0:
                ref_probs[i] = ref_counts[ref_idx[0]] / len(ref)
            if len(curr_idx) > 0:
                curr_probs[i] = curr_counts[curr_idx[0]] / len(curr)
        
        # Calculate JS divergence
        js_distance = jensenshannon(ref_probs, curr_probs)
        
        # Severity
        if js_distance > 0.5:
            severity = DriftSeverity.CRITICAL
        elif js_distance > 0.2:
            severity = DriftSeverity.HIGH
        elif js_distance > 0.1:
            severity = DriftSeverity.MEDIUM
        else:
            severity = DriftSeverity.NONE
        
        return DriftReport(
            feature_name=name,
            method="Jensen-Shannon Divergence",
            statistic=js_distance,
            p_value=None,
            drift_detected=(js_distance > 0.1),
            severity=severity,
            recommendation=f"JS Distance: {js_distance:.3f}"
        )
    
    def _calculate_psi(self, ref, curr, bins=10):
        """
        Population Stability Index - financial services standard.
        """
        # Create bins based on reference distribution
        _, bin_edges = np.histogram(ref, bins=bins)
        
        ref_hist, _ = np.histogram(ref, bins=bin_edges)
        curr_hist, _ = np.histogram(curr, bins=bin_edges)
        
        # Avoid division by zero
        ref_hist = (ref_hist + 0.0001) / len(ref)
        curr_hist = (curr_hist + 0.0001) / len(curr)
        
        psi = np.sum((curr_hist - ref_hist) * np.log(curr_hist / ref_hist))
        
        return psi

# Usage example
reference = {
    'age': np.random.normal(35, 10, 10000),
    'income': np.random.lognormal(10, 1, 10000),
    'category': np.random.choice(['A', 'B', 'C'], 10000)
}

current = {
    'age': np.random.normal(38, 12, 1000),  # Drifted
    'income': np.random.lognormal(10, 1, 1000),  # Not drifted
    'category': np.random.choice(['A', 'B', 'C', 'D'], 1000, p=[0.2, 0.2, 0.2, 0.4])  # New category!
}

detector = ComprehensiveDriftDetector(reference)
reports = detector.detect_drift(current)

for report in reports:
    if report.drift_detected:
        print(f"⚠️  {report.feature_name}: {report.severity.value}")
        print(f"   {report.recommendation}")

8. AWS SageMaker Model Monitor Complete Setup

8.1. Enable Data Capture

# enable_data_capture.py
import boto3
from sagemaker import Session

session = Session()
sm_client = boto3.client('sagemaker')

# Update endpoint to capture data
endpoint_config = sm_client.create_endpoint_config(
    EndpointConfigName='fraud-detector-monitored-config',
    ProductionVariants=[{
        'VariantName': 'AllTraffic',
        'ModelName': 'fraud-detector-v2',
        'InstanceType': 'ml.m5.xlarge',
        'InitialInstanceCount': 1
    }],
    DataCaptureConfig={
        'EnableCapture': True,
        'InitialSamplingPercentage': 100,  # Capture all requests
        'DestinationS3Uri': 's3://my-bucket/model-monitor/data-capture',
        'CaptureOptions': [
            {'CaptureMode': 'Input'},
            {'CaptureMode': 'Output'}
        ],
        'CaptureContentTypeHeader': {
            'CsvContentTypes': ['text/csv'],
            'JsonContentTypes': ['application/json']
        }
    }
)

sm_client.update_endpoint(
    EndpointName='fraud-detector-prod',
    EndpointConfigName='fraud-detector-monitored-config'
)

8.2. Create Baseline

# create_baseline.py
from sagemaker.model_monitor import DefaultModelMonitor
from sagemaker.model_monitor.dataset_format import DatasetFormat

monitor = DefaultModelMonitor(
    role='arn:aws:iam::123456789012:role/SageMakerRole',
    instance_count=1,
    instance_type='ml.m5.xlarge',
    volume_size_in_gb=20,
    max_runtime_in_seconds=3600
)

# Suggest baseline using training data
monitor.suggest_baseline(
    baseline_dataset='s3://my-bucket/training-data/train.csv',
    dataset_format=DatasetFormat.csv(header=True),
    output_s3_uri='s3://my-bucket/model-monitor/baseline',
    wait=True
)

print("✓ Baseline created")
print(f"Statistics: s3://my-bucket/model-monitor/baseline/statistics.json")
print(f"Constraints: s3://my-bucket/model-monitor/baseline/constraints.json")

8.3. Create Monitoring Schedule

# create_schedule.py
from sagemaker.model_monitor import CronExpressionGenerator

monitor.create_monitoring_schedule(
    monitor_schedule_name='fraud-detector-hourly-monitor',
    endpoint_input='fraud-detector-prod',
    output_s3_uri='s3://my-bucket/model-monitor/reports',
    statistics=monitor.baseline_statistics(),
    constraints=monitor.suggested_constraints(),
    schedule_cron_expression=CronExpressionGenerator.hourly(),
    enable_cloudwatch_metrics=True
)

print("✓ Monitoring schedule created")

8.4. Query Violations

# check_violations.py
import boto3
import json

s3 = boto3.client('s3')

def get_latest_violations(bucket, prefix):
    """
    Retrieve the most recent constraint_violations.json report.
    Monitor reports are written under dated sub-prefixes of `prefix`.
    """
    response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
    
    reports = [
        obj for obj in response.get('Contents', [])
        if obj['Key'].endswith('constraint_violations.json')
    ]
    if not reports:
        return []
    
    # Most recent report wins
    latest = max(reports, key=lambda obj: obj['LastModified'])
    
    obj = s3.get_object(Bucket=bucket, Key=latest['Key'])
    violations = json.loads(obj['Body'].read())
    
    return violations.get('violations', [])

violations = get_latest_violations('my-bucket', 'model-monitor/reports')

if violations:
    print("⚠️  Drift detected!")
    for v in violations:
        print(f"Feature: {v['feature_name']}")
        print(f"Violation: {v['violation_type']}")
        print(f"Description: {v['description']}\n")
else:
    print("✓ No violations")

9. GCP Vertex AI Model Monitoring Setup

9.1. Enable Monitoring (Python SDK)

# vertex_monitoring.py
from google.cloud import aiplatform
from google.cloud.aiplatform import model_monitoring

aiplatform.init(project='my-project', location='us-central1')

# Get existing endpoint
endpoint = aiplatform.Endpoint('projects/123/locations/us-central1/endpoints/456')

# Skew: compare serving traffic against the original training data
skew_config = model_monitoring.SkewDetectionConfig(
    data_source='gs://my-bucket/training-data/train.csv',
    data_format='csv',
    target_field='is_fraud',
    skew_thresholds={'age': 0.1, 'amount': 0.15}
)

# Drift: compare today's serving traffic against recent serving traffic
drift_config = model_monitoring.DriftDetectionConfig(
    drift_thresholds={'age': 0.1, 'amount': 0.15}
)

objective_config = model_monitoring.ObjectiveConfig(
    skew_detection_config=skew_config,
    drift_detection_config=drift_config
)

# Create monitoring job
# (exact parameter names can vary slightly across aiplatform SDK versions)
monitoring_job = aiplatform.ModelDeploymentMonitoringJob.create(
    display_name='fraud-detector-monitor',
    endpoint=endpoint,
    logging_sampling_strategy=model_monitoring.RandomSampleConfig(sample_rate=1.0),  # 100%
    schedule_config=model_monitoring.ScheduleConfig(monitor_interval=1),  # Hourly
    objective_configs=objective_config,
    alert_config=model_monitoring.EmailAlertConfig(
        user_emails=['mlops-team@company.com']
    )
)

print(f"Monitoring job created: {monitoring_job.resource_name}")

9.2. Query Monitoring Results (BigQuery)

-- query_drift_results.sql
-- Vertex AI writes monitoring results to BigQuery

SELECT
  model_name,
  feature_name,
  training_stats.mean AS training_mean,
  prediction_stats.mean AS serving_mean,
  ABS(prediction_stats.mean - training_stats.mean) / training_stats.stddev AS drift_score
FROM
  `my-project.model_monitoring.prediction_stats`
WHERE
  DATE(prediction_time) = CURRENT_DATE()
  -- A SELECT alias cannot be referenced in WHERE, so the expression is repeated
  AND ABS(prediction_stats.mean - training_stats.mean) / training_stats.stddev > 2.0  -- more than 2 standard deviations
ORDER BY
  drift_score DESC
LIMIT 10;

10. Real-Time Drift Detection in Inference Code

10.1. Lightweight In-Process Monitor

# realtime_drift_monitor.py
import numpy as np
from collections import deque
from threading import Lock

class RealTimeDriftMonitor:
    """
    Lightweight drift detection inside the inference server.
    Keeps a rolling window of recent feature values and compares their
    means against the training baseline. Adds minimal overhead per request.
    """
    def __init__(self, window_size=1000):
        self.window = deque(maxlen=window_size)
        self.baseline_stats = None
        self.lock = Lock()
    
    def set_baseline(self, baseline_data):
        """
        baseline_data: Dict[feature_name, np.array]
        """
        self.baseline_stats = {
            name: {
                'mean': np.mean(data),
                'std': np.std(data),
                'min': np.min(data),
                'max': np.max(data)
            }
            for name, data in baseline_data.items()
        }
    
    def observe(self, features: dict):
        """
        Called on every inference request.
        """
        with self.lock:
            self.window.append(features)
    
    def check_drift(self):
        """
        Periodically called (e.g., every 1000 requests).
        Returns a drift score (z-score of the mean shift) per feature.
        """
        if self.baseline_stats is None or len(self.window) < 100:
            return {}
        
        with self.lock:
            current_batch = list(self.window)
        
        # Aggregate into dict of arrays
        aggregated = {}
        for features in current_batch:
            for name, value in features.items():
                if name not in aggregated:
                    aggregated[name] = []
                aggregated[name].append(value)
        
        # Calculate drift
        drift_scores = {}
        for name, values in aggregated.items():
            if name not in self.baseline_stats:
                continue
            
            current_mean = np.mean(values)
            baseline_mean = self.baseline_stats[name]['mean']
            baseline_std = self.baseline_stats[name]['std']
            
            # Z-score drift
            drift = abs(current_mean - baseline_mean) / (baseline_std + 1e-6)
            drift_scores[name] = drift
        
        return drift_scores

# Integration with a Flask inference server
# (model loading and baseline wiring omitted for brevity)
import logging
from flask import Flask, request, jsonify

app = Flask(__name__)
logger = logging.getLogger(__name__)
monitor = RealTimeDriftMonitor()
request_count = 0

@app.route('/predict', methods=['POST'])
def predict():
    global request_count
    request_count += 1
    features = request.get_json()
    
    # Record for drift monitoring
    monitor.observe(features)
    
    # Run inference
    prediction = model.predict(features)
    
    # Every 1000 requests, check drift
    if request_count % 1000 == 0:
        drift_scores = monitor.check_drift()
        for feature, score in drift_scores.items():
            if score > 3.0:
                logger.warning(f"Drift detected in {feature}: {score:.2f} std devs")
    
    return jsonify(prediction)

11. Automated Retraining Pipeline

11.1. Complete Airflow DAG

# drift_response_dag.py
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.email import EmailOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from datetime import datetime, timedelta

default_args = {
    'owner': 'mlops-team',
    'depends_on_past': False,
    'start_date': datetime(2024, 1, 1),
    'retries': 1,
    'retry_delay': timedelta(minutes=5)
}

def analyze_drift_severity(**context):
    """
    Download latest drift report and determine severity.
    """
    import boto3
    import json
    
    s3 = boto3.client('s3')
    obj = s3.get_object(Bucket='my-bucket', Key='model-monitor/reports/latest.json')
    report = json.loads(obj['Body'].read())
    
    violations = report.get('violations', [])
    
    if not violations:
        return 'no_action'
    
    critical_count = sum(1 for v in violations if 'critical' in v.get('description', '').lower())
    
    if critical_count > 0:
        return 'emergency_retrain'
    elif len(violations) > 5:
        return 'prepare_training_data'  # first task on the scheduled-retrain path
    else:
        return 'send_alert'

def prepare_training_dataset(**context):
    """
    Fetch recent data from production logs and prepare training set.
    """
    import pandas as pd
    
    # Query data warehouse for last 30 days of labeled data
    query = """
        SELECT * FROM production_predictions
        WHERE labeled = true
        AND timestamp > CURRENT_DATE - INTERVAL '30 days'
    """
    
    # `connection_string` is a placeholder for your warehouse connection
    df = pd.read_sql(query, connection_string)
    
    # Save to S3 (writing directly to s3:// paths requires s3fs on the workers)
    df.to_csv('s3://my-bucket/retraining-data/latest.csv', index=False)
    
    return len(df)

with DAG(
    'drift_response_pipeline',
    default_args=default_args,
    schedule_interval='@hourly',
    catchup=False
) as dag:
    
    # Wait for new monitoring report
    wait_for_report = S3KeySensor(
        task_id='wait_for_monitoring_report',
        bucket_name='my-bucket',
        bucket_key='model-monitor/reports/{{ ds }}/constraint_violations.json',
        timeout=300,
        poke_interval=30
    )
    
    # Analyze drift
    analyze = PythonOperator(
        task_id='analyze_drift',
        python_callable=analyze_drift_severity
    )
    
    # Branch based on severity
    branch = BranchPythonOperator(
        task_id='determine_action',
        python_callable=analyze_drift_severity
    )
    
    # Option 1: Emergency retrain (immediate)
    emergency_retrain = TriggerDagRunOperator(
        task_id='emergency_retrain',
        trigger_dag_id='model_training_pipeline',
        conf={'priority': 'high', 'notify': 'pagerduty'}
    )
    
    # Option 2: Scheduled retrain
    prepare_data = PythonOperator(
        task_id='prepare_training_data',
        python_callable=prepare_training_dataset
    )
    
    scheduled_retrain = TriggerDagRunOperator(
        task_id='scheduled_retrain',
        trigger_dag_id='model_training_pipeline',
        conf={'priority': 'normal'}
    )
    
    # Option 3: Just alert
    send_alert = EmailOperator(
        task_id='send_alert',
        to=['mlops-team@company.com'],
        subject='Drift Detected - Investigate',
        html_content='Minor drift detected. Review monitoring dashboard.'
    )
    
    # Option 4: No action
    no_action = PythonOperator(
        task_id='no_action',
        python_callable=lambda: print("No drift detected")
    )
    
    # Pipeline
    wait_for_report >> analyze >> branch
    branch >> emergency_retrain
    branch >> prepare_data >> scheduled_retrain
    branch >> send_alert
    branch >> no_action

12. Champion/Challenger Pattern for Drift Mitigation

12.1. A/B Testing New Models

# champion_challenger.py
from datetime import datetime, timedelta

import boto3
import numpy as np

sm_client = boto3.client('sagemaker')

def create_ab_test_endpoint():
    """
    Deploy champion and challenger models with traffic splitting.
    """
    endpoint_config = sm_client.create_endpoint_config(
        EndpointConfigName='fraud-detector-ab-test',
        ProductionVariants=[
            {
                'VariantName': 'Champion',
                'ModelName': 'fraud-model-v1',
                'InstanceType': 'ml.m5.xlarge',
                'InitialInstanceCount': 3,
                'InitialVariantWeight': 0.9  # 90% traffic
            },
            {
                'VariantName': 'Challenger',
                'ModelName': 'fraud-model-v2-retrained',
                'InstanceType': 'ml.m5.xlarge',
                'InitialInstanceCount': 1,
                'InitialVariantWeight': 0.1  # 10% traffic
            }
        ]
    )
    
    sm_client.create_endpoint(
        EndpointName='fraud-detector-ab',
        EndpointConfigName='fraud-detector-ab-test'
    )

# After 1 week, analyze metrics
def evaluate_challenger():
    """
    Compare performance metrics between variants.
    """
    cloudwatch = boto3.client('cloudwatch')
    
    metrics = ['ModelLatency', 'Invocation4XXErrors', 'Invocation5XXErrors']
    
    for metric in metrics:
        champion_stats = cloudwatch.get_metric_statistics(
            Namespace='AWS/SageMaker',
            MetricName=metric,
            Dimensions=[
                {'Name': 'EndpointName', 'Value': 'fraud-detector-ab'},
                {'Name': 'VariantName', 'Value': 'Champion'}
            ],
            StartTime=datetime.utcnow() - timedelta(days=7),
            EndTime=datetime.utcnow(),
            Period=3600,
            Statistics=['Average']
        )
        
        challenger_stats = cloudwatch.get_metric_statistics(
            Namespace='AWS/SageMaker',
            MetricName=metric,
            Dimensions=[
                {'Name': 'EndpointName', 'Value': 'fraud-detector-ab'},
                {'Name': 'VariantName', 'Value': 'Challenger'}
            ],
            StartTime=datetime.utcnow() - timedelta(days=7),
            EndTime=datetime.utcnow(),
            Period=3600,
            Statistics=['Average']
        )
        
        print(f"{metric}:")
        print(f"  Champion: {np.mean([d['Average'] for d in champion_stats['Datapoints']]):.2f}")
        print(f"  Challenger: {np.mean([d['Average'] for d in challenger_stats['Datapoints']]):.2f}")

def promote_challenger():
    """
    If challenger performs better, shift 100% traffic.
    """
    sm_client.update_endpoint_weights_and_capacities(
        EndpointName='fraud-detector-ab',
        DesiredWeightsAndCapacities=[
            {'VariantName': 'Champion', 'DesiredWeight': 0.0},
            {'VariantName': 'Challenger', 'DesiredWeight': 1.0}
        ]
    )
    print("✓ Challenger promoted to production")

13. Conclusion

Drift is inevitable. The world changes, users change, adversaries adapt. The question is not “Will my model drift?” but “How quickly will I detect it, and how fast can I respond?”

Key principles:

  1. Monitor inputs AND outputs - Data drift is early warning, prediction drift is the fire
  2. Automate detection, not response - Humans decide to retrain, systems detect the need
  3. Design for rapid iteration - If retraining takes weeks, drift monitoring is pointless
  4. Use statistical rigor - “The model feels worse” is not a metric

With comprehensive monitoring in place—from infrastructure (18.1) to GPUs (18.2) to data (18.3)—you have closed the MLOps loop. Your system is no longer a static artifact deployed once, but a living system that observes itself, detects degradation, and triggers its own evolution.

This is production ML - systems that don’t just work today, but continue working tomorrow.