9.3. Fault Tolerance: The Art of Crash-Proofing
“In a distributed system, a failure is when a computer you didn’t even know existed renders your own computer unusable.” — Leslie Lamport
If you are training a model on a single GPU, a crash is an annoyance. You restart the script, maybe lose an hour of progress.
If you are training a 70B parameter LLM on 512 H100 GPUs for three months, a crash is a statistical certainty. At that scale, hardware failure is not an exception; it is the steady state.
- Memory Errors: Cosmic rays flip bits in HBM3 memory (ECC errors).
- Network Flaps: A single optical transceiver in a Top-of-Rack switch degrades, causing packet loss that times out the NCCL ring.
- Preemption: The cloud provider reclaims your Spot capacity because a higher-paying customer just spun up a cluster.
- Software Bugs: A gradient explosion produces a `NaN`, which propagates through the `AllReduce` operation, poisoning the weights of every GPU in the cluster instantly.
Without a robust fault tolerance strategy, you will never finish training. You will be stuck in a "Sisyphus Loop," rolling the rock up the hill only to have a node fail at 98%, forcing a restart from zero.
This section details the architecture of resilience: how to checkpoint state effectively, how to handle the ruthless economics of Spot instances, and how to build self-healing clusters on AWS and GCP.
9.3.1. The Thermodynamics of Failure
To architect for failure, we must first quantify it. The probability of a successful training run drops exponentially with the number of nodes.
$$ P(\text{Success}) = (1 - p_{\text{fail}})^{N_{\text{nodes}} \times D_{\text{days}}} $$
Let’s assume a single GPU node has a Mean Time Between Failures (MTBF) that implies a 0.1% chance of failing on any given day ($p = 0.001$). This includes hardware issues, driver crashes, and maintenance events.
- Single Node (1 GPU) running for 30 days: $$ 0.999^{30} \approx 97\% \text{ chance of success without interruption.} $$
- Cluster (1,000 Nodes) running for 30 days: $$ 0.999^{30{,}000} \approx 9 \times 10^{-14} \text{ chance of success — effectively zero.} $$
It is mathematically impossible to train large models without interruptions. Therefore, the training system must be viewed not as a continuous process, but as a series of discrete, recoverable segments.
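To make the exponential decay concrete, here is a minimal sketch (plain Python, no dependencies) that evaluates the formula above for a few cluster sizes, using the same 0.1% daily failure rate assumed in the examples:

# Probability that no node fails for the whole run:
# P(success) = (1 - p_fail) ** (n_nodes * n_days)
def run_success_probability(p_daily_fail: float, n_nodes: int, n_days: int) -> float:
    return (1.0 - p_daily_fail) ** (n_nodes * n_days)

if __name__ == "__main__":
    for n_nodes in (1, 16, 128, 1000):
        p = run_success_probability(p_daily_fail=0.001, n_nodes=n_nodes, n_days=30)
        print(f"{n_nodes:>5} nodes, 30 days: P(no interruption) = {p:.2e}")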
The Cost of Checkpointing (The Tax)
Fault tolerance is not free. It is a trade-off between Compute Time (lost progress after a crash) and I/O Overhead (time spent pausing training to write to disk).
If you checkpoint too often, you waste 20% of your GPU cycles writing to S3. If you checkpoint too rarely, a crash destroys 24 hours of compute (worth perhaps $50,000).
Young’s Approximation for Optimal Checkpoint Interval: $$ T_{opt} = \sqrt{2 \times T_{checkpoint} \times T_{MTBF}} $$
Where:
- $T_{opt}$: The optimal time between checkpoints.
- $T_{checkpoint}$: Time it takes to write the checkpoint.
- $T_{MTBF}$: Mean Time Between Failures for the entire cluster.
Example:
- Cluster MTBF is 12 hours (on average, something breaks twice a day).
- Writing the checkpoint takes 5 minutes (0.083 hours).
- $T_{opt} \approx \sqrt{2 \times 0.083 \times 12} \approx 1.41 \text{ hours}$.
You should checkpoint every ~90 minutes.
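As a quick sanity check, a small helper (an illustrative sketch, not part of any library) that applies Young's approximation; plugging in the numbers above reproduces the ~1.4-hour interval:

import math

def optimal_checkpoint_interval_hours(t_checkpoint_hours: float, t_mtbf_hours: float) -> float:
    """Young's approximation: T_opt = sqrt(2 * T_checkpoint * T_MTBF)."""
    return math.sqrt(2.0 * t_checkpoint_hours * t_mtbf_hours)

# Example from the text: 5-minute checkpoint write, 12-hour cluster MTBF
t_opt = optimal_checkpoint_interval_hours(t_checkpoint_hours=5 / 60, t_mtbf_hours=12)
print(f"Checkpoint every {t_opt:.2f} hours (~{t_opt * 60:.0f} minutes)")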
9.3.2. Checkpointing Mechanics: What to Save and How
A common misconception is that you only need to save the model weights (model.state_dict()). For a training run to resume exactly where it left off, ensuring bit-for-bit reproducibility (or at least statistical continuity), you must save much more.
The Anatomy of a Checkpoint
For a Large Language Model trained with the AdamW optimizer and mixed precision, you must capture all of the following (a consolidated sketch follows this list):
- Model Weights (FP16/BF16): The active parameters.
- Master Weights (FP32): The high-precision copy kept by the optimizer to accumulate small gradient updates.
- Optimizer State (FP32):
- Momentum (Beta1): Exponential moving average of gradients.
- Variance (Beta2): Exponential moving average of squared gradients.
- Step Count: For bias correction.
- Learning Rate Scheduler State: Current epoch, current LR, warmup counter.
- Data Loader State: Which epoch? Which batch index? Ideally, the RNG state of the shuffler.
- Random Number Generator (RNG) States: CUDA RNG, Python RNG, and Numpy RNG seeds for every rank.
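A hypothetical sketch of what a "complete" checkpoint dictionary looks like for a single rank. It assumes a standard `model`, `optimizer`, and `scheduler`, plus a `dataloader` or sampler that exposes a `state_dict()` method (not every dataloader does):

import random

import numpy as np
import torch

def build_full_checkpoint(model, optimizer, scheduler, dataloader, epoch, step):
    """Collect every piece of state needed to resume training exactly."""
    return {
        "model": model.state_dict(),            # FP16/BF16 active weights
        "optimizer": optimizer.state_dict(),    # AdamW momentum, variance, step count
                                                # (FP32 master weights in many setups)
        "scheduler": scheduler.state_dict(),    # current LR, warmup counter
        "dataloader": (
            dataloader.state_dict()             # epoch, batch index, shuffler RNG
            if hasattr(dataloader, "state_dict")
            else None
        ),
        "epoch": epoch,
        "step": step,
        "rng": {                                # RNG states for this rank
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all(),
            "numpy": np.random.get_state(),
            "python": random.getstate(),
        },
    }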
The Storage Explosion: For a model with parameters $\Phi$, the checkpoint size is roughly $16 \times \Phi$ bytes.
- Model (BF16): 2 bytes
- Master Model (FP32): 4 bytes
- Optimizer Momentum (FP32): 4 bytes
- Optimizer Variance (FP32): 4 bytes
- Gradients (Transient, usually not saved but exist in VRAM): 2 bytes
A 70B parameter model requires: $$ 70 \times 10^9 \times 16 \text{ bytes} \approx 1.12 \text{ TB per checkpoint.} $$
If you retain the last 5 checkpoints for safety, you are storing 5.6 TB of data per run.
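The arithmetic is simple enough to script; this back-of-the-envelope helper (an illustrative sketch) reproduces the 1.12 TB per-checkpoint figure and the retained-checkpoint total:

def checkpoint_size_tb(num_params: float, bytes_per_param: int = 16) -> float:
    """Rough footprint: weights + FP32 master copy + AdamW momentum/variance."""
    return num_params * bytes_per_param / 1e12

size = checkpoint_size_tb(70e9)  # ~1.12 TB per checkpoint
print(f"Per checkpoint: {size:.2f} TB")
print(f"Retaining 5:    {5 * size:.1f} TB")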
PyTorch Distributed Checkpointing (DCP)
In the old days (PyTorch < 1.13), rank 0 would gather all weights to CPU RAM and write a single .pt file. This causes OOM (Out of Memory) on rank 0 and network bottlenecks.
Modern training uses Sharded Checkpointing. Each GPU writes its own slice of the model and optimizer state directly to storage.
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir):
    """
    Modern sharded checkpointing for FSDP models.
    Each rank writes its own shard to storage in parallel.
    """
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model.state_dict(),
            # NOTE: FSDP optimizer state is often gathered with FSDP.optim_state_dict();
            # shown simplified here.
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
            "rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
        }
        # Write sharded checkpoint.
        # Each rank writes to: checkpoint_dir/__0_0.distcp, __1_0.distcp, etc.
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
        )
    if dist.get_rank() == 0:
        print(f"Checkpoint saved to {checkpoint_dir} at epoch {epoch}, step {step}")

def load_checkpoint(model, optimizer, checkpoint_dir):
    """
    Load from a sharded checkpoint.
    Each rank reads only its own shard.
    """
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # Pre-populate the dict so DCP knows which keys (and shard layouts) to load.
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": 0,
            "step": 0,
            "rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
        }
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
        )
        model.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optimizer"])

    # Restore RNG states so data order and dropout masks stay reproducible
    torch.set_rng_state(state_dict["rng_state"])
    torch.cuda.set_rng_state(state_dict["cuda_rng_state"])

    return state_dict["epoch"], state_dict["step"]
The Storage Backend: S3 vs. EFS vs. FSx
Where do you write 1TB checkpoints? The choice matters for speed and cost.
1. Amazon S3 (Object Storage):
- Pros: Infinite scalability. Durability (99.999999999%). Cheap ($0.023/GB/month).
- Cons: High latency (~50-100ms per write). Eventual consistency issues for rapid updates.
- Use Case: Final checkpoints. Long-term archival.
- Throughput: With `s5cmd` or parallel writes via PyTorch, you can achieve ~10 GB/s of write throughput from a multi-node cluster.
2. Amazon EFS (Elastic File System):
- Pros: POSIX-compatible. Can be mounted directly by all nodes. Lower latency than S3 (~1-5ms).
- Cons: Expensive ($0.30/GB/month). Performance depends on provisioned throughput.
- Use Case: Working checkpoints during active training.
- Architecture Note: Use EFS in “Max I/O” mode for distributed writes. Ensure the mount target is in the same Availability Zone as your cluster.
3. Amazon FSx for Lustre:
- Pros: Built for HPC. Backed by S3 but presents a high-speed POSIX filesystem (sub-millisecond latency). Can achieve 100s of GB/s throughput.
- Cons: Expensive ($0.14-0.60/GB/month depending on config). Requires explicit capacity planning.
- Use Case: The gold standard for massive-scale training. Used by AWS for training models like Olympus.
- Integration: FSx can be linked to an S3 bucket. Changes written to FSx are automatically synced to S3 in the background.
GCP Equivalents:
- Google Cloud Storage (GCS): Like S3. Use `gcsfuse` for direct mounting (slower) or the `gcsfs` Python library for programmatic access (a short sketch follows this list).
- Filestore: Like EFS. NFS-based. Use Filestore High Scale for HPC.
- Parallelstore: Google’s new answer to FSx Lustre. Optimized for AI/ML workloads with tight integration to Vertex AI.
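For the GCS path, a minimal sketch of a programmatic upload with `gcsfs` (the project ID, bucket, and local path are hypothetical; `gcsfs` follows the generic fsspec filesystem API):

import gcsfs

# Assumes application-default credentials are available on the node.
fs = gcsfs.GCSFileSystem(project="my-training-project")  # hypothetical project ID

# Recursively upload a local checkpoint directory to a GCS prefix.
fs.put(
    "/mnt/checkpoints/step-5000/",             # local checkpoint shard directory
    "gs://llm-checkpoints/run-42/step-5000/",  # hypothetical bucket/prefix
    recursive=True,
)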
Terraform Example: FSx for Lustre Checkpoint Backend (AWS):
resource "aws_fsx_lustre_file_system" "checkpoint_fs" {
storage_capacity = 7200 # GB, scales in 1.2TB increments
subnet_ids = [aws_subnet.private.id]
deployment_type = "PERSISTENT_2" # High durability
per_unit_storage_throughput = 250 # MB/s per TB of storage
# Link to S3 bucket for automatic export
import_path = "s3://${aws_s3_bucket.checkpoints.bucket}/training-run-42/"
export_path = "s3://${aws_s3_bucket.checkpoints.bucket}/training-run-42/"
# Auto-import from S3: Any file added to S3 appears in FSx
auto_import_policy = "NEW_CHANGED"
tags = {
Name = "LLM-Checkpoint-Storage"
}
}
# Security Group: Allow Lustre traffic (988/tcp, 1021-1023/tcp)
resource "aws_security_group_rule" "lustre_ingress" {
type = "ingress"
from_port = 988
to_port = 988
protocol = "tcp"
security_group_id = aws_security_group.training_sg.id
self = true
}
9.3.3. Spot Instances: The Economics of Ephemeral Compute
Training an LLM on On-Demand instances can cost $500,000+. Using Spot instances can reduce this by 70%. But Spot capacity can be reclaimed with 2 minutes of notice.
The Trade-Off
- On-Demand: $32.77/hour per `p4d.24xlarge` (8x A100).
- Spot: ~$10-15/hour (varies by demand), but you might get interrupted.
If your training run spans 30 days and Spot saves you $200,000, but you lose 6 hours of compute to interruptions, you still win massively—as long as your checkpointing strategy is robust.
The Architecture for Spot Resilience
1. Mixed Fleet (Heterogeneous Spot + On-Demand): Do not run 100% Spot. Use a hybrid model.
- Core nodes (Rank 0, maybe 10% of cluster): On-Demand (guaranteed stability for orchestration).
- Worker nodes (Rank 1-N): Spot.
If a Spot node disappears, the training job doesn’t lose coordination.
2. Checkpoint Aggressively: On Spot, reduce checkpoint interval. If $T_{MTBF}$ is 6 hours for Spot, checkpoint every 30-60 minutes.
3. Spot Interruption Handler (AWS): AWS provides a metadata endpoint that signals 2 minutes before termination.
Python Daemon for Graceful Shutdown:
import requests
import subprocess
import time

SPOT_TERMINATION_ENDPOINT = "http://169.254.169.254/latest/meta-data/spot/instance-action"

def check_spot_termination():
    """
    Poll the EC2 metadata endpoint.
    If interruption is imminent, trigger an emergency checkpoint.
    """
    try:
        response = requests.get(SPOT_TERMINATION_ENDPOINT, timeout=1)
        if response.status_code == 200:
            # Spot termination notice received
            print("SPOT TERMINATION WARNING: Initiating emergency checkpoint.")
            # Signal the training process (via file touch or signal)
            subprocess.run(["touch", "/tmp/emergency_checkpoint"])
            return True
        # A 404 means no termination notice: all clear
    except requests.exceptions.RequestException:
        # Metadata endpoint unreachable; treat as all clear
        pass
    return False

if __name__ == "__main__":
    while True:
        if check_spot_termination():
            break
        time.sleep(5)  # Poll every 5 seconds
In the training loop:
import os
import sys

import torch.distributed as dist

for step, batch in enumerate(dataloader):
    # Check for the emergency signal written by the interruption daemon
    if os.path.exists("/tmp/emergency_checkpoint"):
        print("Emergency checkpoint triggered. Saving state...")
        save_checkpoint(model, optimizer, epoch, step, checkpoint_dir)
        dist.barrier()  # Ensure all ranks finish writing
        sys.exit(0)     # Graceful exit

    # Normal training step
    loss = train_step(model, batch)
GCP Equivalent: GCP Preemptible (Spot) VMs expose a similar metadata endpoint, queried with the `Metadata-Flavor: Google` header (sketched below):
http://metadata.google.internal/computeMetadata/v1/instance/preempted
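A minimal polling sketch for the GCP side, mirroring the AWS daemon above. It assumes the endpoint returns the string TRUE once the VM has been marked for preemption and that the header shown is required by the metadata server:

import subprocess
import time

import requests

GCE_PREEMPTION_ENDPOINT = (
    "http://metadata.google.internal/computeMetadata/v1/instance/preempted"
)

def check_gce_preemption() -> bool:
    """Return True once the VM has been scheduled for preemption."""
    try:
        response = requests.get(
            GCE_PREEMPTION_ENDPOINT,
            headers={"Metadata-Flavor": "Google"},  # required by the metadata server
            timeout=1,
        )
        if response.status_code == 200 and response.text.strip() == "TRUE":
            print("GCE PREEMPTION WARNING: Initiating emergency checkpoint.")
            subprocess.run(["touch", "/tmp/emergency_checkpoint"])
            return True
    except requests.exceptions.RequestException:
        pass  # Metadata server unreachable; treat as all clear
    return False

if __name__ == "__main__":
    while not check_gce_preemption():
        time.sleep(5)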
Terraform Auto Scaling with Spot (AWS)
resource "aws_autoscaling_group" "spot_workers" {
name = "llm-training-spot-workers"
max_size = 100
min_size = 0
desired_capacity = 50
vpc_zone_identifier = [aws_subnet.private.id]
mixed_instances_policy {
instances_distribution {
on_demand_base_capacity = 5 # 5 guaranteed On-Demand
on_demand_percentage_above_base_capacity = 10 # 10% more On-Demand
spot_allocation_strategy = "capacity-optimized"
}
launch_template {
launch_template_specification {
launch_template_id = aws_launch_template.gpu_node.id
version = "$Latest"
}
# Try multiple instance types to increase Spot availability
override {
instance_type = "p4d.24xlarge"
}
override {
instance_type = "p4de.24xlarge"
}
}
}
tag {
key = "Name"
value = "Spot-Worker"
propagate_at_launch = true
}
}
9.3.4. Failure Detection and Elastic Training
Modern distributed training frameworks support Elastic Training: the ability to dynamically add or remove nodes without restarting from scratch.
PyTorch Elastic (TorchElastic)
TorchElastic allows training jobs to survive node failures by shrinking the world size.
How It Works:
- You define a min/max number of nodes (e.g., min=8, max=16).
- If a node fails, TorchElastic detects the failure (via a rendezvous backend like `etcd` or `c10d`).
- The remaining nodes re-form the process group and continue training.
Launching with torchrun:
# Min 4 nodes, max 8 nodes; 8 GPUs per node; c10d rendezvous; retry up to 3 times on failure
torchrun \
  --nnodes=4:8 \
  --nproc_per_node=8 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=$MASTER_ADDR:29500 \
  --rdzv_id=unique_job_id \
  --max_restarts=3 \
  train.py --config config.yaml
The Rendezvous Service: For production, use `etcd` (or a custom rendezvous handler, e.g. one backed by DynamoDB) as the rendezvous backend. This service stores the current membership of the cluster.
import torch.distributed as dist

def setup_elastic(rank, world_size):
    # The rendezvous backend handles node discovery
    dist.init_process_group(
        backend="nccl",
        init_method="env://",  # TorchElastic sets the env vars
        rank=rank,
        world_size=world_size,
    )
Caveats:
- Elastic training works best with Data Parallelism. Tensor Parallelism and Pipeline Parallelism are harder because they have rigid topologies (you can’t just remove rank 4 from a 3D layout).
- You must reload the checkpoint after the world size changes to reshard the optimizer states.
9.3.5. Health Checks and Automated Recovery
In a long-running training job, you need automated monitoring to detect silent failures (e.g., a GPU degrading but not crashing).
The Health Check Stack
1. NVIDIA DCGM (Data Center GPU Manager): DCGM is the canonical tool for monitoring GPU health.
- Metrics: Temperature, power draw, ECC errors, NVLink errors, PCIe throughput.
- Deployment: Run `dcgm-exporter` as a DaemonSet.
Full Deployment Guide: For the complete Kubernetes DaemonSet configuration, ServiceMonitor setup, and Grafana dashboard JSON for DCGM, please refer to Chapter 18.2: GPU Observability. The configuration there is the canonical source of truth for this book.
2. Prometheus Alerting Rules: Define alerts for anomalous GPU behavior.
groups:
  - name: gpu_health
    rules:
      - alert: GPUHighTemperature
        expr: DCGM_FI_DEV_GPU_TEMP > 85
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "GPU {{ $labels.gpu }} on {{ $labels.instance }} is overheating"
          description: "Temperature is {{ $value }}C"

      - alert: GPUMemoryErrors
        expr: rate(DCGM_FI_DEV_ECC_DBE_VOL_TOTAL[5m]) > 0
        labels:
          severity: critical
        annotations:
          summary: "GPU {{ $labels.gpu }} has uncorrectable memory errors"
          description: "This GPU should be drained and replaced"

      - alert: TrainingStalled
        expr: rate(training_steps_total[10m]) == 0
        for: 15m
        labels:
          severity: critical
        annotations:
          summary: "Training job has stalled on {{ $labels.job_name }}"
3. Automated Remediation (Self-Healing): When an alert fires, trigger a Lambda (AWS) or Cloud Function (GCP) to:
- Drain the unhealthy node from the cluster (Kubernetes cordon + drain).
- Terminate the instance.
- The Auto Scaling Group automatically replaces it.
AWS Lambda for Node Replacement:
import boto3

def lambda_handler(event, context):
    """
    Triggered by a CloudWatch Alarm.
    Terminates the unhealthy EC2 instance.
    The ASG will launch a replacement automatically.
    """
    instance_id = event["detail"]["instance-id"]
    ec2 = boto3.client("ec2")

    print(f"Terminating unhealthy instance: {instance_id}")
    ec2.terminate_instances(InstanceIds=[instance_id])

    return {"status": "terminated", "instance": instance_id}
9.3.6. Gradient Anomaly Detection: Catching NaN Before It Spreads
A single NaN in a gradient can poison the entire model within one AllReduce operation. By the time you notice (loss becomes NaN), the damage is done.
The Solution: Gradient Clipping + NaN Detection
1. Global Gradient Norm Clipping: Standard practice is to clip the L2 norm of the gradient vector to prevent explosions.
import torch.distributed as dist
import wandb
from torch.nn.utils import clip_grad_norm_

# After loss.backward(), before optimizer.step()
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)

if dist.get_rank() == 0:
    # Log the gradient norm to detect anomalies
    wandb.log({"grad_norm": total_norm.item()})
2. NaN Detection Hook: Install a hook to crash immediately if a NaN is detected, before it propagates.
import torch
import torch.distributed as dist

def nan_hook(module, grad_input, grad_output):
    """
    Debugging hook to catch NaN gradients.
    Assumes model, optimizer, epoch, and step are visible in the enclosing scope.
    """
    for i, grad in enumerate(grad_output):
        if grad is not None and torch.isnan(grad).any():
            rank = dist.get_rank()
            print(f"RANK {rank}: NaN detected in {module.__class__.__name__}, output {i}")
            # Trigger emergency checkpoint
            save_checkpoint(model, optimizer, epoch, step, f"/checkpoints/emergency_nan_rank{rank}")
            raise ValueError("NaN detected in gradients. Training halted.")

# Register the hook on all modules
for module in model.modules():
    module.register_full_backward_hook(nan_hook)
3. Automatic Loss Scaling (Mixed Precision):
When training in FP16, gradient underflow can make small gradients vanish. PyTorch's GradScaler dynamically adjusts the loss scale to prevent this. (BF16 shares FP32's exponent range, so it rarely needs loss scaling, though the clipping above still applies.)
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    with autocast():  # Forward pass in FP16
        loss = model(batch)

    # Scale the loss to prevent underflow
    scaler.scale(loss).backward()

    # Unscale before clipping
    scaler.unscale_(optimizer)
    clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Step with gradient scaling
    scaler.step(optimizer)
    scaler.update()
9.3.7. The “Checkpoint Zoo”: Retention Policies and Cost Optimization
If you checkpoint every hour for 30 days, you generate 720 checkpoints at 1TB each = 720TB of storage ($16,560/month on S3).
The Retention Strategy
1. Tiered Retention (a pruning sketch follows the list):
- Aggressive Tier (Last 6 hours): Keep every checkpoint (for rapid rollback).
- Daily Tier (Last 7 days): Keep 1 checkpoint per day.
- Weekly Tier (Last 3 months): Keep 1 checkpoint per week.
- Milestone Tier: Keep checkpoints at key milestones (e.g., “Best validation loss”).
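A local pruning sketch of this policy. It is a hypothetical helper, assuming checkpoints are identified by timezone-aware UTC timestamps and that milestone checkpoints live under a separate prefix that is never touched:

from datetime import datetime, timedelta, timezone

def select_checkpoints_to_keep(timestamps, now=None):
    """
    Apply the tiered retention policy to a list of checkpoint timestamps.
    Returns the set of timestamps to keep; everything else can be deleted.
    """
    now = now or datetime.now(timezone.utc)
    keep = set()
    daily_seen, weekly_seen = set(), set()

    for ts in sorted(timestamps, reverse=True):
        age = now - ts
        if age <= timedelta(hours=6):
            keep.add(ts)                          # aggressive tier: keep everything
        elif age <= timedelta(days=7):
            if ts.date() not in daily_seen:
                daily_seen.add(ts.date())
                keep.add(ts)                      # daily tier: newest per day
        elif age <= timedelta(days=90):
            week = ts.isocalendar()[:2]
            if week not in weekly_seen:
                weekly_seen.add(week)
                keep.add(ts)                      # weekly tier: newest per ISO week
    return keep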
2. Automated Cleanup (S3 Lifecycle Policies):
resource "aws_s3_bucket_lifecycle_configuration" "checkpoint_lifecycle" {
bucket = aws_s3_bucket.checkpoints.id
rule {
id = "cleanup_old_checkpoints"
status = "Enabled"
# Delete checkpoints older than 30 days
expiration {
days = 30
}
# Move to Glacier after 7 days (cheap archival)
transition {
days = 7
storage_class = "GLACIER"
}
}
rule {
id = "keep_best_model"
status = "Enabled"
filter {
prefix = "best-model/"
}
# Never delete the best model
expiration {
days = 0
}
}
}
3. Deduplicated Storage with Diff Checkpoints: Instead of saving the full state every time, save only the diff from the previous checkpoint (like Git).
Implementation Sketch:
import torch

def save_diff_checkpoint(prev_state, current_state, path):
    diff = {}
    for key in current_state:
        if key in prev_state:
            # Store only the change relative to the previous checkpoint
            diff[key] = current_state[key] - prev_state[key]
        else:
            # New parameter (e.g., an added layer): store it in full
            diff[key] = current_state[key]
    torch.save(diff, path)
Combined with sparsification or compression of the diffs (as in Section 9.3.9), this can reduce checkpoint size by 10-50x for incremental updates.
9.3.8. Disaster Recovery: Multi-Region Checkpoint Replication
A single region failure (rare but not impossible) can destroy all checkpoints. For mission-critical training jobs, implement multi-region replication.
Cross-Region Replication Strategy
Async Replication: Write checkpoints to local storage (FSx), then asynchronously replicate to S3 in a different region.
Architecture:
- Primary Region (us-east-1): Training cluster + FSx for Lustre.
- Backup Region (us-west-2): S3 bucket with versioning enabled.
Terraform Implementation:
# Primary S3 bucket (us-east-1)
resource "aws_s3_bucket" "checkpoints_primary" {
  provider = aws.us_east_1
  bucket   = "llm-checkpoints-primary"
}

resource "aws_s3_bucket_versioning" "primary" {
  provider = aws.us_east_1
  bucket   = aws_s3_bucket.checkpoints_primary.id

  versioning_configuration {
    status = "Enabled"
  }
}

# Keep old object versions for 7 days
resource "aws_s3_bucket_lifecycle_configuration" "primary" {
  provider = aws.us_east_1
  bucket   = aws_s3_bucket.checkpoints_primary.id

  rule {
    id     = "expire-noncurrent-versions"
    status = "Enabled"

    filter {}

    noncurrent_version_expiration {
      noncurrent_days = 7
    }
  }
}

# Replica S3 bucket (us-west-2)
resource "aws_s3_bucket" "checkpoints_replica" {
  provider = aws.us_west_2
  bucket   = "llm-checkpoints-replica"
}

resource "aws_s3_bucket_versioning" "replica" {
  provider = aws.us_west_2
  bucket   = aws_s3_bucket.checkpoints_replica.id

  versioning_configuration {
    status = "Enabled"
  }
}

# Replication configuration (requires versioning on both buckets)
resource "aws_s3_bucket_replication_configuration" "replication" {
  provider   = aws.us_east_1
  depends_on = [aws_s3_bucket_versioning.primary]

  bucket = aws_s3_bucket.checkpoints_primary.id
  role   = aws_iam_role.replication.arn

  rule {
    id     = "replicate-all"
    status = "Enabled"

    filter {} # Replicate all objects

    delete_marker_replication {
      status = "Disabled"
    }

    destination {
      bucket        = aws_s3_bucket.checkpoints_replica.arn
      storage_class = "GLACIER_IR" # Cheaper storage class for backups

      replication_time {
        status = "Enabled"
        time {
          minutes = 15 # Replicate within 15 minutes (S3 RTC)
        }
      }

      metrics {
        status = "Enabled"
        event_threshold {
          minutes = 15
        }
      }
    }
  }
}
Cost Analysis:
- Replication: $0.02/GB (one-time).
- Storage in Glacier IR: $0.004/GB/month (75% cheaper than Standard).
- Total for 10 TB: $200 (replication) + $40/month (storage).
Disaster Recovery Testing
Quarterly Drill: Simulate primary region failure.
# 1. Stop training in us-east-1
kubectl delete deployment training-job -n ml-training
# 2. Provision cluster in us-west-2
terraform apply -var="region=us-west-2"
# 3. Restore checkpoint from replica bucket
aws s3 sync s3://llm-checkpoints-replica/run-42/checkpoint-5000 /fsx/restore/
# 4. Resume training
python train.py --resume-from /fsx/restore/checkpoint-5000
Target Recovery Time Objective (RTO): 2 hours. Target Recovery Point Objective (RPO): 15 minutes (last checkpoint).
9.3.9. Incremental Checkpointing and Delta Compression
Saving full checkpoints every hour for a 70B model (1.1 TB each) is wasteful. Most parameters change very little between checkpoints.
Delta Checkpointing
Store only the difference between consecutive checkpoints.
Algorithm:
import torch

def save_delta_checkpoint(prev_checkpoint, current_state, save_path):
    """
    Save only the diff between the current state and the previous checkpoint.
    """
    delta = {}
    for key in current_state:
        if key in prev_checkpoint:
            # Compute the element-wise difference
            diff = current_state[key] - prev_checkpoint[key]

            # Sparsify: only store values above a small threshold
            mask = torch.abs(diff) > 1e-6
            delta[key] = {
                "sparse_values": diff[mask],
                "indices": mask.nonzero(as_tuple=True),
                "shape": diff.shape,
            }
        else:
            # New parameter (e.g., an added layer)
            delta[key] = current_state[key]

    torch.save(delta, save_path)
    return delta

def load_delta_checkpoint(base_checkpoint, delta_path):
    """
    Reconstruct a checkpoint by applying a delta to its base.
    """
    delta = torch.load(delta_path)
    reconstructed = {}

    for key in delta:
        if isinstance(delta[key], dict) and "sparse_values" in delta[key]:
            # Re-apply the sparse diff on top of the base tensor
            base_tensor = base_checkpoint[key]
            diff_tensor = torch.zeros_like(base_tensor)
            diff_tensor[delta[key]["indices"]] = delta[key]["sparse_values"]
            reconstructed[key] = base_tensor + diff_tensor
        else:
            # New parameter
            reconstructed[key] = delta[key]

    return reconstructed
Compression Ratio: For typical LLM training, parameters change by ~0.01-0.1% per step.
- Full checkpoint: 1.1 TB.
- Delta checkpoint: ~10-50 GB (95% reduction).
Trade-off: Reconstruction requires the base checkpoint. If the base is corrupted, all deltas are useless.
Hybrid Strategy (sketched below):
- Every 10 steps: Save delta checkpoint.
- Every 100 steps: Save full checkpoint (new base).
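The scheduling logic in the training loop is a simple branch. A sketch under the assumptions above: `save_delta_checkpoint` is the helper defined earlier in this section, and `current_state` is a plain state dict already moved to the host:

import torch

# Hybrid checkpoint schedule: a full checkpoint every 100 steps becomes the new
# base; every 10 steps in between, only a delta against that base is written.
FULL_EVERY, DELTA_EVERY = 100, 10
base_state = None

def maybe_checkpoint(step, current_state, ckpt_dir):
    global base_state
    if step % FULL_EVERY == 0:
        torch.save(current_state, f"{ckpt_dir}/full_{step}.pt")
        base_state = {k: v.detach().clone() for k, v in current_state.items()}  # new base
    elif step % DELTA_EVERY == 0 and base_state is not None:
        save_delta_checkpoint(base_state, current_state, f"{ckpt_dir}/delta_{step}.pt")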
9.3.10. Checkpoint Validation and Corruption Detection
Silent data corruption (e.g., bit flips in S3, filesystem bugs) can corrupt checkpoints without immediate detection.
Checksum Validation
Compute a cryptographic hash of each checkpoint and store it alongside the data.
Implementation:
import hashlib
import torch

def _sha256_of_file(path, chunk_size=8192):
    """Stream the file through SHA256 without loading it into memory."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            sha256.update(chunk)
    return sha256.hexdigest()

def save_checkpoint_with_hash(state_dict, save_path):
    """
    Save checkpoint with a SHA256 checksum in a sidecar file.
    """
    torch.save(state_dict, save_path)

    hash_value = _sha256_of_file(save_path)
    with open(f"{save_path}.sha256", "w") as f:
        f.write(hash_value)

    return hash_value

def verify_checkpoint(checkpoint_path):
    """
    Verify checkpoint integrity using the stored hash.
    """
    current_hash = _sha256_of_file(checkpoint_path)

    with open(f"{checkpoint_path}.sha256", "r") as f:
        expected_hash = f.read().strip()

    if current_hash != expected_hash:
        raise ValueError(f"Checkpoint corrupted! Expected {expected_hash}, got {current_hash}")
    return True
S3 Object Lock: For compliance, use S3 Object Lock to make checkpoints immutable (cannot be deleted or modified for a retention period).
resource "aws_s3_bucket_object_lock_configuration" "checkpoints" {
bucket = aws_s3_bucket.checkpoints.id
rule {
default_retention {
mode = "GOVERNANCE" # Can be overridden by root user
days = 30
}
}
}
9.3.11. Training Resumption Testing: The “Resume Benchmark”
A checkpoint is useless if you can’t resume from it. Test resume functionality regularly.
Automated Resume Test
Goal: Verify that a resumed run produces identical results to a continuous run (within numerical tolerance).
Test Script:
import random

import numpy as np
import torch
from torch.optim import AdamW

# MyModel, train_step, and get_batch are placeholders for your own model,
# training step, and deterministic data pipeline.

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

def test_deterministic_resume():
    """
    Train for 100 steps, save a checkpoint at step 50, resume, and verify results.
    """
    # Run 1: Train steps 0-100 continuously
    set_seed(42)
    model1 = MyModel()
    optimizer1 = AdamW(model1.parameters())

    losses_continuous = []
    for step in range(100):
        loss = train_step(model1, get_batch(step))
        losses_continuous.append(loss.item())
        optimizer1.step()

    # Run 2: Train steps 0-50, checkpoint, then resume for steps 50-100
    set_seed(42)
    model2 = MyModel()
    optimizer2 = AdamW(model2.parameters())

    losses_resumed = []
    for step in range(50):
        loss = train_step(model2, get_batch(step))
        losses_resumed.append(loss.item())
        optimizer2.step()

    # Checkpoint at step 50
    checkpoint = {
        "model": model2.state_dict(),
        "optimizer": optimizer2.state_dict(),
        "step": 50,
        "rng_state": torch.get_rng_state(),
        "cuda_rng_state": torch.cuda.get_rng_state(),
    }
    torch.save(checkpoint, "checkpoint_step50.pt")

    # Resume into a fresh model/optimizer to prove the checkpoint is self-contained
    checkpoint = torch.load("checkpoint_step50.pt")
    model2 = MyModel()
    optimizer2 = AdamW(model2.parameters())
    model2.load_state_dict(checkpoint["model"])
    optimizer2.load_state_dict(checkpoint["optimizer"])
    torch.set_rng_state(checkpoint["rng_state"])
    torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

    for step in range(50, 100):
        loss = train_step(model2, get_batch(step))
        losses_resumed.append(loss.item())
        optimizer2.step()

    # Compare: the resumed run must match the continuous run step for step
    for i, (loss_cont, loss_res) in enumerate(zip(losses_continuous, losses_resumed)):
        assert abs(loss_cont - loss_res) < 1e-6, f"Step {i}: {loss_cont} != {loss_res}"

    print("Resume test PASSED: Resumed run is identical to continuous run.")
Run this test:
- Before every major training run.
- After upgrading PyTorch, CUDA, or NCCL.
9.3.12. Chaos Engineering for Distributed Training
Proactively inject failures to test resilience.
Chaos Experiment 1: Random Node Termination
Tool: Chaos Mesh (Kubernetes) or custom script.
Experiment:
apiVersion: chaos-mesh.org/v1alpha1
kind: PodChaos
metadata:
  name: kill-random-worker
  namespace: ml-training
spec:
  action: pod-kill
  mode: one
  selector:
    namespaces:
      - ml-training
    labelSelectors:
      app: distributed-training
  scheduler:
    cron: "*/30 * * * *" # Kill one pod every 30 minutes
Expected Behavior: Training should pause, detect the failure, and resume from the last checkpoint without manual intervention.
Chaos Experiment 2: Network Partition
Simulate a network split where nodes can’t communicate.
apiVersion: chaos-mesh.org/v1alpha1
kind: NetworkChaos
metadata:
  name: partition-network
spec:
  action: partition
  mode: all
  selector:
    namespaces:
      - ml-training
    labelSelectors:
      app: distributed-training
  direction: both
  duration: "5m"
Expected Behavior: NCCL should timeout, training should crash, and watchdog should restart from checkpoint.
Chaos Experiment 3: Disk Corruption (Checkpoint Storage)
Corrupt a checkpoint file and verify detection.
#!/bin/bash
# Inject bit flip in checkpoint file
CHECKPOINT="/fsx/checkpoints/run-42/checkpoint-1000/model.pt"
dd if=/dev/urandom of=$CHECKPOINT bs=1 count=100 seek=$RANDOM conv=notrunc
Expected Behavior: On load, checksum validation should fail, and the system should fall back to the previous checkpoint.
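The fallback path this experiment exercises can be as simple as walking the checkpoint history newest-to-oldest and returning the first one whose checksum verifies. A sketch reusing the `verify_checkpoint` helper from Section 9.3.10 (the glob pattern is hypothetical):

import glob

import torch

def load_latest_valid_checkpoint(checkpoint_glob="/fsx/checkpoints/run-42/checkpoint-*/model.pt"):
    """Try checkpoints newest-first; skip any that fail checksum validation."""
    for path in sorted(glob.glob(checkpoint_glob), reverse=True):
        try:
            verify_checkpoint(path)  # raises ValueError on corruption
            return torch.load(path), path
        except (ValueError, FileNotFoundError) as err:
            print(f"Skipping {path}: {err}")
    raise RuntimeError("No valid checkpoint found; cannot resume.")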
9.3.13. Cost of Fault Tolerance: The Insurance Premium
Fault tolerance is not free. It’s an insurance policy. You pay upfront (in time and money) to reduce the risk of catastrophic loss.
Cost Breakdown for 30-Day Training Run (70B Model, 100 Nodes)
| Item | Without FT | With FT | Overhead |
|---|---|---|---|
| Compute (100 nodes × $32.77/hr × 720 hrs) | $2,359,440 | $2,359,440 | $0 |
| Checkpoint I/O (≈5% of time spent writing) | $0 | $117,972 | $117,972 |
| Storage (FSx + S3 replication) | $1,400 | $3,500 | $2,100 |
| Monitoring (DCGM, Prometheus) | $0 | $500 | $500 |
| Spot interruption losses (assume 3 interruptions) | $0 (N/A on On-Demand) | $2,000 | $2,000 |
| Expected loss from failure (10% chance of losing the run) | $235,944 | ~$0 | -$235,944 |
| Total expected cost | $2,596,784 | $2,483,412 | -$113,372 |
Conclusion: Fault tolerance is not just a risk mitigation strategy; it’s also a cost optimization strategy. The expected cost is lower with FT due to reduced failure risk.
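The expected-value arithmetic behind the table fits in a few lines (an illustrative sketch using the table's own assumptions):

compute = 100 * 32.77 * 720  # $2,359,440 of raw GPU time

# Without FT: basic storage, plus a 10% chance of losing the whole run
without_ft = compute + 1_400 + 0.10 * compute

# With FT: ~5% checkpoint I/O overhead, richer storage/monitoring,
# a few Spot interruptions, and (approximately) zero catastrophic risk
with_ft = compute * 1.05 + 3_500 + 500 + 2_000

print(f"Expected cost without FT: ${without_ft:,.0f}")
print(f"Expected cost with FT:    ${with_ft:,.0f}")
print(f"Expected savings:         ${without_ft - with_ft:,.0f}")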
9.3.14. Checkpoint Formats: Trade-offs and Best Practices
Different checkpoint formats have different performance characteristics.
Format Comparison
| Format | Size | Write Speed | Read Speed | Compatibility | Use Case |
|---|---|---|---|---|---|
| PyTorch .pt (Pickle) | Medium | Fast | Fast | PyTorch only | Standard choice |
| Safetensors | Small | Very Fast | Very Fast | Multi-framework | Recommended for production |
| HDF5 | Medium | Medium | Medium | Universal | Legacy systems |
| NumPy .npz | Large | Slow | Slow | Universal | Debugging/inspection |
| TensorFlow Checkpoint | Large | Medium | Medium | TensorFlow | If using TF |
Safetensors: The Modern Standard
Safetensors is a new format developed by Hugging Face. It’s faster, safer, and more portable than Pickle.
Advantages:
- Security: No arbitrary code execution (Pickle can run malicious code).
- Speed: Zero-copy memory mapping (faster load).
- Lazy Loading: Load only needed tensors (useful for inference).
Installation:
pip install safetensors
Usage:
from safetensors.torch import save_file, load_file
# Save
state_dict = model.state_dict()
save_file(state_dict, "checkpoint.safetensors")
# Load
state_dict = load_file("checkpoint.safetensors")
model.load_state_dict(state_dict)
Migration from .pt to safetensors:
import torch
from safetensors.torch import save_file
# Load old checkpoint
old_checkpoint = torch.load("checkpoint.pt")
# Save in new format
save_file(old_checkpoint["model"], "checkpoint.safetensors")
Recommendation: Use Safetensors for all new projects. Convert existing .pt checkpoints during the next training run.
9.3.15. Final Checklist: Production-Ready Fault Tolerance
Before launching a multi-million dollar training run, verify:
1. Checkpointing:
- Checkpoints are sharded (DCP or FSDP state dict).
- Checkpoint interval is optimized (Young’s formula).
- Checkpoints are written to high-speed storage (FSx/Parallelstore).
- Checksums are computed and verified.
2. Backup and Replication:
- Checkpoints are replicated to S3 (or GCS).
- Multi-region replication is enabled for critical runs.
- Retention policy is configured (tiered storage).
3. Failure Detection:
- DCGM Exporter is deployed on all nodes.
- Prometheus alerts are configured for GPU health, network errors, and training stalls.
- Automated remediation (node replacement) is set up.
4. Resumption:
- Resume logic is tested (deterministic resume test passed).
- Dataloader state is saved and restored.
- RNG states are saved and restored.
5. Spot Resilience (if using Spot):
- Spot interruption handler is running.
- Emergency checkpoint on interruption is implemented.
- Mixed On-Demand + Spot fleet is configured.
6. Monitoring:
- Training metrics are logged (loss, throughput, GPU utilization).
- Dashboards are created (Grafana or CloudWatch).
- Alerts are routed to on-call engineers (PagerDuty, Slack).
7. Chaos Testing:
- Node termination chaos experiment passed.
- Network partition chaos experiment passed.
- Checkpoint corruption detection tested.
If all boxes are checked: You are ready for production.
9.3.16. Summary: The Resilience Checklist
When architecting fault tolerance for distributed training:
- Checkpoint Religiously: Use sharded checkpoints (DCP). Write to high-speed storage (FSx/Parallelstore).
- Optimize Checkpoint Interval: Use Young’s formula. Balance I/O cost vs. recompute cost.
- Embrace Spot: Use hybrid On-Demand + Spot. Implement interruption handlers.
- Monitor GPUs: Deploy DCGM. Alert on ECC errors, temperature, and training stalls.
- Detect NaN Early: Use gradient hooks and clipping. Don’t let poison spread.
- Automate Recovery: Use Elastic Training (TorchElastic) for node failures. Auto-replace unhealthy instances.
- Manage Checkpoint Bloat: Implement tiered retention. Use S3 lifecycle policies.
In the next chapter, we will discuss Model Serving and Inference Optimization, where the challenges shift from throughput (training) to latency (serving) and cost-per-token economics.