
9.3. Fault Tolerance: The Art of Crash-Proofing

“A distributed system is one in which the failure of a computer you didn’t even know existed can render your own computer unusable.” — Leslie Lamport

If you are training a model on a single GPU, a crash is an annoyance. You restart the script, maybe lose an hour of progress.

If you are training a 70B parameter LLM on 512 H100 GPUs for three months, a crash is a statistical certainty. At that scale, hardware failure is not an exception; it is the steady state.

  • Memory Errors: Cosmic rays flip bits in HBM3 memory (ECC errors).
  • Network Flaps: A single optical transceiver in a Top-of-Rack switch degrades, causing packet loss that times out the NCCL ring.
  • Preemption: The cloud provider reclaims your Spot capacity because a higher-paying customer just spun up a cluster.
  • Software Bugs: A gradient explosion produces a NaN which propagates through the AllReduce operation, poisoning the weights of every GPU in the cluster instantly.

Without a robust fault tolerance strategy, you will never finish training. You will be stuck in a “Sisyphus Loop,” rolling the rock up the hill only to have a node fail at 98%, forcing a restart from zero.

This section details the architecture of resilience: how to checkpoint state effectively, how to handle the ruthless economics of Spot instances, and how to build self-healing clusters on AWS and GCP.


9.3.1. The Thermodynamics of Failure

To architect for failure, we must first quantify it. If failures are independent, the probability of an uninterrupted training run drops exponentially with the total number of node-days:

$$ P(\text{success}) = (1 - p_{\text{fail}})^{N_{\text{nodes}} \times D_{\text{days}}} $$

Let’s assume a single GPU node has a Mean Time Between Failures (MTBF) that implies a 0.1% chance of failing on any given day ($p = 0.001$). This includes hardware issues, driver crashes, and maintenance events.

  • Single Node (1 GPU) running for 30 days: $$ 0.999^{30} \approx 0.970 $$ a 97% chance of success without interruption.
  • Cluster (1,000 Nodes) running for 30 days: $$ 0.999^{30{,}000} \approx 9 \times 10^{-14} $$ roughly a 0.000000000009% chance of success.

At this scale, an uninterrupted run is practically impossible. Therefore, the training system must be viewed not as a continuous process, but as a series of discrete, recoverable segments.
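The same arithmetic as a short Python sketch. Like the formula, it assumes every node-day is an independent trial with the same failure probability; it also reports the expected number of failures (p × N × D), which is the more useful planning number. The function names are illustrative:

```python
def p_uninterrupted(p_daily_fail: float, n_nodes: int, days: float) -> float:
    """Probability that no node fails, treating each node-day as an independent trial."""
    return (1.0 - p_daily_fail) ** (n_nodes * days)


def expected_failures(p_daily_fail: float, n_nodes: int, days: float) -> float:
    """Expected number of node failures over the run (p * N * D)."""
    return p_daily_fail * n_nodes * days


for nodes in (1, 1_000):
    print(
        f"{nodes:>5} nodes, 30 days: "
        f"P(no failure) = {p_uninterrupted(0.001, nodes, 30):.3g}, "
        f"expected failures = {expected_failures(0.001, nodes, 30):g}"
    )
```

For the 1,000-node cluster, expected_failures(0.001, 1000, 30) is 30: plan for roughly one interruption per day, not for the rare lucky run.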

The Cost of Checkpointing (The Tax)

Fault tolerance is not free. It is a trade-off between Compute Time (lost progress after a crash) and I/O Overhead (time spent pausing training to write to disk).

If you checkpoint too often, you waste 20% of your GPU cycles writing to S3. If you checkpoint too rarely, a crash destroys 24 hours of compute (worth perhaps $50,000).

Young’s Approximation for Optimal Checkpoint Interval: $$ T_{opt} = \sqrt{2 \times T_{checkpoint} \times T_{MTBF}} $$

Where:

  • $T_{opt}$: The optimal time between checkpoints.
  • $T_{checkpoint}$: Time it takes to write the checkpoint.
  • $T_{MTBF}$: Mean Time Between Failures for the entire cluster.

Example:

  • Cluster MTBF is 12 hours (on average, something breaks twice a day).
  • Writing the checkpoint takes 5 minutes (0.083 hours).
  • $T_{opt} \approx \sqrt{2 \times 0.083 \times 12} \approx 1.41 \text{ hours}$.

You should checkpoint every ~90 minutes.
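A small calculator for Young's interval, sketched under the same first-order assumptions the formula rests on: each interval pays one checkpoint write, and each failure throws away, on average, half an interval of work. The overhead_fraction helper is an illustrative addition that evaluates that waste model, which Young's formula minimizes:

```python
import math


def young_interval(t_checkpoint_h: float, mtbf_h: float) -> float:
    """Young's approximation: T_opt = sqrt(2 * T_checkpoint * T_MTBF), all in hours."""
    return math.sqrt(2.0 * t_checkpoint_h * mtbf_h)


def overhead_fraction(interval_h: float, t_checkpoint_h: float, mtbf_h: float) -> float:
    """First-order fraction of time wasted: one write per interval, plus on average
    half an interval of lost work per failure (the quantity Young's formula minimizes)."""
    return t_checkpoint_h / interval_h + interval_h / (2.0 * mtbf_h)


t_write, mtbf = 5 / 60, 12.0  # 5-minute write, 12-hour cluster MTBF
t_opt = young_interval(t_write, mtbf)
print(f"checkpoint every {t_opt * 60:.0f} min; "
      f"~{overhead_fraction(t_opt, t_write, mtbf):.1%} of time lost")
```

With the example numbers (a 5-minute write and a 12-hour MTBF), the interval comes out to √2 ≈ 1.41 hours, and the model predicts that roughly 12% of wall-clock time goes to checkpoint writes and rework.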


9.3.2. Checkpointing Mechanics: What to Save and How

A common misconception is that you only need to save the model weights (model.state_dict()). For a training run to resume exactly where it left off, ensuring bit-for-bit reproducibility (or at least statistical continuity), you must save much more.

The Anatomy of a Checkpoint

For a Large Language Model using AdamW optimizer and Mixed Precision:

  1. Model Weights (FP16/BF16): The active parameters.
  2. Master Weights (FP32): The high-precision copy kept by the optimizer to accumulate small gradient updates.
  3. Optimizer State (FP32):
    • Momentum / first moment (exp_avg, decay rate β1): Exponential moving average of gradients.
    • Variance / second moment (exp_avg_sq, decay rate β2): Exponential moving average of squared gradients.
    • Step Count: For bias correction.
  4. Learning Rate Scheduler State: Current epoch, current LR, warmup counter.
  5. Data Loader State: Which epoch? Which batch index? Ideally, the RNG state of the shuffler.
  6. Random Number Generator (RNG) States: CUDA RNG, Python RNG, and NumPy RNG states for every rank.

The Storage Explosion: For a model with $\Phi$ parameters, the training state is roughly $16 \times \Phi$ bytes, of which about $14 \times \Phi$ bytes must be saved in each checkpoint:

  • Model (BF16): 2 bytes
  • Master Model (FP32): 4 bytes
  • Optimizer Momentum (FP32): 4 bytes
  • Optimizer Variance (FP32): 4 bytes
  • Gradients (Transient, usually not saved but exist in VRAM): 2 bytes

A 70B parameter model therefore requires: $$ 70 \times 10^9 \times 14 \text{ bytes} \approx 0.98 \text{ TB per checkpoint} $$ (about 1.12 TB of training state if you count the full 16 bytes per parameter).

If you retain the last 5 checkpoints for safety, you are storing roughly 5 TB of data per run.
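The per-parameter accounting as a quick estimator, a sketch that counts only the four components listed above as saved (gradients excluded) and ignores small extras such as scheduler, data-loader, and RNG state:

```python
# Bytes per parameter for the components saved in each checkpoint
SAVED_BYTES_PER_PARAM = {
    "model_bf16": 2,       # active weights
    "master_fp32": 4,      # optimizer's high-precision copy
    "adam_exp_avg": 4,     # first moment (momentum)
    "adam_exp_avg_sq": 4,  # second moment (variance)
}
GRAD_BYTES_PER_PARAM = 2  # BF16 gradients: live in VRAM, not saved


def checkpoint_bytes(n_params: int, keep_last: int = 1) -> int:
    """Storage for keep_last retained checkpoints."""
    return n_params * sum(SAVED_BYTES_PER_PARAM.values()) * keep_last


def training_state_bytes(n_params: int) -> int:
    """In-memory training state, including transient gradients."""
    return n_params * (sum(SAVED_BYTES_PER_PARAM.values()) + GRAD_BYTES_PER_PARAM)


n = 70 * 10**9
print(f"one checkpoint:  {checkpoint_bytes(n) / 1e12:.2f} TB")
print(f"last 5 retained: {checkpoint_bytes(n, keep_last=5) / 1e12:.2f} TB")
print(f"training state:  {training_state_bytes(n) / 1e12:.2f} TB")
```

For 70B parameters this gives about 0.98 TB per checkpoint and 4.9 TB for the last five, against 1.12 TB of in-memory training state once the 2-byte gradients are included.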

PyTorch Distributed Checkpointing (DCP)

In the old days (PyTorch < 1.13), rank 0 would gather all weights to CPU RAM and write a single .pt file. This causes OOM (Out of Memory) on rank 0 and network bottlenecks.

Modern training uses Sharded Checkpointing. Each GPU writes its own slice of the model and optimizer state directly to storage.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir):
    """
    Modern sharded checkpointing for FSDP models (PyTorch >= 2.2 DCP API).
    Each rank writes its own shard to storage in parallel.
    """
    # Sharded model and optimizer state dicts (works for FSDP and DDP)
    model_sd, optim_sd = get_state_dict(model, optimizer)

    # Write sharded checkpoint (a collective: every rank must call it)
    # Each rank writes its own file, e.g. checkpoint_dir/__0_0.distcp, __1_0.distcp, etc.
    dcp.save({"model": model_sd, "optimizer": optim_sd}, checkpoint_id=checkpoint_dir)

    # RNG states differ per rank, so keep them (with the progress counters)
    # in a small per-rank side file instead of the sharded state dict.
    rank = dist.get_rank()
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "step": step,
            "rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
        },
        os.path.join(checkpoint_dir, f"extra_state_rank{rank}.pt"),
    )
    dist.barrier()  # The checkpoint is complete only when every rank has written

    if rank == 0:
        print(f"Checkpoint saved to {checkpoint_dir} at epoch {epoch}, step {step}")

def load_checkpoint(model, optimizer, checkpoint_dir):
    """
    Load from sharded checkpoint.
    Each rank loads only its shard; DCP reshards if the world size changed.
    """
    # Build correctly sharded destination state dicts; dcp.load fills them in place
    model_sd, optim_sd = get_state_dict(model, optimizer)
    state_dict = {"model": model_sd, "optimizer": optim_sd}
    dcp.load(state_dict, checkpoint_id=checkpoint_dir)

    # Copy the loaded values into the live model and optimizer
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optimizer"],
    )

    # Restore RNG states and progress counters. After an elastic resize a rank
    # may have no file of its own: take the counters from rank 0's file and
    # keep the freshly seeded RNG.
    own_path = os.path.join(checkpoint_dir, f"extra_state_rank{dist.get_rank()}.pt")
    path = own_path if os.path.exists(own_path) else os.path.join(checkpoint_dir, "extra_state_rank0.pt")
    extra = torch.load(path)
    if path == own_path:
        torch.set_rng_state(extra["rng_state"])
        torch.cuda.set_rng_state(extra["cuda_rng_state"])

    return extra["epoch"], extra["step"]

The Storage Backend: S3 vs. EFS vs. FSx

Where do you write 1TB checkpoints? The choice matters for speed and cost.

1. Amazon S3 (Object Storage):

  • Pros: Infinite scalability. Durability (99.999999999%). Cheap ($0.023/GB/month).
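  • Cons: High per-request latency (~50-100ms per write) and no POSIX semantics (no appends or atomic renames). Consistency is no longer a concern: S3 has provided strong read-after-write consistency since December 2020.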
  • Use Case: Final checkpoints. Long-term archival.
  • Throughput: With s5cmd or parallel writes via PyTorch, you can achieve ~10GB/s writes from a multi-node cluster.

2. Amazon EFS (Elastic File System):

  • Pros: POSIX-compatible. Can be mounted directly by all nodes. Lower latency than S3 (~1-5ms).
  • Cons: Expensive ($0.30/GB/month). Performance depends on provisioned throughput.
  • Use Case: Working checkpoints during active training.
  • Architecture Note: Use EFS in “Max I/O” mode for distributed writes. Ensure the mount target is in the same Availability Zone as your cluster.

3. Amazon FSx for Lustre:

  • Pros: Built for HPC. Backed by S3 but presents a high-speed POSIX filesystem (sub-millisecond latency). Can achieve 100s of GB/s throughput.
  • Cons: Expensive ($0.14-0.60/GB/month depending on config). Requires explicit capacity planning.
  • Use Case: The gold standard for massive-scale training. Used by AWS for training models like Olympus.
  • Integration: FSx can be linked to an S3 bucket. With automatic export configured, changes written to FSx are synced to S3 in the background.

GCP Equivalents:

  • Google Cloud Storage (GCS): Like S3. Use gcsfuse for direct mounting (slower) or gcsfs Python library for programmatic access.
  • Filestore: Like EFS. NFS-based. Use Filestore High Scale for HPC.
  • Parallelstore: Google’s new answer to FSx Lustre. Optimized for AI/ML workloads with tight integration to Vertex AI.

Terraform Example: FSx for Lustre Checkpoint Backend (AWS):

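resource "aws_fsx_lustre_file_system" "checkpoint_fs" {
  storage_capacity            = 7200  # GiB; PERSISTENT_2 scales in 2,400 GiB increments
  subnet_ids                  = [aws_subnet.private.id]
  security_group_ids          = [aws_security_group.training_sg.id]
  deployment_type             = "PERSISTENT_2"  # High durability
  per_unit_storage_throughput = 250  # MB/s per TiB of storage

  tags = {
    Name = "LLM-Checkpoint-Storage"
  }
}

# PERSISTENT_2 file systems link to S3 through a data repository association
# (the import_path/export_path arguments apply only to older deployment types)
resource "aws_fsx_data_repository_association" "checkpoint_s3" {
  file_system_id       = aws_fsx_lustre_file_system.checkpoint_fs.id
  data_repository_path = "s3://${aws_s3_bucket.checkpoints.bucket}/training-run-42/"
  file_system_path     = "/training-run-42"

  s3 {
    # Auto-export: files written to FSx are copied to S3 in the background
    auto_export_policy {
      events = ["NEW", "CHANGED"]
    }

    # Auto-import: any file added to S3 appears in FSx
    auto_import_policy {
      events = ["NEW", "CHANGED"]
    }
  }
}

# Security Group: Allow Lustre traffic between cluster members (988/tcp, 1018-1023/tcp)
resource "aws_security_group_rule" "lustre_ingress" {
  type              = "ingress"
  from_port         = 988
  to_port           = 988
  protocol          = "tcp"
  security_group_id = aws_security_group.training_sg.id
  self              = true
}

resource "aws_security_group_rule" "lustre_ingress_high" {
  type              = "ingress"
  from_port         = 1018
  to_port           = 1023
  protocol          = "tcp"
  security_group_id = aws_security_group.training_sg.id
  self              = true
}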

9.3.3. Spot Instances: The Economics of Ephemeral Compute

Training an LLM on On-Demand instances can cost $500,000+. Using Spot instances can reduce this by up to ~70%, but on AWS, Spot capacity can be reclaimed with only 2 minutes of notice.

The Trade-Off

  • On-Demand: $32.77/hour per p4d.24xlarge (8x A100).
  • Spot: ~$10-15/hour (varies by demand). But you might get interrupted.

If your training run spans 30 days and Spot saves you $200,000, but you lose 6 hours of compute to interruptions, you still win massively—as long as your checkpointing strategy is robust.
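One way to put numbers on this is a sketch that charges each interruption as rework on the whole cluster: the progress lost since the last checkpoint plus restart time, folded into a single hours-per-interruption figure. The $12.50 Spot price in the example is an assumed midpoint of the range above, not a quote:

```python
def effective_cost(hourly_price: float, n_nodes: int, train_hours: float,
                   interruptions: int = 0, lost_hours_per_interruption: float = 0.0) -> float:
    """Cluster cost including rework: each interruption re-runs lost hours on every node."""
    rework_hours = interruptions * lost_hours_per_interruption
    return hourly_price * n_nodes * (train_hours + rework_hours)


on_demand = effective_cost(32.77, n_nodes=100, train_hours=720)
# Assumed: $12.50/hr Spot, 12 interruptions costing 30 minutes of rework each
spot = effective_cost(12.50, n_nodes=100, train_hours=720,
                      interruptions=12, lost_hours_per_interruption=0.5)
print(f"On-Demand: ${on_demand:,.0f}  Spot: ${spot:,.0f}  savings: ${on_demand - spot:,.0f}")
```

Even with a dozen interruptions, the rework term is small next to the price gap; the comparison flips only if interruptions become frequent or each one loses many hours, which is exactly what aggressive checkpointing prevents.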

The Architecture for Spot Resilience

1. Mixed Fleet (Heterogeneous Spot + On-Demand): Do not run 100% Spot. Use a hybrid model.

  • Core nodes (Rank 0, maybe 10% of cluster): On-Demand (guaranteed stability for orchestration).
  • Worker nodes (Rank 1-N): Spot.

If a Spot node disappears, the training job doesn’t lose coordination.

2. Checkpoint Aggressively: On Spot, reduce checkpoint interval. If $T_{MTBF}$ is 6 hours for Spot, checkpoint every 30-60 minutes.

3. Spot Interruption Handler (AWS): AWS provides a metadata endpoint that signals 2 minutes before termination.

Python Daemon for Graceful Shutdown:

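import time
from pathlib import Path

import requests

IMDS_BASE = "http://169.254.169.254/latest"
SPOT_ACTION_URL = f"{IMDS_BASE}/meta-data/spot/instance-action"
EMERGENCY_FLAG = Path("/tmp/emergency_checkpoint")

def get_imds_token():
    """IMDSv2: fetch a short-lived session token (required when IMDSv1 is disabled)."""
    response = requests.put(
        f"{IMDS_BASE}/api/token",
        headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
        timeout=1,
    )
    response.raise_for_status()
    return response.text

def check_spot_termination():
    """
    Poll the EC2 metadata endpoint.
    If interruption is imminent, trigger emergency checkpoint.
    """
    try:
        token = get_imds_token()
        response = requests.get(
            SPOT_ACTION_URL,
            headers={"X-aws-ec2-metadata-token": token},
            timeout=1,
        )
    except requests.exceptions.RequestException:
        # Metadata service unreachable; try again on the next poll
        return False

    if response.status_code == 200:
        # Spot termination notice received (body holds the action and time)
        print("SPOT TERMINATION WARNING: Initiating emergency checkpoint.")
        # Signal the training process via a flag file
        EMERGENCY_FLAG.touch()
        return True
    # 404 = no interruption scheduled (all clear)
    return False

if __name__ == "__main__":
    while True:
        if check_spot_termination():
            break
        time.sleep(5)  # Poll every 5 seconds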

In the training loop:

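import os
import sys

import torch
import torch.distributed as dist

for step, batch in enumerate(dataloader):
    # The flag file exists only on the node that received the notice, but the
    # sharded save is a collective, so all ranks must agree before anyone saves
    local_flag = 1 if os.path.exists("/tmp/emergency_checkpoint") else 0
    flag = torch.tensor([local_flag], device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)

    if flag.item():
        if dist.get_rank() == 0:
            print("Emergency checkpoint triggered. Saving state...")
        save_checkpoint(model, optimizer, epoch, step, checkpoint_dir)
        dist.barrier()  # Ensure all ranks finish
        dist.destroy_process_group()
        sys.exit(0)  # Graceful exit

    # Normal training step
    loss = train_step(model, batch)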

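GCP Equivalent: GCP Spot and preemptible VMs expose a similar metadata key; requests must include the header Metadata-Flavor: Google, and GCP gives only about 30 seconds of notice instead of two minutes. The endpoint is: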

http://metadata.google.internal/computeMetadata/v1/instance/preempted
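A standard-library sketch of the GCP side, mirroring the AWS daemon above. It assumes the documented behavior of this key (the server answers with the text TRUE or FALSE and rejects requests without the Metadata-Flavor: Google header) and reuses the same /tmp/emergency_checkpoint flag file; because the notice is so short, it polls every second rather than every five:

```python
import time
import urllib.request
from pathlib import Path

PREEMPTED_URL = "http://metadata.google.internal/computeMetadata/v1/instance/preempted"
EMERGENCY_FLAG = Path("/tmp/emergency_checkpoint")  # same flag file as the AWS daemon


def is_preempted(body: str) -> bool:
    """The metadata server returns the text TRUE or FALSE for this key."""
    return body.strip().upper() == "TRUE"


def check_preemption(timeout_s: float = 1.0) -> bool:
    # Requests without this header are rejected by the metadata server
    request = urllib.request.Request(PREEMPTED_URL, headers={"Metadata-Flavor": "Google"})
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        return is_preempted(response.read().decode())


def main(poll_interval_s: float = 1.0) -> None:
    """Poll until preemption starts, then raise the emergency flag and exit."""
    while True:
        try:
            if check_preemption():
                print("PREEMPTION NOTICE: ~30 s left, requesting emergency checkpoint.")
                EMERGENCY_FLAG.touch()
                return
        except OSError:
            pass  # metadata server briefly unreachable or slow: try again
        time.sleep(poll_interval_s)
```

Run main() in a sidecar process next to the training job, as with the AWS script. With so little notice, a full sharded save of a large model may not finish in time, so on GCP the flag is often better used to stop cleanly and resume from the last periodic checkpoint.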

Terraform Auto Scaling with Spot (AWS)

resource "aws_autoscaling_group" "spot_workers" {
  name                = "llm-training-spot-workers"
  max_size            = 100
  min_size            = 0
  desired_capacity    = 50
  vpc_zone_identifier = [aws_subnet.private.id]

  mixed_instances_policy {
    instances_distribution {
      on_demand_base_capacity                  = 5   # First 5 instances are On-Demand
      on_demand_percentage_above_base_capacity = 10  # 10% of capacity above the base is On-Demand
      spot_allocation_strategy                 = "capacity-optimized"
    }

    launch_template {
      launch_template_specification {
        launch_template_id = aws_launch_template.gpu_node.id
        version            = "$Latest"
      }

      # Try multiple instance types to increase Spot availability
      override {
        instance_type = "p4d.24xlarge"
      }
      override {
        instance_type = "p4de.24xlarge"
      }
    }
  }

  tag {
    key                 = "Name"
    value               = "Spot-Worker"
    propagate_at_launch = true
  }
}

9.3.4. Failure Detection and Elastic Training

Modern distributed training frameworks support Elastic Training: the ability to dynamically add or remove nodes without restarting from scratch.

PyTorch Elastic (TorchElastic)

TorchElastic allows training jobs to survive node failures by shrinking the world size.

How It Works:

  1. You define a min/max number of nodes (e.g., min=8, max=16).
  2. If a node fails, TorchElastic detects the failure (via a rendezvous backend like etcd or c10d).
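  3. torchrun stops the surviving workers, re-runs the rendezvous with the remaining nodes, and restarts the worker processes with the new world size; each worker then reloads the latest checkpoint and continues training.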

Launching with torchrun:

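# --nnodes=4:8         min 4 nodes, max 8 nodes
# --nproc_per_node=8   8 GPUs per node
# --rdzv_backend=c10d  rendezvous backend
# --max_restarts=3     retry up to 3 times on failure
# (Comments cannot follow a trailing backslash, so they live up here.)
torchrun \
    --nnodes=4:8 \
    --nproc_per_node=8 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29500 \
    --rdzv_id=unique_job_id \
    --max_restarts=3 \
    train.py --config config.yaml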

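The Rendezvous Service: For production, use the c10d backend (built into torchrun, with its store hosted on one of the training nodes) or an external etcd cluster, which keeps working even if a training node is lost. Other stores, such as DynamoDB, would need a custom rendezvous handler. The rendezvous backend stores the current membership of the cluster.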

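import os

import torch
import torch.distributed as dist

def setup_elastic():
    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT
    # for every worker, and sets them again after each re-rendezvous
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    return dist.get_rank(), dist.get_world_size()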

Caveats:

  • Elastic training works best with Data Parallelism. Tensor Parallelism and Pipeline Parallelism are harder because they have rigid topologies (you can’t just remove rank 4 from a 3D layout).
  • You must reload the checkpoint after the world size changes to reshard the optimizer states.
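Because torchrun re-executes the training script after every restart, each worker has to find the checkpoint to resume from on its own. A minimal sketch, assuming a hypothetical layout of checkpoint-<step> directories and a COMPLETE marker file that the saver writes only after all ranks have finished (so a save interrupted mid-write is never picked up):

```python
import re
from pathlib import Path
from typing import Optional

# Hypothetical layout: <root>/checkpoint-<step>/ plus a COMPLETE marker file
# that is written only after every rank has finished saving that step.
CHECKPOINT_DIR_RE = re.compile(r"checkpoint-(\d+)")


def latest_complete_checkpoint(root: str) -> Optional[Path]:
    """Return the newest fully written checkpoint directory under root, or None."""
    root_path = Path(root)
    if not root_path.is_dir():
        return None  # first launch: nothing to resume from

    best_step, best_path = -1, None
    for path in root_path.iterdir():
        match = CHECKPOINT_DIR_RE.fullmatch(path.name)
        if match is None or not path.is_dir():
            continue
        if not (path / "COMPLETE").exists():
            continue  # save was interrupted mid-write: never resume from it
        step = int(match.group(1))  # compare numerically, not lexically
        if step > best_step:
            best_step, best_path = step, path
    return best_path
```

Every worker calls this at startup and, when it returns a path, passes it to load_checkpoint; when it returns None, training starts from step 0. Comparing step numbers numerically avoids the lexical-sort trap where checkpoint-900 would sort after checkpoint-1000.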

9.3.5. Health Checks and Automated Recovery

In a long-running training job, you need automated monitoring to detect silent failures (e.g., a GPU degrading but not crashing).

The Health Check Stack

1. NVIDIA DCGM (Data Center GPU Manager): DCGM is the canonical tool for monitoring GPU health.

  • Metrics: Temperature, power draw, ECC errors, NVLink errors, PCIe throughput.
  • Deployment: Run dcgm-exporter as a DaemonSet.

Full Deployment Guide: For the complete Kubernetes DaemonSet configuration, ServiceMonitor setup, and Grafana dashboard JSON for DCGM, please refer to Chapter 18.2: GPU Observability. The configuration there is the canonical source of truth for this book.

2. Prometheus Alerting Rules: Define alerts for anomalous GPU behavior.

groups:
- name: gpu_health
  rules:
  - alert: GPUHighTemperature
    expr: DCGM_FI_DEV_GPU_TEMP > 85
    for: 5m
    labels:
      severity: warning
    annotations:
      summary: "GPU {{ $labels.gpu }} on {{ $labels.instance }} is overheating"
      description: "Temperature is {{ $value }}C"

  - alert: GPUMemoryErrors
    expr: rate(DCGM_FI_DEV_ECC_DBE_VOL_TOTAL[5m]) > 0
    labels:
      severity: critical
    annotations:
      summary: "GPU {{ $labels.gpu }} has uncorrectable memory errors"
      description: "This GPU should be drained and replaced"

  - alert: TrainingStalled
    expr: rate(training_steps_total[10m]) == 0
    for: 15m
    labels:
      severity: critical
    annotations:
      summary: "Training job has stalled on {{ $labels.job_name }}"

3. Automated Remediation (Self-Healing): When an alert fires, trigger a Lambda (AWS) or Cloud Function (GCP) to:

  • Drain the unhealthy node from the cluster (Kubernetes cordon + drain).
  • Terminate the instance.
  • The Auto Scaling Group automatically replaces it.

AWS Lambda for Node Replacement:

import boto3

def lambda_handler(event, context):
    """
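    Triggered by the alerting pipeline (e.g., an EventBridge rule).
    Assumes the event carries the instance ID under event["detail"]["instance-id"];
    adapt the parsing to whatever your alert source actually sends.
    Terminates the unhealthy EC2 instance; the ASG launches a replacement.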
    """
    instance_id = event['detail']['instance-id']
    ec2 = boto3.client('ec2')

    print(f"Terminating unhealthy instance: {instance_id}")
    ec2.terminate_instances(InstanceIds=[instance_id])

    return {"status": "terminated", "instance": instance_id}

9.3.6. Gradient Anomaly Detection: Catching NaN Before It Spreads

A single NaN in a gradient can poison the entire model within one AllReduce operation. By the time you notice (loss becomes NaN), the damage is done.

The Solution: Gradient Clipping + NaN Detection

1. Global Gradient Norm Clipping: Standard practice is to clip the L2 norm of the gradient vector to prevent explosions.

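import torch.distributed as dist
import wandb
from torch.nn.utils import clip_grad_norm_

# After loss.backward(), before optimizer.step()
# (With the FSDP wrapper, call model.clip_grad_norm_(1.0) instead: parameters
#  are sharded, so the norm must be computed across ranks.)
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)

if dist.get_rank() == 0:
    # Log the gradient norm to detect anomalies
    wandb.log({"grad_norm": total_norm.item()})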

2. NaN Detection Hook: Install a hook that halts training as soon as a NaN appears in the backward pass, before the optimizer applies it to the weights.

import torch
import torch.distributed as dist

def nan_hook(module, grad_input, grad_output):
    """
    Debugging hook to catch NaN gradients.
    """
    for i, grad in enumerate(grad_output):
        if grad is not None and torch.isnan(grad).any():
            rank = dist.get_rank()
            print(f"RANK {rank}: NaN detected in {module.__class__.__name__}, output {i}")
            # Do not call save_checkpoint here: a sharded save is a collective,
            # and ranks that did not see the NaN would never join it (hang).
            # Resume from the last periodic checkpoint instead.
            raise ValueError("NaN detected in gradients. Training halted.")

# Register the hook on all modules (adds overhead: use while debugging)
for module in model.modules():
    module.register_full_backward_hook(nan_hook)

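3. Automatic Loss Scaling (Mixed Precision): With FP16, small gradients can underflow to zero. PyTorch’s GradScaler dynamically adjusts the loss scale to prevent this, and skips the optimizer step whenever the scaled gradients contain inf or NaN. BF16 has the same exponent range as FP32 and normally does not need loss scaling.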

import torch
from torch.nn.utils import clip_grad_norm_

scaler = torch.amp.GradScaler("cuda")

for batch in dataloader:
    optimizer.zero_grad()

    with torch.autocast(device_type="cuda", dtype=torch.float16):  # Forward in FP16
        loss = model(batch)

    # Scale the loss to prevent underflow
    scaler.scale(loss).backward()

    # Unscale before clipping so the threshold applies to the true gradients
    scaler.unscale_(optimizer)
    clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Step with gradient scaling (skipped if any gradient is inf/NaN)
    scaler.step(optimizer)
    scaler.update()

9.3.7. The “Checkpoint Zoo”: Retention Policies and Cost Optimization

If you checkpoint every hour for 30 days, you generate 720 checkpoints at 1TB each = 720TB of storage ($16,560/month on S3).

The Retention Strategy

1. Tiered Retention:

  • Aggressive Tier (Last 6 hours): Keep every checkpoint (for rapid rollback).
  • Daily Tier (Last 7 days): Keep 1 checkpoint per day.
  • Weekly Tier (Last 3 months): Keep 1 checkpoint per week.
  • Milestone Tier: Keep checkpoints at key milestones (e.g., “Best validation loss”).
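The tiered policy can be expressed as a pure function that decides what to keep, leaving the actual deletion to whatever tool manages your storage. A sketch, assuming each checkpoint is identified by a name and a creation time, reading "last 3 months" as 90 days and "per week" as ISO calendar weeks, and keeping the newest checkpoint in each day or week:

```python
from datetime import datetime, timedelta


def checkpoints_to_keep(checkpoints, now, milestones=frozenset()):
    """Return the set of checkpoint names the tiered policy retains.

    checkpoints: iterable of (name, created_at) pairs.
    milestones: names kept unconditionally (e.g. best validation loss).
    """
    keep = set(milestones)
    newest_per_day = {}   # date -> (created_at, name)
    newest_per_week = {}  # (ISO year, ISO week) -> (created_at, name)

    for name, created in checkpoints:
        age = now - created
        if age <= timedelta(hours=6):
            keep.add(name)  # aggressive tier: keep every checkpoint
        if age <= timedelta(days=7):
            day = created.date()
            if day not in newest_per_day or created > newest_per_day[day][0]:
                newest_per_day[day] = (created, name)
        if age <= timedelta(days=90):
            week = tuple(created.isocalendar())[:2]
            if week not in newest_per_week or created > newest_per_week[week][0]:
                newest_per_week[week] = (created, name)

    keep.update(name for _, name in newest_per_day.values())
    keep.update(name for _, name in newest_per_week.values())
    return keep


now = datetime(2024, 1, 31, 12, 0)
hourly = [(f"step-{i}", now - timedelta(hours=i)) for i in range(240)]  # 10 days of hourly saves
print(sorted(checkpoints_to_keep(hourly, now), key=lambda n: int(n.split("-")[1])))
```

Running it against the checkpoint listing before each cleanup pass, and deleting only names not in the returned set, keeps the policy auditable and makes a dry run trivial.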

2. Automated Cleanup (S3 Lifecycle Policies):

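resource "aws_s3_bucket_lifecycle_configuration" "checkpoint_lifecycle" {
  bucket = aws_s3_bucket.checkpoints.id

  # Example layout: rolling checkpoints under checkpoints/, milestones under best-model/.
  # A lifecycle rule cannot exempt objects from another rule, so the best model is
  # protected by living under a prefix that no expiring rule matches.
  rule {
    id     = "cleanup_old_checkpoints"
    status = "Enabled"

    filter {
      prefix = "checkpoints/"
    }

    # Move to Glacier after 7 days (cheap archival).
    # Glacier Flexible Retrieval bills a 90-day minimum, so objects deleted
    # at day 30 are still charged for 90 days of Glacier storage.
    transition {
      days          = 7
      storage_class = "GLACIER"
    }

    # Delete checkpoints older than 30 days
    expiration {
      days = 30
    }
  }
}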

3. Deduplicated Storage with Diff Checkpoints: Instead of saving the full state every time, save only the diff from the previous checkpoint (like Git).

Implementation Sketch:

import torch

def save_diff_checkpoint(prev_state, current_state, path):
    diff = {}
    for key in current_state:
        if key in prev_state:
            diff[key] = current_state[key] - prev_state[key]
        else:
            diff[key] = current_state[key]
    torch.save(diff, path)

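Note that a dense diff like this is the same size as the full state, so on its own it saves no space; savings appear only once the diff is sparsified, quantized, or compressed, as in the delta scheme of 9.3.9.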


9.3.8. Disaster Recovery: Multi-Region Checkpoint Replication

A single region failure (rare but not impossible) can destroy all checkpoints. For mission-critical training jobs, implement multi-region replication.

Cross-Region Replication Strategy

Async Replication: Write checkpoints to local storage (FSx), then asynchronously replicate to S3 in a different region.

Architecture:

  1. Primary Region (us-east-1): Training cluster + FSx for Lustre.
  2. Backup Region (us-west-2): S3 bucket with versioning enabled.

Terraform Implementation:

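# Primary S3 bucket (us-east-1)
resource "aws_s3_bucket" "checkpoints_primary" {
  provider = aws.us_east_1
  bucket   = "llm-checkpoints-primary"
}

resource "aws_s3_bucket_versioning" "checkpoints_primary" {
  provider = aws.us_east_1
  bucket   = aws_s3_bucket.checkpoints_primary.id

  versioning_configuration {
    status = "Enabled"
  }
}

resource "aws_s3_bucket_lifecycle_configuration" "checkpoints_primary" {
  provider = aws.us_east_1
  bucket   = aws_s3_bucket.checkpoints_primary.id

  rule {
    id     = "expire-old-versions"
    status = "Enabled"

    filter {}

    noncurrent_version_expiration {
      noncurrent_days = 7  # Keep old versions for 7 days
    }
  }
}

# Replica S3 bucket (us-west-2)
resource "aws_s3_bucket" "checkpoints_replica" {
  provider = aws.us_west_2
  bucket   = "llm-checkpoints-replica"
}

resource "aws_s3_bucket_versioning" "checkpoints_replica" {
  provider = aws.us_west_2
  bucket   = aws_s3_bucket.checkpoints_replica.id

  versioning_configuration {
    status = "Enabled"
  }
}

# Replication configuration (versioning must be enabled on both buckets first)
resource "aws_s3_bucket_replication_configuration" "replication" {
  provider = aws.us_east_1
  bucket   = aws_s3_bucket.checkpoints_primary.id
  role     = aws_iam_role.replication.arn

  depends_on = [
    aws_s3_bucket_versioning.checkpoints_primary,
    aws_s3_bucket_versioning.checkpoints_replica,
  ]

  rule {
    id     = "replicate-all"
    status = "Enabled"

    filter {}  # Replicate all objects

    # Required whenever a filter is specified
    delete_marker_replication {
      status = "Disabled"
    }

    destination {
      bucket        = aws_s3_bucket.checkpoints_replica.arn
      storage_class = "GLACIER_IR"  # Cheaper storage for backups

      replication_time {
        status = "Enabled"
        time {
          minutes = 15  # Replicate within 15 minutes (S3 RTC)
        }
      }

      metrics {
        status = "Enabled"
        event_threshold {
          minutes = 15
        }
      }
    }
  }
}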

Cost Analysis:

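  • Replication: $0.02/GB for cross-region data transfer (one-time per object), plus $0.015/GB for S3 Replication Time Control, which the configuration above enables.
  • Storage in Glacier IR: $0.004/GB/month (about 83% cheaper than S3 Standard at $0.023/GB/month).
  • Total for 10 TB: $350 (replication, including RTC) + $40/month (storage).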

Disaster Recovery Testing

Quarterly Drill: Simulate primary region failure.

# 1. Stop training in us-east-1
kubectl delete deployment training-job -n ml-training

# 2. Provision cluster in us-west-2
terraform apply -var="region=us-west-2"

# 3. Restore checkpoint from replica bucket
aws s3 sync s3://llm-checkpoints-replica/run-42/checkpoint-5000 /fsx/restore/

# 4. Resume training
python train.py --resume-from /fsx/restore/checkpoint-5000

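Target Recovery Time Objective (RTO): 2 hours. Target Recovery Point Objective (RPO): the newest checkpoint that has finished replicating. S3 RTC replicates 99.99% of objects within 15 minutes, so the worst case is losing one checkpoint interval of training plus that replication lag.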


9.3.9. Incremental Checkpointing and Delta Compression

Saving full checkpoints every hour for a 70B model (1.1 TB each) is wasteful. Most parameters change very little between checkpoints.

Delta Checkpointing

Store only the difference between consecutive checkpoints.

Algorithm:

import torch

def save_delta_checkpoint(prev_checkpoint, current_state, save_path, threshold=1e-6):
    """
    Save only the diff between current state and previous checkpoint.
    Lossy: changes with magnitude <= threshold are dropped, so error
    accumulates along a chain of deltas until the next full checkpoint.
    """
    delta = {}

    for key, value in current_state.items():
        prev = prev_checkpoint.get(key)
        is_float = torch.is_tensor(value) and value.is_floating_point()
        if is_float and prev is not None and prev.shape == value.shape:
            # Compute difference
            diff = (value - prev).flatten()

            # Sparsify: only store entries whose change exceeds the threshold
            indices = (diff.abs() > threshold).nonzero().squeeze(1)

            delta[key] = {
                "sparse_values": diff[indices],
                "indices": indices,  # flat int64 indices: 8 bytes per kept entry
                "shape": tuple(value.shape),
            }
        else:
            # New parameter (e.g., added layer) or non-float entry: store as-is
            delta[key] = value

    torch.save(delta, save_path)
    return delta

def load_delta_checkpoint(base_checkpoint, delta_path):
    """
    Reconstruct checkpoint by applying delta to base.
    """
    delta = torch.load(delta_path)
    reconstructed = dict(base_checkpoint)  # keys absent from the delta keep base values

    for key, entry in delta.items():
        if isinstance(entry, dict) and "sparse_values" in entry:
            # Scatter the stored changes back into a full-size copy of the base
            flat = base_checkpoint[key].flatten().clone()
            flat[entry["indices"]] += entry["sparse_values"]
            reconstructed[key] = flat.view(entry["shape"])
        else:
            # New parameter
            reconstructed[key] = entry

    return reconstructed

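Compression Ratio: For typical LLM training, parameters change by ~0.01-0.1% per step, but a small change is not the same as no change: with Adam-style optimizers nearly every parameter moves a little each step, so a 1e-6 threshold may keep most entries, and each kept entry also carries at least 8 bytes of index. The figures below assume the large majority of entries fall under the threshold; measure the kept fraction on your own run before counting on them.

  • Full checkpoint: 1.1 TB.
  • Delta checkpoint: ~10-50 GB (95% reduction).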

Trade-off: Reconstruction requires the base checkpoint. If the base is corrupted, all deltas are useless.

Hybrid Strategy:

  • Every 10 steps: Save delta checkpoint.
  • Every 100 steps: Save full checkpoint (new base).
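Under the hybrid schedule, restoring a given step means loading the most recent full checkpoint and applying every delta saved since then, in order. A sketch of that bookkeeping, assuming each delta is computed against the immediately preceding save (as when save_delta_checkpoint is called with the last saved state) and that full_every is a multiple of delta_every; restore_chain is an illustrative name:

```python
def restore_chain(target_step: int, full_every: int = 100, delta_every: int = 10):
    """Return (base_step, delta_steps) needed to rebuild the state at target_step.

    Assumes full checkpoints at multiples of full_every (including step 0) and
    chained deltas every delta_every steps, each relative to the previous save.
    """
    if target_step < 0 or target_step % delta_every:
        raise ValueError(f"no checkpoint is saved at step {target_step}")
    base_step = (target_step // full_every) * full_every
    delta_steps = list(range(base_step + delta_every, target_step + 1, delta_every))
    return base_step, delta_steps


base, deltas = restore_chain(250)
print(f"load full checkpoint {base}, then apply deltas {deltas}")
```

The length of the returned list is also the blast radius of a single corrupted file: losing the base, or any delta in the chain, makes every later step in that window unrecoverable.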

9.3.10. Checkpoint Validation and Corruption Detection

Silent data corruption (e.g., bit flips in S3, filesystem bugs) can corrupt checkpoints without immediate detection.

Checksum Validation

Compute a cryptographic hash of each checkpoint and store it alongside the data.

Implementation:

import hashlib
import torch

CHUNK_SIZE = 8 * 1024 * 1024  # 8 MiB reads keep hashing fast on multi-GB files

def _sha256_of_file(path):
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(CHUNK_SIZE):
            sha256.update(chunk)
    return sha256.hexdigest()

def save_checkpoint_with_hash(state_dict, save_path):
    """
    Save checkpoint with SHA256 checksum.
    """
    # Save checkpoint
    torch.save(state_dict, save_path)

    # Compute hash and save it to a sidecar file
    hash_value = _sha256_of_file(save_path)
    with open(f"{save_path}.sha256", "w") as f:
        f.write(hash_value)

    return hash_value

def verify_checkpoint(checkpoint_path):
    """
    Verify checkpoint integrity using stored hash.
    """
    current_hash = _sha256_of_file(checkpoint_path)

    # Load expected hash
    with open(f"{checkpoint_path}.sha256", "r") as f:
        expected_hash = f.read().strip()

    if current_hash != expected_hash:
        raise ValueError(f"Checkpoint corrupted! Expected {expected_hash}, got {current_hash}")

    return True

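S3 Object Lock: For compliance, use S3 Object Lock to make checkpoints immutable (they cannot be deleted or overwritten during a retention period). Object Lock requires versioning and must be enabled on the bucket itself (for example, object_lock_enabled = true on the aws_s3_bucket resource).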

resource "aws_s3_bucket_object_lock_configuration" "checkpoints" {
  bucket = aws_s3_bucket.checkpoints.id

  rule {
    default_retention {
      mode = "GOVERNANCE"  # Users with s3:BypassGovernanceRetention can override; COMPLIANCE allows no overrides
      days = 30
    }
  }
}

9.3.11. Training Resumption Testing: The “Resume Benchmark”

A checkpoint is useless if you can’t resume from it. Test resume functionality regularly.

Automated Resume Test

Goal: Verify that a resumed run produces identical results to a continuous run (within numerical tolerance).

Test Script:

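import random

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

# Minimal stand-ins so the test runs on CPU; replace with your model and pipeline
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Dropout consumes RNG, so the test also checks RNG-state restoration
        self.net = nn.Sequential(nn.Linear(16, 32), nn.Dropout(0.1), nn.Linear(32, 1))

    def forward(self, x):
        return self.net(x)

def get_batch(step):
    # Batch depends only on the step index, like a resumable data loader
    g = torch.Generator().manual_seed(step)
    return torch.randn(8, 16, generator=g), torch.randn(8, 1, generator=g)

def train_step(model, optimizer, batch):
    x, y = batch
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

def get_rng_states():
    states = {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": random.getstate()}
    if torch.cuda.is_available():
        states["cuda"] = torch.cuda.get_rng_state_all()
    return states

def set_rng_states(states):
    torch.set_rng_state(states["torch"])
    np.random.set_state(states["numpy"])
    random.setstate(states["python"])
    if "cuda" in states:
        torch.cuda.set_rng_state_all(states["cuda"])

def test_deterministic_resume(path="checkpoint_step50.pt"):
    """
    Train for 100 steps, save checkpoint at step 50, resume, and verify results.
    """
    # Run 1: Train 0-100 continuously
    set_seed(42)
    model1 = MyModel()
    optimizer1 = AdamW(model1.parameters(), lr=1e-3)
    losses_continuous = [train_step(model1, optimizer1, get_batch(s)) for s in range(100)]

    # Run 2: Train 0-50, then checkpoint at step 50
    set_seed(42)
    model2 = MyModel()
    optimizer2 = AdamW(model2.parameters(), lr=1e-3)
    losses_resumed = [train_step(model2, optimizer2, get_batch(s)) for s in range(50)]
    torch.save(
        {
            "model": model2.state_dict(),
            "optimizer": optimizer2.state_dict(),
            "step": 50,
            "rng": get_rng_states(),
        },
        path,
    )
    del model2, optimizer2  # Simulate the crash

    # Resume into brand-new objects, initialized with a different seed
    set_seed(0)
    model3 = MyModel()
    optimizer3 = AdamW(model3.parameters(), lr=1e-3)
    checkpoint = torch.load(path, weights_only=False)  # our own file; holds NumPy RNG state
    model3.load_state_dict(checkpoint["model"])
    optimizer3.load_state_dict(checkpoint["optimizer"])
    set_rng_states(checkpoint["rng"])

    for s in range(checkpoint["step"], 100):
        losses_resumed.append(train_step(model3, optimizer3, get_batch(s)))

    # Compare
    for i, (loss_cont, loss_res) in enumerate(zip(losses_continuous, losses_resumed)):
        assert abs(loss_cont - loss_res) < 1e-6, f"Step {i}: {loss_cont} != {loss_res}"

    print("Resume test PASSED: Resumed run is identical to continuous run.")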

Run this test:

  • Before every major training run.
  • After upgrading PyTorch, CUDA, or NCCL.

9.3.12. Chaos Engineering for Distributed Training

Proactively inject failures to test resilience.

Chaos Experiment 1: Random Node Termination

Tool: Chaos Mesh (Kubernetes) or custom script.

Experiment:

# Chaos Mesh 2.x defines recurring experiments with a Schedule object
apiVersion: chaos-mesh.org/v1alpha1
kind: Schedule
metadata:
  name: kill-random-worker
  namespace: ml-training
spec:
  schedule: "*/30 * * * *"  # Kill one pod every 30 minutes
  type: PodChaos
  concurrencyPolicy: Forbid
  historyLimit: 5
  podChaos:
    action: pod-kill
    mode: one
    selector:
      namespaces:
        - ml-training
      labelSelectors:
        app: distributed-training

Expected Behavior: Training should pause, detect the failure, and resume from the last checkpoint without manual intervention.

Chaos Experiment 2: Network Partition

Simulate a network split where nodes can’t communicate.

apiVersion: chaos-mesh.org/v1alpha1
kind: NetworkChaos
metadata:
  name: partition-network
spec:
  action: partition
  mode: all
  selector:
    namespaces:
      - ml-training
    labelSelectors:
      app: distributed-training
  direction: both
  duration: "5m"

Expected Behavior: NCCL collectives should time out, training should crash, and the watchdog should restart from the last checkpoint.

Chaos Experiment 3: Disk Corruption (Checkpoint Storage)

Corrupt a checkpoint file and verify detection.

#!/bin/bash
# Simulate corruption: overwrite 100 bytes at a random offset (bash $RANDOM: 0-32767)
CHECKPOINT="/fsx/checkpoints/run-42/checkpoint-1000/model.pt"
dd if=/dev/urandom of=$CHECKPOINT bs=1 count=100 seek=$RANDOM conv=notrunc

Expected Behavior: On load, checksum validation should fail, and the system should fall back to the previous checkpoint.
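The validation and fallback can be sketched as follows, assuming a hypothetical layout in which each `checkpoint-<step>/` directory holds `model.pt` plus a sidecar `model.pt.sha256` written right after the checkpoint finishes saving. `write_checksum`, `is_valid`, and `latest_valid_checkpoint` are illustrative names; sharded checkpoints would need one digest per shard.

```python
import hashlib
from pathlib import Path
from typing import Optional


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    # Stream the file so multi-GB checkpoints are not read into memory at once
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_checksum(ckpt_file: Path) -> None:
    # Call only after the checkpoint write has fully completed
    Path(f"{ckpt_file}.sha256").write_text(sha256_of(ckpt_file))


def is_valid(ckpt_file: Path) -> bool:
    sidecar = Path(f"{ckpt_file}.sha256")
    if not ckpt_file.is_file() or not sidecar.is_file():
        return False  # missing data or missing digest: treat as invalid
    return sha256_of(ckpt_file) == sidecar.read_text().strip()


def latest_valid_checkpoint(run_dir: Path, filename: str = "model.pt") -> Optional[Path]:
    # Newest first, ordered by the numeric step suffix of checkpoint-<step>
    dirs = [d for d in run_dir.glob("checkpoint-*") if d.name.split("-")[-1].isdigit()]
    for d in sorted(dirs, key=lambda d: int(d.name.split("-")[-1]), reverse=True):
        if is_valid(d / filename):
            return d / filename
        print(f"Checksum validation failed for {d / filename}; falling back")
    return None
```

Running the `dd` injection above against `checkpoint-1000` should make `latest_valid_checkpoint` skip it and return the previous checkpoint.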


9.3.13. Cost of Fault Tolerance: The Insurance Premium

Fault tolerance is not free. It’s an insurance policy. You pay upfront (in time and money) to reduce the risk of catastrophic loss.

Cost Breakdown for 30-Day Training Run (70B Model, 100 Nodes)

| Item | Without FT | With FT | Overhead |
|---|---|---|---|
| Compute (100 nodes × $32.77/hr × 720 hrs) | $2,359,440 | $2,359,440 | $0 |
| Checkpoint I/O (training paused during writes) | $0 | ~5% time penalty | ~$118,000 |
| Storage (FSx + S3 replication) | $1,400 | $3,500 | $2,100 |
| Monitoring (DCGM, Prometheus) | $0 | $500 | $500 |
| Spot interruption losses (assume 3 interruptions) | $0 (N/A on On-Demand) | $2,000 | $2,000 |
| Expected loss from failure (10% chance of catastrophic failure) | $235,944 | $0 | -$235,944 |
| Total expected cost | $2,596,784 | $2,483,440 | -$113,344 |

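Conclusion: Under these assumptions, fault tolerance is not just a risk mitigation strategy; it is also a cost optimization strategy. The expected cost is lower with FT because the avoided expected loss from a catastrophic failure (about $236,000) outweighs the combined overhead of checkpointing, storage, monitoring, and Spot interruptions.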
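The table's arithmetic can be reproduced in a few lines, which makes it easy to plug in your own prices and failure probability. The sketch below uses the table's inputs, computes the checkpoint penalty as exactly 5% of compute (the table rounds it to ~$118,000), and adds a break-even failure probability: the catastrophic-failure probability above which FT's fixed overhead pays for itself, under the table's assumption that a catastrophic failure costs the full compute spend. With these inputs it prints a saving of $113,372 and a break-even probability of about 5.2%.

```python
# Inputs from the cost table (whole dollars; the hourly rate is kept in cents to avoid float rounding)
NODES, RATE_CENTS_PER_HR, HOURS = 100, 3277, 720
P_CATASTROPHIC = 0.10  # assumed chance of losing the whole run without FT
STORAGE_NO_FT, STORAGE_FT = 1_400, 3_500
MONITORING_FT, SPOT_LOSSES_FT = 500, 2_000

compute = NODES * RATE_CENTS_PER_HR * HOURS // 100  # $2,359,440
checkpoint_penalty = compute * 5 // 100             # exact 5% time penalty
expected_loss_no_ft = round(P_CATASTROPHIC * compute)

without_ft = compute + STORAGE_NO_FT + expected_loss_no_ft
with_ft = compute + checkpoint_penalty + STORAGE_FT + MONITORING_FT + SPOT_LOSSES_FT
savings = without_ft - with_ft

# FT overhead that does not depend on failures, relative to the no-FT baseline
ft_overhead = with_ft - (compute + STORAGE_NO_FT)
break_even_p = ft_overhead / compute

print(f"Without FT: ${without_ft:,}  With FT: ${with_ft:,}  Savings: ${savings:,}")
print(f"FT pays off when P(catastrophic failure) > {break_even_p:.1%}")
```

Below the break-even probability, FT no longer wins on expected cost alone, although, as with any insurance, the reduced variance of the outcome may still justify it.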


9.3.14. Checkpoint Formats: Trade-offs and Best Practices

Different checkpoint formats have different performance characteristics.

Format Comparison

| Format | Size | Write Speed | Read Speed | Compatibility | Use Case |
|---|---|---|---|---|---|
| PyTorch .pt (Pickle) | Medium | Fast | Fast | PyTorch only | Standard choice |
| Safetensors | Small | Very Fast | Very Fast | Multi-framework | Recommended for production |
| HDF5 | Medium | Medium | Medium | Universal | Legacy systems |
| NumPy .npz | Large | Slow | Slow | Universal | Debugging/inspection |
| TensorFlow Checkpoint | Large | Medium | Medium | TensorFlow | If using TF |

Safetensors: The Modern Standard

Safetensors is a tensor serialization format developed by Hugging Face. It is safer, typically faster to load, and more portable than Pickle-based .pt files.

Advantages:

  • Security: No arbitrary code execution (Pickle can run malicious code).
  • Speed: Zero-copy memory mapping (faster load).
  • Lazy Loading: Load only needed tensors (useful for inference).
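Lazy loading works because a .safetensors file begins with an 8-byte little-endian length followed by a JSON header that records each tensor's dtype, shape, and byte offsets into the data section, so a reader can seek directly to one tensor. The sketch below implements that layout by hand for float32 only, purely to show the mechanism; `write_f32` and `read_f32` are hypothetical helpers. In practice you would use the library's `safe_open`, which handles every dtype and memory-maps the file.

```python
import json
import struct


def write_f32(path, tensors):
    """Write {name: [floats]} as 1-D F32 tensors in the safetensors layout."""
    header, chunks, offset = {}, [], 0
    for name, values in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": [len(values)],
                        "data_offsets": [offset, offset + len(raw)]}
        chunks.append(raw)
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # pad header to 8-byte alignment
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 header length
        f.write(header_bytes)
        f.write(b"".join(chunks))


def read_f32(path, name):
    """Read a single F32 tensor by name without reading the others."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        info = json.loads(f.read(header_len))[name]
        if info["dtype"] != "F32":
            raise ValueError("this sketch only handles float32")
        begin, end = info["data_offsets"]  # relative to the start of the data section
        f.seek(8 + header_len + begin)
        raw = f.read(end - begin)
    return list(struct.unpack(f"<{(end - begin) // 4}f", raw))
```

Only the header and the requested byte range are read, which is why loading a single layer from a multi-GB file is cheap.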

Installation:

pip install safetensors

Usage:

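import torch.nn as nn
from safetensors.torch import save_file, load_file

model = nn.Linear(16, 4)  # stand-in for your model

# Save (model weights only; tensors must be contiguous and must not share memory)
state_dict = model.state_dict()
save_file(state_dict, "checkpoint.safetensors")

# Load
state_dict = load_file("checkpoint.safetensors")
model.load_state_dict(state_dict)

# For models with tied weights (e.g., shared embedding/output matrices),
# use safetensors.torch.save_model(model, path) and load_model(model, path) instead.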

Migration from .pt to safetensors:

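import torch
from safetensors.torch import save_file

# Load old checkpoint on CPU (assumes the weights live under the "model" key)
old_checkpoint = torch.load("checkpoint.pt", map_location="cpu")

# save_file only accepts a flat {name: tensor} dict of contiguous tensors
model_state = {name: t.contiguous() for name, t in old_checkpoint["model"].items()}

# Save model weights in the new format
save_file(model_state, "checkpoint.safetensors")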

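Recommendation: Use Safetensors for model weights in all new projects, and convert existing .pt checkpoints during the next training run. Because Safetensors stores only a flat mapping of names to tensors, keep the optimizer, RNG, and dataloader state needed for resumption in a separate file.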


9.3.15. Final Checklist: Production-Ready Fault Tolerance

Before launching a multi-million-dollar training run, verify:

1. Checkpointing:

  • Checkpoints are sharded (DCP or FSDP state dict).
  • Checkpoint interval is optimized (Young’s formula).
  • Checkpoints are written to high-speed storage (FSx/Parallelstore).
  • Checksums are computed and verified.

2. Backup and Replication:

  • Checkpoints are replicated to S3 (or GCS).
  • Multi-region replication is enabled for critical runs.
  • Retention policy is configured (tiered storage).

3. Failure Detection:

  • DCGM Exporter is deployed on all nodes.
  • Prometheus alerts are configured for GPU health, network errors, and training stalls.
  • Automated remediation (node replacement) is set up.

4. Resumption:

  • Resume logic is tested (deterministic resume test passed).
  • Dataloader state is saved and restored.
  • RNG states are saved and restored.

5. Spot Resilience (if using Spot):

  • Spot interruption handler is running.
  • Emergency checkpoint on interruption is implemented.
  • Mixed On-Demand + Spot fleet is configured.

6. Monitoring:

  • Training metrics are logged (loss, throughput, GPU utilization).
  • Dashboards are created (Grafana or CloudWatch).
  • Alerts are routed to on-call engineers (PagerDuty, Slack).

7. Chaos Testing:

  • Node termination chaos experiment passed.
  • Network partition chaos experiment passed.
  • Checkpoint corruption detection tested.

If every box is checked, you are ready for production.
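Young's formula, referenced in the checkpointing item above, gives the first-order optimal interval between checkpoints as $T_{opt} = \sqrt{2CM}$, where $C$ is the time to write one checkpoint and $M$ is the MTBF of the whole job. The sketch below assumes independent node failures, so the job MTBF is the per-node MTBF divided by the node count. It reuses the 0.1% daily per-node failure probability from the start of this section (an MTBF of roughly 1,000 days) and the cost table's 100 nodes; the 60-second checkpoint cost is an assumption. The waste model $C/T + T/(2M)$ is likewise a first-order approximation that holds when $T \ll M$.

```python
import math

SECONDS_PER_DAY = 86_400


def job_mtbf_s(node_mtbf_days: float, n_nodes: int) -> float:
    # Assumes independent, exponentially distributed node failures
    return node_mtbf_days * SECONDS_PER_DAY / n_nodes


def young_interval_s(checkpoint_cost_s: float, mtbf_s: float) -> float:
    # Young's first-order optimum: T_opt = sqrt(2 * C * M)
    return math.sqrt(2 * checkpoint_cost_s * mtbf_s)


def waste_fraction(interval_s: float, checkpoint_cost_s: float, mtbf_s: float) -> float:
    # Time spent writing checkpoints (C / T) plus expected recompute after a failure (T / 2M)
    return checkpoint_cost_s / interval_s + interval_s / (2 * mtbf_s)


mtbf = job_mtbf_s(node_mtbf_days=1_000, n_nodes=100)  # 10 days for the whole job
t_opt = young_interval_s(checkpoint_cost_s=60, mtbf_s=mtbf)
print(f"Checkpoint every {t_opt / 60:.0f} min; overhead ~{waste_fraction(t_opt, 60, mtbf):.2%}")
```

With these inputs the interval comes out near 170 minutes with about 1.2% overhead. Because $T_{opt}$ grows with $\sqrt{C}$, faster checkpoint storage lowers both the optimal interval and the total overhead; because $M$ shrinks as nodes are added, larger clusters need to checkpoint more often.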


9.3.16. Summary: The Resilience Checklist

When architecting fault tolerance for distributed training:

  1. Checkpoint Religiously: Use sharded checkpoints (DCP). Write to high-speed storage (FSx/Parallelstore).
  2. Optimize Checkpoint Interval: Use Young’s formula. Balance I/O cost vs. recompute cost.
  3. Embrace Spot: Use hybrid On-Demand + Spot. Implement interruption handlers.
  4. Monitor GPUs: Deploy DCGM. Alert on ECC errors, temperature, and training stalls.
  5. Detect NaN Early: Use gradient hooks and clipping. Don’t let poison spread.
  6. Automate Recovery: Use Elastic Training (TorchElastic) for node failures. Auto-replace unhealthy instances.
  7. Manage Checkpoint Bloat: Implement tiered retention. Use S3 lifecycle policies.
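Tiered retention can start from a simple rule: keep the newest few checkpoints on fast storage for quick resumes, keep periodic "milestone" checkpoints for rollback and evaluation, and delete the rest. The sketch below computes that plan from a list of checkpoint steps; `plan_retention`, `keep_last`, and `milestone_every` are hypothetical names and defaults. Moving milestones to S3 or GCS, and aging them into colder storage classes with lifecycle policies, is left to your storage tooling.

```python
def plan_retention(steps, keep_last=3, milestone_every=10_000):
    """Return (keep, delete) lists of checkpoint steps under a two-tier policy."""
    ordered = sorted(set(steps))
    # Tier 1: the most recent checkpoints, for fast resume after a failure
    keep = set(ordered[-keep_last:]) if keep_last > 0 else set()
    # Tier 2: milestones, for rollback after divergence and for evaluation
    keep |= {s for s in ordered if s % milestone_every == 0}
    delete = [s for s in ordered if s not in keep]
    return sorted(keep), delete


keep, delete = plan_retention(range(1_000, 26_000, 1_000))
print(keep)  # [10000, 20000, 23000, 24000, 25000]
```

The `keep_last > 0` guard matters: in Python, `ordered[-0:]` is the whole list, which would silently disable deletion.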

In the next chapter, we will discuss Model Serving and Inference Optimization, where the challenges shift from throughput (training) to latency (serving) and cost-per-token economics.