16.2 Async & Batch Inference: Handling the Long Tail
16.2.1 Introduction: The Asynchronous Paradigm Shift
Real-time inference is a sprint. Asynchronous and batch inference are marathons. They optimize for total throughput and cost efficiency rather than instantaneous response time. This paradigm shift is critical for use cases where:
- Processing time exceeds HTTP timeout limits (video analysis, large document processing)
- Results aren’t needed immediately (nightly analytics, batch labeling for training data)
- Cost optimization is paramount (processing millions of records at the lowest $ per unit)
This chapter explores the architecture, implementation patterns, and operational strategies for async and batch inference across AWS, GCP, and Kubernetes.
The Synchronous Problem
HTTP request-response is brittle for long-running operations:
sequenceDiagram
participant Client
participant LoadBalancer
participant Server
Client->>LoadBalancer: POST /analyze-video
LoadBalancer->>Server: Forward (timeout: 60s)
Note over Server: Processing... 45 seconds
Server->>LoadBalancer: Response
LoadBalancer->>Client: 200 OK
Note over Client,LoadBalancer: BUT if processing > 60s?
Server--xLoadBalancer: Connection closed (timeout)
Client--xServer: Error, retry
Note right of Client: Duplicate processing!
Failure Modes:
- Client timeout: Mobile app loses network connection after 30s
- Load balancer timeout: ALB/nginx default 60s idle timeout
- Server thread exhaustion: Blocking threads waiting for model
- Retry storms: Client retries create duplicate requests
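A common mitigation for the last failure mode is to make submission idempotent: the client (or server) derives a deduplication key, and a retried request returns the existing job instead of spawning a duplicate. A minimal sketch, assuming an in-memory job store standing in for something durable like DynamoDB:
import hashlib
from typing import Dict, Optional

# Stand-in for a durable store (e.g., DynamoDB with a conditional put on the key).
job_store: Dict[str, str] = {}  # idempotency_key -> job_id

def submit_job(payload: bytes, idempotency_key: Optional[str] = None) -> str:
    """Create a job, or return the existing job_id if this key was already seen."""
    key = idempotency_key or hashlib.sha256(payload).hexdigest()
    if key in job_store:
        return job_store[key]            # client retry: no duplicate processing
    job_id = f"job-{key[:12]}"
    job_store[key] = job_id
    # ... enqueue job_id for the worker pool here ...
    return job_id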
16.2.2 Asynchronous Inference Architecture
The core pattern: decouple request submission from result retrieval.
graph LR
Client[Client]
API[API Gateway]
Queue[Message Queue<br/>SQS/Pub-Sub]
Worker[GPU Worker]
Storage[S3/GCS]
DB[(Result DB)]
Client-->|1. POST /submit|API
API-->|2. Enqueue Job|Queue
API-->|3. Return JobID|Client
Queue-->|4. Pull|Worker
Worker-->|5. Process|Worker
Worker-->|6. Upload Results|Storage
Worker-->|7. Update Status|DB
Client-->|8. Poll /status/JobID|API
API-->|9. Query|DB
DB-->|10. Return|API
API-->|11. Result or S3 URI|Client
Key Components:
- Message Queue: Durable, distributed queue (SQS, Pub/Sub, Kafka)
- Worker Pool: Stateless processors that consume jobs from the queue
- Result Storage: S3/GCS for large outputs (images, videos)
- Status Tracking: Database (DynamoDB, Firestore) for job metadata
Job Lifecycle
PENDING → RUNNING → COMPLETED
                  ↘ FAILED
State Machine:
from enum import Enum
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
class JobStatus(Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class Job:
job_id: str
status: JobStatus
input_uri: str
    output_uri: Optional[str] = None
    error_message: Optional[str] = None
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
def duration_seconds(self) -> float:
if self.completed_at and self.started_at:
return (self.completed_at - self.started_at).total_seconds()
return 0.0
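A worker should only move jobs along the edges of this lifecycle. A small guard function (a sketch building on the Job and JobStatus classes above) makes illegal transitions fail loudly:
VALID_TRANSITIONS = {
    JobStatus.PENDING: {JobStatus.RUNNING, JobStatus.FAILED},
    JobStatus.RUNNING: {JobStatus.COMPLETED, JobStatus.FAILED},
    JobStatus.COMPLETED: set(),
    JobStatus.FAILED: set(),
}

def transition(job: Job, new_status: JobStatus) -> None:
    """Move a job to new_status, rejecting transitions the lifecycle does not allow."""
    if new_status not in VALID_TRANSITIONS[job.status]:
        raise ValueError(f"Illegal transition {job.status.value} -> {new_status.value}")
    job.status = new_status
    now = datetime.now()
    if new_status is JobStatus.RUNNING:
        job.started_at = now
    elif new_status in (JobStatus.COMPLETED, JobStatus.FAILED):
        job.completed_at = now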
16.2.3 AWS SageMaker Async Inference
SageMaker Asynchronous Inference is a fully managed option: it provides the internal queue, the workers, auto-scaling (including scale-to-zero), and result delivery to S3.
Architecture
graph TD
Client[Client SDK]
Endpoint[SageMaker Endpoint]
InternalQueue[Internal SQS Queue<br/>Managed by SageMaker]
EC2[ml.g4dn.xlarge Instances]
S3Input[S3 Input Bucket]
S3Output[S3 Output Bucket]
SNS[SNS Topic]
Client-->|1. Upload input|S3Input
Client-->|2. InvokeEndpointAsync with InputLocation|Endpoint
Endpoint-->|3. Enqueue|InternalQueue
InternalQueue-->|4. Pull|EC2
EC2-->|5. Download input|S3Input
EC2-->|6. Process|EC2
EC2-->|7. Upload output|S3Output
EC2-->|8. Notify|SNS
SNS-->|9. Webhook/Email|Client
Implementation
1. Infrastructure (Terraform):
# variables.tf
variable "model_name" {
type = string
default = "video-classifier"
}
# s3.tf
resource "aws_s3_bucket" "async_input" {
bucket = "${var.model_name}-async-input"
}
resource "aws_s3_bucket" "async_output" {
bucket = "${var.model_name}-async-output"
}
# sns.tf
resource "aws_sns_topic" "success" {
name = "${var.model_name}-success"
}
resource "aws_sns_topic" "error" {
name = "${var.model_name}-error"
}
resource "aws_sns_topic_subscription" "success_webhook" {
topic_arn = aws_sns_topic.success.arn
protocol = "https"
endpoint = "https://api.myapp.com/webhooks/sagemaker-success"
}
# sagemaker.tf
resource "aws_sagemaker_model" "model" {
name = var.model_name
execution_role_arn = aws_iam_role.sagemaker_role.arn
primary_container {
image = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0-gpu-py310"
model_data_url = "s3://my-models/${var.model_name}/model.tar.gz"
environment = {
"SAGEMAKER_PROGRAM" = "inference.py"
}
}
}
resource "aws_sagemaker_endpoint_configuration" "async_config" {
name = "${var.model_name}-async-config"
# Async-specific configuration
async_inference_config {
output_config {
s3_output_path = "s3://${aws_s3_bucket.async_output.id}/"
notification_config {
success_topic = aws_sns_topic.success.arn
error_topic = aws_sns_topic.error.arn
}
}
client_config {
max_concurrent_invocations_per_instance = 4
}
}
production_variants {
variant_name = "AllTraffic"
model_name = aws_sagemaker_model.model.name
initial_instance_count = 1
instance_type = "ml.g4dn.xlarge"
}
}
resource "aws_sagemaker_endpoint" "async_endpoint" {
name = "${var.model_name}-async"
endpoint_config_name = aws_sagemaker_endpoint_configuration.async_config.name
}
# Auto-scaling
resource "aws_appautoscaling_target" "async_scaling" {
max_capacity = 10
min_capacity = 0 # Scale to zero!
resource_id = "endpoint/${aws_sagemaker_endpoint.async_endpoint.name}/variant/AllTraffic"
scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
service_namespace = "sagemaker"
}
resource "aws_appautoscaling_policy" "async_scaling_policy" {
name = "${var.model_name}-scaling"
policy_type = "TargetTrackingScaling"
resource_id = aws_appautoscaling_target.async_scaling.resource_id
scalable_dimension = aws_appautoscaling_target.async_scaling.scalable_dimension
service_namespace = aws_appautoscaling_target.async_scaling.service_namespace
target_tracking_scaling_policy_configuration {
customized_metric_specification {
metric_name = "ApproximateBacklogSizePerInstance"
namespace = "AWS/SageMaker"
statistic = "Average"
dimensions {
name = "EndpointName"
value = aws_sagemaker_endpoint.async_endpoint.name
}
}
target_value = 5.0 # Target 5 jobs per instance
scale_in_cooldown = 600 # Wait 10 min before scaling down
scale_out_cooldown = 60 # Scale up quickly
}
}
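The scaling policy above tracks ApproximateBacklogSizePerInstance, a metric SageMaker publishes for async endpoints in the AWS/SageMaker namespace. A quick boto3 sketch to eyeball the backlog while tuning target_value (the endpoint name is the one created above):
import boto3
from datetime import datetime, timedelta

cloudwatch = boto3.client("cloudwatch")

def backlog_per_instance(endpoint_name: str) -> list:
    """Fetch the last hour of ApproximateBacklogSizePerInstance for an async endpoint."""
    resp = cloudwatch.get_metric_statistics(
        Namespace="AWS/SageMaker",
        MetricName="ApproximateBacklogSizePerInstance",
        Dimensions=[{"Name": "EndpointName", "Value": endpoint_name}],
        StartTime=datetime.utcnow() - timedelta(hours=1),
        EndTime=datetime.utcnow(),
        Period=60,
        Statistics=["Average"],
    )
    return sorted(resp["Datapoints"], key=lambda d: d["Timestamp"])

# Example: backlog_per_instance("video-classifier-async")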
2. Client Code (Python):
import boto3
import json
from datetime import datetime
s3_client = boto3.client('s3')
sagemaker_runtime = boto3.client('sagemaker-runtime')
def submit_async_job(video_path: str, endpoint_name: str) -> str:
"""
Submit an async inference job.
Returns:
output_location: S3 URI where results will be written
"""
    # Upload input to S3 (the Terraform above names the bucket "<model_name>-async-input",
    # and the endpoint is "<model_name>-async", so appending "-input" matches it)
    input_bucket = f"{endpoint_name}-input"
input_key = f"inputs/{datetime.now().isoformat()}/video.mp4"
s3_client.upload_file(
Filename=video_path,
Bucket=input_bucket,
Key=input_key
)
input_location = f"s3://{input_bucket}/{input_key}"
# Invoke async endpoint
response = sagemaker_runtime.invoke_endpoint_async(
EndpointName=endpoint_name,
InputLocation=input_location,
InferenceId=f"job-{datetime.now().timestamp()}" # Optional correlation ID
)
output_location = response['OutputLocation']
print(f"Job submitted. Results will be at: {output_location}")
return output_location
def check_job_status(output_location: str) -> dict:
"""
Check if the job is complete.
Returns:
{"status": "pending|completed|failed", "result": <data>}
"""
# Parse S3 URI
parts = output_location.replace("s3://", "").split("/", 1)
bucket = parts[0]
key = parts[1]
try:
obj = s3_client.get_object(Bucket=bucket, Key=key)
result = json.loads(obj['Body'].read())
return {"status": "completed", "result": result}
    except s3_client.exceptions.NoSuchKey:
        # No output yet. Check for an error file next to it (newer endpoints also return a
        # FailureLocation in the invoke response, which is the more reliable place to look)
        error_key = key.replace(".out", ".error")
try:
obj = s3_client.get_object(Bucket=bucket, Key=error_key)
error = obj['Body'].read().decode('utf-8')
return {"status": "failed", "error": error}
except s3_client.exceptions.NoSuchKey:
return {"status": "pending"}
# Usage
output_loc = submit_async_job("my_video.mp4", "video-classifier-async")
# Poll for result
import time
while True:
status = check_job_status(output_loc)
if status['status'] != 'pending':
print(status)
break
time.sleep(5)
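Polling is the simplest client, but the Terraform above already subscribes an HTTPS webhook to the success topic. A minimal receiver sketch (FastAPI here is an assumption; confirm the exact fields of the inner SageMaker message against a real notification before relying on them):
import json
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/webhooks/sagemaker-success")
async def sagemaker_success(request: Request):
    envelope = json.loads(await request.body())          # SNS posts a JSON envelope
    if envelope.get("Type") == "SubscriptionConfirmation":
        # Visit (or programmatically GET) SubscribeURL once to confirm the subscription.
        print("Confirm subscription:", envelope["SubscribeURL"])
        return {"status": "confirmation pending"}
    message = json.loads(envelope["Message"])             # SageMaker's notification payload
    print("Async job finished:", message)                 # includes the S3 output location
    # TODO: mark the corresponding job as completed in your job store
    return {"status": "ok"}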
The Scale-to-Zero Advantage
With async inference:
- Idle periods cost $0 (instances scale to 0)
- Burst capacity (scale from 0 to 10 instances in minutes)
- Pay only for processing time + small queue hosting cost
Cost Comparison (Sporadic Workload: 100 jobs/day, 5 minutes each):
| Deployment | Daily Cost | Monthly Cost |
|---|---|---|
| Real-time (1x ml.g4dn.xlarge 24/7) | $17.67 | $530 |
| Async (scale-to-zero) | 100 × 5min × $0.736/hour = $6.13 | $184 |
Savings: 65%
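The arithmetic behind the table, plus the break-even intuition, in a few lines (rates as assumed in the table):
HOURLY_RATE = 0.736                      # ml.g4dn.xlarge on-demand, as assumed in the table
jobs_per_day, minutes_per_job = 100, 5

always_on_daily = HOURLY_RATE * 24                               # one instance running 24/7
async_daily = jobs_per_day * minutes_per_job / 60 * HOURLY_RATE  # pay only for busy minutes
print(f"always-on ${always_on_daily:.2f}/day vs async ${async_daily:.2f}/day "
      f"-> savings {1 - async_daily / always_on_daily:.0%}")

# Break-even intuition: async stops winning once the backlog keeps one instance busy
# ~24 h/day, i.e. jobs_per_day * minutes_per_job approaches 24 * 60 = 1440 busy minutes.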
16.2.4 Batch Transform: Offline Inference at Scale
Batch Transform is for “offline” workloads: label 10 million images, score all customers for churn risk, etc.
SageMaker Batch Transform
Key Features:
- Massive Parallelism: Spin up 100 instances simultaneously
- Automatic Data Splitting: SageMaker splits large files (CSV, JSON Lines) automatically
- No Server Management: Instances start, process, then terminate
Workflow:
graph LR
Input[S3 Input<br/>input.csv<br/>10 GB]
SM[SageMaker<br/>Batch Transform]
Workers[20x ml.p3.2xlarge<br/>Parallel Processing]
Output[S3 Output<br/>output.csv<br/>Predictions]
Input-->|Split into chunks|SM
SM-->|Distribute|Workers
Workers-->|Process|Workers
Workers-->|Aggregate|Output
Implementation
Python SDK:
from sagemaker.pytorch import PyTorchModel
from sagemaker.transformer import Transformer
# Define model
model = PyTorchModel(
model_data="s3://my-models/image-classifier/model.tar.gz",
role=sagemaker_role,
framework_version="2.0",
py_version="py310",
entry_point="inference.py"
)
# Create transformer
transformer = model.transformer(
instance_count=20, # Massive parallelism
instance_type="ml.p3.2xlarge",
strategy="MultiRecord", # Process multiple records per request
max_payload=10, # Max 10 MB per request
max_concurrent_transforms=8, # Concurrent requests per instance
output_path="s3://my-bucket/batch-output/",
assemble_with="Line", # Output format
accept="application/json"
)
# Start batch job
transformer.transform(
data="s3://my-bucket/batch-input/images.csv",
data_type="S3Prefix",
content_type="text/csv",
split_type="Line", # Split by line
    input_filter="$[1:]",           # Drop the first column of each record (e.g., an ID) from the model input
join_source="Input" # Append prediction to input
)
# Wait for completion
transformer.wait()
# Results are now in s3://my-bucket/batch-output/
Advanced: The join_source Pattern
For ML validation, you often need: Input | Actual | Predicted
Input CSV:
customer_id,feature_1,feature_2,actual_churn
1001,25,50000,0
1002,45,75000,1
With join_source="Input", output becomes:
1001,25,50000,0,0.12
1002,45,75000,1,0.87
The prediction is appended to each line, preserving the input for validation scripts.
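If downstream validation only needs the ID and the score, output_filter can trim the joined record. A sketch reusing the transformer above (filter strings follow SageMaker's JSONPath subset; adjust the indices to your column layout):
# Variation on the transform above: trim what the model sees and what gets written out.
transformer.transform(
    data="s3://my-bucket/batch-input/customers.csv",
    content_type="text/csv",
    split_type="Line",
    input_filter="$[1:]",     # drop the leading customer_id before it reaches the model
                              # (in a real run you would also exclude the label column)
    join_source="Input",      # re-attach the prediction to the full original record
    output_filter="$[0,-1]",  # keep only customer_id and the appended prediction
)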
Handling Failures
Batch Transform writes failed records to a .out.failed file.
import boto3
s3 = boto3.resource('s3')
bucket = s3.Bucket('my-bucket')
# Check for failed records
for obj in bucket.objects.filter(Prefix='batch-output/'):
if obj.key.endswith('.failed'):
print(f"Found failures: {obj.key}")
# Download and inspect
obj.download_file('/tmp/failed.json')
Retry Strategy:
- Extract failed record IDs
- Create a new input file with only failed records
- Re-run Batch Transform with a reduced instance_count, since failures are often rate-limit issues (a sketch follows below)
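Assuming JSON Lines records with an id field (the exact layout of the .failed files depends on your payload format), the retry preparation might look like:
import json
import boto3

s3 = boto3.client("s3")

def collect_failed_ids(bucket: str, prefix: str) -> set:
    """Gather record IDs from every *.failed file under the output prefix."""
    failed = set()
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if obj["Key"].endswith(".failed"):
                body = s3.get_object(Bucket=bucket, Key=obj["Key"])["Body"].read()
                failed.update(json.loads(line)["id"] for line in body.splitlines() if line)
    return failed

def write_retry_input(original_path: str, retry_path: str, failed_ids: set) -> None:
    """Copy only the failed records into a new input file for the retry run."""
    with open(original_path) as src, open(retry_path, "w") as dst:
        for line in src:
            if json.loads(line)["id"] in failed_ids:
                dst.write(line)

# Then call transformer.transform(...) again on retry_path with a smaller instance_count.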
16.2.5 Google Cloud Dataflow: ML Pipelines at Scale
Dataflow (Apache Beam) is Google’s alternative to Batch Transform. It’s more flexible but requires more code.
Apache Beam Primer
Beam is a unified stream and batch processing framework.
Core Concepts:
- PCollection: An immutable distributed dataset
- PTransform: A processing step (Map, Filter, GroupBy, etc.)
- Pipeline: A DAG of PTransforms
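These three concepts compose into very small pipelines. A toy local example (DirectRunner, no ML) shows the shape before the full RunInference pipeline below:
import apache_beam as beam

# Tiny local pipeline: a PCollection of numbers flows through two PTransforms.
with beam.Pipeline() as p:
    (
        p
        | "Create" >> beam.Create([1, 2, 3, 4, 5])          # PCollection of ints
        | "Square" >> beam.Map(lambda x: x * x)              # PTransform: element-wise map
        | "KeepEven" >> beam.Filter(lambda x: x % 2 == 0)    # PTransform: filter
        | "Print" >> beam.Map(print)
    )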
RunInference API
Beam’s RunInference transform handles model loading, batching, and distribution.
Complete Example:
import json

import apache_beam as beam
import torch
import torchvision
from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult, RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
class ImagePreprocessor(beam.DoFn):
def process(self, element):
"""
element: {"image_uri": "gs://bucket/image.jpg", "id": "123"}
"""
from PIL import Image
from torchvision import transforms
import io
# Download image
from google.cloud import storage
client = storage.Client()
bucket = client.bucket(element['image_uri'].split('/')[2])
blob = bucket.blob('/'.join(element['image_uri'].split('/')[3:]))
image_bytes = blob.download_as_bytes()
# Preprocess
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
tensor = transform(image)
yield {"id": element['id'], "tensor": tensor.numpy()}
class Postprocessor(beam.DoFn):
    def process(self, element):
        """
        element: (key, PredictionResult) -- KeyedModelHandler carries the "id" key through inference.
        """
        key, prediction_result = element
        class_idx = prediction_result.inference.argmax().item()
        confidence = prediction_result.inference.max().item()
        yield {
            "id": key,
            "class": int(class_idx),
            "confidence": float(confidence)
        }
def run_pipeline():
    # Model handler: RunInference loads the weights once per worker and batches inputs
    model_handler = PytorchModelHandlerTensor(
        state_dict_path="gs://my-bucket/models/resnet50.pth",
        model_class=torchvision.models.resnet50,   # a class/callable, not an instance
        model_params={},                           # kwargs passed to model_class
        device="GPU"                               # use the GPU on Dataflow workers if present
    )
    # Wrap with KeyedModelHandler so each prediction stays paired with its "id" key
    keyed_handler = KeyedModelHandler(model_handler)
    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        runner='DataflowRunner',
        project='my-gcp-project',
        region='us-central1',
        temp_location='gs://my-bucket/temp',
        staging_location='gs://my-bucket/staging',
        # Worker configuration
        machine_type='n1-standard-8',
        disk_size_gb=100,
        num_workers=10,
        max_num_workers=50,
        # GPU configuration (Dataflow service option)
        dataflow_service_options=['worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver']
    )
with beam.Pipeline(options=pipeline_options) as p:
(
p
| "Read Input" >> beam.io.ReadFromText("gs://my-bucket/input.jsonl")
| "Parse JSON" >> beam.Map(lambda x: json.loads(x))
| "Preprocess" >> beam.ParDo(ImagePreprocessor())
| "Extract Tensor" >> beam.Map(lambda x: (x['id'], x['tensor']))
| "Run Inference" >> RunInference(model_handler)
| "Postprocess" >> beam.ParDo(Postprocessor())
| "Format Output" >> beam.Map(lambda x: json.dumps(x))
| "Write Output" >> beam.io.WriteToText("gs://my-bucket/output.jsonl")
)
if __name__ == "__main__":
run_pipeline()
Execution:
python beam_pipeline.py \
--runner DataflowRunner \
--project my-gcp-project \
--region us-central1 \
--temp_location gs://my-bucket/temp
Auto-Scaling
Dataflow automatically scales workers based on backlog.
Monitoring:
from google.cloud import monitoring_v3

PROJECT_ID = "my-gcp-project"

# MQL queries go through the QueryServiceClient (MetricServiceClient only lists time series)
client = monitoring_v3.QueryServiceClient()
query = '''
fetch dataflow_job
| metric 'dataflow.googleapis.com/job/current_num_vcpus'
| filter resource.job_name == "my-inference-job"
| align rate(1m)
'''
results = client.query_time_series(
    request={"name": f"projects/{PROJECT_ID}", "query": query}
)
16.2.6 DIY Async on Kubernetes
For ultimate control, build async inference on Kubernetes.
Architecture
graph TD
API[FastAPI Service]
Redis[(Redis Queue)]
Worker1[Worker Pod 1<br/>GPU]
Worker2[Worker Pod 2<br/>GPU]
Worker3[Worker Pod 3<br/>GPU]
PG[(PostgreSQL<br/>Job Status)]
S3[S3/Minio<br/>Results]
API-->|Enqueue Job|Redis
Redis-->|Pop|Worker1
Redis-->|Pop|Worker2
Redis-->|Pop|Worker3
Worker1-->|Update Status|PG
Worker1-->|Upload Result|S3
API-->|Query Status|PG
Tech Stack:
- Queue: Redis (with persistence) or RabbitMQ
- Workers: Kubernetes Job or Deployment
- Status DB: PostgreSQL or DynamoDB
- Storage: MinIO (self-hosted S3) or GCS
Implementation
1. API Server (FastAPI):
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import redis
import psycopg2
import uuid
from datetime import datetime
app = FastAPI()
redis_client = redis.Redis(host='redis-service', port=6379)
db_conn = psycopg2.connect("dbname=jobs user=postgres host=postgres-service")
class JobRequest(BaseModel):
video_url: str
@app.post("/jobs")
def create_job(request: JobRequest):
job_id = str(uuid.uuid4())
# Insert into DB
with db_conn.cursor() as cur:
cur.execute(
"INSERT INTO jobs (id, status, input_url, created_at) VALUES (%s, %s, %s, %s)",
(job_id, "pending", request.video_url, datetime.now())
)
db_conn.commit()
# Enqueue
redis_client.rpush("job_queue", job_id)
return {"job_id": job_id, "status": "pending"}
@app.get("/jobs/{job_id}")
def get_job(job_id: str):
with db_conn.cursor() as cur:
cur.execute("SELECT status, output_url, error FROM jobs WHERE id = %s", (job_id,))
row = cur.fetchone()
if not row:
raise HTTPException(status_code=404, detail="Job not found")
return {
"job_id": job_id,
"status": row[0],
"output_url": row[1],
"error": row[2]
}
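Both the API server above and the worker below assume a jobs table. The source doesn't show the DDL, so here is one minimal schema it could use (an assumption):
import psycopg2

# Minimal schema assumed by the API server and worker (not shown in the original text).
DDL = """
CREATE TABLE IF NOT EXISTS jobs (
    id          TEXT PRIMARY KEY,
    status      TEXT NOT NULL,
    input_url   TEXT NOT NULL,
    output_url  TEXT,
    error       TEXT,
    created_at  TIMESTAMPTZ NOT NULL,
    updated_at  TIMESTAMPTZ
);
"""

conn = psycopg2.connect("dbname=jobs user=postgres host=postgres-service")
with conn.cursor() as cur:
    cur.execute(DDL)
conn.commit()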
2. Worker (Python):
import json
import redis
import psycopg2
import boto3
from datetime import datetime
redis_client = redis.Redis(host='redis-service', port=6379)
db_conn = psycopg2.connect("dbname=jobs user=postgres host=postgres-service")
s3_client = boto3.client('s3', endpoint_url='http://minio-service:9000')
def update_job_status(job_id, status, output_url=None, error=None):
with db_conn.cursor() as cur:
cur.execute(
"""
UPDATE jobs
SET status = %s, output_url = %s, error = %s, updated_at = %s
WHERE id = %s
""",
(status, output_url, error, datetime.now(), job_id)
)
db_conn.commit()
def process_job(job_id):
# Fetch job details
with db_conn.cursor() as cur:
cur.execute("SELECT input_url FROM jobs WHERE id = %s", (job_id,))
input_url = cur.fetchone()[0]
try:
update_job_status(job_id, "running")
# Download input
video_path = f"/tmp/{job_id}.mp4"
s3_client.download_file("input-bucket", input_url, video_path)
# Run inference (placeholder)
result = run_model(video_path)
# Upload output
output_key = f"output/{job_id}/result.json"
s3_client.put_object(
Bucket="output-bucket",
Key=output_key,
Body=json.dumps(result)
)
update_job_status(job_id, "completed", output_url=f"s3://output-bucket/{output_key}")
except Exception as e:
update_job_status(job_id, "failed", error=str(e))
# Main loop
while True:
# Blocking pop (timeout 60s)
job_data = redis_client.blpop("job_queue", timeout=60)
if job_data:
job_id = job_data[1].decode('utf-8')
process_job(job_id)
3. Kubernetes Deployment:
# worker-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: inference-worker
spec:
replicas: 3
selector:
matchLabels:
app: inference-worker
template:
metadata:
labels:
app: inference-worker
spec:
containers:
- name: worker
image: gcr.io/my-project/inference-worker:v1
resources:
limits:
nvidia.com/gpu: "1"
env:
- name: REDIS_HOST
value: "redis-service"
- name: DB_HOST
value: "postgres-service"
Horizontal Pod Autoscaler (HPA)
Scale workers based on queue depth.
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: worker-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: inference-worker
minReplicas: 1
maxReplicas: 20
metrics:
- type: External
external:
metric:
name: redis_queue_depth
selector:
matchLabels:
queue: "job_queue"
target:
type: AverageValue
averageValue: "5"
This requires a custom metrics adapter that queries Redis and exposes the queue depth as a Kubernetes metric.
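One minimal path to that: export the queue depth from Redis in Prometheus format and let the Prometheus Adapter (or a tool like KEDA) surface it to the HPA. A sketch, with the metric name matching the HPA above and everything else an assumption:
import time

import redis
from prometheus_client import Gauge, start_http_server

QUEUE_DEPTH = Gauge("redis_queue_depth", "Jobs waiting in the queue", ["queue"])
r = redis.Redis(host="redis-service", port=6379)

if __name__ == "__main__":
    start_http_server(9100)                      # /metrics endpoint for Prometheus to scrape
    while True:
        QUEUE_DEPTH.labels(queue="job_queue").set(r.llen("job_queue"))
        time.sleep(10)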
16.2.7 Comparison Matrix
| Feature | SageMaker Async | SageMaker Batch | Dataflow | DIY Kubernetes |
|---|---|---|---|---|
| Latency | Seconds to minutes | Minutes to hours | Minutes to hours | Configurable |
| Scale-to-Zero | Yes | N/A (ephemeral jobs) | N/A | Manual |
| Max Parallelism | 10-100 instances | 1000+ instances | 10,000+ workers | Limited by cluster |
| Cost (per hour) | Instance cost | Instance cost | vCPU + memory | Instance cost |
| Data Splitting | No | Yes (automatic) | Yes (Beam sources split reads) | Manual |
| Best For | Real-time with bursts | Large batch jobs | Complex ETL + ML | Full control |
16.2.8 Conclusion
Asynchronous and batch inference unlock cost optimization and scale beyond what real-time endpoints can achieve. The trade-off is latency, but for non-interactive workloads, this is acceptable.
Decision Framework:
- User waiting for result → Real-time or Async (< 1 min)
- Webhook/Email notification → Async (1-10 min)
- Nightly batch → Batch Transform / Dataflow (hours)
- Maximum control → DIY on Kubernetes
Master these patterns, and you’ll build systems that process billions of predictions at a fraction of the cost of real-time infrastructure.