20.3 Model Sharding: Running Large Models on Multiple GPUs

The “iPhone Moment” of AI was ChatGPT. But under the hood, ChatGPT doesn’t run on a single GPU; it runs on thousands of them. Even a mid-sized open model like Llama-3-70B cannot fit on a single A100 (80GB) if you want a decent context length and batch size.

This chapter covers Distributed Inference: how to split a single neural network across multiple physical devices and make them act as one.


1. The Math of VRAM: Why Shard?

Let’s do the math for Llama-3-70B.

  • Parameters: 70 Billion.
  • Precision:
    • FP16 (2 bytes): $70B \times 2 = 140$ GB.
    • INT8 (1 byte): $70B \times 1 = 70$ GB.
    • INT4 (0.5 bytes): $70B \times 0.5 = 35$ GB.

The Hardware:

  • NVIDIA A100: 80 GB VRAM.
  • NVIDIA A10G: 24 GB VRAM.
  • NVIDIA T4: 16 GB VRAM.

The Problem: At INT8 (70GB), Llama-70B technically fits on an A100, but that leaves only ~10GB for the KV Cache (context memory), activations, and runtime overhead. A 4k context window can take 1-2 GB per user, so with batch size > 1 you OOM almost immediately. INT4 (35GB) leaves more headroom, at some cost in quality. At FP16 (140GB), it fits on no single card.

The Solution: Sharding. Splitting the weights across card boundaries.
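The arithmetic above can be scripted to size a deployment. This is a minimal sketch under stated assumptions: the Llama-3-70B shape (80 layers, 8 KV heads, head dimension 128) is taken from its published config, the 90% usable-memory fraction mirrors vLLM's default gpu_memory_utilization, and the GPU count is restricted to powers of two because tensor_parallel_size must divide the number of attention heads (for Llama-3-70B that means 1, 2, 4, or 8). The function names are illustrative.

```python
def weight_gb(params_billion: float, bytes_per_param: float) -> float:
    """Weights only: parameter count x bytes per parameter (1 GB = 1e9 bytes)."""
    return params_billion * bytes_per_param


def kv_cache_gb(n_layers: int, n_kv_heads: int, head_dim: int,
                seq_len: int, batch: int, bytes_per_elem: int = 2) -> float:
    """One K and one V vector per layer, per KV head, per token, per sequence."""
    elems = 2 * n_layers * n_kv_heads * head_dim * seq_len * batch
    return elems * bytes_per_elem / 1e9


def min_gpus(total_gb: float, gpu_gb: float, usable_fraction: float = 0.9) -> int:
    """Smallest power-of-two GPU count whose usable memory holds total_gb."""
    n = 1
    while n * gpu_gb * usable_fraction < total_gb:
        n *= 2
    return n


# Assumed Llama-3-70B shape: 80 layers, 8 KV heads (GQA), head_dim 128, FP16 cache.
kv_per_user = kv_cache_gb(80, 8, 128, seq_len=4096, batch=1)
print(f"KV cache per 4k-token user: {kv_per_user:.2f} GB")

for label, bytes_per_param in [("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
    total = weight_gb(70, bytes_per_param) + 8 * kv_per_user  # 8 concurrent users
    print(f"{label}: {total:.0f} GB -> {min_gpus(total, 80)}x A100-80GB, "
          f"{min_gpus(total, 24)}x A10G")
```

With 8 concurrent 4k-token users, the sketch lands on 4x A100 for FP16, 2x A100 for INT8, and a single A100 for INT4.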


2. Parallelism Strategies

There are two main ways to cut the model.

2.1. Pipeline Parallelism (PP)

Vertical Slicing.

  • Concept: Put layers 1-10 on GPU 0, layers 11-20 on GPU 1, and so on.
  • Flow: Batch enters GPU 0 -> Compute -> Send to GPU 1 -> Compute -> …
  • Pros: Simple to implement. Low communication overhead (only passing activations between stages).
  • Cons: The Bubble. While GPU 1 is working, GPU 0 is idle. Utilization is low unless you use micro-batching. Latency is high (sequential processing).
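The bubble is easy to quantify. The sketch below assumes a simple fill-and-drain schedule (forward passes only, as in inference, with equal time per stage): with p stages and m micro-batches the run takes m + p - 1 time steps, and each stage is busy for only m of them, so the idle fraction is (p - 1) / (m + p - 1). The function names are illustrative, not from any library.

```python
def pipeline_timeline(num_stages: int, num_microbatches: int):
    """For each stage (GPU), the micro-batch it processes at each time step (None = idle)."""
    steps = num_microbatches + num_stages - 1
    return [
        [t - s if 0 <= t - s < num_microbatches else None for t in range(steps)]
        for s in range(num_stages)
    ]


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Share of GPU time spent idle: (p - 1) / (m + p - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for stage, row in enumerate(pipeline_timeline(num_stages=2, num_microbatches=4)):
    print(f"GPU {stage}: " + " ".join("." if mb is None else str(mb) for mb in row))

print(f"1 micro-batch, 2 GPUs: {bubble_fraction(2, 1):.0%} idle")
print(f"16 micro-batches, 4 GPUs: {bubble_fraction(4, 16):.0%} idle")
```

With one batch and two GPUs, each GPU idles half the time, which is exactly the bubble described above; micro-batching shrinks it but never removes it.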

2.2. Tensor Parallelism (TP)

Horizontal Slicing.

  • Concept: Split every single matrix multiplication across GPUs.
  • Flow:
    • Layer 1: $Y = W \cdot X$.
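    • Split $W$ along its input dimension into $W_1, W_2$, and split $X$ to match into $X_1, X_2$.
    • GPU 0 computes $W_1 \cdot X_1$. GPU 1 computes $W_2 \cdot X_2$.
    • All-Reduce: GPU 0 and 1 sum their partial results, so both end up with $Y = W_1 \cdot X_1 + W_2 \cdot X_2$.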
  • Pros: Low Latency. Both GPUs work simultaneously.
  • Cons: Massive Communication Overhead. Requires high-bandwidth interconnects (NVLink). If you do TP over Ethernet, it is slow.
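A quick NumPy check makes the two ways of splitting $W$ concrete. Two array halves stand in for two GPUs, and plain concatenation and addition stand in for the All-Gather and All-Reduce collectives; the sizes are arbitrary. Splitting $W$ by output rows gives disjoint slices of $Y$ (gathered), while splitting it by input columns gives full-size partial sums (reduced). Megatron-LM-style layers chain the two, so each MLP block needs only one All-Reduce.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, batch = 8, 6, 3
W = rng.standard_normal((d_out, d_in))   # layer weights, Y = W @ X
X = rng.standard_normal((d_in, batch))   # activations
Y = W @ X                                # single-GPU reference

# Split by output features: each "GPU" holds half the rows of W and all of X.
# The results are disjoint halves of Y, so they are concatenated (All-Gather).
W_out0, W_out1 = np.split(W, 2, axis=0)
Y_gathered = np.concatenate([W_out0 @ X, W_out1 @ X], axis=0)

# Split by input features: each "GPU" holds half the columns of W and the
# matching half of X. Each result is a full-size partial sum (All-Reduce).
W_in0, W_in1 = np.split(W, 2, axis=1)
X_0, X_1 = np.split(X, 2, axis=0)
partial_0, partial_1 = W_in0 @ X_0, W_in1 @ X_1
Y_reduced = partial_0 + partial_1

print(np.allclose(Y_gathered, Y), np.allclose(Y_reduced, Y))  # True True
```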

Verdict for Inference: Use Tensor Parallelism. We care about Latency.


3. The Framework: Ray Serve

Ray is a widely used framework for distributed Python. It lets us define an “Actor” that reserves multiple GPUs and drives them as one unit.

3.1. KubeRay Architecture

On Kubernetes, you deploy a RayCluster.

  • Head Node: Manages state.
  • Worker Groups: GPU nodes (e.g., g5.12xlarge which has 4x A10Gs).

3.2. Ray Serve Implementation

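Serving Llama-70B across 4 GPUs using the vLLM backend. At FP16 the weights alone are 140GB, so this assumes 4x A100 (80GB); a g5.12xlarge (4x A10G, 96GB total) needs a quantized checkpoint instead (see 8. Quantized Sharding).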

import uuid

from ray import serve
from starlette.requests import Request
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

@serve.deployment(ray_actor_options={"num_gpus": 4})
class VLLMPredictor:
    def __init__(self):
        # 1. Start Engine
        # tensor_parallel_size=4 shards every layer across the 4 GPUs reserved above.
        # AsyncLLMEngine.from_engine_args expects AsyncEngineArgs, not EngineArgs.
        args = AsyncEngineArgs(
            model="meta-llama/Meta-Llama-3-70B-Instruct",
            tensor_parallel_size=4,
        )
        self.engine = AsyncLLMEngine.from_engine_args(args)

    async def __call__(self, request: Request):
        # 2. Parse Request
        data = await request.json()
        prompt = data["prompt"]

        # 3. Generate (vLLM needs a unique request_id per call;
        # max_tokens defaults to 16, which truncates most answers)
        results_generator = self.engine.generate(
            prompt,
            SamplingParams(temperature=0.7, max_tokens=256),
            request_id=str(uuid.uuid4()),
        )

        # 4. Consume the stream: each item holds the cumulative text so far,
        # so the last one is the full completion.
        final_text = ""
        async for request_output in results_generator:
            final_text = request_output.outputs[0].text

        return {"text": final_text}

# Deploy
app = VLLMPredictor.bind()

Ops Note: The ray_actor_options={"num_gpus": 4} setting determines the scheduling. Ray will look for a single node with 4 free GPUs, because one actor cannot span nodes. On a cluster of 4 single-GPU nodes it will never be scheduled, and spreading TP across nodes is slow anyway (see 7. Networking). Always pack TP groups onto a single physical instance (e.g., p4d.24xlarge or g5.12xlarge).


4. Serving Engines: vLLM vs. TGI

You don’t write the CUDA kernels yourself. You use an Engine.

4.1. vLLM (Virtual LLM)

  • Feature: PagedAttention. Inspired by OS virtual memory, it splits the KV Cache into fixed-size blocks, so a sequence's cache does not need to be contiguous in memory.
  • Pros: 2-4x higher throughput than naive implementation. Near-zero memory waste.
  • Best For: High concurrency batch serving.

4.2. TGI (Text Generation Inference)

  • Feature: Continuous Batching. Instead of waiting for the whole batch to finish, it adds new requests to the running batch as soon as earlier ones finish generating (because some sequences are shorter than others). vLLM schedules requests this way too.
  • Pros: Hugging Face native. Great Docker support.
  • Best For: Production simplicity.
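The gain from continuous batching shows up even in a toy step counter. This sketch assumes one decode step produces one token for every running sequence and ignores prefill, memory limits, and arrival times, so it illustrates the scheduling idea rather than TGI's or vLLM's actual scheduler. The function names and request lengths are hypothetical.

```python
from collections import deque


def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest sequence finishes; the others sit idle."""
    return sum(
        max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size)
    )


def continuous_batching_steps(lengths, batch_size):
    """A finished sequence's slot is refilled from the queue before the next step."""
    queue, running, steps = deque(lengths), [], 0
    while queue or running:
        while queue and len(running) < batch_size:
            running.append(queue.popleft())
        steps += 1
        # One decode step: every running sequence emits a token; drop finished ones.
        running = [left - 1 for left in running if left > 1]
    return steps


# Output lengths (in tokens) of 8 requests: a few long answers, many short ones.
lengths = [10, 2, 2, 2, 10, 2, 2, 2]
print("static:    ", static_batching_steps(lengths, batch_size=4))      # 20
print("continuous:", continuous_batching_steps(lengths, batch_size=4))  # 12
```

In this toy run the short requests no longer wait behind the long ones, so the same 32 tokens take 12 steps instead of 20.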

Configuration Example (TGI):

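model=meta-llama/Llama-2-70b-chat-hf
num_shard=4
# Llama-2 is gated: export HF_TOKEN=<your Hugging Face read token> first

docker run --gpus all --shm-size 1g -p 8080:80 \
  -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN \
  -v $PWD/data:/data \
  ghcr.io/huggingface/text-generation-inference:1.1.0 \
  --model-id $model \
  --num-shard $num_shard \
  --quantize bitsandbytes-nf4

  • --num-shard 4: Shards the model across 4 GPUs with Tensor Parallelism; TGI launches one process per shard.
  • --quantize bitsandbytes-nf4: Loads the weights in 4-bit NF4, so the 70B model needs roughly 35GB instead of 140GB.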

5. Deployment Pattern: The “Sidecar Shard”

In Kubernetes, getting 4 GPUs to talk requires shared memory (/dev/shm), and a standard Pod only gets a small one.

5.1. The Shared Memory Hack

PyTorch Distributed and NCCL use shared memory for inter-process communication. The default Docker /dev/shm is 64MB, which crashes or hangs distributed runs. Fix: Mount an emptyDir with medium: Memory.

# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: llama-70b
spec:
  replicas: 1
  selector:
    matchLabels:
      app: llama-70b
  template:
    metadata:
      labels:
        app: llama-70b # must match spec.selector
    spec:
      containers:
        - name: inference-server
          image: my-ray-image
          resources:
            limits:
              nvidia.com/gpu: 4 # Request 4 GPUs
          volumeMounts:
            - mountPath: /dev/shm
              name: dshm
      volumes:
        - name: dshm
          emptyDir:
            medium: Memory # RAM-backed filesystem

5.2. Autoscaling Sharded Workloads

Autoscaling a 1-GPU pod is easy. Autoscaling a 4-GPU pod is hard.

  • Bin Packing: You need a single node with at least 4 free GPUs.
  • Karpenter: Use AWS Karpenter to provision new instances specifically for the pod.
    • Provisioner Config: Restrict instance types to g5.12xlarge (requirement key node.kubernetes.io/instance-type).
    • Effect: When a new Pod is pending, Karpenter launches a fresh g5.12xlarge, typically in about a minute, and binds the pod to it. Model loading (see 6. Data Engineering) comes on top of that.

5.3. Fault Tolerance

If 1 GPU dies in a 4-GPU group, the entire group crashes. Tensor Parallelism is brittle.

  • Recovery: Ray Serve handles restart. It kills the actor and restarts it on a healthy node.
  • Health Check: Ensure your Liveness Probe queries the model, not just the server. A GPU can be stuck (ECC errors) while the HTTP server is up.
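A liveness check that exercises the model takes only a few lines of standard-library Python. This sketch assumes the Ray Serve endpoint from 3.2 (JSON {"prompt": ...} in, {"text": ...} out); the function name and the "ping" prompt are hypothetical. Because a stuck GPU often hangs rather than returning an error, the timeout matters as much as the response check.

```python
import json
import urllib.request


def model_is_healthy(url: str, timeout_s: float = 30.0) -> bool:
    """True only if the model generates text, not merely if the HTTP server answers."""
    body = json.dumps({"prompt": "ping"}).encode()
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            payload = json.loads(resp.read())
    except (OSError, ValueError):
        # Refused connection, HTTP error, timeout (hung GPU), or non-JSON body.
        return False
    return isinstance(payload, dict) and bool(payload.get("text"))
```

As an exec liveness probe, a command such as python -c "import sys, probe; sys.exit(0 if probe.model_is_healthy('http://localhost:8000/') else 1)" would work, assuming the function lives in a hypothetical probe.py baked into the image and Ray Serve's default HTTP port 8000. Set the probe's timeoutSeconds above timeout_s.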

In the next section, we look at Data Engineering for Sharding.


6. Data Engineering: The 100GB Loading Problem

When your model is 140GB, loading the weights is your bottleneck. On a standard SSD (500MB/s), reading 140GB takes ~5 minutes. If you autoscale, a 5-minute cold start is unacceptable.

6.1. SafeTensors: The Savior

Pickle (PyTorch default) is slow and insecure. It requires unpickling (CPU work) and memory copying. SafeTensors is a zero-copy format.

  • Memory Mapping: It memory-maps the file on disk directly into the process's address space, so tensors are read on demand.
  • Speed: Typically faster than torch.load(), since there is nothing to unpickle.
  • Safety: No code execution.
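Why can SafeTensors be memory-mapped? The file starts with an 8-byte little-endian header length, followed by a JSON header that gives each tensor's dtype, shape, and byte range (data_offsets, relative to the end of the header), and then the raw tensor bytes. The stdlib-only sketch below reads that header and copies out one tensor's bytes through mmap, so only the pages holding that tensor need to be read from disk. It illustrates the file layout; it is not the safetensors library's API, which you would normally use via safetensors.safe_open. The function names are hypothetical.

```python
import json
import mmap
import struct


def read_safetensors_header(path):
    """Return (header_size, {tensor_name: {"dtype", "shape", "data_offsets"}})."""
    with open(path, "rb") as f:
        (header_size,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_size))
    header.pop("__metadata__", None)  # optional free-form string metadata
    return header_size, header


def tensor_bytes(path, name):
    """Copy out one tensor's raw bytes via mmap; other tensors' pages are not touched."""
    header_size, header = read_safetensors_header(path)
    begin, end = header[name]["data_offsets"]  # relative to the data section
    data_start = 8 + header_size
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        return mm[data_start + begin:data_start + end]
```

The same property is what lets a tensor-parallel rank read only its own slice of a large checkpoint instead of the whole file.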

Conversion Code:

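import torch
from safetensors.torch import load_file, save_file

# Convert a PyTorch .bin (pickle) checkpoint to SafeTensors.
# Large models ship as several pytorch_model-0000X-of-0000Y.bin shards; repeat per shard.
# weights_only=True stops torch.load from running arbitrary pickled code.
weights = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
# save_file needs contiguous tensors (and refuses tensors that share storage).
weights = {name: t.contiguous() for name, t in weights.items()}
save_file(weights, "model.safetensors")

# Sanity check: the round trip preserves every tensor name.
assert load_file("model.safetensors").keys() == weights.keys()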

Note: Hugging Face now defaults to SafeTensors. Always verify your repo has .safetensors files before deploying.

6.2. Fast Loading Architecture

Optimizing the Cold Start:

  1. S3 Throughput: A single S3 download stream is ~100MB/s.
    • Fix: Use high-concurrency downloads (e.g., aws configure set default.s3.max_concurrent_requests 64).
  2. Container Image Baking:
    • Bad: Download weights in the ENTRYPOINT script (slow on every start).
    • Better: Mount an EFS/Filestore volume with the weights pre-loaded (shared across pods).
    • Best: Bake weights into the Docker image (if < 10GB). For 140GB, this is hard.
  3. Instance Store (NVMe):
    • g5 instances come with local NVMe SSDs.
    • Startup Script: aws s3 cp --recursive s3://bucket/model /mnt/nvme/model (or use s5cmd for ~10GB/s throughput).

The s5cmd Trick: A standard single-stream download of 140GB takes over 20 minutes at ~100MB/s. The Go-based s5cmd runs many transfers in parallel and can come close to saturating a 100Gbps link.

# In your startup script
curl -L https://github.com/peak/s5cmd/releases/download/v2.0.0/s5cmd_2.0.0_Linux-64bit.tar.gz | tar xz
./s5cmd cp "s3://my-bucket/llama-70b/*" /data/model/

Result: 140GB download in < 60 seconds (on instances with 100Gbps networking).
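The load times quoted in this section are just size divided by sustained throughput. A small calculator makes the comparison explicit; the throughput figures are the rough numbers used above, and the 100Gbps row is a line-rate upper bound that real downloads (and disk writes) will not fully reach.

```python
MODEL_GB = 140  # Llama-3-70B weights in FP16

# Rough sustained throughputs in GB/s (the figures quoted in this section).
SCENARIOS = {
    "standard SSD (~500MB/s)": 0.5,
    "single S3 stream (~100MB/s)": 0.1,
    "100Gbps NIC, line rate (upper bound)": 100 / 8,  # Gbit/s -> GB/s
}


def load_seconds(size_gb: float, gb_per_s: float) -> float:
    """Best-case transfer time: size divided by sustained throughput."""
    return size_gb / gb_per_s


for name, rate in SCENARIOS.items():
    secs = load_seconds(MODEL_GB, rate)
    print(f"{name:40s} {secs:8.1f} s ({secs / 60:.1f} min)")
```

The SSD row gives the ~5 minutes quoted earlier, the single-stream S3 row shows why naive downloads take over 20 minutes, and the ~11-second line-rate floor shows why < 60 seconds is achievable with s5cmd.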


7. Networking: The Invisible Bottleneck

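In Tensor Parallelism, GPUs synchronize inside every single layer (in Megatron-style TP, one All-Reduce after attention and one after the MLP). Layer 1 Compute -> Sync -> Layer 2 Compute -> Sync. If “Sync” is slow, the GPUs spend much of their time waiting instead of computing.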

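  • PCIe Gen4 x16: ~64 GB/s bidirectional (~32 GB/s each way). Standard slots; this is how the A10G (g5) and T4 (g4dn) talk to each other, since they have no NVLink.
  • NVLink (A100): ~600 GB/s total per GPU. Direct GPU-to-GPU links (through NVSwitch on p4d).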

Ops Implication: You cannot do Tensor Parallelism efficiently across two separate machines (e.g., two g4dn.xlarge instances) over TCP/IP Ethernet. Network latency and bandwidth are orders of magnitude worse than NVLink, and TP pays that cost twice per layer for every generated token. Rule: TP must happen inside a single chassis.
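To see why latency dominates, count the synchronizations. Assuming Megatron-style TP (two All-Reduces per transformer layer in the forward pass) and the Llama-3-70B shape (80 layers, hidden size 8192, FP16), each generated token needs 160 All-Reduces. The per-All-Reduce latencies below (10 µs for an NVLink-class link, 1 ms for a slow network path) are hypothetical round numbers for illustration, not measurements, and the function names are illustrative.

```python
def allreduces_per_token(n_layers: int) -> int:
    # Megatron-style TP: one All-Reduce after attention, one after the MLP.
    return 2 * n_layers


def ring_allreduce_bytes_per_gpu(message_bytes: int, n_gpus: int) -> float:
    # A ring All-Reduce sends 2 * (N - 1) / N times the message from each GPU.
    return 2 * (n_gpus - 1) / n_gpus * message_bytes


def comm_bound_tokens_per_s(n_layers: int, allreduce_latency_s: float) -> float:
    """Decode-speed ceiling if All-Reduce latency were the only cost."""
    return 1.0 / (allreduces_per_token(n_layers) * allreduce_latency_s)


N_LAYERS, HIDDEN, FP16_BYTES = 80, 8192, 2
message = HIDDEN * FP16_BYTES  # one token's hidden state (batch size 1)
print("All-Reduces per token:", allreduces_per_token(N_LAYERS))
print("Bytes sent per GPU per All-Reduce (TP=4):",
      ring_allreduce_bytes_per_gpu(message, 4))

for link, latency_s in [("NVLink-class, 10 us", 10e-6), ("slow network, 1 ms", 1e-3)]:
    print(f"{link}: <= {comm_bound_tokens_per_s(N_LAYERS, latency_s):.0f} tokens/s")
```

The messages are tiny (24KB per GPU here), so at batch size 1 the problem is the number of round trips, not bandwidth: under these assumptions a 1 ms link caps decoding at about 6 tokens/s before any math is done.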

7.2. NCCL (NVIDIA Collective Communication Library)

NCCL is the library that implements collective operations such as All-Reduce. It automatically picks the fastest available path (NVLink > PCIe > Socket).

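Debugging NCCL: If a distributed job hangs or is slow, use:

export NCCL_DEBUG=INFO      # log the detected topology and the transport used per channel
export NCCL_P2P_DISABLE=0   # keep direct GPU-to-GPU (P2P) transport enabled (the default)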

Watch the logs. If you see it falling back to NET/Socket inside a single machine, your NVLink topology is broken or virtualization is misconfigured.

7.3. EFA (Elastic Fabric Adapter)

For Multi-Node training (not inference), network interfaces such as AWS EFA bypass the OS kernel to provide low-latency networking. EFA is less critical for inference (since we keep TP local), but efficient multi-node training effectively requires it (20.4).


8. Quantized Sharding: AWQ and GPTQ

If you can’t afford 4x A100s, you Quantize. Llama-3-70B can fit on 2x A100s or 4x A10Gs if compressed to 4-bit.

8.1. GPTQ (Post-Training Quantization)

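Quantizes weights to 4-bit one layer at a time. It uses second-order (Hessian) information, estimated from a small calibration dataset, to adjust the not-yet-quantized weights so they compensate for the rounding error of each weight it quantizes.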

  • Format: Pre-quantized .safetensors.
  • Serving: vLLM and TGI support loading GPTQ/AWQ models directly.

8.2. AWQ (Activation-aware Weight Quantization)

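Newer (2023). It observes that a small fraction of weight channels matter most, identifies them from activation magnitudes, and scales them up before quantization so they lose less precision. It is often reported to preserve accuracy better than GPTQ.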
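What does "4-bit" mean in memory terms? The sketch below is plain round-to-nearest (RTN) group quantization, the simple baseline that GPTQ and AWQ improve on; it implements neither GPTQ's Hessian-based error compensation nor AWQ's activation-aware scaling. The group size of 128 and FP16 per-group scales are common choices but assumptions here. Storing one 16-bit scale per 128 weights adds 0.125 bits per weight, so 70B parameters take roughly 36GB rather than exactly 35GB.

```python
import numpy as np


def quantize_int4(w: np.ndarray, group_size: int = 128):
    """Symmetric round-to-nearest INT4: one FP16 scale per group of weights."""
    groups = w.reshape(-1, group_size).astype(np.float32)
    scale = np.abs(groups).max(axis=1, keepdims=True) / 7.0  # largest |w| maps to 7
    scale = np.where(scale == 0, 1.0, scale)                  # guard all-zero groups
    # 16 levels = 4 bits; stored in int8 here, real kernels pack two values per byte.
    q = np.clip(np.round(groups / scale), -8, 7).astype(np.int8)
    return q, scale.astype(np.float16)


def dequantize_int4(q: np.ndarray, scale: np.ndarray, shape) -> np.ndarray:
    return (q.astype(np.float32) * scale.astype(np.float32)).reshape(shape)


def bits_per_weight(group_size: int = 128, scale_bits: int = 16) -> float:
    return 4 + scale_bits / group_size


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 512)).astype(np.float32)
q, scale = quantize_int4(w)
w_hat = dequantize_int4(q, scale, w.shape)

print("bits/weight:", bits_per_weight())                                # 4.125
print("70B size (GB):", round(70e9 * bits_per_weight() / 8 / 1e9, 1))  # 36.1
print("mean abs error:", float(np.abs(w - w_hat).mean()))
```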

Serving Config:

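# vLLM
from vllm import AsyncEngineArgs, AsyncLLMEngine

# from_engine_args takes an AsyncEngineArgs object, not keyword arguments.
engine_args = AsyncEngineArgs(
    model="TheBloke/Llama-2-70B-Chat-AWQ",
    quantization="awq",
    tensor_parallel_size=2,  # Fits on 2x A100s!
)
engine = AsyncLLMEngine.from_engine_args(engine_args)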

Cost Math:

  • FP16: 4x A100 ($12/hr).
  • AWQ 4-bit: 2x A100 ($6/hr).
  • Optimization: 50% cost reduction by switching to the AWQ checkpoint and setting quantization="awq".

9. Hands-On Lab: The “Poor Man’s” 70B Cluster

We will simulate a distributed environment using 2x T4 GPUs (cheap) to run a smaller sharded model (e.g., 13B) to prove the pipeline works, since requesting 4x A100s might hit quota limits.

9.1. Setup

  • Instance: g4dn.12xlarge (4x T4 GPUs). Cost: ~$3.9/hr.
  • Goal: Serve Llama-2-13B (26GB FP16) across 2 GPUs (16GB each).

9.2. Code

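# serve.py
from vllm import LLM, SamplingParams

# Llama-2-13B weights need ~26GB in FP16.
# A T4 has 16GB, so 2x T4 = 32GB. That leaves at most ~6GB for KV Cache,
# and less after CUDA context and vLLM's default 90% memory cap.

def main():
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        tensor_parallel_size=2,  # shard every layer across 2 GPUs
        dtype="float16",         # the T4 has no bfloat16 support
        max_model_len=2048,      # shorter context = smaller KV Cache reservation
    )
    params = SamplingParams(temperature=0.7, max_tokens=128)  # default is only 16 tokens
    output = llm.generate("Hello, how are you?", params)
    print(output[0].outputs[0].text)

if __name__ == "__main__":
    main()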

9.3. Observation

Run nvidia-smi in a separate terminal during generation.

  • You should see memory usage spike on GPU 0 AND GPU 1.
  • Compute utilization should rise synchronously.
  • If only GPU 0 moves, TP is not working.

10. Troubleshooting Model Sharding

Symptom: RuntimeError: CUDA out of memory.

  • Check: Are you counting the KV Cache?
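  • Fix: Reduce max_model_len (context size). vLLM defaults to the model's full context (4096 for Llama-2, 8192 for Llama-3); lowering it to 2048 frees up GBs of KV Cache.
  • Fix: Quantize (an AWQ/GPTQ checkpoint in vLLM, or load_in_8bit=True in Hugging Face Transformers).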

Symptom: NCCL timeout or Hang.

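  • Cause: Firewall/Security Group blocking internal ports (multi-node setups).
  • Fix: Allow inbound traffic on all TCP ports from the Security Group itself (use its own ID as the source). NCCL uses random high ports.
  • Cause: /dev/shm too small inside the container.
  • Fix: Mount a RAM-backed emptyDir at /dev/shm (see 5.1).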

Symptom: Throughput is low (2 tokens/sec).

  • Cause: CPU bottleneck.
  • Check: top. If Python is at 100%, data loading or post-processing is the bottleneck.
  • Cause: NVLink missing or broken, so traffic falls back to PCIe. Check the links with nvidia-smi topo -m.

11. Reference Architectures

How do you wire this up in AWS EKS?

11.1. The Single Node Pod

If model > Single GPU but < Single Node (8 GPUs).

  • Node Pool: p4d.24xlarge (8x A100).
  • Pod: Requests nvidia.com/gpu: 8.
  • Networking: NVLink/NVSwitch inside the node (no network hop).

11.2. The Multi-Node Cluster (Training)

If model > 8 GPUs (e.g., training Llama-3.1-405B).

  • Interconnect: EFA (Elastic Fabric Adapter).
  • Deployment: Ray Cluster (Head + Workers).
  • Worker: Each Worker manages 8 GPUs. They talk via EFA.

KubeRay Manifest Example:

apiVersion: ray.io/v1
kind: RayService
metadata:
  name: llama-serving
spec:
  serveConfigV2: |
    applications:
      - name: llama_app
        import_path: serving.app
        runtime_env:
          pip: ["vllm", "ray[serve]"]
  rayClusterConfig:
    rayVersion: '2.9.0'
    headGroupSpec:
      rayStartParams:
        dashboard-host: '0.0.0.0'
      template:
        spec:
          containers:
          - name: ray-head
            image: rayproject/ray:2.9.0-gpu
            resources:
              limits:
                cpu: 2
                memory: 8Gi
    workerGroupSpecs:
    - groupName: gpu-group
      replicas: 1
      minReplicas: 1
      maxReplicas: 5
      rayStartParams: {}
      template:
        spec:
          containers:
          - name: ray-worker
            image: rayproject/ray:2.9.0-gpu
            resources:
              limits:
                nvidia.com/gpu: 4 # THE CRITICAL LINE
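                memory: 160Gi # g5.12xlarge has 192GiB RAM; leave headroom for the OS and /dev/shm
            volumeMounts:
            - mountPath: /dev/shm
              name: dshm
          volumes:
          - name: dshm
            emptyDir:
              medium: Memory # RAM-backed /dev/shm (see 5.1)
          nodeSelector:
            node.kubernetes.io/instance-type: g5.12xlarge # well-known instance-type label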


12. Speculative Decoding

Sharding solves Capacity (fitting the model). It does not solve Latency: autoregressive decoding still produces one token per forward pass.

12.1. The Theory

LLMs are memory-bandwidth bound during decoding: every forward pass streams all the weights from VRAM, so scoring 5 tokens takes roughly the same time as scoring 1. Idea: What if a tiny “Draft Model” (e.g., Llama-3-8B) guessed the next 5 tokens, and the “Oracle Model” (Llama-3-70B) verified them in parallel?

  • Draft: “The cat sat on the” (Fast).
  • Oracle: Check [“The”, “cat”, “sat”, “on”, “the”].
    • If all correct: Acceptance! We generated 5 tokens in 1 step.
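    • If one is wrong: Keep the tokens before the first mismatch, discard the rest, and use the Oracle's own token at that position. Every step still yields at least one token.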
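The accept/reject rule can be written out for the greedy (temperature 0) case. This is a sketch under stated assumptions: the "models" are plain functions from a token list to the next token, and the Oracle checks positions one by one, whereas a real engine scores all k drafted positions in a single batched forward pass. Sampling-based speculative decoding (Leviathan et al., 2023) replaces the exact-match test with a rejection-sampling rule. The helper expected_tokens_per_step uses that paper's formula, which assumes each drafted token is accepted independently with probability alpha. All names and the toy sentence are hypothetical.

```python
from typing import Callable, List

Model = Callable[[List[str]], str]  # greedy model: context -> next token


def speculative_step(context: List[str], draft: Model, oracle: Model, k: int = 5) -> List[str]:
    """One greedy speculative-decoding step; returns the tokens it emits (1 to k+1)."""
    # 1. The cheap draft model guesses k tokens autoregressively.
    guesses: List[str] = []
    for _ in range(k):
        guesses.append(draft(context + guesses))
    # 2. The Oracle verifies each guess (one batched forward pass in a real engine).
    accepted: List[str] = []
    for guess in guesses:
        expected = oracle(context + accepted)
        if guess != expected:
            return accepted + [expected]  # first mismatch: take the Oracle's token, stop
        accepted.append(guess)
    # 3. Every guess matched, so the Oracle's next token comes for free.
    return accepted + [oracle(context + accepted)]


def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Mean tokens per step if each guess is accepted with probability alpha."""
    if alpha == 1.0:
        return k + 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)


SENTENCE = "The cat sat on the mat .".split()


def oracle(ctx: List[str]) -> str:
    return SENTENCE[len(ctx)] if len(ctx) < len(SENTENCE) else "<eos>"


def bad_draft(ctx: List[str]) -> str:
    return "lay" if len(ctx) == 2 else oracle(ctx)  # one wrong guess at position 2


print(speculative_step([], oracle, oracle))     # perfect draft: 6 tokens in one step
print(speculative_step([], bad_draft, oracle))  # ['The', 'cat', 'sat']
```

With an 80% acceptance rate and k = 5, the formula gives about 3.7 tokens per Oracle pass, which is the source of the speedup, minus the draft model's own cost.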

12.2. vLLM Support

vLLM supports this out of the box.

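from vllm import EngineArgs

# These two arguments exist in older vLLM releases (some of which also need
# use_v2_block_manager=True); newer releases take them via speculative_config.
engine_args = EngineArgs(
    model="meta-llama/Meta-Llama-3-70B-Instruct",
    tensor_parallel_size=4,  # the 70B Oracle still needs sharding
    speculative_model="meta-llama/Meta-Llama-3-8B-Instruct",  # The Drafter
    num_speculative_tokens=5,
)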

Ops Impact: You need VRAM for both models, but the draft model is usually small. Result: a 2x-3x speedup in tokens/sec is commonly reported, depending on how often the Oracle accepts the draft's guesses.


13. FAQ

Q: Can I run 70B on CPU? A: Yes, with llama.cpp (GGUF format). Expect roughly 2-3 tokens/second, depending on the CPU and quantization level. Good for debugging. Unusable for production chat (users expect 20-50 t/s).

Q: Do I need InfiniBand? A: For Inference of < 100B models: No. NVLink inside the node is enough. For Training: Yes.

Q: How does this impact Cost? A: Inference cost grows roughly linearly with model size.

  • 7B: $1/hr.
  • 70B: $10/hr.

Your business case must justify the 10x cost. Does 70B provide 10x better answers? (Often: Yes, for coding/reasoning. No, for summarization.)

14. Glossary of Distributed Terms

  • All-Reduce: The MPI operation where every node shares its data with every other node, and they all end up with the Sum/Mean.
  • NVLink: Proprietary NVIDIA interconnect for high-speed GPU-to-GPU communication.
  • Pipeline Parallelism (PP): Assigning layers to different GPUs.
  • Tensor Parallelism (TP): Splitting tensors within a layer across GPUs.
  • Sharding: The general act of partitioning data/weights.
  • vLLM: A popular open-source inference engine optimized for throughput.
  • Weights vs. Activations:
    • Weights: Static parameters (Fixed size).
    • Activations: Dynamic data flowing through net (Depends on Batch Size).
    • KV Cache: Saved attention keys and values for past tokens (grows with context length).

15. References

1. “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM”

  • Narayanan et al. (NVIDIA) (2021): Scales Megatron-LM's tensor parallelism by combining it with pipeline and data parallelism.

2. “Ray: A Distributed Framework for Emerging AI Applications”

  • Moritz et al. (Berkeley) (2018): The foundation of modern distributed AI orchestration.

3. “Efficient Memory Management for Large Language Model Serving with PagedAttention”

  • Kwon et al. (Berkeley) (2023): Revolutionized inference memory management.

16. Final Checklist: Deployment Day

  1. Hardware: Do you have a g5.12xlarge or p4d quota approved?
  2. Format: Is the model in SafeTensors?
  3. Quantization: Did you benchmark AWQ vs FP16?
  4. Engine: Are you using vLLM (Throughput) or TGI (Simplicity)?
  5. Health: Is the Liveness Probe configured to check the Engine loop?

In the next section, we move from Serving models to Teaching them human preferences: RLHF Operations (20.4).