16.2 Async & Batch Inference: Handling the Long Tail

16.2.1 Introduction: The Asynchronous Paradigm Shift

Real-time inference is a sprint. Asynchronous and batch inference are marathons. They optimize for total throughput and cost efficiency rather than instantaneous response time. This paradigm shift is critical for use cases where:

  1. Processing time exceeds HTTP timeout limits (video analysis, large document processing)
  2. Results aren’t needed immediately (nightly analytics, batch labeling for training data)
  3. Cost optimization is paramount (processing millions of records at the lowest $ per unit)

This section explores the architecture, implementation patterns, and operational strategies for async and batch inference across AWS, GCP, and Kubernetes.

The Synchronous Problem

HTTP request-response is brittle for long-running operations:

sequenceDiagram
    participant Client
    participant LoadBalancer
    participant Server
    
    Client->>LoadBalancer: POST /analyze-video
    LoadBalancer->>Server: Forward (timeout: 60s)
    
    Note over Server: Processing... 45 seconds
    
    Server->>LoadBalancer: Response
    LoadBalancer->>Client: 200 OK
    
    Note over Client,LoadBalancer: BUT if processing > 60s?
    
    Server--xLoadBalancer: Connection closed (timeout)
    Client--xServer: Error, retry
    Note right of Client: Duplicate processing!

Failure Modes:

  • Client timeout: Mobile app loses network connection after 30s
  • Load balancer timeout: ALB/nginx default 60s idle timeout
  • Server thread exhaustion: Blocking threads waiting for model
  • Retry storms: Client retries create duplicate requests (a mitigation is sketched below)
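
A common client-side mitigation, even before moving to a queue, is an idempotency key: the client sends the same key on every retry, and the server returns the original job instead of reprocessing. A minimal sketch, assuming a FastAPI endpoint and an in-memory store (in production this would be Redis or DynamoDB):

# Hypothetical sketch: deduplicate retried submissions with an idempotency key.
from fastapi import FastAPI, Header
import uuid

app = FastAPI()
seen_keys: dict[str, str] = {}  # idempotency_key -> job_id (use Redis/DynamoDB in production)

@app.post("/analyze-video")
def analyze_video(idempotency_key: str = Header(...)):
    # A retry after a timeout carries the same Idempotency-Key header,
    # so we return the original job instead of processing the payload twice.
    if idempotency_key in seen_keys:
        return {"job_id": seen_keys[idempotency_key], "duplicate": True}

    job_id = str(uuid.uuid4())
    seen_keys[idempotency_key] = job_id
    # ...enqueue the actual work here instead of doing it inline...
    return {"job_id": job_id, "duplicate": False}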

16.2.2 Asynchronous Inference Architecture

The core pattern: decouple request submission from result retrieval.

graph LR
    Client[Client]
    API[API Gateway]
    Queue[Message Queue<br/>SQS/Pub-Sub]
    Worker[GPU Worker]
    Storage[S3/GCS]
    DB[(Result DB)]
    
    Client-->|1. POST /submit|API
    API-->|2. Enqueue Job|Queue
    API-->|3. Return JobID|Client
    Queue-->|4. Pull|Worker
    Worker-->|5. Process|Worker
    Worker-->|6. Upload Results|Storage
    Worker-->|7. Update Status|DB
    Client-->|8. Poll /status/JobID|API
    API-->|9. Query|DB
    DB-->|10. Return|API
    API-->|11. Result or S3 URI|Client

Key Components:

  1. Message Queue: Durable, distributed queue (SQS, Pub/Sub, Kafka)
  2. Worker Pool: Stateless processors that consume jobs from the queue
  3. Result Storage: S3/GCS for large outputs (images, videos)
  4. Status Tracking: Database (DynamoDB, Firestore) for job metadata

Job Lifecycle

PENDING → RUNNING → COMPLETED
                  ↘ FAILED

State Machine:

from enum import Enum
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

class JobStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

@dataclass
class Job:
    job_id: str
    status: JobStatus
    input_uri: str
    output_uri: Optional[str] = None
    error_message: Optional[str] = None
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    
    def duration_seconds(self) -> float:
        if self.completed_at and self.started_at:
            return (self.completed_at - self.started_at).total_seconds()
        return 0.0
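
The enum above lists the states but not which transitions are legal. A small guard (illustrative, not from any library) keeps a worker from, say, moving a FAILED job to COMPLETED, and stamps the timestamps consistently:

# Hypothetical transition guard for the JobStatus/Job types defined above.
VALID_TRANSITIONS = {
    JobStatus.PENDING: {JobStatus.RUNNING, JobStatus.FAILED},
    JobStatus.RUNNING: {JobStatus.COMPLETED, JobStatus.FAILED},
    JobStatus.COMPLETED: set(),  # terminal
    JobStatus.FAILED: set(),     # terminal
}

def transition(job: Job, new_status: JobStatus) -> None:
    """Move a job to new_status, rejecting illegal jumps and stamping timestamps."""
    if new_status not in VALID_TRANSITIONS[job.status]:
        raise ValueError(f"Illegal transition: {job.status} -> {new_status}")
    job.status = new_status
    now = datetime.now()
    if new_status is JobStatus.RUNNING:
        job.started_at = now
    elif new_status in (JobStatus.COMPLETED, JobStatus.FAILED):
        job.completed_at = now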

16.2.3 AWS SageMaker Async Inference

SageMaker Async is a fully managed async inference service that handles the queue, workers, auto-scaling, and storage.

Architecture

graph TD
    Client[Client SDK]
    Endpoint[SageMaker Endpoint]
    InternalQueue[Internal SQS Queue<br/>Managed by SageMaker]
    EC2[ml.g4dn.xlarge Instances]
    S3Input[S3 Input Bucket]
    S3Output[S3 Output Bucket]
    SNS[SNS Topic]
    
    Client-->|InvokeEndpointAsync|Endpoint
    Endpoint-->|Enqueue|InternalQueue
    InternalQueue-->|Pull|EC2
    Client-->|1. Upload|S3Input
    EC2-->|2. Download|S3Input
    EC2-->|3. Process|EC2
    EC2-->|4. Upload|S3Output
    EC2-->|5. Notify|SNS
    SNS-->|6. Webhook/Email|Client

Implementation

1. Infrastructure (Terraform):

# variables.tf
variable "model_name" {
  type    = string
  default = "video-classifier"
}

# s3.tf
resource "aws_s3_bucket" "async_input" {
  bucket = "${var.model_name}-async-input"
}

resource "aws_s3_bucket" "async_output" {
  bucket = "${var.model_name}-async-output"
}

# sns.tf
resource "aws_sns_topic" "success" {
  name = "${var.model_name}-success"
}

resource "aws_sns_topic" "error" {
  name = "${var.model_name}-error"
}

resource "aws_sns_topic_subscription" "success_webhook" {
  topic_arn = aws_sns_topic.success.arn
  protocol  = "https"
  endpoint  = "https://api.myapp.com/webhooks/sagemaker-success"
}

# sagemaker.tf
resource "aws_sagemaker_model" "model" {
  name               = var.model_name
  execution_role_arn = aws_iam_role.sagemaker_role.arn
  
  primary_container {
    image          = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0-gpu-py310"
    model_data_url = "s3://my-models/${var.model_name}/model.tar.gz"
    
    environment = {
      "SAGEMAKER_PROGRAM" = "inference.py"
    }
  }
}

resource "aws_sagemaker_endpoint_configuration" "async_config" {
  name = "${var.model_name}-async-config"
  
  # Async-specific configuration
  async_inference_config {
    output_config {
      s3_output_path = "s3://${aws_s3_bucket.async_output.id}/"
      
      notification_config {
        success_topic = aws_sns_topic.success.arn
        error_topic   = aws_sns_topic.error.arn
      }
    }
    
    client_config {
      max_concurrent_invocations_per_instance = 4
    }
  }
  
  production_variants {
    variant_name           = "AllTraffic"
    model_name             = aws_sagemaker_model.model.name
    initial_instance_count = 1
    instance_type          = "ml.g4dn.xlarge"
  }
}

resource "aws_sagemaker_endpoint" "async_endpoint" {
  name                 = "${var.model_name}-async"
  endpoint_config_name = aws_sagemaker_endpoint_configuration.async_config.name
}

# Auto-scaling
resource "aws_appautoscaling_target" "async_scaling" {
  max_capacity       = 10
  min_capacity       = 0  # Scale to zero!
  resource_id        = "endpoint/${aws_sagemaker_endpoint.async_endpoint.name}/variant/AllTraffic"
  scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
  service_namespace  = "sagemaker"
}

resource "aws_appautoscaling_policy" "async_scaling_policy" {
  name               = "${var.model_name}-scaling"
  policy_type        = "TargetTrackingScaling"
  resource_id        = aws_appautoscaling_target.async_scaling.resource_id
  scalable_dimension = aws_appautoscaling_target.async_scaling.scalable_dimension
  service_namespace  = aws_appautoscaling_target.async_scaling.service_namespace
  
  target_tracking_scaling_policy_configuration {
    customized_metric_specification {
      metric_name = "ApproximateBacklogSizePerInstance"
      namespace   = "AWS/SageMaker"
      statistic   = "Average"
      
      dimensions {
        name  = "EndpointName"
        value = aws_sagemaker_endpoint.async_endpoint.name
      }
    }
    
    target_value       = 5.0  # Target 5 jobs per instance
    scale_in_cooldown  = 600  # Wait 10 min before scaling down
    scale_out_cooldown = 60   # Scale up quickly
  }
}
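
One caveat with min_capacity = 0: the target-tracking policy above relies on a per-instance backlog metric, which cannot trigger a scale-out while zero instances are running. The commonly documented approach for this case is an additional step-scaling policy driven by the HasBacklogWithoutCapacity metric. A hedged boto3 sketch (policy names, cooldowns, and thresholds are illustrative):

# Hypothetical sketch: step-scaling from zero on HasBacklogWithoutCapacity.
import boto3

endpoint_name = "video-classifier-async"
resource_id = f"endpoint/{endpoint_name}/variant/AllTraffic"

autoscaling = boto3.client("application-autoscaling")
cloudwatch = boto3.client("cloudwatch")

policy = autoscaling.put_scaling_policy(
    PolicyName="scale-out-from-zero",
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="StepScaling",
    StepScalingPolicyConfiguration={
        "AdjustmentType": "ChangeInCapacity",
        "MetricAggregationType": "Average",
        "Cooldown": 60,
        "StepAdjustments": [{"MetricIntervalLowerBound": 0, "ScalingAdjustment": 1}],
    },
)

# Alarm fires when there is a backlog but no capacity to serve it.
cloudwatch.put_metric_alarm(
    AlarmName=f"{endpoint_name}-has-backlog-without-capacity",
    MetricName="HasBacklogWithoutCapacity",
    Namespace="AWS/SageMaker",
    Dimensions=[{"Name": "EndpointName", "Value": endpoint_name}],
    Statistic="Average",
    Period=60,
    EvaluationPeriods=1,
    Threshold=1,
    ComparisonOperator="GreaterThanOrEqualToThreshold",
    AlarmActions=[policy["PolicyARN"]],
)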

2. Client Code (Python):

import boto3
import json
from datetime import datetime

s3_client = boto3.client('s3')
sagemaker_runtime = boto3.client('sagemaker-runtime')

def submit_async_job(video_path: str, endpoint_name: str) -> str:
    """
    Submit an async inference job.
    
    Returns:
        output_location: S3 URI where results will be written
    """
    # Upload input to S3
    input_bucket = f"{endpoint_name}-async-input"
    input_key = f"inputs/{datetime.now().isoformat()}/video.mp4"
    
    s3_client.upload_file(
        Filename=video_path,
        Bucket=input_bucket,
        Key=input_key
    )
    
    input_location = f"s3://{input_bucket}/{input_key}"
    
    # Invoke async endpoint
    response = sagemaker_runtime.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_location,
        InferenceId=f"job-{datetime.now().timestamp()}"  # Optional correlation ID
    )
    
    output_location = response['OutputLocation']
    print(f"Job submitted. Results will be at: {output_location}")
    
    return output_location

def check_job_status(output_location: str) -> dict:
    """
    Check if the job is complete.
    
    Returns:
        {"status": "pending|completed|failed", "result": <data>}
    """
    # Parse S3 URI
    parts = output_location.replace("s3://", "").split("/", 1)
    bucket = parts[0]
    key = parts[1]
    
    try:
        obj = s3_client.get_object(Bucket=bucket, Key=key)
        result = json.loads(obj['Body'].read())
        return {"status": "completed", "result": result}
    except s3_client.exceptions.NoSuchKey:
        # Check for error file (SageMaker writes .error if failed)
        error_key = key.replace(".out", ".error")
        try:
            obj = s3_client.get_object(Bucket=bucket, Key=error_key)
            error = obj['Body'].read().decode('utf-8')
            return {"status": "failed", "error": error}
        except s3_client.exceptions.NoSuchKey:
            return {"status": "pending"}

# Usage
output_loc = submit_async_job("my_video.mp4", "video-classifier-async")

# Poll for result
import time
while True:
    status = check_job_status(output_loc)
    if status['status'] != 'pending':
        print(status)
        break
    time.sleep(5)
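
Polling works, but the SNS topics wired up in the Terraform above allow a push model instead. A minimal webhook sketch: the route matches the subscription endpoint above, the SNS envelope fields (Type, Message, SubscribeURL) are standard, and the exact contents of the SageMaker success message should be checked against the AWS docs:

# Hypothetical receiver for the SNS success topic subscribed above.
import json
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/webhooks/sagemaker-success")
async def sagemaker_success(request: Request):
    body = json.loads(await request.body())

    # SNS HTTPS subscriptions must be confirmed once before delivery starts.
    if body.get("Type") == "SubscriptionConfirmation":
        print(f"Confirm subscription via: {body['SubscribeURL']}")
        return {"status": "confirmation pending"}

    # For notifications, the SageMaker-specific payload (inference ID, output
    # location, etc.) is JSON inside the "Message" field.
    message = json.loads(body["Message"])
    print(f"Async inference finished: {message}")
    return {"status": "received"}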

The Scale-to-Zero Advantage

With async inference:

  • Idle periods cost $0 (instances scale to 0)
  • Burst capacity (scale from 0 to 10 instances in minutes)
  • Pay only for processing time + small queue hosting cost

Cost Comparison (Sporadic Workload: 100 jobs/day, 5 minutes each):

| Deployment | Daily Cost | Monthly Cost |
|---|---|---|
| Real-time (1x ml.g4dn.xlarge, 24/7) | $17.67 | $530 |
| Async (scale-to-zero) | 100 × 5 min × $0.736/hour = $6.13 | $184 |

Savings: 65%
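
The arithmetic behind the table, spelled out (the $0.736/hour rate for ml.g4dn.xlarge comes from the table itself and varies by region):

# Reproduce the cost comparison above.
HOURLY_RATE = 0.736          # ml.g4dn.xlarge, $/hour (region-dependent)
JOBS_PER_DAY = 100
MINUTES_PER_JOB = 5

realtime_daily = 24 * HOURLY_RATE                                   # always-on instance
async_daily = JOBS_PER_DAY * (MINUTES_PER_JOB / 60) * HOURLY_RATE   # pay per processing minute

print(f"Real-time: ${realtime_daily:.2f}/day, ${realtime_daily * 30:.0f}/month")
print(f"Async:     ${async_daily:.2f}/day, ${async_daily * 30:.0f}/month")
print(f"Savings:   {1 - async_daily / realtime_daily:.0%}")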


16.2.4 Batch Transform: Offline Inference at Scale

Batch Transform is for “offline” workloads: label 10 million images, score all customers for churn risk, etc.

SageMaker Batch Transform

Key Features:

  • Massive Parallelism: Spin up 100 instances simultaneously
  • Automatic Data Splitting: SageMaker splits large files (CSV, JSON Lines) automatically
  • No Server Management: Instances start, process, then terminate

Workflow:

graph LR
    Input[S3 Input<br/>input.csv<br/>10 GB]
    SM[SageMaker<br/>Batch Transform]
    Workers[20x ml.p3.2xlarge<br/>Parallel Processing]
    Output[S3 Output<br/>output.csv<br/>Predictions]
    
    Input-->|Split into chunks|SM
    SM-->|Distribute|Workers
    Workers-->|Process|Workers
    Workers-->|Aggregate|Output

Implementation

Python SDK:

from sagemaker.pytorch import PyTorchModel
from sagemaker.transformer import Transformer

# Define model
model = PyTorchModel(
    model_data="s3://my-models/image-classifier/model.tar.gz",
    role=sagemaker_role,  # IAM execution role ARN (defined elsewhere)
    framework_version="2.0",
    py_version="py310",
    entry_point="inference.py"
)

# Create transformer
transformer = model.transformer(
    instance_count=20,  # Massive parallelism
    instance_type="ml.p3.2xlarge",
    strategy="MultiRecord",  # Process multiple records per request
    max_payload=10,  # Max 10 MB per request
    max_concurrent_transforms=8,  # Concurrent requests per instance
    output_path="s3://my-bucket/batch-output/",
    assemble_with="Line",  # Output format
    accept="application/json"
)

# Start batch job
transformer.transform(
    data="s3://my-bucket/batch-input/images.csv",
    data_type="S3Prefix",
    content_type="text/csv",
    split_type="Line",  # Split by line
    input_filter="$[1:]",  # Pass only columns 1+ of each record to the model (drops the first column, e.g. an ID)
    join_source="Input"  # Append prediction to input
)

# Wait for completion
transformer.wait()

# Results are now in s3://my-bucket/batch-output/

Advanced: The join_source Pattern

For ML validation, you often need: Input | Actual | Predicted

Input CSV:

customer_id,feature_1,feature_2,actual_churn
1001,25,50000,0
1002,45,75000,1

With join_source="Input", output becomes:

1001,25,50000,0,0.12
1002,45,75000,1,0.87

The prediction is appended to each line, preserving the input for validation scripts.
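
Because each output line now carries both actual_churn and the predicted probability, validation is a one-pass script over the joined file. A small sketch, assuming the column order from the example above and a 0.5 decision threshold:

# Compute accuracy from the joined Batch Transform output shown above.
# Columns: customer_id, feature_1, feature_2, actual_churn, predicted_probability
import csv

correct = total = 0
with open("batch_output.csv") as f:
    for row in csv.reader(f):
        actual = int(row[3])
        predicted = 1 if float(row[4]) >= 0.5 else 0
        correct += (actual == predicted)
        total += 1

print(f"Accuracy: {correct / total:.2%} over {total} records")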

Handling Failures

Batch Transform writes failed records to a .out.failed file.

import boto3

s3 = boto3.resource('s3')
bucket = s3.Bucket('my-bucket')

# Check for failed records
for obj in bucket.objects.filter(Prefix='batch-output/'):
    if obj.key.endswith('.failed'):
        print(f"Found failures: {obj.key}")
        
        # Download and inspect
        bucket.download_file(obj.key, '/tmp/failed.json')

Retry Strategy (steps 1 and 2 are sketched after this list):

  1. Extract failed record IDs
  2. Create a new input file with only failed records
  3. Re-run Batch Transform with reduced instance_count (failures are often rate-limit issues)
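
A sketch of steps 1 and 2, assuming each line of a .failed file is an original input record that can be resubmitted as-is:

# Hypothetical sketch: collect failed records and build a retry input file.
import boto3

s3 = boto3.resource('s3')
bucket = s3.Bucket('my-bucket')

retry_lines = []
for obj in bucket.objects.filter(Prefix='batch-output/'):
    if obj.key.endswith('.failed'):
        body = obj.get()['Body'].read().decode('utf-8')
        retry_lines.extend(line for line in body.splitlines() if line.strip())

# Write a new input file containing only the failed records, then point a
# second transform job (with a smaller instance_count) at this prefix.
bucket.put_object(
    Key='batch-input-retry/retry.csv',
    Body='\n'.join(retry_lines).encode('utf-8')
)
print(f"{len(retry_lines)} records queued for retry")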

16.2.5 Google Cloud Dataflow: ML Pipelines at Scale

Dataflow, Google’s managed runner for Apache Beam, is the GCP counterpart to Batch Transform. It’s more flexible but requires more code.

Apache Beam Primer

Beam is a unified stream and batch processing framework.

Core Concepts:

  • PCollection: An immutable distributed dataset
  • PTransform: A processing step (Map, Filter, GroupBy, etc.)
  • Pipeline: A DAG of PTransforms

RunInference API

Beam’s RunInference transform handles model loading, batching, and distribution.

Complete Example:

import json
import apache_beam as beam
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.base import RunInference, PredictionResult, KeyedModelHandler
import torch
import torchvision

class ImagePreprocessor(beam.DoFn):
    def process(self, element):
        """
        element: {"image_uri": "gs://bucket/image.jpg", "id": "123"}
        """
        from PIL import Image
        from torchvision import transforms
        import io
        
        # Download image
        from google.cloud import storage
        client = storage.Client()
        bucket = client.bucket(element['image_uri'].split('/')[2])
        blob = bucket.blob('/'.join(element['image_uri'].split('/')[3:]))
        image_bytes = blob.download_as_bytes()
        
        # Preprocess
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Keep the tensor as a torch.Tensor: the model handler stacks tensors before inference
        tensor = transform(image)
        yield {"id": element['id'], "tensor": tensor}

class Postprocessor(beam.DoFn):
    def process(self, element):
        """
        element: (id, PredictionResult(example, inference)) emitted by the keyed handler
        """
        example_id, prediction_result = element
        class_idx = prediction_result.inference.argmax().item()
        confidence = prediction_result.inference.max().item()
        
        yield {
            "id": example_id,
            "class": class_idx,
            "confidence": float(confidence)
        }

def run_pipeline():
    # Model handler: pass the model class and its constructor kwargs, not an instance
    model_handler = PytorchModelHandlerTensor(
        state_dict_path="gs://my-bucket/models/resnet50.pth",
        model_class=torchvision.models.resnet50,
        model_params={},
        device="GPU"  # Uses the GPU on Dataflow workers when one is available
    )
    
    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        runner='DataflowRunner',
        project='my-gcp-project',
        region='us-central1',
        temp_location='gs://my-bucket/temp',
        staging_location='gs://my-bucket/staging',
        
        # Worker configuration
        machine_type='n1-standard-8',
        disk_size_gb=100,
        num_workers=10,
        max_num_workers=50,
        
        # GPU configuration (workers also need a container image with
        # CUDA-enabled dependencies, not shown here)
        dataflow_service_options=['worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver']
    )
    
    with beam.Pipeline(options=pipeline_options) as p:
        (
            p
            | "Read Input" >> beam.io.ReadFromText("gs://my-bucket/input.jsonl")
            | "Parse JSON" >> beam.Map(lambda x: json.loads(x))
            | "Preprocess" >> beam.ParDo(ImagePreprocessor())
            | "Extract Tensor" >> beam.Map(lambda x: (x['id'], x['tensor']))
            | "Run Inference" >> RunInference(model_handler)
            | "Postprocess" >> beam.ParDo(Postprocessor())
            | "Format Output" >> beam.Map(lambda x: json.dumps(x))
            | "Write Output" >> beam.io.WriteToText("gs://my-bucket/output.jsonl")
        )

if __name__ == "__main__":
    run_pipeline()

Execution:

python beam_pipeline.py \
  --runner DataflowRunner \
  --project my-gcp-project \
  --region us-central1 \
  --temp_location gs://my-bucket/temp

Auto-Scaling

Dataflow automatically scales workers based on backlog.

Monitoring:

from google.cloud import monitoring_v3

PROJECT_ID = "my-gcp-project"

# MQL queries are served by the QueryService API
client = monitoring_v3.QueryServiceClient()

query = '''
fetch dataflow_job
| metric 'dataflow.googleapis.com/job/current_num_vcpus'
| filter resource.job_name == "my-inference-job"
| align rate(1m)
'''

results = client.query_time_series(
    request={"name": f"projects/{PROJECT_ID}", "query": query}
)

16.2.6 DIY Async on Kubernetes

For ultimate control, build async inference on Kubernetes.

Architecture

graph TD
    API[FastAPI Service]
    Redis[(Redis Queue)]
    Worker1[Worker Pod 1<br/>GPU]
    Worker2[Worker Pod 2<br/>GPU]
    Worker3[Worker Pod 3<br/>GPU]
    PG[(PostgreSQL<br/>Job Status)]
    S3[S3/Minio<br/>Results]
    
    API-->|Enqueue Job|Redis
    Redis-->|Pop|Worker1
    Redis-->|Pop|Worker2
    Redis-->|Pop|Worker3
    Worker1-->|Update Status|PG
    Worker1-->|Upload Result|S3
    API-->|Query Status|PG

Tech Stack:

  • Queue: Redis (with persistence) or RabbitMQ
  • Workers: Kubernetes Job or Deployment
  • Status DB: PostgreSQL or DynamoDB
  • Storage: MinIO (self-hosted S3) or GCS

Implementation

1. API Server (FastAPI):

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import redis
import psycopg2
import uuid
from datetime import datetime

app = FastAPI()
redis_client = redis.Redis(host='redis-service', port=6379)
db_conn = psycopg2.connect("dbname=jobs user=postgres host=postgres-service")

class JobRequest(BaseModel):
    video_url: str

@app.post("/jobs")
def create_job(request: JobRequest):
    job_id = str(uuid.uuid4())
    
    # Insert into DB
    with db_conn.cursor() as cur:
        cur.execute(
            "INSERT INTO jobs (id, status, input_url, created_at) VALUES (%s, %s, %s, %s)",
            (job_id, "pending", request.video_url, datetime.now())
        )
    db_conn.commit()
    
    # Enqueue
    redis_client.rpush("job_queue", job_id)
    
    return {"job_id": job_id, "status": "pending"}

@app.get("/jobs/{job_id}")
def get_job(job_id: str):
    with db_conn.cursor() as cur:
        cur.execute("SELECT status, output_url, error FROM jobs WHERE id = %s", (job_id,))
        row = cur.fetchone()
        
        if not row:
            raise HTTPException(status_code=404, detail="Job not found")
        
        return {
            "job_id": job_id,
            "status": row[0],
            "output_url": row[1],
            "error": row[2]
        }
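
From the client’s side, the flow is the same submit-then-poll loop as with SageMaker Async. A short usage sketch with the requests library (the hostname is a placeholder):

# Submit a job to the FastAPI service above, then poll until it finishes.
import time
import requests

BASE_URL = "http://inference-api.example.internal"  # placeholder hostname

job = requests.post(
    f"{BASE_URL}/jobs",
    json={"video_url": "s3://input-bucket/my_video.mp4"}
).json()
print("Submitted:", job["job_id"])

while True:
    status = requests.get(f"{BASE_URL}/jobs/{job['job_id']}").json()
    if status["status"] in ("completed", "failed"):
        print(status)
        break
    time.sleep(5)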

2. Worker (Python):

import json
import redis
import psycopg2
import boto3
from datetime import datetime

redis_client = redis.Redis(host='redis-service', port=6379)
db_conn = psycopg2.connect("dbname=jobs user=postgres host=postgres-service")
s3_client = boto3.client('s3', endpoint_url='http://minio-service:9000')

def update_job_status(job_id, status, output_url=None, error=None):
    with db_conn.cursor() as cur:
        cur.execute(
            """
            UPDATE jobs 
            SET status = %s, output_url = %s, error = %s, updated_at = %s 
            WHERE id = %s
            """,
            (status, output_url, error, datetime.now(), job_id)
        )
    db_conn.commit()

def process_job(job_id):
    # Fetch job details
    with db_conn.cursor() as cur:
        cur.execute("SELECT input_url FROM jobs WHERE id = %s", (job_id,))
        input_url = cur.fetchone()[0]
    
    try:
        update_job_status(job_id, "running")
        
        # Download input
        video_path = f"/tmp/{job_id}.mp4"
        s3_client.download_file("input-bucket", input_url, video_path)
        
        # Run inference (placeholder)
        result = run_model(video_path)
        
        # Upload output
        output_key = f"output/{job_id}/result.json"
        s3_client.put_object(
            Bucket="output-bucket",
            Key=output_key,
            Body=json.dumps(result)
        )
        
        update_job_status(job_id, "completed", output_url=f"s3://output-bucket/{output_key}")
        
    except Exception as e:
        update_job_status(job_id, "failed", error=str(e))

# Main loop
while True:
    # Blocking pop (timeout 60s)
    job_data = redis_client.blpop("job_queue", timeout=60)
    
    if job_data:
        job_id = job_data[1].decode('utf-8')
        process_job(job_id)
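
One caveat with this loop: blpop removes the job ID before processing starts, so a crashed worker silently drops the job. Redis’s reliable-queue pattern parks the ID on a processing list until the attempt finishes. A sketch (the processing-list name is an illustrative choice; the producer would switch from rpush to lpush to keep FIFO order with BRPOPLPUSH):

# Hypothetical reliable-queue variant of the worker loop above.
# BRPOPLPUSH atomically moves the job ID to a processing list, so a crashed
# worker leaves evidence behind that a janitor process can re-enqueue.
PROCESSING_LIST = "job_queue:processing"

while True:
    job_data = redis_client.brpoplpush("job_queue", PROCESSING_LIST, timeout=60)
    if job_data is None:
        continue
    job_id = job_data.decode("utf-8")
    try:
        process_job(job_id)
    finally:
        # Remove the ID from the processing list once the attempt is recorded
        # (process_job already marks the job completed or failed in Postgres).
        redis_client.lrem(PROCESSING_LIST, 1, job_id)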

3. Kubernetes Deployment:

# worker-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: inference-worker
spec:
  replicas: 3
  selector:
    matchLabels:
      app: inference-worker
  template:
    metadata:
      labels:
        app: inference-worker
    spec:
      containers:
        - name: worker
          image: gcr.io/my-project/inference-worker:v1
          resources:
            limits:
              nvidia.com/gpu: "1"
          env:
            - name: REDIS_HOST
              value: "redis-service"
            - name: DB_HOST
              value: "postgres-service"

Horizontal Pod Autoscaler (HPA)

Scale workers based on queue depth.

apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: worker-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: inference-worker
  minReplicas: 1
  maxReplicas: 20
  metrics:
    - type: External
      external:
        metric:
          name: redis_queue_depth
          selector:
            matchLabels:
              queue: "job_queue"
        target:
          type: AverageValue
          averageValue: "5"

This requires a custom metrics adapter that queries Redis and exposes the queue depth as a Kubernetes metric.
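
One common way to supply that metric is a small exporter that publishes the queue length for Prometheus, which prometheus-adapter then exposes under the external metric name used above. A sketch with prometheus_client (the port and scrape interval are illustrative):

# Hypothetical exporter: expose Redis queue depth as a Prometheus gauge.
import time
import redis
from prometheus_client import Gauge, start_http_server

queue_depth = Gauge("redis_queue_depth", "Pending jobs in the Redis queue", ["queue"])
redis_client = redis.Redis(host="redis-service", port=6379)

if __name__ == "__main__":
    start_http_server(9400)  # scrape target for Prometheus
    while True:
        queue_depth.labels(queue="job_queue").set(redis_client.llen("job_queue"))
        time.sleep(15)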


16.2.7 Comparison Matrix

| Feature | SageMaker Async | SageMaker Batch | Dataflow | DIY Kubernetes |
|---|---|---|---|---|
| Latency | Seconds to minutes | Minutes to hours | Minutes to hours | Configurable |
| Scale-to-Zero | Yes | N/A (ephemeral jobs) | N/A | Manual |
| Max Parallelism | 10-100 instances | 1000+ instances | 10,000+ workers | Limited by cluster |
| Cost (per hour) | Instance cost | Instance cost | vCPU + memory | Instance cost |
| Data Splitting | No | Yes (automatic) | Yes (manual) | Manual |
| Best For | Real-time with bursts | Large batch jobs | Complex ETL + ML | Full control |

16.2.8 Conclusion

Asynchronous and batch inference unlock cost optimization and scale beyond what real-time endpoints can achieve. The trade-off is latency, but for non-interactive workloads, this is acceptable.

Decision Framework:

  • User waiting for result → Real-time or Async (< 1 min)
  • Webhook/Email notification → Async (1-10 min)
  • Nightly batch → Batch Transform / Dataflow (hours)
  • Maximum control → DIY on Kubernetes

Master these patterns, and you’ll build systems that process billions of predictions at a fraction of the cost of real-time infrastructure.