20.3 Model Sharding: Running Large Models on Multiple GPUs
The “iPhone Moment” of AI was ChatGPT. But under the hood, ChatGPT isn’t running on a GPU. It is running on thousands of GPUs. Even a mid-sized open model like Llama-3-70B cannot fit on a single A100 (80GB) if you want decent context length and batch size.
This section covers Distributed Inference: how to split a single neural network across multiple physical devices and make them act as one.
1. The Math of VRAM: Why Shard?
Let’s do the math for Llama-3-70B.
- Parameters: 70 Billion.
- Precision:
- FP16 (2 bytes): $70B \times 2 = 140$ GB.
- INT8 (1 byte): $70B \times 1 = 70$ GB.
- INT4 (0.5 bytes): $70B \times 0.5 = 35$ GB.
The Hardware:
- NVIDIA A100: 80 GB VRAM.
- NVIDIA A10G: 24 GB VRAM.
- NVIDIA T4: 16 GB VRAM.
The Problem: Even at INT4 (35 GB), Llama-70B technically fits on an A100, but you have limited room left for the KV Cache (Context Memory). A 4k context window can take 1-2 GB per user, so at batch size > 1 you can OOM quickly. At FP16 (140 GB), it does not fit on any single card.
The Solution: Sharding. Splitting the weights across card boundaries.
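To make the arithmetic concrete, here is a small sizing script. It reproduces the table above; the KV-cache constants (80 layers, 8 KV heads, head dimension 128, FP16 values) are approximations of the published Llama-3-70B architecture, so treat the output as an estimate rather than a vendor spec.

```python
# vram_math.py -- rough sizing for Llama-3-70B (estimates, not measurements).
def weight_gb(params_b: float, bytes_per_param: float) -> float:
    """Weight memory in GB: parameters (billions) x bytes per parameter."""
    return params_b * bytes_per_param  # 1e9 params * bytes / 1e9 bytes-per-GB

def kv_cache_gb(tokens: int, layers: int = 80, kv_heads: int = 8,
                head_dim: int = 128, bytes_per_val: int = 2) -> float:
    """KV cache: 2 (K and V) x layers x kv_heads x head_dim x bytes, per token."""
    per_token = 2 * layers * kv_heads * head_dim * bytes_per_val
    return tokens * per_token / 1e9

if __name__ == "__main__":
    for name, bpp in [("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
        print(f"{name}: weights = {weight_gb(70, bpp):.0f} GB")
    # One user with a full 4k context:
    print(f"KV cache @ 4096 tokens: {kv_cache_gb(4096):.2f} GB")
    # A batch of 16 concurrent 4k-context requests:
    print(f"KV cache @ batch 16:    {16 * kv_cache_gb(4096):.1f} GB")
```

Running it gives roughly 1.3 GB of KV cache per 4k-context user, which matches the 1-2 GB figure above and shows why batching eats headroom fast.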
2. Parallelism Strategies
There are two main ways to cut the model.
2.1. Pipeline Parallelism (PP)
Vertical Slicing.
- Concept: Put Layer 1-10 on GPU 0. Layer 11-20 on GPU 1.
- Flow: Batch enters GPU 0 -> Compute -> Send to GPU 1 -> Compute -> …
- Pros: Simple to implement. Low communication overhead (only passing activations between layers).
- Cons: The Bubble. While GPU 1 is working, GPU 0 is idle. Utilization is low unless you use micro-batching. Latency is high (sequential processing).
2.2. Tensor Parallelism (TP)
Horizontal Slicing.
- Concept: Split every single matrix multiplication across GPUs.
- Flow:
- Layer 1: $Y = W \cdot X$.
- Split $W$ into $W_1, W_2$.
- GPU 0 computes $W_1 \cdot X$. GPU 1 computes $W_2 \cdot X$.
- All-Reduce: GPU 0 and 1 communicate to sum their results.
- Pros: Low Latency. Both GPUs work simultaneously.
- Cons: Massive Communication Overhead. Requires high-bandwidth interconnects (NVLink). If you do TP over Ethernet, it is slow.
Verdict for Inference: Use Tensor Parallelism. We care about Latency.
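The All-Reduce step is easier to see in code. The toy sketch below simulates tensor parallelism on CPU: the weight matrix is split along its input dimension, each simulated "GPU" computes a partial product, and summing the partials (the All-Reduce) recovers the exact single-device result. It illustrates the math only; real engines like vLLM or Megatron implement this with sharded weights on real devices and NCCL collectives.

```python
import torch

torch.manual_seed(0)
d_in, d_out, batch = 8, 4, 2

W = torch.randn(d_out, d_in)   # full weight matrix
X = torch.randn(batch, d_in)   # input activations

# Reference: the un-sharded computation Y = X @ W^T
Y_full = X @ W.T

# Split the input dimension in half: each shard owns one slice.
W1, W2 = W[:, :d_in // 2], W[:, d_in // 2:]   # lives on "GPU 0" / "GPU 1"
X1, X2 = X[:, :d_in // 2], X[:, d_in // 2:]

# Each device computes a partial result independently (no communication yet).
partial_0 = X1 @ W1.T   # GPU 0
partial_1 = X2 @ W2.T   # GPU 1

# All-Reduce: sum the partials so every device ends up with the full Y.
Y_sharded = partial_0 + partial_1

assert torch.allclose(Y_full, Y_sharded, atol=1e-5)
print("Sharded result matches the single-device result.")
```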
3. The Framework: Ray Serve
Ray is the industry standard for distributed Python. It allows us to define an “Actor” that conceptually spans multiple GPUs.
3.1. KubeRay Architecture
On Kubernetes, you deploy a RayCluster.
- Head Node: Manages state.
- Worker Groups: GPU nodes (e.g., g5.12xlarge, which has 4x A10G GPUs).
3.2. Ray Serve Implementation
Serving Llama-70B across 4 GPUs using vLLM backend.
```python
import uuid

import ray
from ray import serve
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams


@serve.deployment(ray_actor_options={"num_gpus": 4})
class VLLMPredictor:
    def __init__(self):
        # 1. Start Engine.
        # vLLM sees the 4 GPUs assigned to this actor and initializes
        # Tensor Parallelism across them.
        args = AsyncEngineArgs(
            model="meta-llama/Llama-3-70b-hf",
            tensor_parallel_size=4,
            trust_remote_code=True,
        )
        self.engine = AsyncLLMEngine.from_engine_args(args)

    async def __call__(self, request):
        # 2. Parse Request
        data = await request.json()
        prompt = data.get("prompt")

        # 3. Generate (AsyncLLMEngine requires a unique request_id per call)
        results_generator = self.engine.generate(
            prompt,
            SamplingParams(temperature=0.7),
            request_id=str(uuid.uuid4()),
        )

        # 4. Stream Output (keep the last update, which holds the full text)
        final_text = ""
        async for request_output in results_generator:
            final_text = request_output.outputs[0].text
        return {"text": final_text}


# Deploy
app = VLLMPredictor.bind()
```
Ops Note: The ray_actor_options={"num_gpus": 4} argument on @serve.deployment determines the scheduling. Ray will look for a node with 4 free GPUs. If you have 4 single-GPU nodes, this fails unless your TP setup supports multi-node (slow). Always try to pack TP groups onto a single physical instance (e.g., p4d or g5 metal).
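To exercise the deployment locally, something like the sketch below works. Port 8000 is Ray Serve's default HTTP port and the payload shape matches the __call__ handler above; adjust both for your environment.

```python
# deploy_and_query.py -- minimal sketch; assumes the VLLMPredictor class above.
import requests
from ray import serve

# Deploys the app and waits until the replicas report healthy.
serve.run(VLLMPredictor.bind())

resp = requests.post(
    "http://127.0.0.1:8000/",   # Ray Serve's default HTTP endpoint
    json={"prompt": "Explain tensor parallelism in one sentence."},
    timeout=300,
)
print(resp.json()["text"])
```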
4. Serving Engines: vLLM vs. TGI
You don’t write the CUDA kernels yourself. You use an Engine.
4.1. vLLM (Virtual LLM)
- Feature: PagedAttention. Inspired by OS Virtual Memory. It fragments the KV Cache into blocks, allowing non-contiguous memory allocation.
- Pros: 2-4x higher throughput than naive implementation. Near-zero memory waste.
- Best For: High concurrency batch serving.
4.2. TGI (Text Generation Inference)
- Feature: Continuous Batching. Instead of waiting for the whole batch to finish, it injects new requests as soon as old ones finish generation (because some sentences are shorter).
- Pros: Hugging Face native. Great Docker support.
- Best For: Production simplicity.
Configuration Example (TGI):
```bash
model=meta-llama/Llama-2-70b-chat-hf
num_shard=4

docker run --gpus all --shm-size 1g -p 8080:80 \
  -v $PWD/data:/data \
  ghcr.io/huggingface/text-generation-inference:1.1.0 \
  --model-id $model \
  --num-shard $num_shard \
  --quantize bitsandbytes-nf4
```
--num-shard 4: This flag triggers the Tensor Parallelism logic automatically.
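Once the container is up, TGI exposes a /generate endpoint on the mapped host port (8080 above). A minimal client:

```python
import requests

# Query the TGI server started by the docker command above.
resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "What is tensor parallelism?",
        "parameters": {"max_new_tokens": 128, "temperature": 0.7},
    },
    timeout=120,
)
print(resp.json()["generated_text"])
```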
5. Deployment Pattern: The “Sidecar Shard”
In Kubernetes, getting 4 GPUs to talk requires shared memory (/dev/shm).
Standard Pods have limits.
5.1. The Shared Memory Hack
PyTorch Distributed uses shared memory for inter-process communication (IPC).
The default Docker /dev/shm is only 64 MB, which crashes distributed runs.
Fix: Mount an emptyDir volume with medium: Memory, as in the manifest below.
```yaml
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: llama-70b
spec:
  replicas: 1
  selector:                      # required: must match the pod template labels
    matchLabels:
      app: llama-70b
  template:
    metadata:
      labels:
        app: llama-70b
    spec:
      containers:
        - name: inference-server
          image: my-ray-image
          resources:
            limits:
              nvidia.com/gpu: 4        # Request 4 GPUs
          volumeMounts:
            - mountPath: /dev/shm
              name: dshm
      volumes:
        - name: dshm
          emptyDir:
            medium: Memory             # RAM-backed filesystem
```
5.2. Autoscaling Sharded Workloads
Autoscaling a 1-GPU pod is easy. Autoscaling a 4-GPU pod is hard.
- Bin Packing: You need a node with exactly 4 contiguous GPUs free.
- Karpenter: Use AWS Karpenter to provision new instances specifically for the pod.
- Provisioner Config: pin the instance type, e.g. instance-type: [g5.12xlarge] (a sketch follows this list).
- Effect: when a new Pod request comes in, Karpenter spins up a fresh g5.12xlarge in roughly 60 seconds and binds the pod to it.
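A rough sketch of such a provisioner, using the older v1alpha5 API (newer Karpenter releases use the NodePool CRD instead, so check the schema for your version; the names here are placeholders):

```yaml
# karpenter-provisioner.yaml -- sketch only; schema varies by Karpenter version.
apiVersion: karpenter.sh/v1alpha5
kind: Provisioner
metadata:
  name: gpu-shard
spec:
  requirements:
    - key: node.kubernetes.io/instance-type
      operator: In
      values: ["g5.12xlarge"]     # 4x A10G per node, matching the TP group size
  ttlSecondsAfterEmpty: 300       # scale the expensive node back down when idle
  providerRef:
    name: default                 # your AWSNodeTemplate
```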
5.3. Fault Tolerance
If 1 GPU dies in a 4-GPU group, the entire group crashes. Tensor Parallelism is brittle.
- Recovery: Ray Serve handles restart. It kills the actor and restarts it on a healthy node.
- Health Check: Ensure your Liveness Probe queries the model, not just the server. A GPU can be stuck (ECC errors) while the HTTP server is up (see the probe sketch below).
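A sketch of what that looks like on the inference-server container from 5.1. The /health path, and the idea that it triggers a tiny one-token generation, are assumptions about your own server code, not something Ray or vLLM provide by default:

```yaml
# Excerpt from the inference-server container spec (sketch).
livenessProbe:
  httpGet:
    path: /health              # assumed endpoint that runs a 1-token generation
    port: 8000
  initialDelaySeconds: 600     # a 70B model takes minutes to load; don't kill it early
  periodSeconds: 30
  timeoutSeconds: 20
  failureThreshold: 3
```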
Next, we turn to the data engineering side of sharding.
6. Data Engineering: The 100GB Loading Problem
When your model is 140GB, model.load_state_dict() is your bottleneck.
On a standard SSD (500MB/s), loading 140GB takes ~5 minutes.
If you have autoscaling, a 5-minute cold start is unacceptable.
6.1. SafeTensors: The Savior
Pickle (PyTorch default) is slow and insecure. It requires unpickling (CPU work) and memory copying. SafeTensors is a zero-copy format.
- Memory Mapping: the file on disk is mapped directly into the process's address space.
- Speed: faster than torch.load() on a pickle checkpoint.
- Safety: no arbitrary code execution on load.
Conversion Code:
from safetensors.torch import save_file, load_file
import torch
# Convert PyTorch bin to SafeTensors
weights = torch.load("pytorch_model.bin")
save_file(weights, "model.safetensors")
Note: Hugging Face now defaults to SafeTensors. Always verify your repo has .safetensors files before deploying.
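The memory-mapping claim is easy to verify: safetensors can open a checkpoint and materialize a single tensor without reading the rest of the file. A small sketch (the file name is illustrative):

```python
from safetensors import safe_open

# The file is memory-mapped; tensors are only read when requested.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    names = list(f.keys())
    print(f"{len(names)} tensors in checkpoint")
    first = f.get_tensor(names[0])   # only this tensor is pulled from disk
    print(names[0], tuple(first.shape), first.dtype)
```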
6.2. Fast Loading Architecture
Optimizing the Cold Start:
- S3 Throughput: a single standard S3 stream gives roughly 100 MB/s.
  - Fix: use high-concurrency downloads (the AWS CLI max_concurrent_requests setting, or s5cmd below).
- Container Image Baking:
  - Bad: download weights in the ENTRYPOINT script (slow on every start).
  - Better: mount an EFS/Filestore volume with the weights pre-loaded (shared across pods).
  - Best: bake weights into the Docker image (works if < 10 GB; for 140 GB this is impractical).
- Instance Store (NVMe): g5 instances come with local NVMe SSDs.
  - Startup script: aws s3 cp s3://bucket/model /mnt/nvme/model (or use s5cmd for ~10 GB/s throughput).
The s5cmd Trick:
A standard single-stream download of 140 GB takes forever.
The Go-based s5cmd parallelizes the transfer and can saturate a 100 Gbps network link.
```bash
# In your startup script
curl -L https://github.com/peak/s5cmd/releases/download/v2.0.0/s5cmd_2.0.0_Linux-64bit.tar.gz | tar xz
./s5cmd cp "s3://my-bucket/llama-70b/*" /data/model/
```
Result: 140GB download in < 60 seconds (on instances with 100Gbps networking).
7. Networking: The Invisible Bottleneck
In Tensor Parallelism, GPUs talk to each other every single layer. Layer 1 Compute -> Sync -> Layer 2 Compute -> Sync. If “Sync” is slow, the GPUs spend 50% of time waiting.
7.1. NVLink vs. PCIe
- PCIe Gen4: ~64 GB/s. (Standard slots).
- NVLink: ~600 GB/s. (Bridge between GPUs).
Ops Implication:
You cannot do Tensor Parallelism efficiently across two separate machines (e.g., two g4dn.xlarge instances) over TCP/IP Ethernet. The latency (milliseconds) is 1000x too slow compared to NVLink (microseconds).
Rule: TP must happen inside a single chassis.
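Rough numbers make the rule concrete. In the Megatron-style scheme, each transformer layer all-reduces the activations twice per generated token (once after attention, once after the MLP). The sketch below uses approximate Llama-3-70B dimensions and idealized link bandwidths, and ignores per-operation latency (which makes Ethernet even worse in practice), so read it as an order-of-magnitude argument:

```python
# Order-of-magnitude sync cost per generated token (decode phase, batch size 1).
layers, hidden, bytes_fp16 = 80, 8192, 2       # approximate Llama-3-70B dimensions
all_reduces_per_layer = 2                      # after attention and after the MLP

bytes_per_token = layers * all_reduces_per_layer * hidden * bytes_fp16
print(f"Activation traffic per token: {bytes_per_token / 1e6:.1f} MB")

for link, gb_per_s in [("NVLink   (~600 GB/s)", 600),
                       ("PCIe Gen4 (~64 GB/s)", 64),
                       ("25GbE      (~3 GB/s)", 3)]:
    ms = bytes_per_token / (gb_per_s * 1e9) * 1e3
    print(f"{link}: ~{ms:.3f} ms of bandwidth time per token")
```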
7.2. NCCL (NVIDIA Collective Communication Library)
NCCL is the protocol used for AllReduce.
It automatically detects the best path (NVLink > PCIe > Socket).
Debugging NCCL: If distributed training hangs or is slow, use:
```bash
export NCCL_DEBUG=INFO
export NCCL_P2P_DISABLE=0
```
Watch the logs. If you see it falling back to NET/Socket inside a single machine, your NVLink topology is broken or virtualization is misconfigured.
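Before touching NCCL settings, confirm what topology the node actually exposes:

```bash
# GPU-to-GPU connectivity matrix: NV# means NVLink, PIX/PHB are PCIe paths,
# SYS means traffic crosses the CPU interconnect.
nvidia-smi topo -m
```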
7.3. EFA (Elastic Fabric Adapter)
For Multi-Node training (not inference), technologies like AWS EFA bypass the OS kernel to provide low-latency networking. While less critical for Inference (since we keep TP local to the node), it is effectively mandatory for distributed Training (20.4).
8. Quantized Sharding: AWQ and GPTQ
If you can’t afford 4x A100s, you Quantize. Llama-3-70B can fit on 2x A100s or 4x A10Gs if compressed to 4-bit.
8.1. GPTQ (Post-Training Quantization)
Reduces weights to 4-bit by analyzing the Hessian (curvature) of the loss landscape, identifying which weights “don’t matter”.
- Format: pre-quantized .safetensors checkpoints.
- Serving: vLLM and TGI support loading GPTQ/AWQ models directly.
8.2. AWQ (Activation-aware Weight Quantization)
Newer standard. Better at preserving reasoning capabilities than GPTQ.
Serving Config:
```python
# vLLM
from vllm import AsyncLLMEngine, AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="TheBloke/Llama-2-70B-Chat-AWQ",
    quantization="awq",
    tensor_parallel_size=2,   # Fits on 2x A100s!
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```
Cost Math:
- FP16: 4x A100 ($12/hr).
- AWQ 4-bit: 2x A100 ($6/hr).
- Optimization: 50% cost reduction by changing one line of config.
9. Hands-On Lab: The “Poor Man’s” 70B Cluster
We will simulate a distributed environment using 2x T4 GPUs (cheap) to run a smaller sharded model (e.g., 13B) to prove the pipeline works, since requesting 4x A100s might hit quota limits.
9.1. Setup
- Instance: g4dn.12xlarge (4x T4 GPUs). Cost: ~$3.9/hr.
- Goal: serve Llama-2-13B (26 GB at FP16) across 2 of those GPUs (16 GB each).
9.2. Code
```python
# serve.py
from vllm import LLM, SamplingParams

# The 13B model needs ~26GB at FP16.
# A single T4 has 16GB; 2x T4 = 32GB, so it fits with ~6GB left for the KV Cache.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    tensor_parallel_size=2,  # utilize 2 GPUs via tensor parallelism
)

output = llm.generate("Hello, how are you?")
print(output[0].outputs[0].text)
```
9.3. Observation
Run nvidia-smi in a separate terminal during generation.
- You should see memory usage spike on GPU 0 and GPU 1.
- Compute utilization should rise on both GPUs synchronously.
- If only GPU 0 moves, TP is not working.
10. Troubleshooting Model Sharding
Symptom: RuntimeError: CUDA out of memory.
- Check: are you counting the KV Cache?
- Fix: reduce max_model_len (the context size). The default is often 4096; lowering it to 2048 frees up several GB (see the vLLM sketch after this list).
- Fix: quantize (e.g., load_in_8bit=True).
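For vLLM specifically, the context-length fix maps to the max_model_len engine argument, and VRAM headroom can also be tuned with gpu_memory_utilization (quantization is handled via the quantization flag rather than load_in_8bit). A minimal sketch with illustrative values:

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    tensor_parallel_size=2,
    max_model_len=2048,            # smaller context window -> smaller KV cache
    gpu_memory_utilization=0.90,   # cap the fraction of VRAM vLLM claims
)
```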
Symptom: NCCL timeout or Hang.
- Cause: Firewall/Security Group blocking internal ports.
- Fix: allow inbound traffic on all TCP ports from Self (the security group's own ID). NCCL uses random high ports.
Symptom: Throughput is low (2 tokens/sec).
- Cause: you may be CPU-bound.
  - Check: top. If the Python process is pinned at 100%, data loading or post-processing is the bottleneck.
- Cause: broken NVLink, forcing traffic over PCIe.
11. Reference Architectures
How do you wire this up in AWS EKS?
11.1. The Single Node Pod
If model > Single GPU but < Single Node (8 GPUs).
- Node Pool: p4d.24xlarge (8x A100).
- Pod: requests nvidia.com/gpu: 8.
- Networking: stays inside the node (NVLink).
11.2. The Multi-Node Cluster (Training)
If model > 8 GPUs (e.g., Training Llama-3-400B).
- Interconnect: EFA (Elastic Fabric Adapter).
- Deployment: Ray Cluster (Head + Workers).
- Worker: Each Worker manages 8 GPUs. They talk via EFA.
KubeRay Manifest Example:
```yaml
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: llama-serving
spec:
  serveConfigV2: |
    applications:
      - name: llama_app
        import_path: serving.app
        runtime_env:
          pip: ["vllm", "ray[serve]"]
  rayClusterConfig:
    rayVersion: '2.9.0'
    headGroupSpec:
      rayStartParams:
        dashboard-host: '0.0.0.0'
      template:
        spec:
          containers:
            - name: ray-head
              image: rayproject/ray:2.9.0-gpu
              resources:
                limits:
                  cpu: 2
                  memory: 8Gi
    workerGroupSpecs:
      - groupName: gpu-group
        replicas: 1
        minReplicas: 1
        maxReplicas: 5
        rayStartParams: {}
        template:
          spec:
            containers:
              - name: ray-worker
                image: rayproject/ray:2.9.0-gpu
                resources:
                  limits:
                    nvidia.com/gpu: 4        # THE CRITICAL LINE
                    memory: 200Gi
            nodeSelector:
              node.kubernetes.io/instance-type: g5.12xlarge  # maps to physical hardware
```
12. Future Trends: Speculative Decoding
Sharding solves Capacity (Fitting the model). It does not solve Latency (Autoregressive is slow).
12.1. The Theory
LLMs are memory-bandwidth bound: at small batch sizes, processing 5 tokens takes roughly the same time as processing 1, because the cost is dominated by streaming the weights from memory. Idea: What if we had a tiny “Draft Model” (Llama-7B) guess the next 5 tokens, and the “Oracle Model” (Llama-70B) verify them all in parallel?
- Draft: “The cat sat on the” (Fast).
- Oracle: Check [“The”, “cat”, “sat”, “on”, “the”].
- If all correct: Acceptance! We generated 5 tokens in 1 step.
- If wrong: Reject and re-generate.
12.2. vLLM Support
vLLM supports this out of the box.
```python
from vllm import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-3-70b-hf",
    speculative_model="meta-llama/Llama-3-8b-hf",  # The Drafter
    num_speculative_tokens=5,
)
```
Ops Impact: You need VRAM for both models. But the draft model is usually small. Result: 2x-3x speedup in tokens/sec.
13. FAQ
Q: Can I run 70B on CPU?
A: Yes, with llama.cpp (GGUF format). It will run at 2-3 tokens/second. Good for debugging. Unusable for production chat (users expect 20-50 t/s).
Q: Do I need InfiniBand? A: For Inference of < 100B models: No. NVLink inside the node is enough. For Training: Yes.
Q: How does this impact Cost? A: Inference cost scales roughly linearly with model size.
- 7B: ~$1/hr.
- 70B: ~$10/hr.
Your business case must justify the 10x cost. Does 70B provide 10x better answers? (Often: yes for coding/reasoning; no for summarization.)
14. Glossary of Distributed Terms
- All-Reduce: the collective operation where every node contributes its data and all nodes end up with the combined result (e.g., the sum or mean).
- NVLink: Proprietary NVIDIA cable for high-speed GPU-to-GPU talk.
- Pipeline Parallelism (PP): Assigning layers to different GPUs.
- Tensor Parallelism (TP): Splitting tensors within a layer across GPUs.
- Sharding: The general act of partitioning data/weights.
- vLLM: The leading open-source inference engine optimized for throughput.
- Weights vs. Activations:
- Weights: Static parameters (Fixed size).
- Activations: Dynamic data flowing through net (Depends on Batch Size).
- KV Cache: Saved activations for Attention (grows with Context Length).
15. References
1. Narayanan et al. (NVIDIA, 2021). “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.” The paper that defined Tensor Parallelism.
2. Moritz et al. (UC Berkeley, 2018). “Ray: A Distributed Framework for Emerging AI Applications.” The foundation of modern distributed AI orchestration.
3. Kwon et al. (UC Berkeley, 2023). “vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention.” Revolutionized inference memory management.
16. Final Checklist: Deployment Day
- Hardware: Do you have a g5.12xlarge or p4d quota approved?
- Format: Is the model in SafeTensors?
- Quantization: Did you benchmark AWQ vs FP16?
- Engine: Are you using vLLM (Throughput) or TGI (Simplicity)?
- Health: Is the Liveness Probe configured to check the Engine loop?
In the next section, we move from Serving models to Teaching them human preferences: RLHF Operations (20.4).