32.5. Model Contracts: The API of AI
Tip
The Software Engineering View: Treat your ML Model exactly like a microservice. It must have a defined API, strict typing, and SLA guarantees. If the input format changes silently, the model breaks. If the output probability distribution shifts drastically, downstream systems break.
A Model Contract is a formal agreement between the Model Provider (Data Scientist) and the Model Consumer (Backend Engineer / Application). In mature MLOps organizations, you cannot deploy a model without a signed contract.
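What does the signed artifact look like? A minimal sketch, kept deliberately small; the manifest keys here are illustrative, not a standard:

```python
# model_contract.py - a hypothetical contract manifest, expressed as plain
# Python for illustration; many teams keep this as YAML in the model repo
# and review changes to it like any other code
MODEL_CONTRACT = {
    "model": "credit-scoring",
    "version": "2.1.0",
    "provider": "risk-ds-team",
    "consumers": ["loan-origination-backend"],
    "schema": {"input": "CreditScoringInput", "output": "PredictionOutput"},
    "semantic": {"expectation_suite": "model_outputs"},
    "sla": {"latency_p95_ms": 100, "error_rate": 0.001, "uptime": "99.9%"},
    "signed_off_by": ["data-science-lead", "backend-lead"],
}
```

The rest of this section fills in each of those layers.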
32.5.1. The Three Layers of Contracts
| Layer | What It Validates | When It’s Checked | Tooling |
|---|---|---|---|
| Schema | JSON structure, types | Request time | Pydantic |
| Semantic | Data meaning, business rules | Handover/CI | Great Expectations |
| SLA | Latency, throughput, uptime | Continuous monitoring | k6, Prometheus |
graph TB
A[API Request] --> B{Schema Contract}
B -->|Invalid| C[400 Bad Request]
B -->|Valid| D{Semantic Contract}
D -->|Violated| E[422 Unprocessable]
D -->|Valid| F[Model Inference]
F --> G{SLA Contract}
G -->|Breached| H[Alert + Fallback]
G -->|Met| I[200 Response]
32.5.2. Schema Contracts with Pydantic
Pydantic is the de facto standard for Python API contracts, and FastAPI auto-generates OpenAPI specs directly from Pydantic models.
Complete Contract Example
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import List, Optional, Literal
from datetime import datetime
from enum import Enum
class RiskCategory(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class CreditScoringInput(BaseModel):
"""Input contract for credit scoring model."""
# Schema layer - types and constraints
applicant_id: str = Field(..., min_length=8, max_length=32, pattern=r"^[A-Z0-9]+$")
age: int = Field(..., ge=18, le=120, description="Applicant age in years")
annual_income: float = Field(..., gt=0, description="Annual income in USD")
loan_amount: float = Field(..., gt=0, le=10_000_000)
employment_years: float = Field(..., ge=0, le=50)
credit_history_months: int = Field(..., ge=0, le=600)
existing_debt: float = Field(0, ge=0)
loan_purpose: Literal["home", "auto", "education", "personal", "business"]
    # Semantic validators
    @field_validator("age")
    @classmethod
    def validate_age(cls, v):
        # Redundant with ge=18 above, but shows where per-field business
        # logic hooks in (e.g., jurisdiction-specific age rules)
        if v < 18:
            raise ValueError("Applicant must be 18 or older")
        return v
@field_validator("annual_income")
@classmethod
def validate_income(cls, v):
if v > 100_000_000:
raise ValueError("Income seems unrealistic, please verify")
return v
@model_validator(mode="after")
def validate_debt_ratio(self):
debt_ratio = (self.existing_debt + self.loan_amount) / self.annual_income
if debt_ratio > 10:
raise ValueError(f"Debt-to-income ratio {debt_ratio:.1f} exceeds maximum 10")
return self
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "applicant_id": "APP12345678",
                "age": 35,
                "annual_income": 85000,
                "loan_amount": 250000,
                "employment_years": 8,
                "credit_history_months": 156,
                "existing_debt": 15000,
                "loan_purpose": "home"
            }
        }
    )
class PredictionOutput(BaseModel):
"""Output contract for credit scoring model."""
applicant_id: str
default_probability: float = Field(..., ge=0.0, le=1.0)
risk_category: RiskCategory
confidence: float = Field(..., ge=0.0, le=1.0)
model_version: str
prediction_timestamp: datetime
feature_contributions: Optional[dict] = None
@field_validator("default_probability")
@classmethod
def validate_probability(cls, v):
if v < 0 or v > 1:
raise ValueError("Probability must be between 0 and 1")
return round(v, 6)
class BatchInput(BaseModel):
"""Batch prediction input."""
applications: List[CreditScoringInput] = Field(..., min_length=1, max_length=1000)
correlation_id: Optional[str] = None
priority: Literal["low", "normal", "high"] = "normal"
class BatchOutput(BaseModel):
"""Batch prediction output."""
predictions: List[PredictionOutput]
processed: int
failed: int
latency_ms: float
correlation_id: Optional[str]
class ErrorResponse(BaseModel):
"""Standard error response."""
error_code: str
message: str
details: Optional[dict] = None
request_id: str
FastAPI Implementation
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from datetime import datetime, timezone
app = FastAPI(
title="Credit Scoring API",
description="ML-powered credit risk assessment",
version="2.0.0"
)
@app.exception_handler(ValueError)
async def validation_exception_handler(request: Request, exc: ValueError):
    # Catches ValueErrors raised inside endpoint code; Pydantic request
    # validation failures are handled by FastAPI's built-in 422 handler
    return JSONResponse(
        status_code=422,
        content=ErrorResponse(
            error_code="VALIDATION_ERROR",
            message=str(exc),
            request_id=getattr(request.state, "request_id", "unknown")
        ).model_dump()
    )
@app.post(
"/v2/predict",
response_model=PredictionOutput,
responses={
400: {"model": ErrorResponse},
422: {"model": ErrorResponse},
500: {"model": ErrorResponse}
}
)
async def predict(request: CreditScoringInput) -> PredictionOutput:
"""Single prediction endpoint.
The model contract guarantees:
- Input validation per CreditScoringInput schema
- Output format per PredictionOutput schema
- Latency < 100ms for P95
- Probability calibration within 5% of actual
"""
# ... model inference ...
return PredictionOutput(
applicant_id=request.applicant_id,
default_probability=0.15,
risk_category=RiskCategory.MEDIUM,
confidence=0.92,
model_version="2.1.0",
        prediction_timestamp=datetime.now(timezone.utc)
)
@app.post("/v2/batch", response_model=BatchOutput)
async def batch_predict(request: BatchInput) -> BatchOutput:
"""Batch prediction endpoint.
Contract guarantees:
- Maximum 1000 items per batch
- Processing completes within 30 seconds
- Partial failures return with individual error details
"""
    # ... batch processing ...
    raise HTTPException(status_code=501, detail="Batch endpoint not implemented in this example")
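Schema contracts are cheap to test. A minimal pytest sketch, assuming the models and app above are importable from a module named app:

```python
# test_contract_schema.py - hypothetical module layout (app.py)
from fastapi.testclient import TestClient
from app import app

client = TestClient(app)

VALID_PAYLOAD = {
    "applicant_id": "APP12345678",
    "age": 35,
    "annual_income": 85000,
    "loan_amount": 250000,
    "employment_years": 8,
    "credit_history_months": 156,
    "existing_debt": 15000,
    "loan_purpose": "home",
}

def test_valid_payload_returns_contracted_output():
    response = client.post("/v2/predict", json=VALID_PAYLOAD)
    assert response.status_code == 200
    body = response.json()
    assert 0.0 <= body["default_probability"] <= 1.0
    assert body["risk_category"] in {"low", "medium", "high", "critical"}

def test_underage_applicant_is_rejected():
    # FastAPI returns 422 for Pydantic validation failures
    response = client.post("/v2/predict", json={**VALID_PAYLOAD, "age": 17})
    assert response.status_code == 422
```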
32.5.3. Semantic Contracts with Great Expectations
Schema validates syntax. Semantic contracts validate meaning and business rules.
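The distinction in one example: a payload can be perfectly typed and still be nonsense that only cross-field business rules can catch.

```python
# Passes the schema contract: every individual type and range check holds.
# Fails a semantic contract: a 19-year-old with 45 years of employment is
# impossible, and no per-field constraint will ever notice.
payload = {
    "applicant_id": "APP00000001",
    "age": 19,                  # valid in isolation: 18 <= age <= 120
    "annual_income": 50000.0,   # valid: > 0
    "loan_amount": 100000.0,    # valid
    "employment_years": 45,     # valid in isolation: 0 <= x <= 50
    "credit_history_months": 12,
    "existing_debt": 0,
    "loan_purpose": "personal",
}
```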
Golden Dataset Testing
import great_expectations as gx
import pandas as pd
from typing import Any, Dict, List
from dataclasses import dataclass

# NOTE: Great Expectations' Python API changed heavily across releases;
# this sketch mixes the 0.18-style fluent datasource/checkpoint calls with
# 1.x-style expectation classes for brevity. Pin a GX version and adjust.
@dataclass
class ContractViolation:
rule_name: str
expectation_type: str
column: str
    observed_value: Any
expected: str
class SemanticContractValidator:
"""Validate model outputs against semantic contracts."""
def __init__(self, expectation_suite_name: str = "model_outputs"):
self.context = gx.get_context()
self.suite_name = expectation_suite_name
def create_expectation_suite(self) -> gx.ExpectationSuite:
"""Define semantic expectations for model outputs."""
suite = self.context.add_expectation_suite(self.suite_name)
# Probability must be calibrated (between 0 and 1)
suite.add_expectation(
gx.expectations.ExpectColumnValuesToBeBetween(
column="default_probability",
min_value=0.0,
max_value=1.0
)
)
# Risk category must match probability ranges
suite.add_expectation(
gx.expectations.ExpectColumnPairValuesToBeInSet(
column_A="risk_category",
column_B="probability_bucket",
value_pairs_set=[
("low", "0.0-0.25"),
("medium", "0.25-0.5"),
("high", "0.5-0.75"),
("critical", "0.75-1.0")
]
)
)
# VIP rule: High income should rarely be high risk
suite.add_expectation(
gx.expectations.ExpectColumnValuesToMatchRegex(
column="vip_check",
regex=r"^(pass|exempt)$"
)
)
# Distribution stability
suite.add_expectation(
gx.expectations.ExpectColumnMeanToBeBetween(
column="default_probability",
min_value=0.05,
max_value=0.30 # Historical range
)
)
return suite
def validate_predictions(
self,
predictions_df: pd.DataFrame,
reference_df: pd.DataFrame = None
) -> Dict:
"""Validate predictions against contracts."""
# Add derived columns for validation
predictions_df = predictions_df.copy()
predictions_df["probability_bucket"] = pd.cut(
predictions_df["default_probability"],
bins=[0, 0.25, 0.5, 0.75, 1.0],
labels=["0.0-0.25", "0.25-0.5", "0.5-0.75", "0.75-1.0"]
)
# VIP check
predictions_df["vip_check"] = predictions_df.apply(
lambda r: "pass" if r["annual_income"] < 1_000_000 else (
"pass" if r["default_probability"] < 0.5 else "fail"
),
axis=1
)
# Run validation
datasource = self.context.sources.add_pandas("predictions")
data_asset = datasource.add_dataframe_asset("predictions_df")
batch_request = data_asset.build_batch_request(dataframe=predictions_df)
checkpoint = self.context.add_or_update_checkpoint(
name="model_validation",
validations=[{
"batch_request": batch_request,
"expectation_suite_name": self.suite_name
}]
)
results = checkpoint.run()
return self._parse_results(results)
def _parse_results(self, results) -> Dict:
"""Parse validation results into actionable report."""
violations = []
for result in results.run_results.values():
for expectation_result in result["validation_result"]["results"]:
if not expectation_result["success"]:
violations.append(ContractViolation(
rule_name=expectation_result["expectation_config"]["expectation_type"],
expectation_type=expectation_result["expectation_config"]["expectation_type"],
                        column=expectation_result["expectation_config"]["kwargs"].get("column", "N/A"),
observed_value=expectation_result["result"].get("observed_value"),
expected=str(expectation_result["expectation_config"])
))
return {
"passed": len(violations) == 0,
"violation_count": len(violations),
"violations": [v.__dict__ for v in violations]
}
# CI/CD integration
def verify_model_before_deploy(model_endpoint: str, test_data_path: str) -> bool:
"""Gate function for CI/CD pipeline."""
import requests
validator = SemanticContractValidator()
# Load golden test dataset
test_df = pd.read_parquet(test_data_path)
# Get predictions from staging model
predictions = []
for _, row in test_df.iterrows():
response = requests.post(
f"{model_endpoint}/v2/predict",
json=row.to_dict()
)
if response.status_code == 200:
predictions.append(response.json())
predictions_df = pd.DataFrame(predictions)
predictions_df = predictions_df.merge(
test_df[["applicant_id", "annual_income"]],
on="applicant_id"
)
# Validate
result = validator.validate_predictions(predictions_df)
if not result["passed"]:
print(f"❌ Contract violations: {result['violation_count']}")
for v in result["violations"]:
print(f" - {v['rule_name']}: {v['observed_value']}")
return False
print("✅ All semantic contracts passed")
return True
32.5.4. Service Level Contracts (SLA)
SLAs define operational guarantees: latency, throughput, availability.
SLA Definition
# sla.yaml
service: credit-scoring-model
version: 2.1.0
performance:
latency:
p50: 50ms
p95: 100ms
p99: 200ms
throughput:
sustained_rps: 500
burst_rps: 1000
cold_start: 2s
availability:
uptime: 99.9%
error_rate: 0.1%
quality:
probability_calibration: 5% # Within 5% of actual rate
feature_drift_threshold: 0.1
  prediction_distribution_stability: 0.95 # stability score, roughly equivalent to PSI < 0.05
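The SLA file is itself a contract, so give it a schema too. A sketch that loads sla.yaml into typed fields (assumes PyYAML; the model names are illustrative):

```python
# sla_loader.py - validate sla.yaml in CI so a typo fails the build
# instead of silently weakening a guarantee
import yaml
from pydantic import BaseModel

class LatencySLA(BaseModel):
    p50: str
    p95: str
    p99: str

class PerformanceSLA(BaseModel):
    latency: LatencySLA
    throughput: dict
    cold_start: str

class AvailabilitySLA(BaseModel):
    uptime: str
    error_rate: str

class SLADocument(BaseModel):
    service: str
    version: str
    performance: PerformanceSLA
    availability: AvailabilitySLA
    quality: dict

with open("sla.yaml") as f:
    sla = SLADocument.model_validate(yaml.safe_load(f))
print(sla.performance.latency.p95)  # "100ms"
```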
Load Testing with k6
// k6-sla-test.js
import http from 'k6/http';
import { check, sleep } from 'k6';
import { Rate, Trend } from 'k6/metrics';
import { textSummary } from 'https://jslib.k6.io/k6-summary/0.0.1/index.js';
// Custom metrics
const errorRate = new Rate('errors');
const latencyTrend = new Trend('latency_ms');
export const options = {
stages: [
{ duration: '1m', target: 100 }, // Ramp up
{ duration: '5m', target: 500 }, // Sustained load
{ duration: '1m', target: 1000 }, // Burst test
{ duration: '2m', target: 500 }, // Back to sustained
{ duration: '1m', target: 0 }, // Ramp down
],
thresholds: {
// SLA enforcement
'http_req_duration': ['p(95)<100', 'p(99)<200'], // Latency
'http_req_failed': ['rate<0.001'], // Error rate
'errors': ['rate<0.001'],
},
};
const payload = JSON.stringify({
applicant_id: 'TEST12345678',
age: 35,
annual_income: 85000,
loan_amount: 250000,
employment_years: 8,
credit_history_months: 156,
existing_debt: 15000,
loan_purpose: 'home'
});
const headers = { 'Content-Type': 'application/json' };
export default function () {
const res = http.post(
`${__ENV.API_URL}/v2/predict`,
payload,
{ headers }
);
const success = check(res, {
'status is 200': (r) => r.status === 200,
'response has prediction': (r) => r.json('default_probability') !== undefined,
'response time < 100ms': (r) => r.timings.duration < 100,
});
  errorRate.add(!success);
  latencyTrend.add(res.timings.duration);
sleep(0.1);
}
export function handleSummary(data) {
// Generate SLA compliance report
const p95Latency = data.metrics.http_req_duration.values['p(95)'];
const errorPct = data.metrics.http_req_failed.values.rate * 100;
const slaCompliance = {
latency_p95: {
target: 100,
actual: p95Latency,
passed: p95Latency < 100
},
error_rate: {
target: 0.1,
actual: errorPct,
passed: errorPct < 0.1
},
overall_passed: p95Latency < 100 && errorPct < 0.1
};
return {
'sla-report.json': JSON.stringify(slaCompliance, null, 2),
stdout: textSummary(data, { indent: ' ', enableColors: true })
};
}
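k6 already fails the run with a nonzero exit code when a threshold in options.thresholds is breached; the JSON report from handleSummary is what you archive for dashboards. A sketch of the surrounding CI step (the staging URL is a placeholder):

```python
# sla_gate.py - run the k6 script and gate the pipeline on the result
import json
import subprocess
import sys

result = subprocess.run(
    ["k6", "run", "-e", "API_URL=https://staging.example.com", "k6-sla-test.js"]
)

with open("sla-report.json") as f:
    report = json.load(f)
print(json.dumps(report, indent=2))

# a nonzero k6 exit code means a threshold was breached
sys.exit(0 if result.returncode == 0 and report["overall_passed"] else 1)
```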
Python SLA Monitor
from prometheus_client import Histogram, Counter, Gauge
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import numpy as np
# Prometheus metrics
REQUEST_LATENCY = Histogram(
"model_request_latency_seconds",
"Request latency in seconds",
["endpoint", "model_version"],
buckets=[.01, .025, .05, .075, .1, .25, .5, 1.0]
)
REQUEST_ERRORS = Counter(
"model_request_errors_total",
"Total request errors",
["endpoint", "model_version", "error_type"]
)
SLA_COMPLIANCE = Gauge(
"model_sla_compliance",
"SLA compliance status (1=compliant, 0=breached)",
["sla_type", "model_version"]
)
@dataclass
class SLAConfig:
latency_p95_ms: float = 100
latency_p99_ms: float = 200
error_rate_threshold: float = 0.001
throughput_rps: float = 500
class SLAMonitor:
"""Monitor and enforce SLA compliance."""
def __init__(self, config: SLAConfig):
self.config = config
self.request_times = []
self.error_count = 0
self.total_count = 0
        self.window_start = datetime.now(timezone.utc)
self.window_size = timedelta(minutes=5)
def record_request(
self,
latency_ms: float,
success: bool,
endpoint: str,
model_version: str
):
"""Record a request for SLA tracking."""
self.total_count += 1
self.request_times.append(latency_ms)
if not success:
self.error_count += 1
REQUEST_ERRORS.labels(
endpoint=endpoint,
model_version=model_version,
error_type="prediction_error"
).inc()
REQUEST_LATENCY.labels(
endpoint=endpoint,
model_version=model_version
).observe(latency_ms / 1000) # Convert to seconds
# Check window
self._maybe_reset_window()
def _maybe_reset_window(self):
"""Reset metrics window if expired."""
        now = datetime.now(timezone.utc)
if now - self.window_start > self.window_size:
self._evaluate_sla()
self.request_times = []
self.error_count = 0
self.total_count = 0
self.window_start = now
    def _evaluate_sla(self):
        """Evaluate SLA compliance."""
        if not self.request_times:
            return
times = np.array(self.request_times)
p95 = np.percentile(times, 95)
p99 = np.percentile(times, 99)
error_rate = self.error_count / self.total_count if self.total_count > 0 else 0
# Update Prometheus gauges
latency_compliant = p95 <= self.config.latency_p95_ms
error_compliant = error_rate <= self.config.error_rate_threshold
SLA_COMPLIANCE.labels(sla_type="latency_p95", model_version="2.1.0").set(
1 if latency_compliant else 0
)
SLA_COMPLIANCE.labels(sla_type="error_rate", model_version="2.1.0").set(
1 if error_compliant else 0
)
# Log if breached
if not latency_compliant:
print(f"⚠️ SLA BREACH: P95 latency {p95:.1f}ms > {self.config.latency_p95_ms}ms")
if not error_compliant:
print(f"⚠️ SLA BREACH: Error rate {error_rate:.4f} > {self.config.error_rate_threshold}")
    def get_status(self) -> dict:
        """Get current SLA status."""
        if not self.request_times:
            return {"status": "no_data"}
times = np.array(self.request_times)
return {
"window_start": self.window_start.isoformat(),
"total_requests": self.total_count,
"error_count": self.error_count,
"error_rate": self.error_count / self.total_count,
"latency_p50_ms": float(np.percentile(times, 50)),
"latency_p95_ms": float(np.percentile(times, 95)),
"latency_p99_ms": float(np.percentile(times, 99)),
"p95_compliant": np.percentile(times, 95) <= self.config.latency_p95_ms,
"error_rate_compliant": (self.error_count / self.total_count) <= self.config.error_rate_threshold
}
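Wiring the monitor into the FastAPI service from 32.5.2 takes one middleware. A sketch, assuming monitor is a module-level instance attached to the existing app:

```python
import time
from fastapi import Request

# assumes `app` is the FastAPI instance from 32.5.2 and SLAMonitor/SLAConfig
# are defined as above
monitor = SLAMonitor(SLAConfig())

@app.middleware("http")
async def sla_tracking(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    latency_ms = (time.perf_counter() - start) * 1000
    monitor.record_request(
        latency_ms=latency_ms,
        success=response.status_code < 500,
        endpoint=request.url.path,
        model_version="2.1.0",
    )
    return response
```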
32.5.5. Consumer-Driven Contract Testing (Pact)
Integration tests are slow and flaky. Contract tests are fast and deterministic: each side verifies against a recorded contract instead of a live dependency.
Workflow
sequenceDiagram
participant Frontend
participant Pact Broker
participant ML API
Frontend->>Frontend: Write consumer test
Frontend->>Pact Broker: Publish pact.json
ML API->>Pact Broker: Fetch pact.json
ML API->>ML API: Verify against local server
ML API-->>Pact Broker: Report verification
Pact Broker-->>Frontend: Contract verified ✓
Consumer Side (Frontend)
// consumer.pact.spec.js
const { Pact, Matchers } = require('@pact-foundation/pact');
const path = require('path');
const axios = require('axios');
describe('Credit Scoring API Contract', () => {
const provider = new Pact({
consumer: 'frontend-app',
provider: 'credit-scoring-api',
port: 1234,
log: path.resolve(process.cwd(), 'logs', 'pact.log'),
dir: path.resolve(process.cwd(), 'pacts'),
});
beforeAll(() => provider.setup());
afterAll(() => provider.finalize());
afterEach(() => provider.verify());
describe('predict endpoint', () => {
it('returns prediction for valid input', async () => {
      // Arrange
      const requestBody = {
        applicant_id: 'APP12345678',
        age: 35,
        annual_income: 85000,
        loan_amount: 250000,
        employment_years: 8,
        credit_history_months: 156,
        existing_debt: 15000,
        loan_purpose: 'home'
      };
      const expectedResponse = {
        applicant_id: 'APP12345678',
        default_probability: 0.15,
        risk_category: 'medium',
        confidence: 0.92,
        model_version: '2.1.0'
      };
await provider.addInteraction({
state: 'model is healthy',
uponReceiving: 'a prediction request',
withRequest: {
method: 'POST',
path: '/v2/predict',
headers: { 'Content-Type': 'application/json' },
          body: requestBody
},
willRespondWith: {
status: 200,
headers: { 'Content-Type': 'application/json' },
body: {
applicant_id: Matchers.string('APP12345678'),
default_probability: Matchers.decimal(0.15),
risk_category: Matchers.regex('(low|medium|high|critical)', 'medium'),
confidence: Matchers.decimal(0.92),
model_version: Matchers.string('2.1.0')
}
}
});
      // Act - the request must exactly match the interaction above
      const response = await axios.post(
        'http://localhost:1234/v2/predict',
        requestBody,
        { headers: { 'Content-Type': 'application/json' } }
      );
      // Assert
      expect(response.status).toBe(200);
      expect(response.data).toMatchObject(expectedResponse);
      expect(response.data.risk_category).toMatch(/low|medium|high|critical/);
});
});
});
Provider Side (ML API)
# test_pact_provider.py
import pytest
from pact import Verifier
import subprocess
import time
import os
@pytest.fixture(scope="module")
def provider_server():
"""Start the FastAPI server."""
process = subprocess.Popen(
["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
time.sleep(3) # Wait for server to start
yield "http://localhost:8000"
process.terminate()
def test_verify_contract(provider_server):
"""Verify that we honor the consumer's contract."""
verifier = Verifier(
provider="credit-scoring-api",
provider_base_url=provider_server
)
    # pact-python handles provider states via an HTTP hook on the provider
    # (provider_states_setup_url); Verifier has no Python state-handler callback
    exit_code, logs = verifier.verify_pacts(
        # From local pact file; see test_verify_from_broker for the broker flow
        "./pacts/frontend-app-credit-scoring-api.json",
        provider_states_setup_url=f"{provider_server}/_pact/provider_states"
    )
    # verify_pacts returns the verifier's exit code: 0 means success
    assert exit_code == 0, f"Pact verification failed:\n{logs}"
def test_verify_from_broker(provider_server):
"""Verify against Pact Broker."""
verifier = Verifier(
provider="credit-scoring-api",
provider_base_url=provider_server
)
    exit_code, logs = verifier.verify_with_broker(
        broker_url=os.environ.get("PACT_BROKER_URL"),
        broker_token=os.environ.get("PACT_BROKER_TOKEN"),
        publish_version=os.environ.get("GIT_SHA", "dev"),
        provider_version_tag=os.environ.get("GIT_BRANCH", "main")
    )
    if exit_code != 0:  # 0 means every interaction verified
        print("Contract violations detected:")
        print(logs)
        pytest.fail("Pact verification failed")
32.5.6. High-Performance Contracts: gRPC + Protobuf
For high-throughput systems, JSON serialization and parsing become the bottleneck. Use Protocol Buffers.
Proto Definition
// credit_scoring.proto
syntax = "proto3";
package mlops.credit;
import "google/protobuf/timestamp.proto";  // required for PredictResponse below
option go_package = "github.com/company/ml-api/proto/credit";
option java_package = "com.company.ml.credit";
// Service definition
service CreditScoring {
rpc Predict(PredictRequest) returns (PredictResponse);
rpc BatchPredict(BatchPredictRequest) returns (BatchPredictResponse);
rpc StreamPredict(stream PredictRequest) returns (stream PredictResponse);
}
// Request message
message PredictRequest {
string applicant_id = 1;
int32 age = 2;
double annual_income = 3;
double loan_amount = 4;
double employment_years = 5;
int32 credit_history_months = 6;
double existing_debt = 7;
LoanPurpose loan_purpose = 8;
// Reserved for future use
reserved 9, 10;
reserved "deprecated_field";
}
enum LoanPurpose {
LOAN_PURPOSE_UNSPECIFIED = 0;
LOAN_PURPOSE_HOME = 1;
LOAN_PURPOSE_AUTO = 2;
LOAN_PURPOSE_EDUCATION = 3;
LOAN_PURPOSE_PERSONAL = 4;
LOAN_PURPOSE_BUSINESS = 5;
}
// Response message
message PredictResponse {
string applicant_id = 1;
double default_probability = 2;
RiskCategory risk_category = 3;
double confidence = 4;
string model_version = 5;
google.protobuf.Timestamp prediction_timestamp = 6;
map<string, double> feature_contributions = 7;
}
enum RiskCategory {
RISK_CATEGORY_UNSPECIFIED = 0;
RISK_CATEGORY_LOW = 1;
RISK_CATEGORY_MEDIUM = 2;
RISK_CATEGORY_HIGH = 3;
RISK_CATEGORY_CRITICAL = 4;
}
message BatchPredictRequest {
repeated PredictRequest requests = 1;
string correlation_id = 2;
}
message BatchPredictResponse {
repeated PredictResponse predictions = 1;
int32 processed = 2;
int32 failed = 3;
double latency_ms = 4;
}
Python gRPC Server
# grpc_server.py
import time
from concurrent import futures
from datetime import datetime, timezone

import grpc
from google.protobuf.timestamp_pb2 import Timestamp

import credit_scoring_pb2
import credit_scoring_pb2_grpc
class CreditScoringServicer(credit_scoring_pb2_grpc.CreditScoringServicer):
"""gRPC implementation of credit scoring service."""
def __init__(self, model):
self.model = model
self.model_version = "2.1.0"
def Predict(self, request, context):
"""Single prediction."""
# Validate contract
if request.age < 18 or request.age > 120:
context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"Age must be between 18 and 120"
)
if request.annual_income <= 0:
context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"Annual income must be positive"
)
# Run inference
probability = self._predict(request)
# Build response
        timestamp = Timestamp()
        timestamp.FromDatetime(datetime.now(timezone.utc))
return credit_scoring_pb2.PredictResponse(
applicant_id=request.applicant_id,
default_probability=probability,
risk_category=self._categorize_risk(probability),
confidence=0.92,
model_version=self.model_version,
prediction_timestamp=timestamp
)
    def BatchPredict(self, request, context):
        """Batch prediction."""
        start = time.perf_counter()
        predictions = []
        failed = 0
        for req in request.requests:
            try:
                # NOTE: Predict() calls context.abort() on invalid input,
                # which raises and ends the whole RPC; a production batch
                # path should validate items without aborting
                pred = self.Predict(req, context)
                predictions.append(pred)
            except Exception:
                failed += 1
latency = (time.perf_counter() - start) * 1000
return credit_scoring_pb2.BatchPredictResponse(
predictions=predictions,
processed=len(predictions),
failed=failed,
latency_ms=latency
)
def StreamPredict(self, request_iterator, context):
"""Streaming prediction for real-time processing."""
for request in request_iterator:
yield self.Predict(request, context)
def _predict(self, request) -> float:
# Model inference
return 0.15
def _categorize_risk(self, probability: float) -> int:
if probability < 0.25:
return credit_scoring_pb2.RISK_CATEGORY_LOW
elif probability < 0.5:
return credit_scoring_pb2.RISK_CATEGORY_MEDIUM
elif probability < 0.75:
return credit_scoring_pb2.RISK_CATEGORY_HIGH
else:
return credit_scoring_pb2.RISK_CATEGORY_CRITICAL
def serve():
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
]
)
credit_scoring_pb2_grpc.add_CreditScoringServicer_to_server(
CreditScoringServicer(model=None),
server
)
server.add_insecure_port('[::]:50051')
server.start()
server.wait_for_termination()
if __name__ == '__main__':
serve()
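The consumer side is symmetric. A minimal client sketch against the stubs generated by grpc_tools.protoc:

```python
# grpc_client.py - minimal sketch; assumes the generated modules from
# credit_scoring.proto are on the path
import grpc
import credit_scoring_pb2
import credit_scoring_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = credit_scoring_pb2_grpc.CreditScoringStub(channel)
    request = credit_scoring_pb2.PredictRequest(
        applicant_id="APP12345678",
        age=35,
        annual_income=85000,
        loan_amount=250000,
        employment_years=8,
        credit_history_months=156,
        existing_debt=15000,
        loan_purpose=credit_scoring_pb2.LOAN_PURPOSE_HOME,
    )
    try:
        # a deadline enforces the latency SLA from the client side too
        response = stub.Predict(request, timeout=0.2)
        print(response.default_probability, response.risk_category)
    except grpc.RpcError as err:
        # INVALID_ARGUMENT means the contract rejected the input
        print(err.code(), err.details())
```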
32.5.7. Schema Registry for Event-Driven Systems
In Kafka-based architectures, use a Schema Registry to enforce contracts.
Architecture
graph LR
A[Model Producer] --> B{Schema Registry}
B -->|Valid| C[Kafka Topic]
B -->|Invalid| D[Reject]
C --> E[Consumer A]
C --> F[Consumer B]
B --> G[Schema Evolution Check]
G -->|Compatible| H[Allow]
G -->|Breaking| D
Avro Schema Definition
Note: a logicalType only takes effect when nested inside the type object; written as a sibling of "type", Avro silently ignores it.
{
"type": "record",
"name": "PredictionEvent",
"namespace": "com.company.ml.events",
"doc": "Credit scoring prediction event",
"fields": [
{
"name": "event_id",
"type": "string",
"doc": "Unique event identifier"
},
{
"name": "applicant_id",
"type": "string"
},
{
"name": "default_probability",
"type": "double",
"doc": "Probability of default (0.0-1.0)"
},
{
"name": "risk_category",
"type": {
"type": "enum",
"name": "RiskCategory",
"symbols": ["LOW", "MEDIUM", "HIGH", "CRITICAL"]
}
},
{
"name": "model_version",
"type": "string"
},
    {
      "name": "timestamp",
      "type": {"type": "long", "logicalType": "timestamp-millis"}
    },
{
"name": "feature_contributions",
"type": ["null", {"type": "map", "values": "double"}],
"default": null,
"doc": "Optional SHAP values"
}
]
}
Python Producer with Schema Registry
from confluent_kafka import SerializingProducer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroSerializer
from dataclasses import dataclass
from typing import Optional, Dict
import uuid
import time
@dataclass
class PredictionEvent:
applicant_id: str
default_probability: float
risk_category: str
model_version: str
feature_contributions: Optional[Dict[str, float]] = None
def to_dict(self) -> dict:
return {
"event_id": str(uuid.uuid4()),
"applicant_id": self.applicant_id,
"default_probability": self.default_probability,
"risk_category": self.risk_category,
"model_version": self.model_version,
"timestamp": int(time.time() * 1000),
"feature_contributions": self.feature_contributions
}
class PredictionEventProducer:
"""Produce prediction events with schema validation."""
def __init__(
self,
bootstrap_servers: str,
schema_registry_url: str,
topic: str
):
self.topic = topic
# Schema Registry client
schema_registry = SchemaRegistryClient({
"url": schema_registry_url
})
# Load Avro schema
with open("schemas/prediction_event.avsc") as f:
schema_str = f.read()
# Serializer with schema validation
avro_serializer = AvroSerializer(
schema_registry,
schema_str,
lambda event, ctx: event.to_dict()
)
# Producer config
self.producer = SerializingProducer({
"bootstrap.servers": bootstrap_servers,
"value.serializer": avro_serializer,
"acks": "all"
})
    def produce(self, event: PredictionEvent) -> None:
        """Produce event to Kafka with schema validation."""
        self.producer.produce(
            topic=self.topic,
            value=event,
            key=event.applicant_id,
            on_delivery=self._delivery_report
        )
        # flush() per message keeps the example simple; high-throughput
        # producers should call poll(0) here and flush() only on shutdown
        self.producer.flush()
def _delivery_report(self, err, msg):
if err:
print(f"Failed to deliver: {err}")
else:
print(f"Delivered to {msg.topic()}[{msg.partition()}]")
# Usage
producer = PredictionEventProducer(
bootstrap_servers="kafka:9092",
schema_registry_url="http://schema-registry:8081",
topic="predictions"
)
event = PredictionEvent(
applicant_id="APP12345678",
default_probability=0.15,
risk_category="MEDIUM",
model_version="2.1.0"
)
producer.produce(event)
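Consumers deserialize against the same registered schema, so a producer can never silently ship a payload the schema does not allow. A sketch of the consuming side (group id and topic are illustrative):

```python
from confluent_kafka import DeserializingConsumer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroDeserializer

schema_registry = SchemaRegistryClient({"url": "http://schema-registry:8081"})
# the deserializer fetches the writer's schema from the registry using
# the schema ID embedded in every message
avro_deserializer = AvroDeserializer(schema_registry)

consumer = DeserializingConsumer({
    "bootstrap.servers": "kafka:9092",
    "group.id": "risk-dashboard",
    "auto.offset.reset": "earliest",
    "value.deserializer": avro_deserializer,
})
consumer.subscribe(["predictions"])

while True:
    msg = consumer.poll(1.0)
    if msg is None or msg.error():
        continue
    event = msg.value()  # plain dict matching the Avro schema
    print(event["applicant_id"], event["default_probability"])
```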
32.5.8. Versioning and Breaking Changes
Semantic Versioning for ML
| Version Change | Type | Example | Action |
|---|---|---|---|
| MAJOR (v2.0.0) | Breaking API | Remove input field | New endpoint URL |
| MINOR (v1.2.0) | Backward compatible | Add optional field | Deploy in place |
| PATCH (v1.2.1) | Bug fix | Fix memory leak | Deploy in place |
Non-Breaking Changes
✅ Safe to deploy in place:
- Adding optional input fields
- Adding new output fields
- Improving model accuracy
- Adding new endpoints
- Relaxing validation (accepting more formats)
Breaking Changes
❌ Require new version:
- Removing or renaming input fields
- Changing output field types
- Tightening validation
- Changing probability distribution significantly
- Removing endpoints
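Several of these breaking changes can be caught mechanically by diffing the JSON Schemas of consecutive Pydantic model versions. A sketch of such a CI gate (the helper name is illustrative):

```python
# schema_diff.py - flag breaking input-contract changes between versions
from typing import Dict, List

def find_breaking_changes(old: Dict, new: Dict) -> List[str]:
    """Compare two Model.model_json_schema() outputs for an input model."""
    breaking = []
    old_props = old.get("properties", {})
    new_props = new.get("properties", {})
    # Removing or renaming an input field breaks existing callers
    for field in old_props:
        if field not in new_props:
            breaking.append(f"removed field: {field}")
    # Any newly required field breaks callers that do not send it
    for field in set(new.get("required", [])) - set(old.get("required", [])):
        breaking.append(f"field now required: {field}")
    # Changing a field's type breaks consumers of the OpenAPI spec
    for field, spec in old_props.items():
        if field in new_props and spec.get("type") != new_props[field].get("type"):
            breaking.append(f"type change: {field}")
    return breaking

# In CI:
# breaking = find_breaking_changes(
#     CreditScoringInputV1.model_json_schema(),
#     CreditScoringInputV2.model_json_schema(),
# )
# assert not breaking, f"Bump the MAJOR version: {breaking}"
```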
Migration Pattern
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import warnings
app = FastAPI()
# V1 - Deprecated. CreditScoringInputV1/V2 and the transform_* helpers are
# assumed to be defined alongside the models from 32.5.2.
@app.post("/v1/predict")
async def predict_v1(request: CreditScoringInputV1):
warnings.warn("v1 is deprecated, migrate to v2", DeprecationWarning)
# Transform to v2 format
v2_input = transform_v1_to_v2(request)
# Use v2 logic
result = await predict_v2(v2_input)
# Transform back to v1 format
return transform_v2_to_v1(result)
# V2 - Current
@app.post("/v2/predict")
async def predict_v2(request: CreditScoringInputV2):
# ... implementation
pass
# Deprecation header middleware
@app.middleware("http")
async def deprecation_header(request: Request, call_next):
response = await call_next(request)
if "/v1/" in request.url.path:
response.headers["Deprecation"] = "true"
response.headers["Sunset"] = "2024-06-01T00:00:00Z"
response.headers["Link"] = '</v2/predict>; rel="successor-version"'
return response
32.5.9. Summary Checklist
| Layer | What to Define | Tool | When to Check |
|---|---|---|---|
| Schema | Types, constraints | Pydantic | Every request |
| Semantic | Business rules | Great Expectations | CI/CD |
| SLA | Latency, error rate | k6, Prometheus | Continuous |
| Consumer | Cross-team contracts | Pact | CI before deploy |
| Events | Message format | Schema Registry | Produce time |
Golden Rules
- Schema first: Define Pydantic/Protobuf before writing code
- Test semantics: Run Great Expectations on golden datasets
- Enforce SLAs: k6 load tests in CI/CD
- Consumer contracts: Pact verification before merge
- Version everything: Never break v1
[End of Section 32.5]