Chapter 15: Distributed Training Strategies
15.1. Parallelism: The Physics of Scale
“The quantity of meaning typically created by a neural network is roughly proportional to the square root of the number of floating-point operations used to train it… assuming you can fit it in memory.” — The Scaling Hypothesis (paraphrased)
In the early days of Deep Learning (circa 2012-2016), a “large” model fit comfortably onto a single NVIDIA K80 GPU. The primary engineering challenge was algorithmic: vanishing gradients, initialization schemes, and hyperparameter tuning.
Today, the primary challenge is physics.
Modern Foundation Models (LLMs) and large Computer Vision models often exceed the VRAM capacity of any single accelerator in existence.
- Llama-3-70B: In FP16 precision, the parameters alone require ~140GB of memory (70B × 2 bytes). An NVIDIA H100 has 80GB. You cannot even load the weights onto a single GPU, let alone train the model.
- The Training Multiplier: To train a model, you need not just the parameters, but also the gradients (same size), the optimizer states (Adam keeps two FP32 values per parameter, i.e., 4× the size of FP16 weights), and the activations (intermediate outputs).
For the Architect and Principal Engineer, distributed training is no longer an optional optimization for speed; it is a hard requirement for existence.
This chapter dissects the taxonomy of parallelism. We will move beyond the high-level buzzwords (“Data Parallel”, “Model Parallel”) and examine the exact memory layouts, communication primitives, and bandwidth requirements that dictate whether your training run finishes in 3 weeks or crashes in 3 milliseconds.
9.1.0. The Memory Equation: Why We Go Parallel
Before selecting a strategy, we must quantify the enemy. Why exactly do we run out of memory (OOM)?
The total memory $M_{total}$ required to train a model with $\Phi$ parameters using the Adam optimizer can be approximated as:
$$ M_{total} = M_{model} + M_{grad} + M_{opt} + M_{act} + M_{frag} $$
Where:
- $M_{model}$ (Parameters):
  - In 16-bit precision (FP16/BF16): $2 \times \Phi$ bytes.
  - Example: 7B model $\approx$ 14 GB.
- $M_{grad}$ (Gradients):
  - Stores the gradient with respect to every parameter, in the same precision.
  - Size: $2 \times \Phi$ bytes.
  - Example: 7B model $\approx$ 14 GB.
- $M_{opt}$ (Optimizer States):
  - Standard Adam maintains the momentum ($m$) and variance ($v$) for every parameter.
  - These are typically stored in FP32 (single precision) for numerical stability, even if the weights are FP16 (mixed-precision training).
  - Size: $4 \text{ bytes (FP32)} \times 2 \text{ states} \times \Phi = 8 \times \Phi$ bytes.
  - Example: 7B model $\approx$ 56 GB.
  - Mixed-precision training usually also keeps an FP32 master copy of the weights (another $4 \times \Phi$ bytes, 28 GB for a 7B model). The reality check below leaves it out, so its total is a lower bound.
- $M_{act}$ (Activations):
  - The intermediate outputs of every layer, needed for the backward pass (chain rule).
  - Scales linearly with batch size ($B$) and sequence length ($S$); a materialized attention score matrix adds a term quadratic in $S$.
  - $M_{act} \propto B \times S \times \text{HiddenDim} \times \text{Layers}$.
  - Note: For long contexts, this can exceed the model size.
- $M_{frag}$ (Fragmentation):
  - Overhead from the CUDA caching allocator: memory that is reserved but split into fragments too small to serve new allocations.
The 7B Parameter Reality Check: Summing up the static requirement (Weights + Gradients + Optimizer) for a 7B model: $$ 14 + 14 + 56 = 84 \text{ GB} $$
Verdict: You cannot fully fine-tune a 7B model on a single A100 (80GB) using standard Adam without advanced techniques. The static training state alone exceeds the card's capacity before a single activation is stored.
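The arithmetic above is easy to script. The helper below is a minimal sketch (the function name and its bytes-per-parameter defaults are our own, chosen to match the figures in this section) that reproduces the 84 GB total and shows how counting an FP32 master copy of the weights raises it further. It ignores activations and fragmentation.

```python
def training_memory_gb(num_params: float, weight_bytes: int = 2, grad_bytes: int = 2,
                       adam_state_bytes: int = 8, master_weight_bytes: int = 0) -> dict:
    """Static training memory (weights + gradients + Adam states) in GB (1 GB = 1e9 bytes).

    Hypothetical helper. Activations and allocator fragmentation are excluded because
    they depend on batch size, sequence length, and runtime behavior.
    """
    gb = 1e9
    parts = {
        "weights": num_params * weight_bytes / gb,
        "gradients": num_params * grad_bytes / gb,
        "adam_states": num_params * adam_state_bytes / gb,
        "fp32_master_weights": num_params * master_weight_bytes / gb,
    }
    parts["total"] = sum(parts.values())  # computed before "total" is added to the dict
    return parts


# The chapter's 7B example: 14 + 14 + 56 = 84 GB
print(training_memory_gb(7e9))
# Also counting an FP32 master copy of the weights (4 bytes/param): 112 GB
print(training_memory_gb(7e9, master_weight_bytes=4)["total"])
```

Changing the defaults (for example `adam_state_bytes=0` for plain SGD without momentum) shows how strongly the optimizer choice drives the total.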
To solve this, we split the problem. How we split it determines the “Parallelism Strategy.”
9.1.1. Data Parallelism (DP) & Distributed Data Parallel (DDP)
Data Parallelism is the simplest, most robust, and most common form of distributed training. It assumes the model fits entirely on a single device, but the data is too large to process quickly enough.
The Architecture
- Replication: The full model is copied to every GPU in the cluster (Rank 0 to Rank $N-1$).
- Scatter: The global batch of data (e.g., 1024 images) is split into mini-batches (e.g., 32 images per GPU).
- Forward/Backward: Each GPU computes gradients on its local slice of data independently.
- Synchronization (AllReduce): Before the optimizer step, all GPUs must agree on the average gradient.
- Update: Every GPU updates its local weights identically. They remain synchronized bit-for-bit.
The Communication Primitive: Ring AllReduce
The naive approach to synchronization is a “Parameter Server” (all GPUs send gradients to a central node, which averages them and sends them back). This creates a massive bandwidth bottleneck at the central node.
Modern DDP uses Ring AllReduce.
- Topology: GPUs are logically arranged in a ring.
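- Step 1 (Scatter-Reduce): Each gradient vector is split into $N$ chunks. At every step, GPU $k$ sends one chunk to GPU $k+1$ while receiving a chunk from GPU $k-1$ and adding it into its own copy. After $N-1$ steps, each GPU holds one chunk ($1/N$ of the vector) that is fully summed across all GPUs.
- Step 2 (AllGather): Over another $N-1$ steps, the GPUs circulate the summed chunks until every GPU has the full summed gradient vector.
- Bandwidth Efficiency: Each GPU sends (and receives) $2 \frac{N-1}{N}$ times the gradient size, which approaches $2\times$ as $N$ grows. Per-GPU traffic is therefore essentially constant regardless of the number of GPUs, although latency grows with the number of steps.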
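To make the two phases concrete, here is a single-process simulation of Ring AllReduce in plain Python. It is a sketch for intuition, not how NCCL is implemented: the function name and the exact chunk schedule are illustrative assumptions. It also counts how many elements each worker sends, which works out to $2 \frac{N-1}{N}$ times the vector length.

```python
def ring_allreduce(grads):
    """Simulate a sum Ring AllReduce over N workers (hypothetical helper).

    grads: list of N equal-length lists; worker k starts with grads[k].
    Returns (per-worker results, number of elements each worker sent).
    """
    n = len(grads)
    length = len(grads[0])
    assert length % n == 0, "for simplicity, the vector length must divide evenly into N chunks"
    size = length // n
    # buffers[k][c] is worker k's current copy of chunk c
    buffers = [[list(g[c * size:(c + 1) * size]) for c in range(n)] for g in grads]
    sent = [0] * n

    # Phase 1: scatter-reduce. At step s, worker k sends chunk (k - s) % n to worker k + 1,
    # which adds it into its own copy. Afterwards worker k owns the full sum of chunk (k + 1) % n.
    for s in range(n - 1):
        msgs = [(k, (k - s) % n, list(buffers[k][(k - s) % n])) for k in range(n)]
        for k, c, data in msgs:  # all messages of a step are "in flight" simultaneously
            dst = (k + 1) % n
            buffers[dst][c] = [a + b for a, b in zip(buffers[dst][c], data)]
            sent[k] += len(data)

    # Phase 2: all-gather. Worker k forwards the fully reduced chunk it most recently obtained;
    # the receiver overwrites its stale copy. After n - 1 steps every worker has every chunk.
    for s in range(n - 1):
        msgs = [(k, (k + 1 - s) % n, list(buffers[k][(k + 1 - s) % n])) for k in range(n)]
        for k, c, data in msgs:
            buffers[(k + 1) % n][c] = data
            sent[k] += len(data)

    results = [[x for chunk in b for x in chunk] for b in buffers]
    return results, sent


grads = [[float(k + 1)] * 8 for k in range(4)]  # 4 workers, 8-element gradients
results, sent = ring_allreduce(grads)
print(results[0])  # every element is 1 + 2 + 3 + 4 = 10.0
print(sent)        # each worker sent 2 * (4 - 1) / 4 * 8 = 12 elements
```

Doubling the number of workers while keeping the vector length fixed barely changes `sent`, which is the bandwidth property described in the list.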
The Python Implementation (PyTorch DDP)
In modern PyTorch, DistributedDataParallel groups gradients into buckets and launches an AllReduce for each bucket as soon as all of its gradients have been computed during the backward pass. Communication for the last layers therefore overlaps with backpropagation through the earlier ones.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(rank, world_size):
    # 1. Initialize Process Group (NCCL is the standard backend for GPUs)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # assumes a single node, so rank == local GPU index

    # 2. Bind model to local GPU (MyTransformer, dataset, and criterion are defined elsewhere)
    model = MyTransformer().to(rank)

    # 3. Wrap with DDP
    # This registers hooks to trigger AllReduce during .backward()
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

    # 4. Use DistributedSampler to ensure each GPU gets different data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    # For multiple epochs, call sampler.set_epoch(epoch) so the shuffle differs per epoch
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(rank), labels.to(rank)
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = criterion(outputs, labels)
        # 5. Magic happens here:
        # Gradients are computed locally.
        # As buckets fill up, they are AllReduced across the cluster asynchronously.
        loss.backward()
        # By the time we hit step(), gradients are synced.
        optimizer.step()

    dist.destroy_process_group()
The Bottleneck
DDP is typically network-bound.
- The amount of data transmitted per step is proportional to the model size, not the batch size.
- If you have a slow interconnect (e.g., standard 10Gbps Ethernet), the GPUs will spend more time waiting for gradients to arrive than computing.
- AWS Implication: For large models, use instances with EFA (Elastic Fabric Adapter), such as p4d.24xlarge (400 Gbps).
- GCP Implication: Use Fast Socket and compact placement policies.
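A back-of-the-envelope estimate shows why the interconnect matters. This sketch assumes a Ring AllReduce in which each GPU sends $2 \frac{N-1}{N}$ times the gradient buffer over a link running at the stated rate. It ignores latency, protocol overhead, hierarchical (NVLink-then-network) reduction, and the overlap with backward compute that DDP provides, so treat the result as a rough lower bound; the function name is our own.

```python
def allreduce_seconds(num_params: float, bytes_per_grad: int, num_gpus: int, link_gbps: float) -> float:
    """Rough lower bound on the time for one Ring AllReduce of all gradients (hypothetical helper).

    Each GPU sends (and receives) 2 * (N - 1) / N times the gradient buffer.
    """
    payload_bytes = num_params * bytes_per_grad
    sent_bytes = 2 * (num_gpus - 1) / num_gpus * payload_bytes
    link_bytes_per_s = link_gbps * 1e9 / 8  # Gbit/s -> bytes/s
    return sent_bytes / link_bytes_per_s


# 7B-parameter model, BF16 gradients (2 bytes), 16 GPUs
for gbps in (10, 400):
    print(f"{gbps:>4} Gbps: {allreduce_seconds(7e9, 2, 16, gbps):.2f} s of communication per step")
```

At 10 Gbps the gradient exchange alone takes about 21 seconds per step, which no amount of compute overlap can hide.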
9.1.2. Breaking the Memory Wall: ZeRO and FSDP
DDP has a fatal flaw: Memory Redundancy. If you have 16 GPUs, you store 16 identical copies of the weights, 16 identical copies of the optimizer states, and 16 identical copies of the gradients. For large models, this is a colossal waste of VRAM.
ZeRO (Zero Redundancy Optimizer), popularized by Microsoft DeepSpeed and implemented natively in PyTorch as FSDP (Fully Sharded Data Parallel), solves this by sharding the model states across GPUs.
The Three Stages of ZeRO
ZeRO trades Communication for Memory.
- ZeRO-1 (Optimizer Sharding):
  - Concept: Every GPU holds the full parameters and gradients, but only stores and updates a subset (1/N) of the optimizer states.
  - Mechanism: At the end of the step, gradients are reduced to the GPU responsible for each slice of the optimizer state. That GPU updates its slice of the weights, and the updated weights are then gathered back to every GPU.
  - Memory Savings: Up to $4\times$ as $N$ grows (it removes the optimizer state redundancy, counting the FP32 master weights plus Adam's two states, 12 bytes per parameter, as the sharded state).
- ZeRO-2 (Gradient Sharding):
  - Concept: Shard the gradients as well. Each GPU only holds gradients for the slice of parameters it updates.
  - Memory Savings: Up to $8\times$ combined with Stage 1.
- ZeRO-3 (Parameter Sharding) / FSDP:
  - Concept: Shard everything. The full model does not exist on any single GPU.
  - Mechanism:
    - Forward Pass: When GPU 1 needs Layer 3, it gathers the missing shards of Layer 3's weights from GPUs 2…N (an AllGather), computes the output, then immediately discards the gathered weights to free memory.
    - Backward Pass: Same fetch-compute-discard pattern, followed by a Reduce-Scatter so each GPU keeps only its shard of the gradients.
  - Memory Savings: Model-state memory per GPU falls linearly with $N$. In principle, a 1T-parameter model becomes trainable given enough GPUs (activations still need their own treatment).
  - Cost: More communication than DDP (about $1.5\times$ the volume, according to the ZeRO paper), because every forward and backward pass re-gathers the parameters over the network.
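The savings of each stage follow from a simple per-GPU formula. The sketch below uses the ZeRO paper's accounting for mixed-precision Adam (2 bytes of 16-bit weights, 2 bytes of gradients, and $K = 12$ bytes of FP32 master weights, momentum, and variance per parameter). The function name is our own, "stage 0" means plain DDP, and activations are excluded.

```python
def zero_bytes_per_gpu(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU bytes for weights, gradients, and optimizer states under ZeRO (hypothetical helper).

    stage 0 = plain DDP (everything replicated), 1 = shard optimizer states,
    2 = also shard gradients, 3 = also shard parameters.
    """
    weights, grads, opt = 2 * num_params, 2 * num_params, k * num_params
    if stage >= 1:
        opt /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        weights /= num_gpus
    return weights + grads + opt


phi, n = 7e9, 64
for stage in range(4):
    gb = zero_bytes_per_gpu(phi, n, stage) / 1e9
    print(f"ZeRO-{stage}: {gb:6.1f} GB per GPU")
```

With 64 GPUs this gives about 112, 29.3, 15.5, and 1.75 GB per GPU; as $N$ grows, the stage 1 and stage 2 figures approach $4\Phi$ and $2\Phi$ bytes, which is where the $4\times$ and $8\times$ savings come from.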
PyTorch FSDP Implementation
FSDP is a widely used choice for fine-tuning LLMs (e.g., Llama 2/3) on AWS/GCP today.
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Assumes the process group is already initialized and `model` is a loaded Llama model.

# Define a policy to wrap each Transformer Block individually
# This allows FSDP to free the gathered weights of Block 1 while computing Block 2
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# Mixed Precision Policy: compute and communicate in BF16;
# the sharded master weights keep the model's original dtype (e.g., FP32)
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    auto_wrap_policy=llama_auto_wrap_policy,
    # FULL_SHARD = ZeRO-3 (Shard params, grads, opt)
    # SHARD_GRAD_OP = ZeRO-2 (Shard grads, opt; params stay gathered from forward through backward)
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=bf16_policy,
    device_id=torch.cuda.current_device(),
)
# The rest of the training loop looks identical to DDP (create the optimizer after wrapping)
FSDP vs. DDP Decision Matrix
- Model < 2GB: Use DDP. It's faster (less communication).
- Model fits in VRAM but tight: Use ZeRO-2 (FSDP SHARD_GRAD_OP).
- Model > VRAM: Use ZeRO-3 (FSDP FULL_SHARD). This is what makes it possible to train 70B models across multiple A100s.
9.1.3. Tensor Parallelism (TP): “Slicing the Brain”
ZeRO/FSDP shards data and states, but the computation of a single layer is still monolithic. What if a single matrix multiplication is so large it takes too long, or the weight matrix itself is larger than VRAM?
Tensor Parallelism (pioneered by NVIDIA’s Megatron-LM) splits the individual tensors (matrices) across GPUs. This is “Intra-Layer” parallelism.
The Mathematics of Splitting
Consider a standard Linear Layer: $Y = XA$.
- $X$: Input vector ($1 \times D_{in}$).
- $A$: Weight matrix ($D_{in} \times D_{out}$).
If we have 2 GPUs, we can split $A$ in two ways:
1. Column Parallelism: Split $A$ column-wise into $A_1$ and $A_2$. $$ A = [A_1 | A_2] $$ GPU 1 computes $Y_1 = X A_1$. GPU 2 computes $Y_2 = X A_2$. The output $Y$ is the concatenation $[Y_1 | Y_2]$.
   - Communication: Each GPU needs the full input $X$ (Broadcast). At the end, we gather the parts of $Y$ (AllGather).
2. Row Parallelism: Split $A$ row-wise, and split the input $X$ column-wise to match. $$ A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}, \qquad X = [X_1 | X_2] $$ GPU 1 computes $Y_1 = X_1 A_1$. GPU 2 computes $Y_2 = X_2 A_2$. The output is $Y = Y_1 + Y_2$.
   - Communication: We need to sum the partial results (AllReduce).
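Both splits can be checked numerically on one machine. In this NumPy sketch, each "GPU" is just a slice of the matrices: concatenation stands in for AllGather and addition for AllReduce. The sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 8, 6
X = rng.standard_normal((1, d_in))      # input row vector (1 x D_in)
A = rng.standard_normal((d_in, d_out))  # weight matrix (D_in x D_out)
Y = X @ A                               # reference: single-device result

# Column parallelism: each "GPU" holds half of A's columns and the full X.
A1, A2 = np.hsplit(A, 2)
Y_col = np.concatenate([X @ A1, X @ A2], axis=1)  # AllGather = concatenation

# Row parallelism: each "GPU" holds half of A's rows and the matching half of X's columns.
A1_r, A2_r = np.vsplit(A, 2)
X1, X2 = np.hsplit(X, 2)
Y_row = X1 @ A1_r + X2 @ A2_r                     # AllReduce = sum of partial results

print(np.allclose(Y, Y_col), np.allclose(Y, Y_row))  # True True
```

Note that the column-parallel output is already split along $D_{out}$, which is exactly the input layout a following row-parallel layer expects.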
The Megatron-LM Transformer Block
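Megatron combines these so that each attention block and each MLP block needs only one AllReduce in the forward pass (and one in the backward pass): the column-parallel layer's split output feeds the row-parallel layer directly, with no communication in between.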
- Attention Layer: Uses Column Parallelism for $Q, K, V$ projections. The heads are split across GPUs.
- Output Projection: Uses Row Parallelism.
- MLP Layer: Uses Column Parallelism for the first expansion layer ($4h$) and Row Parallelism for the reduction layer.
The “f” and “g” Operators: In TP code, you will see a pair of conjugate operators that insert communication at the block boundaries.
- $f$: Forward = Identity (Pass); Backward = AllReduce.
- $g$: Forward = AllReduce; Backward = Identity.
The Cost of TP
TP requires blocking communication in the middle of the forward pass.
- Layer 1 Part A cannot finish until Layer 1 Part B sends its partial sum.
- This requires extremely low latency.
- Architectural Rule: Keep TP within a single node (NVLink). Avoid spanning TP across Ethernet/InfiniBand; the latency penalty usually destroys performance.
9.1.4. Pipeline Parallelism (PP): “The Assembly Line”
If a model is too deep (too many layers) to fit on one GPU, we can stack the GPUs vertically.
- GPU 0: Layers 1-8
- GPU 1: Layers 9-16
- …
- GPU 3: Layers 25-32
This is Pipeline Parallelism.
The Bubble Problem
The naive implementation is synchronous:
- GPU 0 processes Batch A. GPU 1, 2, 3 are idle.
- GPU 0 sends activations to GPU 1.
- GPU 1 processes Batch A. GPU 0, 2, 3 are idle.
This results in huge “Bubbles” (idle time). In a naive setup with 4 GPUs, utilization is only ~25%.
Solution: Micro-Batching (GPipe / 1F1B)
To reduce bubbles, we split the global batch into “Micro-Batches” (e.g., Global Batch 1024 -> 4 micro-batches of 256).
1F1B (One Forward, One Backward) Schedule: Instead of waiting for all forward passes to finish (as GPipe does), a GPU starts the backward pass for Micro-Batch 1 as soon as possible and then alternates forward and backward steps. This keeps the pipeline full while capping the number of micro-batches whose activations each stage must hold.
- Memory Impact: PP reduces memory per GPU because each GPU only holds the parameters for $1/N$ of the layers.
- Communication: Only happens at the boundaries (GPU 0 sends to GPU 1). This is low bandwidth compared to TP or DDP.
- Architectural Rule: PP is excellent for Inter-Node parallelism because it tolerates higher latency (Ethernet) better than TP.
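The benefit of micro-batching can be quantified. Assuming every stage takes the same time per micro-batch, a pipeline with $p$ stages and $m$ micro-batches is idle for $(p-1)$ of its $(m+p-1)$ time slots, the bubble fraction commonly quoted for GPipe-style schedules (1F1B has the same bubble but holds fewer activations). The function name below is our own.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe/1F1B-style pipeline (hypothetical helper).

    Assumes equal per-stage time, so fill and drain cost (p - 1) slots out of (m + p - 1).
    """
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)


print(bubble_fraction(4, 1))   # 0.75: the naive case, i.e., 25% utilization
print(bubble_fraction(4, 4))   # ~0.43
print(bubble_fraction(4, 32))  # ~0.09
```

The bubble shrinks as $m$ grows relative to $p$, which is why pipeline configurations use many more micro-batches than stages.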
9.1.5. Sequence Parallelism (SP) and the “Long Context” Era
With the advent of RAG (Retrieval Augmented Generation) and “Context Windows” of 128k+ tokens (e.g., Claude 3, GPT-4 Turbo), the activations ($M_{act}$) become the dominant memory consumer, surpassing parameters.
Standard TP splits the hidden dimension. Sequence Parallelism splits the Sequence Length dimension ($S$).
Ring Attention
If the sequence length is 100k, we cannot materialize the $S \times S$ attention score matrix on one GPU. Ring Attention instead shards the sequence across a ring of GPUs: each GPU keeps its block of queries, passes its Key/Value block to the next GPU, computes partial attention scores against each K/V block as it arrives, and merges them with FlashAttention's online softmax (a running maximum and normalizer per row), without ever materializing the full $S \times S$ matrix.
This is critical for “Infinite Context” architectures.
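The core of Ring Attention is the online softmax merge. The NumPy sketch below simulates it sequentially on one machine: each "device" keeps a block of queries and, at every ring step, processes the K/V block it would have received from its neighbor, rescaling its running sums when the row maximum changes. It omits causal masking and the actual send/receive overlap, and the function names are illustrative; the final check compares against ordinary full attention.

```python
import numpy as np


def full_attention(Q, K, V):
    """Reference: materializes the full S x S score matrix."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V


def ring_attention(Q, K, V, num_devices):
    """Sequential simulation of Ring Attention with an online softmax (hypothetical helper)."""
    d = Q.shape[1]
    q_blocks = np.array_split(Q, num_devices)
    kv_blocks = list(zip(np.array_split(K, num_devices), np.array_split(V, num_devices)))
    outputs = []
    for dev, q in enumerate(q_blocks):
        m = np.full((q.shape[0], 1), -np.inf)      # running row max
        l = np.zeros((q.shape[0], 1))              # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[1]))   # running unnormalized output
        for step in range(num_devices):
            # At this ring step, device `dev` holds the K/V block that started on device (dev - step)
            k, v = kv_blocks[(dev - step) % num_devices]
            s = q @ k.T / np.sqrt(d)               # only (block x block) scores exist at once
            m_new = np.maximum(m, s.max(axis=1, keepdims=True))
            scale = np.exp(m - m_new)              # rescale old partial sums to the new max
            p = np.exp(s - m_new)
            l = l * scale + p.sum(axis=1, keepdims=True)
            acc = acc * scale + p @ v
            m = m_new
        outputs.append(acc / l)
    return np.vstack(outputs)


rng = np.random.default_rng(0)
S, d = 64, 16
Q, K, V = (rng.standard_normal((S, d)) for _ in range(3))
print(np.allclose(ring_attention(Q, K, V, num_devices=4), full_attention(Q, K, V)))  # True
```

Each device only ever holds block-sized score tiles, so per-device memory depends on $S / N$ rather than on $S$.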
9.1.6. 3D Parallelism: The Grand Unified Theory
To train a state-of-the-art model (e.g., 175B+ parameters) on a cluster of thousands of GPUs (e.g., AWS P4d or P5 instances), we combine all three methods. This is 3D Parallelism.
The goal is to map the parallelism type to the hardware topology to minimize communication cost.
The Mapping Strategy
Imagine a cluster of 100 nodes, each with 8 GPUs (800 GPUs total).
- Intra-Node (Fastest, NVLink 600GB/s): Use Tensor Parallelism (TP).
  - Set $TP_{degree} = 8$. Ideally, the entire model width fits on one node.
- Inter-Node (Fast, EFA/InfiniBand 400Gbps): Use Pipeline Parallelism (PP).
  - Split the model depth across nodes. $PP_{degree} = 4$.
- Outer Loop (Slowest, but robust): Use Data Parallelism (DP).
  - Replicate the entire TP+PP pipeline.
  - $DP_{degree} = \frac{\text{Total GPUs}}{TP \times PP} = \frac{800}{8 \times 4} = 25$.
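In practice each global rank must be assigned a position in the TP × PP × DP grid. The sketch below uses one common convention (an assumption here, not a specific framework's API): TP varies fastest, so with 8 GPUs per node each TP group occupies a single node, while PP and DP groups span nodes.

```python
def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> dict:
    """Map a global rank to (tp, pp, dp) coordinates with TP varying fastest (hypothetical helper).

    With 8 GPUs per node and tp == 8, ranks 0-7 (one node) form one TP group,
    so TP traffic stays on NVLink; PP and DP groups span nodes.
    """
    assert 0 <= rank < tp * pp * dp
    return {
        "tp": rank % tp,
        "pp": (rank // tp) % pp,
        "dp": rank // (tp * pp),
    }


TP, PP = 8, 4
DP = 800 // (TP * PP)  # 25
print(DP)
print(rank_to_coords(0, TP, PP, DP))    # {'tp': 0, 'pp': 0, 'dp': 0}
print(rank_to_coords(13, TP, PP, DP))   # {'tp': 5, 'pp': 1, 'dp': 0}
print(rank_to_coords(799, TP, PP, DP))  # {'tp': 7, 'pp': 3, 'dp': 24}
```

Ranks that share the same `pp` and `dp` form a TP group; ranks that share `tp` and `pp` form a DP group, which is the group gradients are AllReduced over.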
The ds_config.json (DeepSpeed) Example
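DeepSpeed reads the training settings for this layout from a JSON config. With 3D parallelism, ZeRO Stage 1 is usually enough, because TP and PP already split the parameters. A loss_scale of 0 selects dynamic loss scaling. The batch settings must satisfy train_batch_size = train_micro_batch_size_per_gpu × gradient accumulation steps × DP degree; for the 800-GPU layout (DP = 25), 2000 = 4 × 20 × 25.
{
  "train_batch_size": 2000,
  "train_micro_batch_size_per_gpu": 4,
  "steps_per_print": 10,
  "zero_optimization": {
    "stage": 1,
    "reduce_bucket_size": 5e8
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000
  },
  "gradient_clipping": 1.0,
  "prescale_gradients": false,
  "wall_clock_breakdown": false
}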
Note: The actual TP/PP degrees are usually flags passed to the launch script (e.g., Megatron-DeepSpeed launcher), not just the JSON config.
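DeepSpeed expects the batch settings to be consistent: train_batch_size must equal train_micro_batch_size_per_gpu × gradient accumulation steps × the data-parallel degree (TP and PP ranks do not count). The hypothetical helper below performs the same check for the 800-GPU layout (DP = 25) so a mismatch is caught before launch; for example, 2048 is not a multiple of 4 × 25.

```python
def gradient_accumulation_steps(train_batch_size: int, micro_batch_per_gpu: int, dp_degree: int) -> int:
    """Return the accumulation steps implied by a DeepSpeed-style batch config (hypothetical helper).

    Requires train_batch_size == micro_batch_per_gpu * accumulation_steps * dp_degree exactly.
    Only data-parallel replicas count, not TP or PP ranks.
    """
    per_step = micro_batch_per_gpu * dp_degree
    if train_batch_size % per_step != 0:
        raise ValueError(f"train_batch_size={train_batch_size} is not a multiple of {per_step}")
    return train_batch_size // per_step


print(gradient_accumulation_steps(2000, 4, 25))  # 20
try:
    gradient_accumulation_steps(2048, 4, 25)
except ValueError as err:
    print(err)  # train_batch_size=2048 is not a multiple of 100
```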
9.1.7. Hardware Specifics: AWS vs. GCP
The choice of cloud provider dictates your parallelism constraints.
AWS: The “Explicit Network” Approach
- Instance: p4d.24xlarge (8x A100) or p5.48xlarge (8x H100).
- Fabric: AWS uses EFA (Elastic Fabric Adapter), which bypasses the OS kernel (via Libfabric) to provide low-latency communication.
- Optimization: You must install the AWS OFI NCCL plugin (aws-ofi-nccl). Without it, NCCL falls back to TCP sockets over Ethernet, and AllReduce performance can drop by 10-50x.
- Topology: AWS clusters are often built in “Placement Groups” (Cluster strategy) to ensure physical proximity.
GCP: The “Transparent Fabric” Approach
- Instance: a3-highgpu (8x H100) or TPU Pods.
- Fabric: GCP uses Jupiter networking and specialized “Rail-aligned” designs for H100 clusters.
- TPU Interconnect: If using TPUs, chips are linked by a dedicated inter-chip interconnect (ICI) arranged as a torus (3D on TPU v4), which is significantly faster than Ethernet. TPUs support “GSPMD” (General and Scalable Parallelization for ML), which lets you write code as if it ran on a single device while the XLA compiler handles the sharding automatically.
- Optimization: On GCP GPUs, use the NCCL network plugin GCP provides for the machine type (e.g., GPUDirect-TCPX on A3 VMs), so GPUs exchange data with the NICs without staging it through host memory.
9.1.7. Activation Checkpointing: Trading Compute for Memory
Even with FSDP, the activations ($M_{act}$) can dominate memory usage, especially for large batch sizes or long sequences. Activation Checkpointing (also called Gradient Checkpointing) is a technique to dramatically reduce activation memory at the cost of recomputation.
The Mechanism
During the forward pass, instead of storing all intermediate activations, we only store activations at specific “checkpoint” layers.
During the backward pass, when we need the activations of a non-checkpointed layer, we recompute them by running a mini forward pass from the last checkpoint.
Memory-Compute Trade-off:
- Without Checkpointing: Store all $L$ layers of activations. Memory: $O(L)$.
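- With Checkpointing: Store a checkpoint every $\sqrt{L}$ layers. Memory: $O(\sqrt{L})$. Compute: roughly one extra forward pass, about $1.33\times$ the normal cost (~33% overhead), since a backward pass costs about twice a forward pass.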
For a 32-layer Transformer:
- Normal: Store 32 sets of activations.
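- Checkpointed: Store ~6 checkpoint boundaries and recompute the rest. Counting each boundary as a full layer, peak activation memory is about 6 boundaries plus one ~6-layer segment being recomputed, roughly 12 layers' worth instead of 32 (a saving of over 60%). Boundaries are usually much smaller than a layer's internal activations, so the real saving is typically larger.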
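The $\sqrt{L}$ rule falls out of a simple cost model. In the sketch below (a deliberately crude assumption: each stored boundary counts as one layer's worth of activations, and during backward one segment is fully rematerialized), peak memory is the number of boundaries plus one segment, which is minimized near a segment length of $\sqrt{L}$. The function name is our own.

```python
import math


def checkpoint_memory_units(num_layers: int, segment: int) -> int:
    """Peak activation memory, in 'one layer of activations' units (hypothetical cost model).

    Keeps one boundary per segment (counted as a full layer, an overestimate)
    plus the fully materialized activations of the one segment being recomputed.
    """
    return math.ceil(num_layers / segment) + segment


L = 32
k = round(math.sqrt(L))                  # ~sqrt(L) layers per segment -> 6
print(k, checkpoint_memory_units(L, k))  # 6 12  (vs. 32 units with no checkpointing)
```

Every non-boundary layer runs its forward pass twice under this scheme, which is where the "one extra forward pass" compute cost comes from.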
PyTorch Implementation
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

# MultiHeadAttention and MLP are assumed to be defined elsewhere
class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x):
        # Use gradient checkpointing for this block:
        # only the block input x is saved; the intermediate activations
        # inside _forward_impl are recomputed during the backward pass
        return checkpoint.checkpoint(self._forward_impl, x, use_reentrant=False)

    def _forward_impl(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
FSDP Integration: When using FSDP, activation checkpointing is applied per wrapped block.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

# MyTransformer and TransformerBlock are placeholders for your model and its block class
model = MyTransformer()

# Wrap each block with FSDP (the policy must be told which class is a "block")
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    ),
)

# Apply activation checkpointing to specific layers
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=lambda submodule: checkpoint_wrapper(
        submodule,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    ),
    check_fn=lambda submodule: isinstance(submodule, TransformerBlock),
)
When to Use:
- Your model fits in VRAM, but you want to increase batch size.
- Training very long sequences (>4k tokens) where activations explode.
- You have abundant compute but limited memory (older GPUs like V100 16GB).
When NOT to Use:
- If your training is compute-bound and memory is not the constraint. The extra recomputation (roughly one additional forward pass) will only make it slower.
- If you're using Tensor Parallelism: recomputing a block also re-runs its TP collectives, adding communication to the critical path.
9.1.8. Gradient Accumulation: Simulating Larger Batches
Modern LLM training often uses enormous batch sizes (e.g., about 4 million tokens per batch for Llama 2). A batch that large rarely fits in memory in a single forward/backward pass.
Gradient Accumulation solves this by splitting the logical batch into micro-batches, accumulating gradients across multiple forward/backward passes, then stepping the optimizer once.
The Algorithm
# Assumes model, dataloader, optimizer, and world_size are defined elsewhere
optimizer.zero_grad()
# Logical batch size = 1024, but VRAM only fits 32 per GPU
micro_batch_size = 32
accumulation_steps = 1024 // (micro_batch_size * world_size)
for i, batch in enumerate(dataloader):
    # Forward pass
    outputs = model(batch)
    loss = outputs.loss / accumulation_steps  # Scale loss
    # Backward pass (gradients accumulate)
    loss.backward()
    # Only step optimizer every N accumulation steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
The Gradient Scaling Trap
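A common bug: forgetting to scale the loss by 1/accumulation_steps. If you don't scale:
- Gradients become accumulation_steps times larger.
- With plain SGD, the learning rate effectively becomes lr * accumulation_steps (adaptive optimizers such as Adam largely normalize the scale away, but gradient clipping thresholds no longer mean what you intended).
- Training diverges or converges to a suboptimal solution.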
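A tiny least-squares problem shows why the 1/accumulation_steps factor is needed. This NumPy sketch (data and model invented for illustration, with equal-sized micro-batches) compares the full-batch gradient of a mean loss with the accumulated micro-batch gradients, with and without scaling.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((32, 4))  # one logical batch of 32 examples (synthetic)
y = rng.standard_normal(32)
w = rng.standard_normal(4)


def grad_mse(w, X, y):
    """Gradient of the mean loss 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


full = grad_mse(w, X, y)  # gradient of the full-batch mean loss

steps = 4  # 4 micro-batches of 8
micro = list(zip(np.split(X, steps), np.split(y, steps)))
scaled = sum(grad_mse(w, Xb, yb) / steps for Xb, yb in micro)  # loss / accumulation_steps
unscaled = sum(grad_mse(w, Xb, yb) for Xb, yb in micro)        # the bug: no scaling

print(np.allclose(scaled, full))            # True
print(np.allclose(unscaled, steps * full))  # True: the gradient is 4x too large
```

The equivalence is exact only when the micro-batches have equal size; with a ragged last micro-batch, weight each micro-batch loss by its share of the examples instead.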
DDP and Gradient Accumulation
In standard DDP, gradients are synchronized on every .backward() call, even if you’re accumulating. This wastes bandwidth.
Solution: Use no_sync() context:
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes model, dataloader, optimizer, and accumulation_steps are defined elsewhere
ddp_model = DDP(model)
for i, batch in enumerate(dataloader):
    # Disable gradient synchronization for accumulation steps
    if (i + 1) % accumulation_steps != 0:
        with ddp_model.no_sync():
            loss = ddp_model(batch).loss / accumulation_steps
            loss.backward()
    else:
        # Final step: allow synchronization
        loss = ddp_model(batch).loss / accumulation_steps
        loss.backward()  # AllReduce happens here
        optimizer.step()
        optimizer.zero_grad()
FSDP and Gradient Accumulation
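FSDP offers the same no_sync() context manager. Inside it, FSDP skips the gradient reduce-scatter and accumulates unsharded gradients locally, which saves communication at the cost of extra memory; the reduce-scatter happens on the first backward pass outside the context.
Memory Implication: Compared with simply training at the micro-batch size, gradient accumulation does not reduce peak memory: each micro-batch's activations are still held until its backward pass finishes, and the accumulated gradients persist across micro-batches. It is primarily a tool for achieving large effective batch sizes, not for fitting larger models.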
9.1.9. CPU Offloading: The Last Resort
When a model is so large that even FSDP with full sharding cannot fit it, you can offload parameters and optimizer states to CPU RAM.
The Hierarchy of Memory
| Memory Type | Capacity | Bandwidth | Latency |
|---|---|---|---|
| GPU HBM (A100) | 80 GB | 2 TB/s | ~100 ns |
| CPU RAM | 1-2 TB | 200 GB/s | ~1 μs |
| NVMe SSD | 4-8 TB | 7 GB/s | ~100 μs |
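CPU offloading keeps parameters and/or optimizer states in CPU RAM and moves data across PCIe as needed during the forward and backward passes; with optimizer offload, the optimizer step itself runs on the CPU.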
DeepSpeed ZeRO-Infinity (Offload to CPU/NVMe)
DeepSpeed ZeRO-Infinity extends ZeRO-3 to use CPU RAM and even NVMe SSDs.
Configuration (ds_config.json). Setting "pin_memory": true keeps the offloaded tensors in pinned (page-locked) host memory for faster PCIe transfers:
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6
  },
  "train_batch_size": 16,
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 16,
  "fp16": {
    "enabled": true
  }
}
The Performance Penalty:
- PCIe Bandwidth: The link between CPU and GPU is typically PCIe Gen4 x16 (~32 GB/s per direction), roughly 60x slower than the ~2 TB/s of A100 HBM.
- Implication: Training can slow down by 5-10x compared to pure GPU training.
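A quick calculation illustrates the PCIe gap. The sketch uses the bandwidth figures from the table and the bullets above (≈2 TB/s HBM, ≈32 GB/s PCIe Gen4 x16) and assumes the full 16-bit weights of a 70B model cross the link once. Real offload pipelines overlap transfers with compute and also move optimizer states and gradients, so this is only an order-of-magnitude illustration; the function name is our own.

```python
def transfer_seconds(num_bytes: float, gb_per_s: float) -> float:
    """Time to move num_bytes over a link with the given bandwidth (1 GB = 1e9 bytes)."""
    return num_bytes / (gb_per_s * 1e9)


fp16_weights = 70e9 * 2  # 70B parameters in FP16/BF16: 140 GB
for name, bw in [("HBM (A100)", 2000), ("PCIe Gen4 x16", 32)]:
    print(f"{name:>14}: {transfer_seconds(fp16_weights, bw):.2f} s to stream the weights once")
```

Streaming the weights once takes about 0.07 s from HBM but over 4 s across PCIe, and each training step needs the weights for both the forward and the backward pass.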
When to Use:
- Fine-tuning massive models (70B+) on a single node with large CPU RAM.
- Prototyping before committing to multi-node infrastructure.
- Budget constraints (cheaper to use large CPU RAM than rent 8x H100s).
When NOT to Use:
- Production training at scale. Multi-node FSDP without offloading is faster and more cost-effective.
QLoRA: Quantization + Offloading
An alternative to full-precision offloading is QLoRA (Quantized Low-Rank Adaptation).
Instead of offloading FP16/FP32 weights to CPU, you:
- Load the base model in 4-bit or 8-bit quantization (relative to FP16, 4-bit cuts weight memory by 4x and 8-bit by 2x).
- Freeze the base model.
- Train small “adapter” layers (LoRA) in FP16.
Memory Savings: A 70B model in 4-bit requires ~35 GB (fits on a single A100). The adapter layers are tiny (<1 GB).
Use Case: Fine-tuning Llama-2-70B on a single A100 for domain adaptation.
Libraries: bitsandbytes (through its Hugging Face transformers integration) + Hugging Face peft.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load model in 4-bit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# Add LoRA adapters
lora_config = LoraConfig(
    r=16,  # Rank of the adaptation matrix
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # Which layers to adapt
    lora_dropout=0.1,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # Well under 1% of params are trainable (~0.05% for this config)
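The trainable-parameter fraction can be estimated by hand. Each LoRA adapter on a $d_{in} \times d_{out}$ weight adds $r(d_{in} + d_{out})$ parameters. The shapes below are assumptions taken from Llama-2-70B's published configuration (hidden size 8192, 80 layers, grouped-query attention with 8 key/value heads of dimension 128), so treat the result as an estimate.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (d_in x r) and B (r x d_out) next to a frozen d_in x d_out weight."""
    return r * (d_in + d_out)


# Assumed Llama-2-70B shapes: hidden size 8192, 80 layers,
# 8 KV heads of dim 128 -> v_proj maps 8192 -> 1024 (q_proj maps 8192 -> 8192).
hidden, layers, kv_dim, r = 8192, 80, 8 * 128, 16
per_layer = lora_params(hidden, hidden, r) + lora_params(hidden, kv_dim, r)  # q_proj + v_proj
trainable = per_layer * layers
print(f"{trainable / 1e6:.1f}M trainable LoRA parameters")  # 32.8M
print(f"{trainable / 70e9:.3%} of ~70B")                     # 0.047%
print(f"4-bit base weights: ~{70e9 * 0.5 / 1e9:.0f} GB (before quantization constants)")  # ~35 GB
```

Adding more target modules (for example the MLP projections) raises the count, but it stays a small fraction of the frozen base model.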
9.1.10. Debugging Distributed Training: Common Failure Modes
Distributed training introduces failure modes that don’t exist in single-GPU training.
1. Deadlock: Mismatched Collectives
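Symptom: Training hangs indefinitely with no error message. GPU utilization can be misleading: NCCL kernels waiting on a collective often show as busy in nvidia-smi even though no progress is made.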
Cause: One rank hits an AllReduce, but another rank doesn’t (e.g., due to a conditional).
# BAD CODE (Will Deadlock)
if rank == 0:
    loss = model(batch)
    loss.backward()  # Triggers AllReduce
# Rank 1 never calls backward, so AllReduce never completes.
Fix: Ensure all ranks execute collective operations (AllReduce, Broadcast, Barrier) together.
2. Gradient Divergence: Non-Deterministic Ops
Symptom: Loss diverges or fluctuates wildly. Different ranks produce different losses for the same input.
Cause: Non-deterministic operations (e.g., torch.nn.functional.dropout without a fixed seed).
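Fix: Seed every rank explicitly. Use the same seed wherever ranks must agree (e.g., model initialization, if your wrapper does not broadcast rank 0's weights the way DDP does), and add a rank offset only where randomness should differ per rank, such as dropout masks or data augmentation.
import random

import numpy as np
import torch

def set_seed(seed, rank):
    # Rank-offset seed: each rank draws different dropout masks / augmentations.
    # Where ranks must draw identical random values, pass rank=0 on every rank.
    torch.manual_seed(seed + rank)
    torch.cuda.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)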
3. NCCL Timeout
Symptom: RuntimeError: NCCL error: unhandled system error. Training crashes after several minutes.
Cause: Network packet loss or a straggler node.
Debug:
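- Set export NCCL_DEBUG=INFO to see detailed logs.
- Check for network errors: dmesg | grep -i error.
- Run nccl-tests to isolate the bad node.
Fix: Replace the faulty node or increase the collective timeout, which PyTorch takes as an argument when creating the process group.
import datetime
import torch.distributed as dist

# rank and world_size are read from the environment (e.g., when launched with torchrun)
dist.init_process_group("nccl", timeout=datetime.timedelta(hours=2))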
4. OOM on One Rank Only
Symptom: Rank 3 crashes with OOM, but ranks 0, 1, 2 are fine.
Cause: Imbalanced data (e.g., Rank 3 gets the longest sequences).
Fix: Cap the maximum sequence length, and bucket examples by length (or batch them by a fixed token budget) so every rank's batches have a similar padded size.
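One way to equalize per-rank memory is to batch by a token budget rather than a fixed number of examples. The sketch below (function name and budget are illustrative) sorts examples by length and starts a new batch whenever batch size × longest sequence would exceed the budget. In a real pipeline you would wrap this logic in a PyTorch batch sampler and shard the resulting batches across ranks.

```python
def token_budget_batches(lengths, max_tokens):
    """Group example indices into batches whose padded size stays within max_tokens.

    Hypothetical helper. Sorting by length keeps similar lengths together (bucketing),
    and the budget caps batch_size * longest_sequence. An example longer than the
    budget still gets a batch of its own.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, longest = [], [], 0
    for i in order:
        new_longest = max(longest, lengths[i])
        if current and new_longest * (len(current) + 1) > max_tokens:
            batches.append(current)  # close the batch before it exceeds the budget
            current, new_longest = [], lengths[i]
        current.append(i)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


lengths = [120, 4000, 350, 90, 2100, 380, 3900, 100]
for batch in token_budget_batches(lengths, max_tokens=4096):
    print(batch, [lengths[i] for i in batch])
```

Short sequences end up packed five to a batch while each long sequence gets its own, so every batch costs at most about 4096 padded tokens.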
5. Slow Startup (Rank 0 Initialization Bottleneck)
Symptom: Rank 0 takes 10 minutes to initialize, while ranks 1-7 wait idle.
Cause: Rank 0 is downloading the model from Hugging Face Hub, while others wait.
Fix: Pre-download the model to shared storage (EFS/FSx), or let rank 0 download it while the other ranks wait at torch.distributed.barrier().
import torch.distributed as dist
from transformers import AutoModel

rank = dist.get_rank()  # assumes the process group is already initialized
if rank == 0:
    # Download model
    model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
    model.save_pretrained("/shared/models/llama-2-7b")
dist.barrier()  # Wait for rank 0 to finish
# All ranks load from shared storage
model = AutoModel.from_pretrained("/shared/models/llama-2-7b")
9.1.11. Mixed Precision Training: The BF16 vs. FP16 Debate
Mixed precision training (using 16-bit floats for speed, 32-bit for accuracy) is standard practice. But choosing between BFloat16 (BF16) and Float16 (FP16) has profound implications.
The Numeric Formats
FP32 (Single Precision):
- Sign: 1 bit, Exponent: 8 bits, Mantissa: 23 bits.
- Range: ~$10^{-38}$ to $10^{38}$.
- Precision: ~7 decimal digits.
FP16 (Half Precision):
- Sign: 1 bit, Exponent: 5 bits, Mantissa: 10 bits.
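- Range: ~$6 \times 10^{-5}$ (smallest normal value; subnormals reach ~$6 \times 10^{-8}$) to $6.55 \times 10^{4}$.
- Precision: ~3 decimal digits.
- Problem: Narrow range. Gradients below ~$6 \times 10^{-8}$ underflow to zero, and values below ~$6 \times 10^{-5}$ lose precision as subnormals.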
BF16 (Brain Float16):
- Sign: 1 bit, Exponent: 8 bits, Mantissa: 7 bits.
- Range: Same as FP32 (~$10^{-38}$ to $10^{38}$).
- Precision: ~2 decimal digits.
- Advantage: Same exponent range as FP32, so no underflow issues.
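NumPy can demonstrate these ranges directly. It has a native float16 type but no bfloat16, so the sketch emulates BF16 by keeping the top 16 bits of an FP32 value (truncation, whereas hardware conversions normally round to nearest); `to_bf16` is our own helper name.

```python
import numpy as np

fp16 = np.finfo(np.float16)
print(float(fp16.max))   # 65504.0: largest finite FP16 value
print(float(fp16.tiny))  # 6.103515625e-05: smallest normal FP16 value

# Values below half the smallest FP16 subnormal (~6e-08) round to zero
print(float(np.float16(1e-8)))  # 0.0
print(float(np.float16(1e-5)))  # close to 1e-05, but stored as a lower-precision subnormal


def to_bf16(x):
    """Emulate BF16 by zeroing the low 16 bits of FP32 (truncation, not round-to-nearest)."""
    bits = np.array(x, dtype=np.float32, ndmin=1).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


print(to_bf16(1e-8))   # ~1e-08: BF16 keeps FP32's exponent range
print(to_bf16(1.001))  # [1.]: only 7 explicit mantissa bits, so 1.001 collapses to 1.0
```

The two formats fail in opposite ways: FP16 keeps more digits but loses tiny values, while BF16 keeps tiny values but loses digits.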
When to Use Which
Use FP16 if:
- Training CNNs (Computer Vision models). Activations are well-behaved.
- Using older GPUs without native BF16 support, such as V100 (fast FP16 Tensor Cores) or P100 (fast FP16, but no Tensor Cores).
- You are willing to use loss scaling (see below).
Use BF16 if:
- Training Transformers (LLMs). Attention scores can have extreme ranges.
- Using modern GPUs (A100, H100) with native BF16 Tensor Core support.
- You want simplicity (no loss scaling required).
Automatic Mixed Precision (AMP) in PyTorch
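PyTorch's automatic mixed precision (AMP: torch.autocast plus a gradient scaler) automates this.
Basic Usage:
import torch

# MyModel and dataloader are assumed to be defined elsewhere
amp_dtype = torch.bfloat16  # or torch.float16
model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Loss scaling is only needed for FP16; with enabled=False the scaler calls are pass-throughs
# (torch.amp.GradScaler needs PyTorch 2.3+; older releases use torch.cuda.amp.GradScaler)
scaler = torch.amp.GradScaler("cuda", enabled=(amp_dtype == torch.float16))
for batch in dataloader:
    optimizer.zero_grad()
    # Forward pass in mixed precision
    with torch.autocast(device_type="cuda", dtype=amp_dtype):
        outputs = model(batch)
        loss = outputs.loss
    # Backward pass (gradients in FP32, the dtype of the parameters)
    scaler.scale(loss).backward()
    # Unscale gradients before clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Optimizer step (skipped if FP16 gradients overflowed), then update the scale
    scaler.step(optimizer)
    scaler.update()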
The Loss Scaling Trick (FP16 only):
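To prevent gradient underflow, we multiply the loss by a large constant (e.g., 1024) before .backward(). This shifts small gradients into the representable range. Before the optimizer step, we unscale. GradScaler picks the factor dynamically: it lowers the scale and skips the step when it finds inf/NaN gradients, and raises it again after a run of clean steps.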
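The effect of scaling can be checked on a single value. This NumPy sketch treats a gradient as proportional to the loss (so scaling the loss by 1024 scales the gradient by 1024), casts it to FP16 as a half-precision backward pass would, and unscales in FP32. The gradient value is invented for illustration.

```python
import numpy as np

true_grad = 1e-8  # a small gradient value (illustrative)
scale = 1024.0    # static loss scale

# Without scaling, the FP16 gradient underflows to zero
print(float(np.float16(true_grad)))  # 0.0

# With scaling, the scaled gradient is representable in FP16 (as a subnormal) ...
scaled_fp16 = np.float16(true_grad * scale)
# ... and unscaling in FP32 before the optimizer step recovers approximately the true value
recovered = np.float32(scaled_fp16) / np.float32(scale)
print(float(recovered))  # about 1.001e-08
```

The small error in the recovered value comes from FP16's limited precision in the subnormal range; a larger scale would reduce it, which is why dynamic scalers start high and back off only on overflow.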
BF16 Simplification:
If using BF16, skip the GradScaler entirely:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(batch).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
FSDP Mixed Precision Policy
When using FSDP, you specify precision per tensor type.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Compute in BF16, reduce (gradient communication) in BF16;
# the sharded master weights stay in the model's original dtype (FP32)
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype of the gathered params used for forward/backward
    reduce_dtype=torch.bfloat16,  # Gradient communication
    buffer_dtype=torch.bfloat16,  # Buffers (e.g., BatchNorm running stats)
)
model = FSDP(model, mixed_precision=mp_policy)
Performance Impact:
- A100 BF16 Tensor Cores: 312 TFLOPS (dense).
- A100 FP32 (CUDA cores, no Tensor Cores): 19.5 TFLOPS; TF32 Tensor Cores: 156 TFLOPS.
- Speedup: ~16x over plain FP32 for matrix operations (~2x over TF32).
9.1.12. Flash Attention: The Memory Breakthrough
Standard attention has a memory complexity of $O(N^2)$ where $N$ is the sequence length. For a 128k-token context ($N = 131{,}072$), the score matrix alone has $2^{34}$ elements: 32 GiB in FP16 (64 GiB in FP32) per attention head, per sequence.
Flash Attention (by Dao et al., 2022) reduces memory to $O(N)$ while maintaining exact correctness.
The Standard Attention Bottleneck
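# Standard Attention (Simplified: one head, no masking)
import math
import torch

batch, seq_len, head_dim = 2, 1024, 64
x = torch.randn(batch, seq_len, head_dim)
linear_q, linear_k, linear_v = (torch.nn.Linear(head_dim, head_dim) for _ in range(3))
Q = linear_q(x)  # (batch, seq_len, head_dim)
K = linear_k(x)
V = linear_v(x)
# Problem: This matrix is seq_len x seq_len
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, seq_len, seq_len)
attn = torch.softmax(scores, dim=-1)
out = attn @ V  # (batch, seq_len, head_dim)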
For $N = 100{,}000$ tokens:
- scores matrix: $10^5 \times 10^5 = 10^{10}$ elements.
- At FP16: $10^{10} \times 2 \text{ bytes} = 20 \text{ GB}$ (per head, per sequence).
This is stored in GPU HBM during the forward pass and needed again during backward.
Flash Attention: Tiling and Recomputation
Flash Attention never materializes the full $N \times N$ matrix. It:
- Splits $Q, K, V$ into tiles (e.g., 128 tokens per tile).
- Computes attention for one tile at a time, keeping only the output.
- During backward, recomputes the attention scores on-the-fly.
Trade-off: More FLOPs (recomputation), but drastically less memory.
Memory Savings:
- Standard Attention: $O(N^2)$ memory.
- Flash Attention: $O(N)$ memory.
For 100k tokens: Reduction from 20 GB to ~200 MB.
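The gap between the two can be estimated per head and per sequence. The sketch below is a rough model (our assumption, with illustrative function names): the naive path stores the full $N \times N$ score matrix in FP16, while a tiled kernel keeps only the $N \times d$ output plus two FP32 softmax statistics per row. Real kernels also hold tiles in on-chip SRAM and save slightly different statistics for the backward pass.

```python
def naive_scores_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one S x S attention-score matrix (one head, one sequence)."""
    return seq_len * seq_len * bytes_per_elem


def flash_extra_bytes(seq_len: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Rough O(N) footprint for one head: the N x d output plus two FP32 statistics per row."""
    return seq_len * head_dim * bytes_per_elem + seq_len * 2 * 4


for n in (8_192, 100_000, 131_072):
    print(f"N={n:>7}: naive {naive_scores_bytes(n) / 1e9:8.2f} GB | "
          f"tiled {flash_extra_bytes(n, head_dim=128) / 1e6:6.2f} MB")
```

The naive figure grows 4x when $N$ doubles, while the tiled figure only doubles.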
PyTorch Integration (Flash Attention 2)
As of PyTorch 2.0+, Flash Attention is integrated via F.scaled_dot_product_attention.
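import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch 2.3+

# Flash kernels need CUDA tensors in FP16/BF16 shaped (batch, heads, seq_len, head_dim)
Q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Restrict SDPA to the Flash Attention backend; it raises an error instead of
# falling back to the math kernel if Flash Attention is unsupported.
# (PyTorch 2.0-2.2: torch.backends.cuda.sdp_kernel(enable_flash=True,
#  enable_math=False, enable_mem_efficient=False))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(Q, K, V)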
Requirements:
- An NVIDIA GPU of the Ampere generation or newer (e.g., A100, H100).
- CUDA 11.6+.
Fallback: On older GPUs (V100), PyTorch uses a memory-efficient attention variant (slower but still better than naive).
Flash Attention in Transformers
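Hugging Face Transformers can use Flash Attention 2 directly (this requires the separate flash-attn package to be installed).
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,  # Flash Attention 2 requires FP16 or BF16
    attn_implementation="flash_attention_2",  # <--- Enable Flash Attention
    device_map="auto",  # requires the accelerate package
)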
Performance Benchmark (Llama-2-7B, 8k context, A100):
- Standard Attention: 45 tokens/sec, 72 GB VRAM.
- Flash Attention 2: 120 tokens/sec, 38 GB VRAM.
9.1.13. Performance Profiling: Finding the Bottleneck
Training a model is an optimization problem. But you can’t optimize what you don’t measure.
PyTorch Profiler
The PyTorch Profiler captures detailed traces of GPU operations.
import torch
from torch.profiler import profile, ProfilerActivity, schedule
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),  # skip 1 step, warm up 1, record 3
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, batch in enumerate(dataloader):
        if step >= 5:  # wait + warmup + active = 5 steps
            break
        outputs = model(**batch)  # batch: dict of tensors (input_ids, labels, ...)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()  # Notify profiler of step boundary
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Output (example):
--------------------------------- ------------ ------------ ------------
Name                              Self CPU %   Self CPU     CPU total %
--------------------------------- ------------ ------------ ------------
aten::addmm                       2.50%        10.234ms     15.20%
aten::_scaled_dot_product...      1.20%        4.912ms      45.30%
aten::copy_                       5.10%        20.891ms     5.10%
Memcpy DtoH (Device -> Host)      8.30%        33.981ms     8.30%
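Interpretation:
- High Memcpy DtoH: Data is being copied from GPU to CPU, often unnecessarily. Check whether you are calling .cpu() or .item() inside the training loop; each call also forces the CPU to wait for the GPU.
- High aten::copy_: Likely a dtype or device conversion (e.g., repeated .to() or .float() calls) or other avoidable tensor copies.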
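A common source of per-step synchronization is logging the loss with .item() on every iteration. A minimal sketch of the alternative, keeping a running sum on the device and synchronizing only when logging; train_epoch is a hypothetical helper and the MSE loss on (x, y) batches is a placeholder for your model's real loss.

```python
import torch

def train_epoch(model, batches, optimizer, log_every: int = 100):
    """Accumulate the loss on the device and synchronize only every `log_every` steps.

    Calling loss.item() on every step forces a GPU->CPU copy and a stream sync;
    a running sum kept on the device defers that cost until we actually log.
    """
    device = next(model.parameters()).device
    running = torch.zeros((), device=device)
    history = []
    for step, (x, y) in enumerate(batches, start=1):
        loss = torch.nn.functional.mse_loss(model(x), y)  # placeholder loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        running += loss.detach()  # stays on the device, no sync
        if step % log_every == 0:
            history.append(running.item() / log_every)  # one sync per log_every steps
            running.zero_()
    return history
```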
TensorBoard Profiler Visualization
Load the trace in TensorBoard (this requires the torch-tb-profiler plugin):
tensorboard --logdir=./log/profiler
Navigate to the “Profiler” tab. You’ll see:
- Timeline: GPU kernel execution over time. Look for gaps (idle time).
- Operator View: Which operations consume the most time.
- Memory View: Peak memory usage per operation.
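Red Flags:
- Long gaps between kernels: Data loading bottleneck. Use num_workers > 0 and pin_memory=True.
- AllReduce consuming >50% of time: Network bottleneck. Verify EFA is working (e.g., run with NCCL_DEBUG=INFO and confirm NCCL selects the OFI/EFA network plugin rather than falling back to sockets).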
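Before opening a full trace, a rough timer can confirm whether the input pipeline is the culprit. A minimal sketch, assuming step_fn is a hypothetical callable that runs forward, backward, and the optimizer step; measure_input_stall is likewise a hypothetical name. It returns the fraction of wall-clock time spent waiting for the next batch.

```python
import time
import torch

def measure_input_stall(dataloader, step_fn, num_steps: int = 50) -> float:
    """Fraction of wall-clock time spent waiting for data rather than running steps.

    torch.cuda.synchronize() makes queued GPU work count toward the step,
    not toward the next data wait.
    """
    wait, compute = 0.0, 0.0
    it = iter(dataloader)
    for _ in range(num_steps):
        t0 = time.perf_counter()
        try:
            batch = next(it)
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    total = wait + compute
    return wait / total if total > 0 else 0.0
```

A value close to 1 points at data loading; a small value with poor scaling points back at compute or communication.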
NVIDIA Nsight Systems
For deeper profiling (CPU, GPU, NCCL), use Nsight Systems.
nsys profile -t cuda,nvtx,osrt,cudnn,cublas \
-o training_profile \
python train.py
Open the .nsys-rep file in the Nsight Systems GUI. You can see:
- NCCL communication timelines.
- Kernel launch overhead.
- CPU-GPU synchronization points.
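The -t nvtx flag above records NVTX ranges, which appear as labeled spans on the Nsight timeline and make it easy to see where a step spends its time. A minimal sketch of annotating a training step; nvtx_range and train_step are hypothetical helpers, and model(batch).pow(2).mean() is only a placeholder loss. The helper falls back to a no-op context when CUDA is unavailable.

```python
import contextlib
import torch

def nvtx_range(name: str):
    """Return an NVTX range on CUDA machines, or a no-op context elsewhere."""
    if torch.cuda.is_available():
        return torch.cuda.nvtx.range(name)
    return contextlib.nullcontext()

def train_step(model, batch, optimizer):
    with nvtx_range("forward"):
        loss = model(batch).pow(2).mean()  # placeholder loss
    with nvtx_range("backward"):
        loss.backward()
    with nvtx_range("optimizer"):
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return loss.detach()
```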
9.1.14. Real-World Case Study: Training Llama-3-70B on AWS
Let’s walk through a production deployment.
Goal: Fine-tune Llama-3-70B on a custom dataset (500M tokens) using AWS.
Cluster Configuration:
- Instances: 8x p4d.24xlarge (64 A100 GPUs total).
- Network: EFA, cluster placement group, single AZ.
- Storage: FSx for Lustre (10 TB, linked to S3).
Step 1: Cost Estimation
Training time estimate: 7 days.
- Compute: 8 nodes × $32.77/hr × 168 hrs ≈ $44,043.
- FSx: 10,000 GB × $0.14/GB-month × (7/30) month ≈ $327.
- Total: ~$44,370.
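The estimate is simple enough to script, which makes it easy to rerun for other cluster sizes or durations. A minimal sketch using the chapter's assumed prices ($32.77 per node-hour, $0.14 per GB-month of FSx); training_cost is a hypothetical helper, and prorating storage over a 30-day month is an assumption (AWS prorates by usage time).

```python
def training_cost(nodes, hourly_rate, days, fsx_gb, fsx_rate_gb_month):
    """Return (compute, storage, total) in dollars for a fixed-size run."""
    hours = days * 24
    compute = nodes * hourly_rate * hours
    storage = fsx_gb * fsx_rate_gb_month * days / 30  # assumed 30-day month proration
    return compute, storage, compute + storage

compute, storage, total = training_cost(8, 32.77, 7, 10_000, 0.14)
print(f"compute ${compute:,.0f} + FSx ${storage:,.0f} = ${total:,.0f}")
```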
Step 2: Parallelism Strategy
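70B parameters in BF16:
- Model: 140 GB ($2\Phi$).
- Gradients: 140 GB ($2\Phi$).
- Optimizer (Adam, FP32 momentum and variance): 560 GB ($8\Phi$).
- Total: 840 GB, before activations. Keeping an FP32 master copy of the weights (standard mixed-precision training) adds another 280 GB ($4\Phi$), for 1,120 GB.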
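The same accounting as a small calculator, extended to estimate model state per GPU under the 3D layout chosen below. This is a sketch under simplifying assumptions: TP × PP splits every tensor evenly (ignoring embeddings and uneven stages), ZeRO-1 shards the FP32 optimizer state (master weights plus moments) across DP ranks, and activations are excluded. Both function names are hypothetical.

```python
def training_memory_gb(params_billion: float, fp32_master: bool = True) -> dict:
    """Training state in GB (excluding activations) for BF16 mixed precision with Adam."""
    phi = params_billion * 1e9
    mem = {
        "weights_bf16": 2 * phi,
        "grads_bf16": 2 * phi,
        "adam_moments_fp32": 8 * phi,  # momentum + variance, 4 bytes each
    }
    if fp32_master:
        mem["master_weights_fp32"] = 4 * phi
    mem = {k: v / 1e9 for k, v in mem.items()}
    mem["total"] = sum(mem.values())
    return mem

def per_gpu_gb(mem: dict, tp: int, pp: int, dp: int, zero1: bool = True) -> float:
    """Model-state GB per GPU: TP x PP split every tensor; ZeRO-1 also shards optimizer state over DP."""
    model_parallel = tp * pp
    weights_and_grads = mem["weights_bf16"] + mem["grads_bf16"]
    optimizer = mem["adam_moments_fp32"] + mem.get("master_weights_fp32", 0.0)
    return weights_and_grads / model_parallel + optimizer / (model_parallel * (dp if zero1 else 1))

mem = training_memory_gb(70)
print(f"total {mem['total']:.0f} GB, per GPU {per_gpu_gb(mem, tp=8, pp=2, dp=4):.1f} GB")
```

Under these assumptions each GPU holds roughly 30.6 GB of model state, leaving under 10 GB of a 40 GB A100 for activations, which is likely why the launcher below enables activation checkpointing.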
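Each A100 in a p4d.24xlarge has 40 GB of HBM (the p4de.24xlarge variant has 80 GB). Either way, no single GPU comes close to holding 840 GB, so we need aggressive sharding.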
Choice: 3D Parallelism.
- TP = 8 (intra-node, use NVLink).
- PP = 2 (split 80 layers across 2 nodes).
- DP = 4 (replicate the TP+PP pipeline 4 times).
Verification: $8 \text{ nodes} \times 8 \text{ GPUs/node} = 64 \text{ GPUs}$. $TP \times PP \times DP = 8 \times 2 \times 4 = 64$. ✓
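How the 64 global ranks map onto these groups decides whether TP traffic stays on NVLink. A minimal sketch assuming a TP-fastest, then DP, then PP rank ordering (Megatron-LM uses a TP-fastest layout by default, but check your framework's configuration); rank_coords is a hypothetical helper. With 8 GPUs per node, every TP group then lands on a single node, and each pipeline replica spans two nodes (one per stage).

```python
def rank_coords(rank: int, tp: int = 8, pp: int = 2, dp: int = 4) -> tuple:
    """Map a global rank to (dp, pp, tp) indices under an assumed TP-fastest ordering."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp            # consecutive ranks share a TP group
    dp_idx = (rank // tp) % dp    # next: data-parallel replica
    pp_idx = rank // (tp * dp)    # slowest: pipeline stage
    return dp_idx, pp_idx, tp_idx

for r in (0, 7, 8, 31, 32, 63):
    print(r, rank_coords(r))
```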
Step 3: Launcher Script (Megatron-DeepSpeed)
#!/bin/bash
# Nodes
NNODES=8
GPUS_PER_NODE=8
WORLD_SIZE=$((NNODES * GPUS_PER_NODE))
# Parallelism config
TP=8
PP=2
# DP is not passed explicitly: WORLD_SIZE / (TP * PP) = 64 / (8 * 2) = 4
# Master node (rank 0)
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
deepspeed --num_nodes=$NNODES \
--num_gpus=$GPUS_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
pretrain_llama.py \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--micro-batch-size 1 \
--global-batch-size 512 \
--train-iters 100000 \
--lr 3e-4 \
--lr-decay-style cosine \
--min-lr 3e-5 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--bf16 \
--zero-stage 1 \
--checkpoint-activations \
--save-interval 1000 \
--save /fsx/checkpoints/llama3-70b \
--load /fsx/checkpoints/llama3-70b \
--data-path /fsx/data/my_dataset_text_document \
--vocab-file /fsx/models/tokenizer.model \
--tensorboard-dir /fsx/logs
Step 4: Monitoring
Deploy Prometheus + Grafana + DCGM Exporter. Watch:
- GPU utilization (target: >90%).
- Network throughput (expect ~40 GB/s during AllReduce).
- Loss curve (should decrease smoothly).
Step 5: Checkpointing
Checkpoint every 1000 steps (~2 hours). Each checkpoint: 1.1 TB. Retain last 5 checkpoints (5.5 TB total).
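Retention is easy to automate with a small cleanup job. A minimal sketch, assuming (hypothetically) that each checkpoint is a subdirectory of the save path whose name ends in its iteration number, such as iter_0001000 or global_step1000; check your framework's actual naming before deleting anything. Run it on a single rank, and only after the newest checkpoint has been fully written.

```python
import re
import shutil
from pathlib import Path

def prune_checkpoints(save_dir: str, keep_last: int = 5) -> list:
    """Delete all but the newest `keep_last` checkpoint directories; return removed names.

    Assumes each checkpoint is a subdirectory whose name ends in its iteration
    number. Files (e.g. a 'latest' pointer) and other entries are ignored.
    """
    ckpts = []
    for entry in Path(save_dir).iterdir():
        match = re.search(r"(\d+)$", entry.name)
        if entry.is_dir() and match:
            ckpts.append((int(match.group(1)), entry))
    ckpts.sort(key=lambda item: item[0])  # oldest iteration first
    to_remove = ckpts[:-keep_last] if keep_last > 0 else ckpts
    for _, path in to_remove:
        shutil.rmtree(path)
    return [path.name for _, path in to_remove]
```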
Step 6: Failure Handling
On day 3, node 7 has an ECC error. GPU 7.3 is marked unhealthy.
- CloudWatch alarm triggers Lambda.
- Lambda terminates node 7.
- Auto Scaling Group launches replacement.
- Training resumes from latest checkpoint (lost ~30 minutes of compute).
Final Result:
- Training completed in 6.8 days.
- Final model uploaded to S3.
- Total cost: $43,200 (under budget).
9.1.15. Advanced Optimization Techniques
1. Gradient Checkpointing with Selective Layers
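Not all layers benefit equally from activation checkpointing. Layers whose activations are large relative to their recompute cost (attention, whose intermediate activations grow with sequence length) benefit more than layers with small activations (LayerNorm).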
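import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
def should_checkpoint(layer: nn.Module) -> bool:
    # Only checkpoint attention modules (nn.MultiheadAttention stands in for your model's class);
    # cheap layers such as LayerNorm keep their activations instead of being recomputed.
    return isinstance(layer, nn.MultiheadAttention)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=should_checkpoint,
)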
2. Dynamic Loss Scaling
Instead of a fixed loss scale (e.g., 1024), use dynamic scaling that adapts to gradient magnitudes. Loss scaling is needed for FP16; BF16 shares FP32's exponent range and usually trains without it.
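from torch.amp import GradScaler  # PyTorch 2.3+; older releases: torch.cuda.amp.GradScaler
scaler = GradScaler(
    "cuda",
    init_scale=2**16,  # Start high
    growth_factor=2.0,  # Double the scale after growth_interval overflow-free steps
    backoff_factor=0.5,  # Halve the scale (and skip that step) when inf/NaN gradients appear
    growth_interval=2000,  # Consecutive overflow-free steps required before growing
)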
3. Fused Optimizers
Standard optimizers (Adam, SGD) launch many small CUDA kernels. Fused optimizers combine these into a single kernel.
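from apex.optimizers import FusedAdam  # NVIDIA Apex library
optimizer = FusedAdam(model.parameters(), lr=1e-4)
# Without Apex: PyTorch's built-in fused kernel (parameters must be on CUDA)
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, fused=True)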
Speedup: 5-10% faster than torch.optim.Adam.
4. CPU Offloading for Inactive Ranks
In Pipeline Parallelism, GPUs in later pipeline stages are idle during the first few micro-batches. Offload their inactive weights to CPU during this time.
Implementation: DeepSpeed’s ZeRO-Offload with PP-aware scheduling.
5. Overlapping Data Loading with Computation
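Use torch.utils.data.DataLoader with:
- num_workers > 0: Load and prefetch batches in background CPU worker processes.
- pin_memory=True: Use pinned (page-locked) memory for faster CPU-to-GPU transfer.
- prefetch_factor=2: Each worker keeps 2 batches ready (the value is per worker).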
from torch.utils.data import DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=2,  # Batches prefetched per worker (requires num_workers > 0)
    persistent_workers=True,  # Keep workers alive between epochs
)
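Pinned memory pays off only when the host-to-device copy is issued asynchronously, so the copy can overlap with GPU work already in flight. A minimal sketch of the consuming side; iterate_on_device is a hypothetical helper that assumes the dataset yields (x, y) pairs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def iterate_on_device(dataloader, device):
    """Yield batches moved to `device`; non_blocking copies from pinned memory run asynchronously."""
    for x, y in dataloader:
        yield x.to(device, non_blocking=True), y.to(device, non_blocking=True)
```

In a training loop this replaces the plain `for batch in dataloader:`, e.g. `for x, y in iterate_on_device(dataloader, torch.device("cuda")):`.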
9.1.16. Summary: The Architect’s Decision Tree
When designing your training cluster, use this heuristic:
- Does the model fit in one GPU?
- Yes: Use DDP. Simple, standard.
- Limit: ~1.5B params (mixed-precision Adam at ~16 bytes/param) on 24GB VRAM.
- Does it almost fit (or fit with small batch size)?
- Yes: Use FSDP (ZeRO-3).
- Limit: ~20B params on A100 80GB (single node).
- Is it a massive model (70B+)?
- Single Node: Use FSDP with CPU Offloading (slow) or QLoRA (quantized).
- Multi-Node: Use 3D Parallelism.
- TP = 8 (fill the node).
- PP = Model Depth / Layers per Node.
- DP = Remaining scale.
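The decision tree can be encoded as a first-cut helper. A minimal sketch: suggest_strategy is a hypothetical function, and it counts only weights, gradients, and optimizer state at 16 bytes per parameter, so its bounds are upper limits. The tree's ~20B single-node FSDP figure is more conservative because it leaves headroom for activations.

```python
BYTES_PER_PARAM = 16  # BF16/FP16 weights + grads (2 + 2) and FP32 Adam master + moments (4 + 4 + 4)

def suggest_strategy(params_billion, gpu_mem_gb, gpus_per_node, num_nodes):
    """Rough encoding of the decision tree; ignores activations, so treat as an upper bound."""
    state_gb = params_billion * BYTES_PER_PARAM
    if state_gb <= gpu_mem_gb:
        return "DDP"
    if state_gb <= gpu_mem_gb * gpus_per_node:
        return "FSDP (ZeRO-3), single node"
    if num_nodes == 1:
        return "FSDP + CPU offload, or QLoRA"
    return f"3D parallelism (TP={gpus_per_node} within each node, PP/DP across nodes)"

print(suggest_strategy(70, 40, 8, 8))
```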
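The Golden Rule of Distributed Training: Communication is the killer. Always keep the heaviest, most frequent communication (TP, which synchronizes inside every layer) within the NVLink domain, and push lighter or less frequent communication (PP point-to-point activation transfers, DP gradient AllReduce once per step, which can overlap with backward) across the inter-node network (EFA/InfiniBand).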
Technical debt in distributed systems manifests as GPU Idle Time. If your nvidia-smi shows GPU utilization fluctuating between 100% and 0%, your parallelism strategy is misaligned with your network topology.