18.3 Drift Detection Strategies
In traditional software, code behaves deterministically: if (a > b) always yields the same result for the same input. In Machine Learning, the “logic” is learned from data, and that logic is only valid as long as the world matches the data it was learned from.
Drift is the phenomenon where a model’s performance degrades over time not because the code changed (bugs), but because the world changed. It is the entropy of AI systems.
1. Taxonomy of Drift
Drift is often used as a catch-all term, but we must distinguish three distinct failure modes in order to treat them correctly.
1.1. Data Drift (Covariate Shift)
$P(X)$ changes. The statistical distribution of the input variables changes, but the relationship between input and output $P(Y|X)$ remains the same.
- Example: An autonomous car trained in Sunny California is deployed in Snowy Boston. The model has never seen snow visuals. The inputs (pixels) have drifted significantly.
- Detection: Monitoring the statistics of input features (mean, variance, null rate). This requires no ground truth labels. It can be detected at the moment of inference.
1.2. Concept Drift
$P(Y|X)$ changes. The fundamental relationship changes. The input data looks the same to the statistical monitor, but the “correct answer” is now different.
- Example: A spam classifier. “Free COVID Test” was a legitimate email in 2020. In 2023, it is likely spam. The text features ($X$) look similar, but the correct label ($Y$) has changed.
- Detection: Requires ground truth (labels). Because labels are often delayed (users mark spam days later), this is harder to detect in real time; the sketch below shows the typical delayed-label join.
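Once the delayed labels do arrive, the standard pattern is to join them back onto logged predictions and track a rolling quality metric. A minimal sketch (the logging schema and the pandas dataframes are assumptions, not part of any specific tool):
import pandas as pd

def rolling_accuracy(preds: pd.DataFrame, labels: pd.DataFrame, window: str = "7D") -> pd.Series:
    """
    preds:  logged predictions with columns [request_id, timestamp, predicted_label]
    labels: delayed ground truth with columns [request_id, true_label]
    """
    joined = preds.merge(labels, on="request_id", how="inner")
    joined["timestamp"] = pd.to_datetime(joined["timestamp"])
    joined = joined.sort_values("timestamp").set_index("timestamp")
    correct = (joined["predicted_label"] == joined["true_label"]).astype(float)
    # Rolling accuracy over a time window; a sustained drop signals concept drift
    return correct.rolling(window).mean()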
1.3. Prior Probability Shift
$P(Y)$ changes. The distribution of the target variable changes.
- Example: A fraud model where fraud is normally 0.1% of traffic. Suddenly, a bot attack makes fraud 5% of traffic.
- Impact: The model's decision threshold and expected alert volume were calibrated for rare fraud. Even if the model itself is unchanged, a 50x jump in prevalence dramatically changes the alert volume, the precision/recall trade-off, and the downstream review workload (see the sketch below).
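To make the impact concrete, here is a back-of-the-envelope sketch with a fixed, hypothetical true positive rate and false positive rate; only the prevalence changes, yet alert volume and precision move dramatically:
# prior_shift_impact.py
# Illustrative sketch: how a shift in P(Y) changes alert volume and precision
# for a *fixed* model (the TPR/FPR values are hypothetical).

def alert_stats(prevalence, tpr=0.90, fpr=0.01, traffic=1_000_000):
    frauds = prevalence * traffic
    legit = traffic - frauds
    true_alerts = tpr * frauds        # frauds correctly flagged
    false_alerts = fpr * legit        # legitimate traffic incorrectly flagged
    total_alerts = true_alerts + false_alerts
    precision = true_alerts / total_alerts
    return total_alerts, precision

for p in (0.001, 0.05):               # fraud at 0.1% vs. 5% of traffic
    alerts, precision = alert_stats(p)
    print(f"prevalence={p:.1%}: {alerts:,.0f} alerts, precision={precision:.2f}")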
2. Statistical Detection Methods
How do we mathematically prove “this dataset is different from that dataset”? We compare the Reference Distribution (Training Data) with the Current Distribution (Inference Data window).
2.1. Rolling Your Own Drift Detector (Python)
While SageMaker and Vertex AI provide managed tools, understanding the math is key. Here is a minimal drift detector built on scipy (a more complete version appears in Section 7).
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import ks_2samp
class DriftDetector:
def __init__(self, reference_data):
self.reference = reference_data
def check_drift_numerical(self, current_data, threshold=0.1):
"""
        Uses the Kolmogorov-Smirnov test (non-parametric).
        Returns a dict; drift_detected is True when p_value < 0.05.
        """
        statistic, p_value = ks_2samp(self.reference, current_data)
        # A small p-value rejects the null hypothesis that both samples
        # were drawn from the same distribution, i.e. drift is likely.
        is_drift = p_value < 0.05
return {
"method": "Kolmogorov-Smirnov",
"statistic": statistic,
"p_value": p_value,
"drift_detected": is_drift
}
def check_drift_categorical(self, current_data, threshold=0.1):
"""
Uses Jensen-Shannon Divergence on probability distributions
"""
# 1. Calculate Probabilities (Histograms)
ref_counts = np.unique(self.reference, return_counts=True)
cur_counts = np.unique(current_data, return_counts=True)
# Align domains (omitted for brevity)
p = self._normalize(ref_counts)
q = self._normalize(cur_counts)
# 2. Calculate JS Distance
js_distance = jensenshannon(p, q)
return {
"method": "Jensen-Shannon",
"distance": js_distance,
"drift_detected": js_distance > threshold
}
def _normalize(self, counts):
return counts[1] / np.sum(counts[1])
# Usage
# detector = DriftDetector(training_df['age'])
# result = detector.check_drift_numerical(serving_df['age'])
2.2. Kullback-Leibler (KL) Divergence
A measure of how one probability distribution $P$ diverges from a second, expected probability distribution $Q$: $$ D_{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)} $$
- Pros: Theoretically well-founded; it is the canonical divergence of information theory.
- Cons: Asymmetric ($D_{KL}(P \| Q) \neq D_{KL}(Q \| P)$), and it blows up to infinity if $Q(x)$ is 0 anywhere that $P(x)$ is non-zero. This makes it unstable for real-world monitoring.
2.3. Jensen-Shannon (JS) Divergence
A symmetric and smoothed version of KL divergence: $$ JSD(P \| Q) = \frac{1}{2} D_{KL}(P \| M) + \frac{1}{2} D_{KL}(Q \| M) $$ where $M = \frac{1}{2}(P + Q)$ is the average of the two distributions.
- Key Property: Symmetric and always finite ($0 \le JSD \le 1$). It has become the de facto standard in cloud monitoring tools.
- Threshold: A common alerting threshold is $JSD > 0.1$ (noticeable drift) or $JSD > 0.2$ (significant drift). A short numerical comparison of KL and JS follows.
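A quick numerical comparison of the two measures on hypothetical histograms (note that scipy's jensenshannon returns the JS distance, the square root of the divergence):
import numpy as np
from scipy.stats import entropy                  # entropy(p, q) computes D_KL(p || q)
from scipy.spatial.distance import jensenshannon

p = np.array([0.70, 0.20, 0.10])                 # reference distribution
q = np.array([0.50, 0.30, 0.20])                 # current distribution

print(entropy(p, q), entropy(q, p))              # asymmetric: the two values differ
print(jensenshannon(p, q), jensenshannon(q, p))  # symmetric, bounded, finite

# KL blows up when q assigns zero probability where p does not:
q_zero = np.array([0.90, 0.10, 0.00])
print(entropy(p, q_zero))                        # inf
print(jensenshannon(p, q_zero))                  # still finite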
3. AWS SageMaker Model Monitor
SageMaker provides a fully managed solution to automate this.
3.1. The Mechanism
- Data Capture: The endpoint config is updated to capture input/output payloads to S3 (EnableCapture=True). This creates JSONL files in S3, partitioned by hour.
- Baseline Job: You run a processing job on the training data (e.g., train.csv). It calculates statistics (mean, discrete counts, quantiles) and saves a constraints.json and a statistics.json.
- Monitoring Schedule: A recurring cron job (e.g., hourly) spins up a temporary container.
- Comparison: The container reads the captured S3 data for that hour, computes its stats, compares them to the baseline, and checks them against the constraints.
3.2. Pre-processing Scripts (The Power Move)
SageMaker’s default monitor handles Tabular data (CSV/JSON). But what if you send Images (Base64) or Text?
- Feature Engineering: You can supply a custom Python script (preprocessing.py) to the monitor; attaching it to the schedule is sketched after the script below.
# preprocessing.py for SageMaker Model Monitor
import json
def preprocess_handler(inference_record):
"""
Transforms raw input (e.g., Text Review) into features (Length, Sentiment)
"""
input_data = inference_record.endpoint_input.data
output_data = inference_record.endpoint_output.data
payload = json.loads(input_data)
prediction = json.loads(output_data)
# Feature 1: Review Length (Numerical Drift)
text_len = len(payload['review_text'])
# Feature 2: Confidence Score (Model Uncertainty)
confidence = prediction['score']
# Return formatted validation map
return {
"text_length": text_len,
"confidence": confidence
}
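The script is attached when the monitoring schedule is created. Below is a sketch using the SageMaker Python SDK's record_preprocessor_script argument; the monitor object is the DefaultModelMonitor configured in Section 8.2, and the endpoint name, schedule name, and S3 paths are placeholders:
# Attach the custom preprocessing script to the monitoring schedule (sketch).
from sagemaker.model_monitor import CronExpressionGenerator

monitor.create_monitoring_schedule(
    monitor_schedule_name='review-classifier-monitor',
    endpoint_input='review-classifier-prod',
    record_preprocessor_script='s3://my-bucket/model-monitor/scripts/preprocessing.py',
    output_s3_uri='s3://my-bucket/model-monitor/reports',
    statistics=monitor.baseline_statistics(),
    constraints=monitor.suggested_constraints(),
    schedule_cron_expression=CronExpressionGenerator.hourly()
)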
4. GCP Vertex AI Model Monitoring
Google’s approach is similar but integrates deeply with their data platform (BigQuery).
4.1. Training-Serving Skew vs. Prediction Drift
Vertex distinguishes explicitly between the two (a code sketch of both checks follows this list):
- Skew: Is the data I’m serving now different from the data I trained on?
- Requires: Access to Training Data (BigQuery/GCS).
- Detects: Integration bugs. Use of different feature engineering versions.
- Drift: Is the data I’m serving today different from the data I served yesterday?
- Requires: Only serving logs.
- Detects: World changes.
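Conceptually, both checks are the same statistical comparison with a different reference window. A minimal sketch reusing the DriftDetector from Section 2.1 (the dataframes are placeholders):
# Skew: reference = training data, current = today's serving traffic
skew_detector = DriftDetector(training_df['age'])
skew_report = skew_detector.check_drift_numerical(serving_today_df['age'])

# Drift: reference = yesterday's serving traffic, current = today's serving traffic
drift_detector = DriftDetector(serving_yesterday_df['age'])
drift_report = drift_detector.check_drift_numerical(serving_today_df['age'])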
4.2. Feature Attribution Drift
Vertex AI adds a layer that identifies which feature caused the drift.
- It runs an XAI (Explainable AI) attribution method (Shapley values) on the incoming predictions.
- It detects drift in the Feature Importances.
- Alert Example: “Prediction drift detected. Main contributor: user_age feature importance increased by 40%.”
- Why it matters: If user_id drifts (input drift), it might not matter if the model ignores user_id. But if user_age (a top feature) drifts, the model’s output will swing wildly. A hand-rolled approximation of this check is sketched below.
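Outside of Vertex, the same idea can be approximated by hand: compute per-feature attributions for a reference window and for the current window, then compare the normalized importance profiles. A sketch that assumes attributions (e.g., SHAP values) have already been computed; the arrays and feature names are placeholders:
import numpy as np
from scipy.spatial.distance import jensenshannon

def attribution_drift(ref_attr: np.ndarray, cur_attr: np.ndarray, feature_names):
    """
    ref_attr, cur_attr: [n_samples, n_features] attribution matrices
    (e.g., SHAP values) for the reference and current windows.
    """
    # Mean absolute attribution = global importance per feature
    ref_imp = np.abs(ref_attr).mean(axis=0)
    cur_imp = np.abs(cur_attr).mean(axis=0)
    # Normalize into importance distributions and compare with JS distance
    ref_p = ref_imp / ref_imp.sum()
    cur_p = cur_imp / cur_imp.sum()
    js = jensenshannon(ref_p, cur_p)
    # Which feature's importance moved the most?
    deltas = cur_p - ref_p
    top = feature_names[int(np.argmax(np.abs(deltas)))]
    return {"js_distance": js, "top_contributor": top}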
5. Unstructured Data Drift (Embedding Drift)
For NLP and Vision, monitoring pixel means is useless. We monitor Embeddings.
5.1. The Technique
- Reference: Pass your validation set through the model (e.g., ResNet50) and capture the vector from the penultimate layer (a 1x2048 float vector); an extraction sketch follows this list.
- Live: Capture the same vector for every inference request.
- Dimensionality Reduction: You cannot run JS Divergence on 2048 dimensions (Curse of Dimensionality).
- Apply PCA or UMAP to reduce the vectors to 2D or 50D.
- Drift Check: Measure the drift in this lower-dimensional space.
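As a concrete example of the first step, here is a sketch of capturing penultimate-layer embeddings with a torchvision ResNet50 (the model choice and preprocessing are illustrative, and a recent torchvision version is assumed):
import torch
import torchvision.models as models
import torchvision.transforms as T

# Load a pretrained ResNet50 and replace the classification head with identity,
# so a forward pass returns the 2048-dim penultimate-layer embedding.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.fc = torch.nn.Identity()
model.eval()

preprocess = T.Compose([
    T.Resize(256), T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def embed(pil_images):
    batch = torch.stack([preprocess(img) for img in pil_images])
    return model(batch).numpy()   # shape: [N, 2048]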
5.2. Implementing an Embedding Monitor
Using scikit-learn's PCA for monitoring.
import numpy as np
from sklearn.decomposition import PCA
from scipy.spatial.distance import euclidean
class EmbeddingMonitor:
def __init__(self, ref_embeddings):
# ref_embeddings: [N, 2048]
self.pca = PCA(n_components=2)
self.pca.fit(ref_embeddings)
self.ref_reduced = self.pca.transform(ref_embeddings)
self.ref_centroid = np.mean(self.ref_reduced, axis=0)
def check_drift(self, batch_embeddings):
# 1. Project new data to same PCA space
curr_reduced = self.pca.transform(batch_embeddings)
# 2. Calculate Centroid Shift
curr_centroid = np.mean(curr_reduced, axis=0)
shift = euclidean(self.ref_centroid, curr_centroid)
return shift
6. Drift Response Playbook (Airflow)
What do you do when the pager goes off? You trigger a DAG.
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.email import EmailOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
def check_drift_severity(**context):
    drift_score = context['ti'].xcom_pull(task_ids='analyze_drift_magnitude')
if drift_score > 0.5:
return 'retrain_model'
elif drift_score > 0.2:
return 'send_warning_email'
else:
return 'do_nothing'
with DAG(
    'drift_response_pipeline',
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
    catchup=False
) as dag:
    # analyze_drift_logic (assumed defined elsewhere) computes the drift
    # score and pushes it to XCom for the branch task below.
    analyze_drift = PythonOperator(
        task_id='analyze_drift_magnitude',
        python_callable=analyze_drift_logic
    )
branch_task = BranchPythonOperator(
task_id='decide_action',
python_callable=check_drift_severity
)
retrain = TriggerDagRunOperator(
task_id='retrain_model',
trigger_dag_id='training_pipeline_v1'
)
    warning = EmailOperator(
        task_id='send_warning_email',
        to='mlops-team@company.com',
        subject='Moderate Drift Detected',
        html_content='Drift score between 0.2 and 0.5; review the monitoring dashboard.'
    )
    # Branch target for the "no drift" path
    do_nothing = EmptyOperator(task_id='do_nothing')
    analyze_drift >> branch_task >> [retrain, warning, do_nothing]
In this chapter, we have closed the loop on the MLOps lifecycle. From Strategy (Part I) to Monitoring (Part VII), you now possess the architectural blueprint to build systems that survive in the real world.
7. Complete Statistical Drift Detection Library
7.1. Production-Grade Drift Detector
# drift_detector.py
import numpy as np
from scipy import stats
from scipy.spatial.distance import jensenshannon
from dataclasses import dataclass
from typing import Dict, List, Optional
from enum import Enum
class DriftSeverity(Enum):
NONE = "none"
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
@dataclass
class DriftReport:
feature_name: str
method: str
statistic: float
p_value: Optional[float]
drift_detected: bool
severity: DriftSeverity
recommendation: str
class ComprehensiveDriftDetector:
def __init__(self, reference_data: Dict[str, np.ndarray]):
"""
reference_data: Dict mapping feature names to arrays
"""
self.reference = reference_data
self.feature_types = self._infer_types()
def _infer_types(self):
"""Automatically detect numerical vs categorical features."""
types = {}
for name, data in self.reference.items():
unique_ratio = len(np.unique(data)) / len(data)
if unique_ratio < 0.05 or data.dtype == object:
types[name] = 'categorical'
else:
types[name] = 'numerical'
return types
def detect_drift(self, current_data: Dict[str, np.ndarray]) -> List[DriftReport]:
"""
Run drift detection on all features.
"""
reports = []
for feature_name in self.reference.keys():
if feature_name not in current_data:
continue
ref = self.reference[feature_name]
curr = current_data[feature_name]
if self.feature_types[feature_name] == 'numerical':
report = self._detect_numerical_drift(feature_name, ref, curr)
else:
report = self._detect_categorical_drift(feature_name, ref, curr)
reports.append(report)
return reports
def _detect_numerical_drift(self, name, ref, curr) -> DriftReport:
"""
Multiple statistical tests for numerical features.
"""
# Test 1: Kolmogorov-Smirnov Test
ks_stat, p_value = stats.ks_2samp(ref, curr)
# Test 2: Population Stability Index (PSI)
psi = self._calculate_psi(ref, curr)
# Severity determination
if p_value < 0.001 or psi > 0.25:
severity = DriftSeverity.CRITICAL
recommendation = "Immediate retraining required"
elif p_value < 0.01 or psi > 0.1:
severity = DriftSeverity.HIGH
recommendation = "Schedule retraining within 24 hours"
elif p_value < 0.05:
severity = DriftSeverity.MEDIUM
recommendation = "Monitor closely, consider retraining"
else:
severity = DriftSeverity.NONE
recommendation = "No action needed"
return DriftReport(
feature_name=name,
method="KS Test + PSI",
statistic=ks_stat,
p_value=p_value,
drift_detected=(p_value < 0.05),
severity=severity,
recommendation=recommendation
)
def _detect_categorical_drift(self, name, ref, curr) -> DriftReport:
"""
Jensen-Shannon divergence for categorical features.
"""
# Create probability distributions
ref_unique, ref_counts = np.unique(ref, return_counts=True)
curr_unique, curr_counts = np.unique(curr, return_counts=True)
# Align categories
all_categories = np.union1d(ref_unique, curr_unique)
ref_probs = np.zeros(len(all_categories))
curr_probs = np.zeros(len(all_categories))
for i, cat in enumerate(all_categories):
ref_idx = np.where(ref_unique == cat)[0]
curr_idx = np.where(curr_unique == cat)[0]
if len(ref_idx) > 0:
ref_probs[i] = ref_counts[ref_idx[0]] / len(ref)
if len(curr_idx) > 0:
curr_probs[i] = curr_counts[curr_idx[0]] / len(curr)
# Calculate JS divergence
js_distance = jensenshannon(ref_probs, curr_probs)
# Severity
if js_distance > 0.5:
severity = DriftSeverity.CRITICAL
elif js_distance > 0.2:
severity = DriftSeverity.HIGH
elif js_distance > 0.1:
severity = DriftSeverity.MEDIUM
else:
severity = DriftSeverity.NONE
return DriftReport(
feature_name=name,
method="Jensen-Shannon Divergence",
statistic=js_distance,
p_value=None,
drift_detected=(js_distance > 0.1),
severity=severity,
recommendation=f"JS Distance: {js_distance:.3f}"
)
def _calculate_psi(self, ref, curr, bins=10):
"""
Population Stability Index - financial services standard.
"""
# Create bins based on reference distribution
_, bin_edges = np.histogram(ref, bins=bins)
ref_hist, _ = np.histogram(ref, bins=bin_edges)
curr_hist, _ = np.histogram(curr, bins=bin_edges)
# Avoid division by zero
ref_hist = (ref_hist + 0.0001) / len(ref)
curr_hist = (curr_hist + 0.0001) / len(curr)
psi = np.sum((curr_hist - ref_hist) * np.log(curr_hist / ref_hist))
return psi
# Usage example
reference = {
'age': np.random.normal(35, 10, 10000),
'income': np.random.lognormal(10, 1, 10000),
'category': np.random.choice(['A', 'B', 'C'], 10000)
}
current = {
'age': np.random.normal(38, 12, 1000), # Drifted
'income': np.random.lognormal(10, 1, 1000), # Not drifted
'category': np.random.choice(['A', 'B', 'C', 'D'], 1000, p=[0.2, 0.2, 0.2, 0.4]) # New category!
}
detector = ComprehensiveDriftDetector(reference)
reports = detector.detect_drift(current)
for report in reports:
if report.drift_detected:
print(f"⚠️ {report.feature_name}: {report.severity.value}")
print(f" {report.recommendation}")
8. AWS SageMaker Model Monitor Complete Setup
8.1. Enable Data Capture
# enable_data_capture.py
import boto3
from sagemaker import Session
session = Session()
sm_client = boto3.client('sagemaker')
# Update endpoint to capture data
endpoint_config = sm_client.create_endpoint_config(
EndpointConfigName='fraud-detector-monitored-config',
ProductionVariants=[{
'VariantName': 'AllTraffic',
'ModelName': 'fraud-detector-v2',
'InstanceType': 'ml.m5.xlarge',
'InitialInstanceCount': 1
}],
DataCaptureConfig={
'EnableCapture': True,
'InitialSamplingPercentage': 100, # Capture all requests
'DestinationS3Uri': 's3://my-bucket/model-monitor/data-capture',
'CaptureOptions': [
{'CaptureMode': 'Input'},
{'CaptureMode': 'Output'}
],
'CaptureContentTypeHeader': {
'CsvContentTypes': ['text/csv'],
'JsonContentTypes': ['application/json']
}
}
)
sm_client.update_endpoint(
EndpointName='fraud-detector-prod',
EndpointConfigName='fraud-detector-monitored-config'
)
8.2. Create Baseline
# create_baseline.py
from sagemaker.model_monitor import DefaultModelMonitor
from sagemaker.model_monitor.dataset_format import DatasetFormat
monitor = DefaultModelMonitor(
role='arn:aws:iam::123456789012:role/SageMakerRole',
instance_count=1,
instance_type='ml.m5.xlarge',
volume_size_in_gb=20,
max_runtime_in_seconds=3600
)
# Suggest baseline using training data
monitor.suggest_baseline(
baseline_dataset='s3://my-bucket/training-data/train.csv',
dataset_format=DatasetFormat.csv(header=True),
output_s3_uri='s3://my-bucket/model-monitor/baseline',
wait=True
)
print("✓ Baseline created")
print(f"Statistics: s3://my-bucket/model-monitor/baseline/statistics.json")
print(f"Constraints: s3://my-bucket/model-monitor/baseline/constraints.json")
8.3. Create Monitoring Schedule
# create_schedule.py
from sagemaker.model_monitor import CronExpressionGenerator
monitor.create_monitoring_schedule(
monitor_schedule_name='fraud-detector-hourly-monitor',
endpoint_input='fraud-detector-prod',
output_s3_uri='s3://my-bucket/model-monitor/reports',
statistics=monitor.baseline_statistics(),
constraints=monitor.suggested_constraints(),
schedule_cron_expression=CronExpressionGenerator.hourly(),
enable_cloudwatch_metrics=True
)
print("✓ Monitoring schedule created")
8.4. Query Violations
# check_violations.py
import boto3
import json
s3 = boto3.client('s3')
def get_latest_violations(bucket, prefix):
"""
Retrieve the most recent constraint violations.
"""
    # List report objects under the prefix; SageMaker writes each run's
    # violations to a dated subfolder as constraint_violations.json
    response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
    if 'Contents' not in response:
        return []
    candidates = [o for o in response['Contents']
                  if o['Key'].endswith('constraint_violations.json')]
    if not candidates:
        return []
    # Get most recent
    latest = sorted(candidates, key=lambda x: x['LastModified'], reverse=True)[0]
obj = s3.get_object(Bucket=bucket, Key=latest['Key'])
violations = json.loads(obj['Body'].read())
return violations['violations']
violations = get_latest_violations('my-bucket', 'model-monitor/reports')
if violations:
print("⚠️ Drift detected!")
for v in violations:
print(f"Feature: {v['feature_name']}")
print(f"Violation: {v['violation_type']}")
print(f"Description: {v['description']}\n")
else:
print("✓ No violations")
9. GCP Vertex AI Model Monitoring Setup
9.1. Enable Monitoring (Python SDK)
# vertex_monitoring.py
from google.cloud import aiplatform
aiplatform.init(project='my-project', location='us-central1')
# Get existing endpoint
endpoint = aiplatform.Endpoint('projects/123/locations/us-central1/endpoints/456')
# Configure monitoring
from google.cloud.aiplatform_v1.types import ModelMonitoringObjectiveConfig
monitoring_config = ModelMonitoringObjectiveConfig(
training_dataset={
'data_format': 'csv',
'gcs_source': {'uris': ['gs://my-bucket/training-data/train.csv']},
'target_field': 'is_fraud'
},
training_prediction_skew_detection_config={
'skew_thresholds': {
'age': ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig.SkewThreshold(
value=0.1
),
'amount': ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig.SkewThreshold(
value=0.15
)
}
},
prediction_drift_detection_config={
'drift_thresholds': {
'age': ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig.DriftThreshold(
value=0.1
),
'amount': ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig.DriftThreshold(
value=0.15
)
}
}
)
# Create monitoring job
monitoring_job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name='fraud-detector-monitor',
endpoint=endpoint,
logging_sampling_strategy=aiplatform.helpers.LoggingSamplingStrategy(1.0), # 100%
schedule_config=aiplatform.helpers.ScheduleConfig(cron_expression='0 * * * *'), # Hourly
model_monitoring_objective_configs=[monitoring_config],
model_monitoring_alert_config=aiplatform.helpers.EmailAlertConfig(
user_emails=['mlops-team@company.com']
)
)
print(f"Monitoring job created: {monitoring_job.resource_name}")
9.2. Query Monitoring Results (BigQuery)
-- query_drift_results.sql
-- Vertex AI writes monitoring results to BigQuery
SELECT
model_name,
feature_name,
training_stats.mean AS training_mean,
prediction_stats.mean AS serving_mean,
ABS(prediction_stats.mean - training_stats.mean) / training_stats.stddev AS drift_score
FROM
`my-project.model_monitoring.prediction_stats`
WHERE
  DATE(prediction_time) = CURRENT_DATE()
  -- BigQuery does not allow SELECT aliases in WHERE, so repeat the expression
  AND ABS(prediction_stats.mean - training_stats.mean) / training_stats.stddev > 2.0 -- more than 2 standard deviations
ORDER BY
drift_score DESC
LIMIT 10;
10. Real-Time Drift Detection in Inference Code
10.1. Lightweight In-Process Monitor
# realtime_drift_monitor.py
import numpy as np
from collections import deque
from threading import Lock
class RealTimeDriftMonitor:
"""
Embedding drift detection within the inference server.
Minimal overhead (<1ms per request).
"""
def __init__(self, window_size=1000):
self.window = deque(maxlen=window_size)
self.baseline_stats = None
self.lock = Lock()
def set_baseline(self, baseline_data):
"""
baseline_data: Dict[feature_name, np.array]
"""
self.baseline_stats = {
name: {
'mean': np.mean(data),
'std': np.std(data),
'min': np.min(data),
'max': np.max(data)
}
for name, data in baseline_data.items()
}
def observe(self, features: dict):
"""
Called on every inference request.
"""
with self.lock:
self.window.append(features)
def check_drift(self):
"""
Periodically called (e.g., every 1000 requests).
Returns drift score for each feature.
"""
        if self.baseline_stats is None or len(self.window) < 100:
            return {}
with self.lock:
current_batch = list(self.window)
# Aggregate into dict of arrays
aggregated = {}
for features in current_batch:
for name, value in features.items():
if name not in aggregated:
aggregated[name] = []
aggregated[name].append(value)
# Calculate drift
drift_scores = {}
for name, values in aggregated.items():
if name not in self.baseline_stats:
continue
current_mean = np.mean(values)
baseline_mean = self.baseline_stats[name]['mean']
baseline_std = self.baseline_stats[name]['std']
# Z-score drift
drift = abs(current_mean - baseline_mean) / (baseline_std + 1e-6)
drift_scores[name] = drift
return drift_scores
# Integration with a Flask inference server (sketch: `app`, `model`, `logger`,
# `request_count`, and `baseline_features` are assumed to be defined elsewhere)
from flask import request, jsonify

monitor = RealTimeDriftMonitor()
monitor.set_baseline(baseline_features)  # Dict[str, np.ndarray] computed from training data
@app.route('/predict', methods=['POST'])
def predict():
features = request.get_json()
# Record for drift monitoring
monitor.observe(features)
# Run inference
prediction = model.predict(features)
# Every 1000 requests, check drift
if request_count % 1000 == 0:
drift_scores = monitor.check_drift()
for feature, score in drift_scores.items():
if score > 3.0:
logger.warning(f"Drift detected in {feature}: {score:.2f} std devs")
return jsonify(prediction)
11. Automated Retraining Pipeline
11.1. Complete Airflow DAG
# drift_response_dag.py
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.email import EmailOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
from datetime import datetime, timedelta
default_args = {
'owner': 'mlops-team',
'depends_on_past': False,
'start_date': datetime(2024, 1, 1),
'retries': 1,
'retry_delay': timedelta(minutes=5)
}
def analyze_drift_severity(**context):
"""
Download latest drift report and determine severity.
"""
import boto3
import json
s3 = boto3.client('s3')
obj = s3.get_object(Bucket='my-bucket', Key='model-monitor/reports/latest.json')
report = json.loads(obj['Body'].read())
violations = report.get('violations', [])
if not violations:
return 'no_action'
critical_count = sum(1 for v in violations if 'critical' in v.get('description', '').lower())
if critical_count > 0:
return 'emergency_retrain'
    elif len(violations) > 5:
        # Return the first task of the scheduled-retrain path
        return 'prepare_training_data'
else:
return 'send_alert'
def prepare_training_dataset(**context):
"""
Fetch recent data from production logs and prepare training set.
"""
import pandas as pd
# Query data warehouse for last 30 days of labeled data
query = """
SELECT * FROM production_predictions
WHERE labeled = true
AND timestamp > CURRENT_DATE - INTERVAL '30 days'
"""
    # connection_string is assumed to be configured elsewhere (e.g., an Airflow connection)
    df = pd.read_sql(query, connection_string)
# Save to S3
df.to_csv('s3://my-bucket/retraining-data/latest.csv', index=False)
return len(df)
with DAG(
'drift_response_pipeline',
default_args=default_args,
schedule_interval='@hourly',
catchup=False
) as dag:
# Wait for new monitoring report
wait_for_report = S3KeySensor(
task_id='wait_for_monitoring_report',
bucket_name='my-bucket',
        bucket_key='model-monitor/reports/{{ ds }}/constraint_violations.json',
timeout=300,
poke_interval=30
)
# Analyze drift
analyze = PythonOperator(
task_id='analyze_drift',
python_callable=analyze_drift_severity
)
# Branch based on severity
branch = BranchPythonOperator(
task_id='determine_action',
python_callable=analyze_drift_severity
)
# Option 1: Emergency retrain (immediate)
emergency_retrain = TriggerDagRunOperator(
task_id='emergency_retrain',
trigger_dag_id='model_training_pipeline',
conf={'priority': 'high', 'notify': 'pagerduty'}
)
# Option 2: Scheduled retrain
prepare_data = PythonOperator(
task_id='prepare_training_data',
python_callable=prepare_training_dataset
)
scheduled_retrain = TriggerDagRunOperator(
task_id='scheduled_retrain',
trigger_dag_id='model_training_pipeline',
conf={'priority': 'normal'}
)
# Option 3: Just alert
send_alert = EmailOperator(
task_id='send_alert',
to=['mlops-team@company.com'],
subject='Drift Detected - Investigate',
html_content='Minor drift detected. Review monitoring dashboard.'
)
# Option 4: No action
no_action = PythonOperator(
task_id='no_action',
python_callable=lambda: print("No drift detected")
)
# Pipeline
wait_for_report >> analyze >> branch
branch >> emergency_retrain
branch >> prepare_data >> scheduled_retrain
branch >> send_alert
branch >> no_action
12. Champion/Challenger Pattern for Drift Mitigation
12.1. A/B Testing New Models
# champion_challenger.py
import boto3
import numpy as np
from datetime import datetime, timedelta

sm_client = boto3.client('sagemaker')
def create_ab_test_endpoint():
"""
Deploy champion and challenger models with traffic splitting.
"""
endpoint_config = sm_client.create_endpoint_config(
EndpointConfigName='fraud-detector-ab-test',
ProductionVariants=[
{
'VariantName': 'Champion',
'ModelName': 'fraud-model-v1',
'InstanceType': 'ml.m5.xlarge',
'InitialInstanceCount': 3,
'InitialVariantWeight': 0.9 # 90% traffic
},
{
'VariantName': 'Challenger',
'ModelName': 'fraud-model-v2-retrained',
'InstanceType': 'ml.m5.xlarge',
'InitialInstanceCount': 1,
'InitialVariantWeight': 0.1 # 10% traffic
}
]
)
sm_client.create_endpoint(
EndpointName='fraud-detector-ab',
EndpointConfigName='fraud-detector-ab-test'
)
# After 1 week, analyze metrics
def evaluate_challenger():
"""
Compare performance metrics between variants.
"""
cloudwatch = boto3.client('cloudwatch')
metrics = ['ModelLatency', 'Invocation4XXErrors', 'Invocation5XXErrors']
for metric in metrics:
champion_stats = cloudwatch.get_metric_statistics(
Namespace='AWS/SageMaker',
MetricName=metric,
Dimensions=[
{'Name': 'EndpointName', 'Value': 'fraud-detector-ab'},
{'Name': 'VariantName', 'Value': 'Champion'}
],
StartTime=datetime.utcnow() - timedelta(days=7),
EndTime=datetime.utcnow(),
Period=3600,
Statistics=['Average']
)
challenger_stats = cloudwatch.get_metric_statistics(
Namespace='AWS/SageMaker',
MetricName=metric,
Dimensions=[
{'Name': 'EndpointName', 'Value': 'fraud-detector-ab'},
{'Name': 'VariantName', 'Value': 'Challenger'}
],
StartTime=datetime.utcnow() - timedelta(days=7),
EndTime=datetime.utcnow(),
Period=3600,
Statistics=['Average']
)
print(f"{metric}:")
print(f" Champion: {np.mean([d['Average'] for d in champion_stats['Datapoints']]):.2f}")
print(f" Challenger: {np.mean([d['Average'] for d in challenger_stats['Datapoints']]):.2f}")
def promote_challenger():
"""
If challenger performs better, shift 100% traffic.
"""
sm_client.update_endpoint_weights_and_capacities(
EndpointName='fraud-detector-ab',
DesiredWeightsAndCapacities=[
{'VariantName': 'Champion', 'DesiredWeight': 0.0},
{'VariantName': 'Challenger', 'DesiredWeight': 1.0}
]
)
print("✓ Challenger promoted to production")
13. Conclusion
Drift is inevitable. The world changes, users change, adversaries adapt. The question is not “Will my model drift?” but “How quickly will I detect it, and how fast can I respond?”
Key principles:
- Monitor inputs AND outputs - Data drift is early warning, prediction drift is the fire
- Automate detection, not response - Humans decide to retrain, systems detect the need
- Design for rapid iteration - If retraining takes weeks, drift monitoring is pointless
- Use statistical rigor - “The model feels worse” is not a metric
With comprehensive monitoring in place—from infrastructure (18.1) to GPUs (18.2) to data (18.3)—you have closed the MLOps loop. Your system is no longer a static artifact deployed once, but a living system that observes itself, detects degradation, and triggers its own evolution.
This is production ML: systems that don’t just work today, but continue working tomorrow.