32.5. Model Contracts: The API of AI

Tip

The Software Engineering View: Treat your ML model exactly like a microservice. It must have a defined API, strict typing, and SLA guarantees. If the input format changes silently, the model breaks. If the output probability distribution shifts drastically, downstream systems break.

A Model Contract is a formal agreement between the Model Provider (Data Scientist) and the Model Consumer (Backend Engineer / Application). In mature MLOps organizations, you cannot deploy a model without a signed contract.


32.5.1. The Three Layers of Contracts

| Layer    | What It Validates            | When It’s Checked     | Tooling            |
|----------|------------------------------|-----------------------|--------------------|
| Schema   | JSON structure, types        | Request time          | Pydantic           |
| Semantic | Data meaning, business rules | Handover/CI           | Great Expectations |
| SLA      | Latency, throughput, uptime  | Continuous monitoring | k6, Prometheus     |
graph TB
    A[API Request] --> B{Schema Contract}
    B -->|Invalid| C[400 Bad Request]
    B -->|Valid| D{Semantic Contract}
    D -->|Violated| E[422 Unprocessable]
    D -->|Valid| F[Model Inference]
    F --> G{SLA Contract}
    G -->|Breached| H[Alert + Fallback]
    G -->|Met| I[200 Response]

32.5.2. Schema Contracts with Pydantic

Pydantic is the de facto standard for Python API contracts, and FastAPI auto-generates OpenAPI specs directly from Pydantic models.

Complete Contract Example

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import List, Optional, Literal
from datetime import datetime
from enum import Enum

class RiskCategory(str, Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

class CreditScoringInput(BaseModel):
    """Input contract for credit scoring model."""
    
    # Schema layer - types and constraints
    applicant_id: str = Field(..., min_length=8, max_length=32, pattern=r"^[A-Z0-9]+$")
    age: int = Field(..., ge=18, le=120, description="Applicant age in years")
    annual_income: float = Field(..., gt=0, description="Annual income in USD")
    loan_amount: float = Field(..., gt=0, le=10_000_000)
    employment_years: float = Field(..., ge=0, le=50)
    credit_history_months: int = Field(..., ge=0, le=600)
    existing_debt: float = Field(0, ge=0)
    loan_purpose: Literal["home", "auto", "education", "personal", "business"]
    
    # Semantic validators
    @field_validator("age")
    @classmethod
    def validate_age(cls, v):
        if v < 18:
            raise ValueError("Applicant must be 18 or older")
        return v
    
    @field_validator("annual_income")
    @classmethod
    def validate_income(cls, v):
        if v > 100_000_000:
            raise ValueError("Income seems unrealistic, please verify")
        return v
    
    @model_validator(mode="after")
    def validate_debt_ratio(self):
        debt_ratio = (self.existing_debt + self.loan_amount) / self.annual_income
        if debt_ratio > 10:
            raise ValueError(f"Debt-to-income ratio {debt_ratio:.1f} exceeds maximum 10")
        return self
    
    # Pydantic v2 style; `class Config` is deprecated in v2
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "applicant_id": "APP12345678",
                "age": 35,
                "annual_income": 85000,
                "loan_amount": 250000,
                "employment_years": 8,
                "credit_history_months": 156,
                "existing_debt": 15000,
                "loan_purpose": "home"
            }
        }
    )


class PredictionOutput(BaseModel):
    """Output contract for credit scoring model."""
    
    applicant_id: str
    default_probability: float = Field(..., ge=0.0, le=1.0)
    risk_category: RiskCategory
    confidence: float = Field(..., ge=0.0, le=1.0)
    model_version: str
    prediction_timestamp: datetime
    feature_contributions: Optional[dict] = None
    
    @field_validator("default_probability")
    @classmethod
    def validate_probability(cls, v):
        if v < 0 or v > 1:
            raise ValueError("Probability must be between 0 and 1")
        return round(v, 6)


class BatchInput(BaseModel):
    """Batch prediction input."""
    
    applications: List[CreditScoringInput] = Field(..., min_length=1, max_length=1000)
    correlation_id: Optional[str] = None
    priority: Literal["low", "normal", "high"] = "normal"


class BatchOutput(BaseModel):
    """Batch prediction output."""
    
    predictions: List[PredictionOutput]
    processed: int
    failed: int
    latency_ms: float
    correlation_id: Optional[str] = None


class ErrorResponse(BaseModel):
    """Standard error response."""
    
    error_code: str
    message: str
    details: Optional[dict] = None
    request_id: str
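
A quick smoke test makes the contract concrete: invalid input is rejected before it can reach the model. A minimal sketch, assuming the models above are importable:

from pydantic import ValidationError

try:
    CreditScoringInput(
        applicant_id="APP12345678",
        age=16,                    # violates the ge=18 constraint
        annual_income=85000,
        loan_amount=250000,
        employment_years=0,
        credit_history_months=0,
        existing_debt=0,
        loan_purpose="home",
    )
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # e.g. "Input should be greater than or equal to 18"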

FastAPI Implementation

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from datetime import datetime
import logging

app = FastAPI(
    title="Credit Scoring API",
    description="ML-powered credit risk assessment",
    version="2.0.0"
)

@app.exception_handler(ValueError)
async def validation_exception_handler(request: Request, exc: ValueError):
    return JSONResponse(
        status_code=422,
        content=ErrorResponse(
            error_code="VALIDATION_ERROR",
            message=str(exc),
            # Set by the request-ID middleware (sketched below); fall back if absent
            request_id=getattr(request.state, "request_id", "unknown")
        ).model_dump()
    )


@app.post(
    "/v2/predict",
    response_model=PredictionOutput,
    responses={
        400: {"model": ErrorResponse},
        422: {"model": ErrorResponse},
        500: {"model": ErrorResponse}
    }
)
async def predict(request: CreditScoringInput) -> PredictionOutput:
    """Single prediction endpoint.
    
    The model contract guarantees:
    - Input validation per CreditScoringInput schema
    - Output format per PredictionOutput schema
    - Latency < 100ms for P95
    - Probability calibration within 5% of actual
    """
    # ... model inference ...
    
    return PredictionOutput(
        applicant_id=request.applicant_id,
        default_probability=0.15,
        risk_category=RiskCategory.MEDIUM,
        confidence=0.92,
        model_version="2.1.0",
        prediction_timestamp=datetime.utcnow()
    )


@app.post("/v2/batch", response_model=BatchOutput)
async def batch_predict(request: BatchInput) -> BatchOutput:
    """Batch prediction endpoint.
    
    Contract guarantees:
    - Maximum 1000 items per batch
    - Processing completes within 30 seconds
    - Partial failures return with individual error details
    """
    # ... batch processing ...
    # Placeholder return so the declared response_model is honored
    return BatchOutput(
        predictions=[],
        processed=0,
        failed=0,
        latency_ms=0.0,
        correlation_id=request.correlation_id
    )
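
The error handler above reads request.state.request_id, which assumes middleware populates it. A minimal sketch of such middleware, reusing app and Request from the listing (the X-Request-ID header name is an assumption):

import uuid

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    # Attach a correlation ID so error responses can reference the request
    request.state.request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
    response = await call_next(request)
    response.headers["X-Request-ID"] = request.state.request_id
    return response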

32.5.3. Semantic Contracts with Great Expectations

Schema validates syntax. Semantic contracts validate meaning and business rules.

Golden Dataset Testing

import great_expectations as gx
import pandas as pd
from typing import Any, Dict, List, Optional
from dataclasses import dataclass

@dataclass
class ContractViolation:
    rule_name: str
    expectation_type: str
    column: str
    observed_value: Any
    expected: str

class SemanticContractValidator:
    """Validate model outputs against semantic contracts."""
    
    def __init__(self, expectation_suite_name: str = "model_outputs"):
        self.context = gx.get_context()
        self.suite_name = expectation_suite_name
    
    def create_expectation_suite(self) -> gx.ExpectationSuite:
        """Define semantic expectations for model outputs (GX 1.x fluent API)."""
        suite = self.context.suites.add(
            gx.ExpectationSuite(name=self.suite_name)
        )
        self.suite = suite
        
        # Probability must be calibrated (between 0 and 1)
        suite.add_expectation(
            gx.expectations.ExpectColumnValuesToBeBetween(
                column="default_probability",
                min_value=0.0,
                max_value=1.0
            )
        )
        
        # Risk category must match probability ranges
        suite.add_expectation(
            gx.expectations.ExpectColumnPairValuesToBeInSet(
                column_A="risk_category",
                column_B="probability_bucket",
                value_pairs_set=[
                    ("low", "0.0-0.25"),
                    ("medium", "0.25-0.5"),
                    ("high", "0.5-0.75"),
                    ("critical", "0.75-1.0")
                ]
            )
        )
        
        # VIP rule: High income should rarely be high risk
        suite.add_expectation(
            gx.expectations.ExpectColumnValuesToMatchRegex(
                column="vip_check",
                regex=r"^(pass|exempt)$"
            )
        )
        
        # Distribution stability
        suite.add_expectation(
            gx.expectations.ExpectColumnMeanToBeBetween(
                column="default_probability",
                min_value=0.05,
                max_value=0.30  # Historical range
            )
        )
        
        return suite
    
    def validate_predictions(
        self,
        predictions_df: pd.DataFrame,
        reference_df: Optional[pd.DataFrame] = None
    ) -> Dict:
        """Validate predictions against contracts."""
        
        # Add derived columns for validation
        predictions_df = predictions_df.copy()
        predictions_df["probability_bucket"] = pd.cut(
            predictions_df["default_probability"],
            bins=[0, 0.25, 0.5, 0.75, 1.0],
            labels=["0.0-0.25", "0.25-0.5", "0.5-0.75", "0.75-1.0"]
        )
        
        # VIP check: high-income applicants flagged high-risk need review
        predictions_df["vip_check"] = predictions_df.apply(
            lambda r: "pass" if r["annual_income"] < 1_000_000 else (
                "pass" if r["default_probability"] < 0.5 else "fail"
            ),
            axis=1
        )
        
        # Run validation (GX 1.x fluent API; older releases differ)
        suite = getattr(self, "suite", None) or self.create_expectation_suite()
        datasource = self.context.data_sources.add_pandas("predictions")
        data_asset = datasource.add_dataframe_asset("predictions_df")
        batch_def = data_asset.add_batch_definition_whole_dataframe("default")
        batch = batch_def.get_batch(batch_parameters={"dataframe": predictions_df})
        
        results = batch.validate(suite)
        
        return self._parse_results(results)
    
    def _parse_results(self, results) -> Dict:
        """Parse validation results into an actionable report."""
        violations = []
        
        # to_json_dict() normalizes the result object across GX versions
        for r in results.to_json_dict()["results"]:
            if r["success"]:
                continue
            config = r["expectation_config"]
            # GX 1.x names the field "type"; older releases used "expectation_type"
            exp_type = config.get("type") or config.get("expectation_type", "unknown")
            violations.append(ContractViolation(
                rule_name=exp_type,
                expectation_type=exp_type,
                column=config.get("kwargs", {}).get("column", "N/A"),
                observed_value=r.get("result", {}).get("observed_value"),
                expected=str(config.get("kwargs", {}))
            ))
        
        return {
            "passed": len(violations) == 0,
            "violation_count": len(violations),
            "violations": [v.__dict__ for v in violations]
        }


# CI/CD integration
def verify_model_before_deploy(model_endpoint: str, test_data_path: str) -> bool:
    """Gate function for CI/CD pipeline."""
    import json
    import requests
    
    validator = SemanticContractValidator()
    
    # Load golden test dataset
    test_df = pd.read_parquet(test_data_path)
    
    # Get predictions from staging model
    predictions = []
    for _, row in test_df.iterrows():
        response = requests.post(
            f"{model_endpoint}/v2/predict",
            # to_json round-trip coerces numpy scalars to JSON-native types
            json=json.loads(row.to_json())
        )
        if response.status_code == 200:
            predictions.append(response.json())
    
    predictions_df = pd.DataFrame(predictions)
    predictions_df = predictions_df.merge(
        test_df[["applicant_id", "annual_income"]],
        on="applicant_id"
    )
    
    # Validate
    result = validator.validate_predictions(predictions_df)
    
    if not result["passed"]:
        print(f"❌ Contract violations: {result['violation_count']}")
        for v in result["violations"]:
            print(f"  - {v['rule_name']}: {v['observed_value']}")
        return False
    
    print("✅ All semantic contracts passed")
    return True

32.5.4. Service Level Contracts (SLA)

SLAs define operational guarantees: latency, throughput, availability.

SLA Definition

# sla.yaml
service: credit-scoring-model
version: 2.1.0

performance:
  latency:
    p50: 50ms
    p95: 100ms
    p99: 200ms
  throughput:
    sustained_rps: 500
    burst_rps: 1000
  cold_start: 2s

availability:
  uptime: 99.9%
  error_rate: 0.1%

quality:
  probability_calibration: 5%  # Within 5% of actual rate
  feature_drift_threshold: 0.1
  prediction_distribution_stability: 0.95  # 1 - PSI, i.e. PSI < 0.05
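
The stability target above is expressed via the Population Stability Index (PSI). A minimal sketch of how PSI could be computed between a reference window and the current window of predicted probabilities (the fixed [0, 1] binning is an assumption that fits probability outputs):

import numpy as np

def population_stability_index(
    reference: np.ndarray, current: np.ndarray, bins: int = 10
) -> float:
    """PSI between two probability distributions; PSI > 0.05 breaches the SLA above."""
    edges = np.linspace(0.0, 1.0, bins + 1)  # scores are probabilities in [0, 1]
    ref_pct = np.histogram(reference, bins=edges)[0] / len(reference)
    cur_pct = np.histogram(current, bins=edges)[0] / len(current)
    # Clip empty bins so the log term stays finite
    ref_pct = np.clip(ref_pct, 1e-6, None)
    cur_pct = np.clip(cur_pct, 1e-6, None)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

By convention, PSI above roughly 0.1 is read as moderate drift and above 0.25 as major drift; the 0.05 threshold here is deliberately strict.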

Load Testing with k6

// k6-sla-test.js
import http from 'k6/http';
import { check, sleep } from 'k6';
import { Rate, Trend } from 'k6/metrics';
// textSummary is not a k6 built-in; it comes from the k6 jslib
import { textSummary } from 'https://jslib.k6.io/k6-summary/0.0.2/index.js';

// Custom metrics
const errorRate = new Rate('errors');
const latencyP95 = new Trend('latency_p95');

export const options = {
  stages: [
    { duration: '1m', target: 100 },   // Ramp up
    { duration: '5m', target: 500 },   // Sustained load
    { duration: '1m', target: 1000 },  // Burst test
    { duration: '2m', target: 500 },   // Back to sustained
    { duration: '1m', target: 0 },     // Ramp down
  ],
  thresholds: {
    // SLA enforcement
    'http_req_duration': ['p(95)<100', 'p(99)<200'],  // Latency
    'http_req_failed': ['rate<0.001'],                 // Error rate
    'errors': ['rate<0.001'],
  },
};

const payload = JSON.stringify({
  applicant_id: 'TEST12345678',
  age: 35,
  annual_income: 85000,
  loan_amount: 250000,
  employment_years: 8,
  credit_history_months: 156,
  existing_debt: 15000,
  loan_purpose: 'home'
});

const headers = { 'Content-Type': 'application/json' };

export default function () {
  const res = http.post(
    `${__ENV.API_URL}/v2/predict`,
    payload,
    { headers }
  );
  
  const success = check(res, {
    'status is 200': (r) => r.status === 200,
    'response has prediction': (r) => r.json('default_probability') !== undefined,
    'response time < 100ms': (r) => r.timings.duration < 100,
  });
  
  errorRate.add(!success);
  latencyTrend.add(res.timings.duration);
  
  sleep(0.1);
}

export function handleSummary(data) {
  // Generate SLA compliance report
  const p95Latency = data.metrics.http_req_duration.values['p(95)'];
  const errorPct = data.metrics.http_req_failed.values.rate * 100;
  
  const slaCompliance = {
    latency_p95: {
      target: 100,
      actual: p95Latency,
      passed: p95Latency < 100
    },
    error_rate: {
      target: 0.1,
      actual: errorPct,
      passed: errorPct < 0.1
    },
    overall_passed: p95Latency < 100 && errorPct < 0.1
  };
  
  return {
    'sla-report.json': JSON.stringify(slaCompliance, null, 2),
    stdout: textSummary(data, { indent: ' ', enableColors: true })
  };
}

Python SLA Monitor

from prometheus_client import Histogram, Counter, Gauge
from dataclasses import dataclass
from typing import Optional
from datetime import datetime, timedelta
import time

# Prometheus metrics
REQUEST_LATENCY = Histogram(
    "model_request_latency_seconds",
    "Request latency in seconds",
    ["endpoint", "model_version"],
    buckets=[.01, .025, .05, .075, .1, .25, .5, 1.0]
)

REQUEST_ERRORS = Counter(
    "model_request_errors_total",
    "Total request errors",
    ["endpoint", "model_version", "error_type"]
)

SLA_COMPLIANCE = Gauge(
    "model_sla_compliance",
    "SLA compliance status (1=compliant, 0=breached)",
    ["sla_type", "model_version"]
)

@dataclass
class SLAConfig:
    latency_p95_ms: float = 100
    latency_p99_ms: float = 200
    error_rate_threshold: float = 0.001
    throughput_rps: float = 500
    
class SLAMonitor:
    """Monitor and enforce SLA compliance."""
    
    def __init__(self, config: SLAConfig):
        self.config = config
        self.request_times = []
        self.error_count = 0
        self.total_count = 0
        self.window_start = datetime.utcnow()
        self.window_size = timedelta(minutes=5)
    
    def record_request(
        self, 
        latency_ms: float, 
        success: bool,
        endpoint: str,
        model_version: str
    ):
        """Record a request for SLA tracking."""
        self.total_count += 1
        self.request_times.append(latency_ms)
        
        if not success:
            self.error_count += 1
            REQUEST_ERRORS.labels(
                endpoint=endpoint,
                model_version=model_version,
                error_type="prediction_error"
            ).inc()
        
        REQUEST_LATENCY.labels(
            endpoint=endpoint,
            model_version=model_version
        ).observe(latency_ms / 1000)  # Convert to seconds
        
        # Check window
        self._maybe_reset_window()
    
    def _maybe_reset_window(self):
        """Reset metrics window if expired."""
        now = datetime.utcnow()
        if now - self.window_start > self.window_size:
            self._evaluate_sla()
            self.request_times = []
            self.error_count = 0
            self.total_count = 0
            self.window_start = now
    
    def _evaluate_sla(self):
        """Evaluate SLA compliance."""
        import numpy as np
        
        if not self.request_times:
            return
        
        times = np.array(self.request_times)
        p95 = np.percentile(times, 95)
        p99 = np.percentile(times, 99)
        error_rate = self.error_count / self.total_count if self.total_count > 0 else 0
        
        # Update Prometheus gauges
        latency_compliant = p95 <= self.config.latency_p95_ms
        error_compliant = error_rate <= self.config.error_rate_threshold
        
        SLA_COMPLIANCE.labels(sla_type="latency_p95", model_version="2.1.0").set(
            1 if latency_compliant else 0
        )
        SLA_COMPLIANCE.labels(sla_type="error_rate", model_version="2.1.0").set(
            1 if error_compliant else 0
        )
        
        # Log if breached
        if not latency_compliant:
            print(f"⚠️ SLA BREACH: P95 latency {p95:.1f}ms > {self.config.latency_p95_ms}ms")
        
        if not error_compliant:
            print(f"⚠️ SLA BREACH: Error rate {error_rate:.4f} > {self.config.error_rate_threshold}")
    
    def get_status(self) -> dict:
        """Get current SLA status."""
        import numpy as np
        
        if not self.request_times:
            return {"status": "no_data"}
        
        times = np.array(self.request_times)
        
        return {
            "window_start": self.window_start.isoformat(),
            "total_requests": self.total_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / self.total_count,
            "latency_p50_ms": float(np.percentile(times, 50)),
            "latency_p95_ms": float(np.percentile(times, 95)),
            "latency_p99_ms": float(np.percentile(times, 99)),
            "p95_compliant": np.percentile(times, 95) <= self.config.latency_p95_ms,
            "error_rate_compliant": (self.error_count / self.total_count) <= self.config.error_rate_threshold
        }
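
Wiring the monitor into the serving path can be done with request middleware. A sketch reusing the FastAPI app and Request import from 32.5.2 (treating only 5xx responses as SLA errors is an assumption; 4xx are contract rejections):

import time

sla_monitor = SLAMonitor(SLAConfig())

@app.middleware("http")
async def track_sla(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    latency_ms = (time.perf_counter() - start) * 1000
    
    # 5xx counts against the SLA; client-side contract violations do not
    sla_monitor.record_request(
        latency_ms=latency_ms,
        success=response.status_code < 500,
        endpoint=request.url.path,
        model_version="2.1.0",
    )
    return response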

32.5.5. Consumer-Driven Contract Testing (Pact)

Integration tests are slow and flaky. Contract tests are fast and deterministic.

Workflow

sequenceDiagram
    participant Frontend
    participant Pact Broker
    participant ML API
    
    Frontend->>Frontend: Write consumer test
    Frontend->>Pact Broker: Publish pact.json
    ML API->>Pact Broker: Fetch pact.json
    ML API->>ML API: Verify against local server
    ML API-->>Pact Broker: Report verification
    Pact Broker-->>Frontend: Contract verified ✓

Consumer Side (Frontend)

// consumer.pact.spec.js
const { Pact, Matchers } = require('@pact-foundation/pact');
const path = require('path');
const axios = require('axios');

describe('Credit Scoring API Contract', () => {
  const provider = new Pact({
    consumer: 'frontend-app',
    provider: 'credit-scoring-api',
    port: 1234,
    log: path.resolve(process.cwd(), 'logs', 'pact.log'),
    dir: path.resolve(process.cwd(), 'pacts'),
  });

  beforeAll(() => provider.setup());
  afterAll(() => provider.finalize());
  afterEach(() => provider.verify());

  describe('predict endpoint', () => {
    it('returns prediction for valid input', async () => {
      // Arrange
      const requestBody = {
        applicant_id: 'APP12345678',
        age: 35,
        annual_income: 85000,
        loan_amount: 250000,
        employment_years: 8,
        credit_history_months: 156,
        existing_debt: 15000,
        loan_purpose: 'home'
      };

      await provider.addInteraction({
        state: 'model is healthy',
        uponReceiving: 'a prediction request',
        withRequest: {
          method: 'POST',
          path: '/v2/predict',
          headers: { 'Content-Type': 'application/json' },
          body: requestBody
        },
        willRespondWith: {
          status: 200,
          headers: { 'Content-Type': 'application/json' },
          body: {
            applicant_id: Matchers.string('APP12345678'),
            default_probability: Matchers.decimal(0.15),
            risk_category: Matchers.term({ matcher: 'low|medium|high|critical', generate: 'medium' }),
            confidence: Matchers.decimal(0.92),
            model_version: Matchers.string('2.1.0')
          }
        }
      });

      // Act
      const response = await axios.post(
        'http://localhost:1234/v2/predict',
        requestBody,
        { headers: { 'Content-Type': 'application/json' } }
      );

      // Assert
      expect(response.status).toBe(200);
      expect(response.data.risk_category).toMatch(/low|medium|high|critical/);
    });
  });
});

Provider Side (ML API)

# test_pact_provider.py
import pytest
from pact import Verifier
import subprocess
import time
import os

@pytest.fixture(scope="module")
def provider_server():
    """Start the FastAPI server."""
    process = subprocess.Popen(
        ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    time.sleep(3)  # Wait for server to start
    yield "http://localhost:8000"
    process.terminate()


def test_verify_contract(provider_server):
    """Verify that we honor the consumer's contract."""
    verifier = Verifier(
        provider="credit-scoring-api",
        provider_base_url=provider_server
    )
    
    # pact-python sets up provider states ("model is healthy") via an HTTP
    # endpoint exposed by the provider app, not an in-process callback;
    # the /_pact/provider_states path here is an assumption
    
    # Verify against a local pact file (see the broker test below)
    success, logs = verifier.verify_pacts(
        "./pacts/frontend-app-credit-scoring-api.json",
        provider_states_setup_url=f"{provider_server}/_pact/provider_states"
    )
    
    assert success, f"Pact verification failed:\n{logs}"


def test_verify_from_broker(provider_server):
    """Verify against Pact Broker."""
    verifier = Verifier(
        provider="credit-scoring-api",
        provider_base_url=provider_server
    )
    
    success, logs = verifier.verify_with_broker(
        broker_url=os.environ.get("PACT_BROKER_URL"),
        broker_token=os.environ.get("PACT_BROKER_TOKEN"),
        publish_version=os.environ.get("GIT_SHA", "dev"),
        provider_version_tag=os.environ.get("GIT_BRANCH", "main")
    )
    
    if not success:
        print("Contract violations detected:")
        print(logs)
        pytest.fail("Pact verification failed")

32.5.6. High-Performance Contracts: gRPC + Protobuf

For high-throughput systems, JSON is too slow. Use Protocol Buffers.

Proto Definition

// credit_scoring.proto
syntax = "proto3";

package mlops.credit;

import "google/protobuf/timestamp.proto";

option go_package = "github.com/company/ml-api/proto/credit";
option java_package = "com.company.ml.credit";

// Service definition
service CreditScoring {
  rpc Predict(PredictRequest) returns (PredictResponse);
  rpc BatchPredict(BatchPredictRequest) returns (BatchPredictResponse);
  rpc StreamPredict(stream PredictRequest) returns (stream PredictResponse);
}

// Request message
message PredictRequest {
  string applicant_id = 1;
  int32 age = 2;
  double annual_income = 3;
  double loan_amount = 4;
  double employment_years = 5;
  int32 credit_history_months = 6;
  double existing_debt = 7;
  LoanPurpose loan_purpose = 8;
  
  // Reserved for future use
  reserved 9, 10;
  reserved "deprecated_field";
}

enum LoanPurpose {
  LOAN_PURPOSE_UNSPECIFIED = 0;
  LOAN_PURPOSE_HOME = 1;
  LOAN_PURPOSE_AUTO = 2;
  LOAN_PURPOSE_EDUCATION = 3;
  LOAN_PURPOSE_PERSONAL = 4;
  LOAN_PURPOSE_BUSINESS = 5;
}

// Response message
message PredictResponse {
  string applicant_id = 1;
  double default_probability = 2;
  RiskCategory risk_category = 3;
  double confidence = 4;
  string model_version = 5;
  google.protobuf.Timestamp prediction_timestamp = 6;
  map<string, double> feature_contributions = 7;
}

enum RiskCategory {
  RISK_CATEGORY_UNSPECIFIED = 0;
  RISK_CATEGORY_LOW = 1;
  RISK_CATEGORY_MEDIUM = 2;
  RISK_CATEGORY_HIGH = 3;
  RISK_CATEGORY_CRITICAL = 4;
}

message BatchPredictRequest {
  repeated PredictRequest requests = 1;
  string correlation_id = 2;
}

message BatchPredictResponse {
  repeated PredictResponse predictions = 1;
  int32 processed = 2;
  int32 failed = 3;
  double latency_ms = 4;
}

Python gRPC Server

# grpc_server.py
import grpc
from concurrent import futures
import credit_scoring_pb2
import credit_scoring_pb2_grpc
from google.protobuf.timestamp_pb2 import Timestamp
from datetime import datetime

class CreditScoringServicer(credit_scoring_pb2_grpc.CreditScoringServicer):
    """gRPC implementation of credit scoring service."""
    
    def __init__(self, model):
        self.model = model
        self.model_version = "2.1.0"
    
    def Predict(self, request, context):
        """Single prediction."""
        # Validate contract
        if request.age < 18 or request.age > 120:
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "Age must be between 18 and 120"
            )
        
        if request.annual_income <= 0:
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "Annual income must be positive"
            )
        
        # Run inference
        probability = self._predict(request)
        
        # Build response
        timestamp = Timestamp()
        timestamp.FromDatetime(datetime.utcnow())
        
        return credit_scoring_pb2.PredictResponse(
            applicant_id=request.applicant_id,
            default_probability=probability,
            risk_category=self._categorize_risk(probability),
            confidence=0.92,
            model_version=self.model_version,
            prediction_timestamp=timestamp
        )
    
    def BatchPredict(self, request, context):
        """Batch prediction."""
        import time
        start = time.perf_counter()
        
        predictions = []
        failed = 0
        
        for req in request.requests:
            try:
                # Note: context.abort in Predict terminates the whole RPC;
                # true per-item failure handling needs validation that raises
                # a catchable error instead of aborting
                pred = self.Predict(req, context)
                predictions.append(pred)
            except Exception:
                failed += 1
        
        latency = (time.perf_counter() - start) * 1000
        
        return credit_scoring_pb2.BatchPredictResponse(
            predictions=predictions,
            processed=len(predictions),
            failed=failed,
            latency_ms=latency
        )
    
    def StreamPredict(self, request_iterator, context):
        """Streaming prediction for real-time processing."""
        for request in request_iterator:
            yield self.Predict(request, context)
    
    def _predict(self, request) -> float:
        # Model inference
        return 0.15
    
    def _categorize_risk(self, probability: float) -> int:
        if probability < 0.25:
            return credit_scoring_pb2.RISK_CATEGORY_LOW
        elif probability < 0.5:
            return credit_scoring_pb2.RISK_CATEGORY_MEDIUM
        elif probability < 0.75:
            return credit_scoring_pb2.RISK_CATEGORY_HIGH
        else:
            return credit_scoring_pb2.RISK_CATEGORY_CRITICAL


def serve():
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ]
    )
    
    credit_scoring_pb2_grpc.add_CreditScoringServicer_to_server(
        CreditScoringServicer(model=None),
        server
    )
    
    server.add_insecure_port('[::]:50051')
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    serve()
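
Client stubs are generated from the proto (e.g. python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. credit_scoring.proto). A minimal client sketch against the server above; the 100ms deadline mirrors the SLA:

# grpc_client.py
import grpc
import credit_scoring_pb2
import credit_scoring_pb2_grpc

def main():
    # Insecure channel matches the server's add_insecure_port above
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = credit_scoring_pb2_grpc.CreditScoringStub(channel)
        
        request = credit_scoring_pb2.PredictRequest(
            applicant_id="APP12345678",
            age=35,
            annual_income=85000,
            loan_amount=250000,
            employment_years=8,
            credit_history_months=156,
            existing_debt=15000,
            loan_purpose=credit_scoring_pb2.LOAN_PURPOSE_HOME,
        )
        
        try:
            response = stub.Predict(request, timeout=0.1)  # SLA: p95 < 100ms
            print(response.default_probability, response.risk_category)
        except grpc.RpcError as err:
            # INVALID_ARGUMENT signals a contract violation
            print(err.code(), err.details())

if __name__ == "__main__":
    main()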

32.5.7. Schema Registry for Event-Driven Systems

In Kafka-based architectures, use a Schema Registry to enforce contracts.

Architecture

graph LR
    A[Model Producer] --> B{Schema Registry}
    B -->|Valid| C[Kafka Topic]
    B -->|Invalid| D[Reject]
    C --> E[Consumer A]
    C --> F[Consumer B]
    
    B --> G[Schema Evolution Check]
    G -->|Compatible| H[Allow]
    G -->|Breaking| D

Avro Schema Definition

{
  "type": "record",
  "name": "PredictionEvent",
  "namespace": "com.company.ml.events",
  "doc": "Credit scoring prediction event",
  "fields": [
    {
      "name": "event_id",
      "type": "string",
      "doc": "Unique event identifier"
    },
    {
      "name": "applicant_id",
      "type": "string"
    },
    {
      "name": "default_probability",
      "type": "double",
      "doc": "Probability of default (0.0-1.0)"
    },
    {
      "name": "risk_category",
      "type": {
        "type": "enum",
        "name": "RiskCategory",
        "symbols": ["LOW", "MEDIUM", "HIGH", "CRITICAL"]
      }
    },
    {
      "name": "model_version",
      "type": "string"
    },
    {
      "name": "timestamp",
      "type": {"type": "long", "logicalType": "timestamp-millis"},
      "doc": "Event time in epoch millis"
    },
    {
      "name": "feature_contributions",
      "type": ["null", {"type": "map", "values": "double"}],
      "default": null,
      "doc": "Optional SHAP values"
    }
  ]
}
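
The "Schema Evolution Check" in the diagram corresponds to the registry's compatibility endpoint. A sketch that gates CI on it via the Schema Registry REST API (the registry URL and the predictions-value subject name are assumptions):

import requests

def is_compatible(registry_url: str, subject: str, schema_path: str) -> bool:
    """Ask the Schema Registry whether a new schema is compatible
    with the latest registered version of the subject."""
    with open(schema_path) as f:
        schema_str = f.read()
    
    resp = requests.post(
        f"{registry_url}/compatibility/subjects/{subject}/versions/latest",
        headers={"Content-Type": "application/vnd.schemaregistry.v1+json"},
        json={"schema": schema_str},
    )
    resp.raise_for_status()
    return resp.json()["is_compatible"]

# Gate a CI pipeline on compatibility before producing with the new schema
if not is_compatible(
    "http://schema-registry:8081",
    "predictions-value",
    "schemas/prediction_event.avsc",
):
    raise SystemExit("Breaking schema change - new topic or major version required")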

Python Producer with Schema Registry

from confluent_kafka import SerializingProducer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroSerializer
from dataclasses import dataclass
from typing import Optional, Dict
import uuid
import time

@dataclass
class PredictionEvent:
    applicant_id: str
    default_probability: float
    risk_category: str
    model_version: str
    feature_contributions: Optional[Dict[str, float]] = None
    
    def to_dict(self) -> dict:
        return {
            "event_id": str(uuid.uuid4()),
            "applicant_id": self.applicant_id,
            "default_probability": self.default_probability,
            "risk_category": self.risk_category,
            "model_version": self.model_version,
            "timestamp": int(time.time() * 1000),
            "feature_contributions": self.feature_contributions
        }


class PredictionEventProducer:
    """Produce prediction events with schema validation."""
    
    def __init__(
        self,
        bootstrap_servers: str,
        schema_registry_url: str,
        topic: str
    ):
        self.topic = topic
        
        # Schema Registry client
        schema_registry = SchemaRegistryClient({
            "url": schema_registry_url
        })
        
        # Load Avro schema
        with open("schemas/prediction_event.avsc") as f:
            schema_str = f.read()
        
        # Serializer with schema validation
        avro_serializer = AvroSerializer(
            schema_registry,
            schema_str,
            lambda event, ctx: event.to_dict()
        )
        
        # Producer config
        self.producer = SerializingProducer({
            "bootstrap.servers": bootstrap_servers,
            "value.serializer": avro_serializer,
            "acks": "all"
        })
    
    def produce(self, event: PredictionEvent) -> None:
        """Produce event to Kafka with schema validation."""
        self.producer.produce(
            topic=self.topic,
            value=event,
            key=event.applicant_id,
            on_delivery=self._delivery_report
        )
        self.producer.flush()
    
    def _delivery_report(self, err, msg):
        if err:
            print(f"Failed to deliver: {err}")
        else:
            print(f"Delivered to {msg.topic()}[{msg.partition()}]")


# Usage
producer = PredictionEventProducer(
    bootstrap_servers="kafka:9092",
    schema_registry_url="http://schema-registry:8081",
    topic="predictions"
)

event = PredictionEvent(
    applicant_id="APP12345678",
    default_probability=0.15,
    risk_category="MEDIUM",
    model_version="2.1.0"
)

producer.produce(event)
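
On the consuming side, the registered schema drives deserialization. A minimal consumer sketch with confluent-kafka (the group id is an assumption):

from confluent_kafka import DeserializingConsumer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroDeserializer

schema_registry = SchemaRegistryClient({"url": "http://schema-registry:8081"})
avro_deserializer = AvroDeserializer(schema_registry)  # uses the writer's registered schema

consumer = DeserializingConsumer({
    "bootstrap.servers": "kafka:9092",
    "group.id": "risk-dashboard",
    "auto.offset.reset": "earliest",
    "value.deserializer": avro_deserializer,
})
consumer.subscribe(["predictions"])

while True:
    msg = consumer.poll(1.0)
    if msg is None:
        continue
    if msg.error():
        print(f"Consumer error: {msg.error()}")
        continue
    event = msg.value()  # dict matching the registered Avro schema
    print(event["applicant_id"], event["default_probability"])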

32.5.8. Versioning and Breaking Changes

Semantic Versioning for ML

| Version Change | Type                | Example            | Action           |
|----------------|---------------------|--------------------|------------------|
| MAJOR (v2.0.0) | Breaking API        | Remove input field | New endpoint URL |
| MINOR (v1.2.0) | Backward compatible | Add optional field | Deploy in place  |
| PATCH (v1.2.1) | Bug fix             | Fix memory leak    | Deploy in place  |

Non-Breaking Changes

✅ Safe to deploy in place:

  • Adding optional input fields
  • Adding new output fields
  • Improving model accuracy
  • Adding new endpoints
  • Relaxing validation (accepting more formats)

Breaking Changes

❌ Require new version:

  • Removing or renaming input fields
  • Changing output field types
  • Tightening validation
  • Changing probability distribution significantly
  • Removing endpoints

Migration Pattern

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import warnings

app = FastAPI()

# V1 - Deprecated (the V1/V2 input models and transform helpers are
# assumed to be defined elsewhere)
@app.post("/v1/predict")
async def predict_v1(request: CreditScoringInputV1):
    warnings.warn("v1 is deprecated, migrate to v2", DeprecationWarning)
    
    # Transform to v2 format
    v2_input = transform_v1_to_v2(request)
    
    # Use v2 logic
    result = await predict_v2(v2_input)
    
    # Transform back to v1 format
    return transform_v2_to_v1(result)


# V2 - Current
@app.post("/v2/predict")
async def predict_v2(request: CreditScoringInputV2):
    # ... implementation
    pass


# Deprecation header middleware
@app.middleware("http")
async def deprecation_header(request: Request, call_next):
    response = await call_next(request)
    
    if "/v1/" in request.url.path:
        response.headers["Deprecation"] = "true"
        response.headers["Sunset"] = "2024-06-01T00:00:00Z"
        response.headers["Link"] = '</v2/predict>; rel="successor-version"'
    
    return response

32.5.9. Summary Checklist

| Layer    | What to Define       | Tool               | When to Check    |
|----------|----------------------|--------------------|------------------|
| Schema   | Types, constraints   | Pydantic           | Every request    |
| Semantic | Business rules       | Great Expectations | CI/CD            |
| SLA      | Latency, error rate  | k6, Prometheus     | Continuous       |
| Consumer | Cross-team contracts | Pact               | CI before deploy |
| Events   | Message format       | Schema Registry    | Produce time     |

Golden Rules

  1. Schema first: Define Pydantic/Protobuf before writing code
  2. Test semantics: Run Great Expectations on golden datasets
  3. Enforce SLAs: k6 load tests in CI/CD
  4. Consumer contracts: Pact verification before merge
  5. Version everything: Never break v1

[End of Section 32.5]