Chapter 13: The GCP Compute Ecosystem
13.2. The TPU (Tensor Processing Unit) Deep Dive
“We are running out of computing capability. Moore’s Law is effectively dead… The solution is domain-specific architectures.” — John Hennessy, Turing Award Winner and Chairman of Alphabet
In the grand theater of cloud computing, the Graphics Processing Unit (GPU) is the charismatic rock star—versatile, powerful, and universally recognized. It was born for gaming, pivoted to crypto, and found its destiny in AI. However, inside Google’s data centers, there exists a different kind of entity. A silent, industrial-scale mathematician built for a singular purpose: matrix multiplication.
This is the Tensor Processing Unit (TPU).
For the Principal Engineer or Systems Architect, the TPU represents a fundamental divergence in philosophy. While AWS focuses on providing the best possible raw primitives (EC2 instances with NVIDIA cards attached via PCIe), Google Cloud offers a vertically integrated supercomputer.
Choosing the TPU is not just swapping one chip for another; it is adopting a different paradigm of parallelism, networking, and compilation. It is a choice that yields massive dividends in cost-performance and scalability, but demands a rigorous understanding of the underlying “Physics” of the hardware.
This section dissects the TPU from the silicon up to the pod level, contrasting it with the NVIDIA ecosystem, and laying out the architectural patterns required to tame this beast.
13.2.1. The Architecture of Efficiency: Systolic Arrays
To understand why a TPU is orders of magnitude more power-efficient than a general-purpose CPU or even a GPU for specific workloads, we must look at the von Neumann Bottleneck.
In a CPU or GPU, every operation typically involves:
- Fetching an instruction.
- Fetching data from memory (Registers/L1/L2/HBM) to the Arithmetic Logic Unit (ALU).
- Performing the calculation.
- Writing the result back to memory.
For a massive matrix multiplication (the beating heart of Deep Learning), this creates a traffic jam. The ALUs spend more time waiting for data to travel across the wires than they do calculating.
The Systolic Paradigm
The TPU abandons this “Fetch-Execute-Write” cycle for a Systolic Array architecture. The term “systolic” comes from biology (systole), referring to the rhythmic pumping of the heart.
Imagine a bucket brigade.
- CPU/GPU Approach: The worker runs to the water source, fills a bucket, runs to the fire, throws it, and runs back.
- TPU Approach: A line of workers stands still. They pass the full bucket to their left and the empty bucket to their right in a synchronized rhythm.
In the TPU’s Matrix Multiply Unit (MXU):
- Weight parameters are pre-loaded into the array and stay stationary.
- Data (activations) flows in from the left.
- Partial sums flow down from the top.
- In each clock cycle, a cell performs a multiply-accumulate (MAC) operation and passes the data to its neighbor.
$$ C_{ij} = \sum_{k} A_{ik} \times B_{kj} $$
The data flows through the chip like blood. No memory access is required for intermediate results. This allows the TPU to pack tens of thousands of multipliers into a tiny area with minimal heat generation, achieving a TOPS-per-watt ratio that traditional architectures cannot touch.
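To make the weight-stationary dataflow concrete, here is a toy emulation of the arithmetic in plain NumPy. It reproduces the accumulation each MXU cell performs (one multiply-accumulate per "beat"), not the hardware's pipelining or physical layout; the function name and shapes are purely illustrative.
import numpy as np

def systolic_matmul(A, B):
    """Toy emulation of a weight-stationary systolic matmul.

    B plays the role of the pre-loaded (stationary) weights; A streams
    through one column per beat, and every 'cell' (i, j) performs one
    multiply-accumulate per beat. Real hardware pipelines these beats
    across the physical grid instead of looping.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=np.float32)
    for k in range(K):  # one beat of the systolic rhythm
        # Every cell (i, j) accumulates A[i, k] * B[k, j] in parallel
        C += np.outer(A[:, k], B[k, :])
    return C

A = np.random.rand(4, 8).astype(np.float32)
B = np.random.rand(8, 3).astype(np.float32)
assert np.allclose(systolic_matmul(A, B), A @ B, atol=1e-5)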
The Architect’s Constraint: Static Shapes and Padding
This physical reality imposes a strict software constraint: Uniformity.
The Systolic Array is a rigid physical grid (e.g., 128x128). It loves big, rectangular blocks of numbers. It hates irregularity.
- The Scenario: You are processing sentences of variable length. One is 5 tokens, one is 100 tokens.
- The CPU/GPU: Handles this via masking and dynamic control flow relatively well.
- The TPU: The compiler must “pad” the 5-token sentence to a fixed size (e.g., 128) with zeros to fit the array rhythm.
- The Debt: If you choose a bucket size of 128, and your average sentence length is 20, you are wasting 84% of your compute cycles multiplying zeros.
Architectural Mitigation:
- Bucketing: Sort inputs by length and use multiple distinct compiled graphs for different length buckets (e.g., bucket_64, bucket_128, bucket_256); a minimal sketch follows this list.
- Packing: Concatenate multiple short sequences into one long sequence to fill the buffer, using attention masking to prevent them from “seeing” each other.
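A minimal sketch of the bucketing idea, assuming pre-tokenized integer sequences; the bucket boundaries and pad id are illustrative, not a recommendation.
import numpy as np

BUCKETS = (64, 128, 256)  # illustrative bucket boundaries
PAD_ID = 0

def pad_to_bucket(tokens):
    """Pad a token list to the smallest bucket that fits it.

    Each distinct bucket length yields exactly one compiled XLA program,
    so the number of recompilations is bounded by len(BUCKETS).
    """
    for size in BUCKETS:
        if len(tokens) <= size:
            return np.array(tokens + [PAD_ID] * (size - len(tokens)), dtype=np.int32)
    raise ValueError(f"Sequence of length {len(tokens)} exceeds the largest bucket")

padded = pad_to_bucket([101, 7, 42, 9, 5])  # 5 tokens -> padded to 64, not 256
print(padded.shape)  # (64,)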
13.2.2. The Generations: Choosing the Right Silicon
Unlike NVIDIA’s relatively linear progression (V100 → A100 → H100), Google’s TPU lineup branches into specialized roles. Understanding the difference between “v5e” and “v5p” is critical for your budget and performance.
TPU v4: The Optical Supercomputer
- Era: The workhorse of 2023-2024.
- Key Innovation: Optical Circuit Switching (OCS).
- Traditional clusters use electrical packet switches (InfiniBand/Ethernet).
- TPU v4 pods connect 4,096 chips using mirrors. Yes, MEMS mirrors.
- The Benefit: Reconfigurability. You can dynamically slice a 4,096-chip pod into arbitrary topologies (cubes, meshes) without recabling.
- Use Case: Large-scale training where you need a dedicated “slice” of topology.
TPU v5e: The Efficiency Specialist (“Lite”)
- Philosophy: “Not everyone is training GPT-4.”
- Design: Optimized for cost-performance (FLOPS/$).
- Specs: Roughly half the chip area of a v4, but higher density.
- Interconnect: High-speed, but optimized for smaller topologies (up to 256 chips).
- Target:
- Inference (Serving Llama-2-70b).
- Fine-tuning (LoRA).
- Training mid-sized models (< 100B parameters).
- The Trap: Do not try to train a 1 Trillion parameter model on v5e; the cross-chip communication overhead will kill you.
TPU v5p: The Performance Beast
- Philosophy: “We need to beat the H100.”
- Design: Massive High Bandwidth Memory (HBM) capacity and bandwidth.
- Specs: 2x-3x faster than v4.
- Interconnect: 600 GB/s inter-chip links. Scales to nearly 9,000 chips in a single pod, and further with Multislice.
- Target: Frontier model training. If you are burning $10M+ on a training run, this is your chip.
The Decision Matrix
| Constraint | Recommended Silicon | Reason |
|---|---|---|
| Workload: Serving Llama-3-8B | TPU v5e | v5p is overkill; v5e offers the best price/performance for inference. |
| Workload: Training 7B-70B model | TPU v4 / v5e | Good balance. v5e for cost, v4 if you need faster convergence. |
| Workload: Training > 100B model | TPU v5p | You need the HBM capacity and the OCS scale. |
| Budget: Limited | TPU v5e | Highest FLOPS per dollar. |
| Codebase: PyTorch (Standard) | GPU (A100/H100) | While PyTorch/XLA exists, GPUs are still the path of least resistance for pure PyTorch. |
| Codebase: JAX / TensorFlow | TPU | Native compilation advantage. |
13.2.3. Topology and Interconnects: The 3D Torus
In the NVIDIA world, we talk about NVLink within a server and InfiniBand/RoCE across servers. In the TPU world, these boundaries dissolve: the TPU Inter-Chip Interconnect (ICI) fuses the chips into a single logical mesh.
The 3D Torus
Imagine a 3D grid of chips (X, Y, Z axes).
- Chip (0,0,0) is directly connected to (0,0,1), (0,1,0), and (1,0,0).
- This allows extremely low-latency communication for “Neighbor” operations.
The Wrap-Around: In a Torus, the edge connects back to the beginning. Chip (N, 0, 0) connects to Chip (0, 0, 0). This reduces the maximum number of hops (diameter) across the network.
The Topology Awareness Trap
When you provision a TPU Pod Slice (e.g., v4-128), you are physically renting a sub-section of this 3D lattice.
- The Default: You get a shape, say $4 \times 4 \times 8$.
- The Code Impact: If your model parallelism strategy assumes a ring, but the hardware provides a cube, your gradients will take inefficient paths through the silicon.
Mitigation: Topology-Aware Placement. In XLA and JAX, you can explicitly map your model’s dimensions to the hardware mesh dimensions.
# JAX Topology Definition
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
# We define the physical mesh provided by the TPU slice
# "x", "y", "z" map to the physical interconnect axes
device_mesh = mesh_utils.create_device_mesh((4, 4, 8))
mesh = Mesh(device_mesh, axis_names=('x', 'y', 'z'))
# We map Model Layers to the Mesh
# Here, we shard the 'batch' dimension across 'x' and 'y' (16-way data parallelism)
# And the 'embed' dimension across 'z' (8-way model parallelism)
sharding_spec = NamedSharding(mesh, PartitionSpec(('x', 'y'), 'z'))
By aligning the logical sharding with the physical wires, you can achieve near-linear scaling efficiency (90%+) where Ethernet-based clusters often drop to 60-70%.
13.2.4. The Software Stack: XLA and the “Graph”
Using a TPU effectively requires accepting a hard truth: You are not writing Python code; you are writing a meta-program that generates a computation graph.
The Compiler: XLA (Accelerated Linear Algebra)
When you run code on a CPU, the interpreter executes line-by-line. When you run code on a TPU via XLA:
- Tracing: Python runs. It records operations (Add, MatMul, Relu) into a symbolic graph. It does not execute them.
- Optimization: XLA analyzes the graph. It fuses operations.
  - Fusion Example: `Relu(Add(MatMul(A, B), C))` becomes a single hardware kernel call; no intermediate results are written to memory.
- Compilation: The graph is lowered to machine code for the specific TPU version.
- Execution: The binary is uploaded to the TPU and run.
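You can inspect the graph that tracing produces before it ever reaches the hardware. The sketch below uses jax.make_jaxpr to print the captured program for the fusion example above, then lowers and compiles it ahead of time; the shapes are arbitrary.
import jax
import jax.numpy as jnp

def block(a, b, c):
    return jax.nn.relu(a @ b + c)  # MatMul -> Add -> Relu

a = jnp.ones((128, 256), dtype=jnp.bfloat16)
b = jnp.ones((256, 512), dtype=jnp.bfloat16)
c = jnp.ones((512,), dtype=jnp.bfloat16)

# The jaxpr is the symbolic graph captured by tracing -- no math has run yet.
print(jax.make_jaxpr(block)(a, b, c))

# jit hands this graph to XLA, which can fuse the add and relu into the
# matmul epilogue so intermediates never round-trip through HBM.
compiled = jax.jit(block).lower(a, b, c).compile()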
The Trap: Recompilation Hell
This compilation step takes time (seconds to minutes).
- The Anti-Pattern: Passing a changing Python scalar or a varying tensor shape into a JIT-compiled function.
- The Result: XLA sees a “new” function signature. It triggers a full recompilation. The system stalls for 30 seconds.
- The Symptom: “My first batch takes 30 seconds, my second batch takes 30 seconds…” (It should take 10ms).
Code Example: The “Static Argument” Fix
import jax
from functools import partial

# BAD: 'dropout_rate' is traced like a tensor, so Python-level control flow
# on it breaks, and changing values risk retracing the function.
@jax.jit
def train_step_bad(params, inputs, dropout_rate):
    # Logic utilizing dropout_rate
    pass

# GOOD: Tell JAX that 'dropout_rate' is a static configuration value, not a tensor.
# Each distinct value is compiled once and then cached.
@partial(jax.jit, static_argnames=['dropout_rate'])
def train_step_good(params, inputs, dropout_rate):
    # Logic utilizing dropout_rate
    pass
13.2.5. Operationalizing TPUs: The Host-Device Relationship
Operating a TPU differs fundamentally from running GPU instances on EC2.
The Split Brain: Worker vs. Accelerator
- Single-Host (v2/v3/v5e small slices): One VM controls 1, 4, or 8 TPU chips. This feels like a standard GPU box.
- Multi-Host (Pod Slices): This is where it gets weird.
  - You provision a `v4-128` slice.
  - GCP spins up 16 separate VMs (hosts).
  - Each VM controls 4 TPU chips.
  - Your Python code must run on all 16 VMs simultaneously.
The Orchestration Challenge
You cannot simply `ssh` into one box and run `python train.py`. You need to launch the process on the entire fleet in sync.
Tooling Solution: TPU VMs. Historically, Google exposed “TPU Nodes,” where you could not SSH into the host. Now, with the TPU VM architecture, you have root access to the machines physically attached to the TPUs.
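Once the same script is running on every host in the slice, wiring the processes together in JAX is usually a single call. A minimal sketch; on Cloud TPU VMs the coordinator address and process indices are typically auto-detected, so no arguments are required.
import jax

# Must be called before any other JAX operations on multi-host slices.
jax.distributed.initialize()

# Each host sees only its locally attached chips, but collectives span the slice.
print(f"process {jax.process_index()}/{jax.process_count()}: "
      f"{jax.local_device_count()} local chips, "
      f"{jax.device_count()} chips in the slice")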
The Startup Script (GKE / JobSet). In Kubernetes (GKE), this is handled by the JobSet API or the TPU operator, which creates a headless service so the workers can discover each other.
# Kubernetes JobSet snippet for TPU Multi-Host
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: llama-3-training
spec:
  replicatedJobs:
    - name: workers
      replicas: 1
      template:
        spec:
          parallelism: 4    # 4 VMs (e.g., a v4-32 slice)
          completions: 4
          template:
            spec:
              nodeSelector:
                cloud.google.com/gke-tpu-topology: 2x2x4  # Request a specific topology
              containers:
                - name: train
                  image: us-docker.pkg.dev/my-project/train:latest
                  env:
                    - name: JAX_COORDINATOR_ADDRESS
                      value: "$(master-service-host):8471"
Fault Tolerance: The “Orbax” Checkpoint
In a system with 4,096 chips, the probability of a cosmic-ray bit flip or a hardware failure during a multi-week run approaches 1.
- Synchronous Failure: If one chip fails, the global barrier synchronization halts the entire pod.
- The Mitigation: Frequent, asynchronous checkpointing.
- Orbax: Google’s open-source library designed for checkpointing massive sharded arrays across distributed hosts without blocking the training loop for too long.
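A minimal checkpointing sketch using Orbax's CheckpointManager. The bucket path, train_step, state, and batches are placeholders, and the exact API surface varies between Orbax releases, so treat this as the shape of the solution rather than a drop-in recipe.
import orbax.checkpoint as ocp

# Keep the last 3 checkpoints; save every 1,000 steps (hypothetical GCS path).
options = ocp.CheckpointManagerOptions(save_interval_steps=1000, max_to_keep=3)
mngr = ocp.CheckpointManager("gs://my-bucket/ckpts", options=options)

for step in range(num_steps):  # num_steps, train_step, state, batches: defined elsewhere
    state, loss = train_step(state, next(batches))
    # save() is cheap off-interval; on-interval writes can proceed asynchronously.
    mngr.save(step, args=ocp.args.StandardSave(state))

mngr.wait_until_finished()  # block until any in-flight async writes have landed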
13.2.6. Benchmarking and Cost Economics: The MFU Metric
When comparing TPU v5p to H100, do not look at “Peak TFLOPS” in the spec sheet. That is a theoretical number assuming perfect spherical cows in a vacuum.
Look at MFU (Model FLOPs Utilization): $$ \text{MFU} = \frac{\text{Observed Throughput (FLOPs/sec)}}{\text{Theoretical Peak (FLOPs/sec)}} $$
- GPU Reality: On large clusters, GPUs often struggle to sustain >40-50% MFU due to PCIe bottlenecks and Ethernet latency.
- TPU Reality: Due to the OCS and native mesh networking, well-tuned TPU workloads frequently hit 60-75% MFU.
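As a back-of-the-envelope check, MFU can be estimated from throughput using the common ~6 × parameters × tokens approximation for the training FLOPs of a dense transformer. A minimal sketch with illustrative numbers:
def estimate_mfu(num_params, tokens_per_sec, peak_flops_per_chip, num_chips):
    """Rough MFU estimate for dense transformer training.

    The 6*N*D rule of thumb counts ~6 FLOPs per parameter per token
    for the combined forward and backward pass.
    """
    observed_flops = 6 * num_params * tokens_per_sec
    theoretical_peak = peak_flops_per_chip * num_chips
    return observed_flops / theoretical_peak

# Illustrative: 52B params, 125k tokens/sec, ~459 BF16 TFLOPS per chip, 128 chips
print(f"MFU ~= {estimate_mfu(52e9, 125e3, 459e12, 128):.0%}")  # roughly 66%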
The Economic Implications
If Chip A costs $2/hr and claims 100 TFLOPS (but delivers 40%), and Chip B costs $2/hr and claims 80 TFLOPS (but delivers 60%):
- Chip A Effective: 40 TFLOPS
- Chip B Effective: 48 TFLOPS
Chip B (the TPU, often) is 20% faster in reality, despite being “slower” on paper.
Cost Efficiency (v5e). The TPU v5e is aggressively priced. For workloads that fit within its memory/interconnect constraints, it often delivers 3x-4x better performance-per-dollar than A100s. It is the “Toyota Camry” of AI chips—reliable, efficient, and everywhere.
13.2.7. Architecture Patterns for Large-Scale Training
Scaling to thousands of chips requires sophisticated parallelism strategies.
SPMD (Single Program, Multiple Data)
You write one program. It runs on every chip. The only difference is the slice of data each chip sees.
The Sharding Dimensions
To train a model larger than the memory of a single chip (e.g., 70B params > 16GB HBM), you must shard.
- Data Parallelism (DP): Copy model to all chips. Split batch across chips.
- Limit: Model must fit in one chip.
- Fully Sharded Data Parallel (FSDP): Shard the model parameters, gradients, and optimizer state across chips. Gather them only when needed for computation.
- Tensor Parallelism (TP): Split individual matrix multiplications across chips.
- Requires: Ultra-fast interconnect (ICI). This is the TPU’s home turf.
- Pipeline Parallelism (PP): Put Layer 1 on Chip A, Layer 2 on Chip B.
- Problem: “The Bubble”. Chip B waits for Chip A.
- TPU Context: Often unnecessary on TPU pods because Tensor Parallelism scales so well on the mesh.
GSPMD: The Generalizer
Google developed GSPMD, a compiler pass in XLA that handles sharding automatically based on simple annotations.
# The "Magic" of GSPMD in JAX
# We annotate the weight matrix "W"
# "mesh" is our 2D grid of chips
# P('x', 'y') means: Shard the first dimension on mesh axis x, second on y.
W = jax.random.normal(key, (8192, 8192))
W_sharded = jax.device_put(W, NamedSharding(mesh, PartitionSpec('x', 'y')))
# Now, any operation on W_sharded is automatically distributed.
# A matmul: Y = X @ W
# XLA generates the necessary "All-Gather" and "Reduce-Scatter" collectives
# to move data across the ICI wires without the user writing communication code.
13.2.8. Common Pitfalls (The Anti-Pattern Zoo)
1. The Data Feed Starvation
The TPU is a Ferrari engine. If you feed it with a garden hose (standard Python DataLoader), it will stall.
- Symptom: TPU utilization oscillates (0% -> 100% -> 0%).
- Cause: The CPU host cannot unzip/parse images fast enough.
- Fix: Use `tf.data` or `grain` (Google’s JAX data loader), which are optimized for prefetching and C++ execution. Store data in ArrayRecord or TFRecord format, not loose JPEGs. A minimal `tf.data` sketch follows.
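The sketch below shows the prefetching pattern; the feature spec, file pattern, and bucket path are placeholders.
import tensorflow as tf

# Feature spec for fixed-length, pre-tokenized examples (illustrative)
feature_spec = {"input_ids": tf.io.FixedLenFeature([512], tf.int64)}

def parse_fn(record):
    return tf.io.parse_single_example(record, feature_spec)

def build_input_pipeline(file_pattern, batch_size=128):
    """Keep the host ahead of the TPU: parallel reads, parallel parsing, prefetch."""
    files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(10_000)
    ds = ds.batch(batch_size, drop_remainder=True)  # static shapes keep XLA happy
    return ds.prefetch(tf.data.AUTOTUNE)            # overlap host work with TPU compute

ds = build_input_pipeline("gs://my-bucket/train-*.tfrecord")  # hypothetical path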
2. The Floating Point Trap (BF16 vs FP32)
TPUs are designed for BFloat16 (Brain Floating Point).
- BF16 has the same range as FP32 (8-bit exponent) but lower precision (7-bit mantissa).
- The Trap: Using standard FP16 (IEEE). TPUs emulate FP16 slowly or cast it.
- The Fix: Always use BF16 for training. It is numerically stable (unlike FP16) and runs at peak speed on MXUs.
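A minimal sketch of the BF16 pattern: keep master parameters in FP32, cast to BF16 at the point of use so the MXUs run at full rate, and promote activations back before the loss. No loss scaling is required, unlike FP16. The shapes here are arbitrary.
import jax
import jax.numpy as jnp

def forward(params_f32, x_bf16):
    # Cast weights to BF16 where they are consumed; the MXU runs BF16 natively
    # and accumulates partial sums in higher precision internally.
    w = params_f32["w"].astype(jnp.bfloat16)
    b = params_f32["b"].astype(jnp.bfloat16)
    y = x_bf16 @ w + b
    return y.astype(jnp.float32)  # promote back before computing the loss

key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (1024, 1024), dtype=jnp.float32),
          "b": jnp.zeros((1024,), dtype=jnp.float32)}
x = jnp.ones((8, 1024), dtype=jnp.bfloat16)
y = jax.jit(forward)(params, x)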
3. The “Opaque Error”
When XLA crashes, it often emits a C++ stack trace from the compiler internals that looks like hieroglyphics.
- Strategy:
  - Disable JIT (`jax.disable_jit()`) to debug logic errors in pure Python.
  - Use `jax.debug.print()`, which injects print operations into the compiled graph (runtime printing). Both are shown in the sketch below.
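Both techniques in one minimal sketch; the function and values are illustrative.
import jax
import jax.numpy as jnp

@jax.jit
def suspicious_step(x):
    y = jnp.log(x)  # NaN source if x <= 0
    # Injects a print op into the compiled graph; it runs at execution time.
    jax.debug.print("min(y) = {m}", m=jnp.min(y))
    return y * 2.0

x = jnp.linspace(-1.0, 1.0, 8)
suspicious_step(x)           # prints from inside the compiled program

with jax.disable_jit():      # fall back to op-by-op execution in plain Python
    suspicious_step(x)       # ordinary debugging (pdb, breakpoints, print) works here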
13.2.9. Conclusion: The Strategic Bet
Adopting TPUs is a strategic bet on Vertical Integration.
- On AWS: You are integrating components from Intel (CPU), NVIDIA (GPU), and AWS (Nitro/EFA). You are the integrator.
- On GCP: You are entering a walled garden where the cooler, the chip, the network switch, the compiler, and the orchestration software were all designed by the same company to do one thing: Math.
For generic, explorative work or teams deeply entrenched in legacy CUDA kernels, the friction may be too high. But for organizations aiming to train foundation models or serve inference at global scale, the TPU offers an architectural purity and economic efficiency that is arguably the highest in the cloud.
In the next section, we will look at how to orchestrate these powerful compute resources using Kubernetes, and the specific quirks of managing EKS vs GKE for AI workloads.
13.2.10. Real-World Case Study: Foundation Model Training on TPU v5p
Company: LangTech AI (anonymized)
Challenge: Train a 52B parameter encoder-decoder model (T5-style) for multilingual translation with <$150k budget.
Initial GPU Baseline (A100):
# Configuration: 64× a2-ultragpu-1g (64× A100 80GB)
# Cost: ~$16/hr per instance
# Total: $16 × 64 = $1,024/hr
# Estimated training time: 18 days
# Total cost: $1,024 × 24 × 18 = $442,368 (WAY OVER BUDGET)
# Standard PyTorch FSDP
from transformers import T5ForConditionalGeneration
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model = T5ForConditionalGeneration.from_pretrained("t5-11b")
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
# Bottleneck: Cross-node communication for all-reduce
# Achieved MFU: ~42% (significant network overhead)
Migrated to TPU v5p Pod:
# Configuration: v5p-128 (128 TPU v5p chips)
# Cost: ~$8/hr per chip
# Total: $8 × 128 = $1,024/hr (same as GPU option)
# Actual training time: 10 days (45% faster!)
# Total cost: $1,024 × 24 × 10 = $245,760
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax

# JAX model definition (TransformerEncoderLayer / TransformerDecoderLayer are
# assumed to be defined elsewhere in the training codebase)
class T5Model(nn.Module):
vocab_size: int = 32128
d_model: int = 1024
num_layers: int = 24
@nn.compact
def __call__(self, input_ids, decoder_input_ids):
# Encoder
encoder_embed = nn.Embed(self.vocab_size, self.d_model)(input_ids)
encoder_output = encoder_embed
for _ in range(self.num_layers):
encoder_output = TransformerEncoderLayer(self.d_model)(encoder_output)
# Decoder
decoder_embed = nn.Embed(self.vocab_size, self.d_model)(decoder_input_ids)
decoder_output = decoder_embed
for _ in range(self.num_layers):
decoder_output = TransformerDecoderLayer(self.d_model)(
decoder_output, encoder_output
)
logits = nn.Dense(self.vocab_size)(decoder_output)
return logits
# Sharding specification for 128 TPUs (8×4×4 mesh)
from jax.sharding import Mesh, PartitionSpec, NamedSharding
devices = jax.devices()
device_mesh = np.array(devices).reshape(8, 4, 4)
mesh = Mesh(device_mesh, axis_names=('data', 'model', 'tensor'))
# Shard model across tensor dimension, data across data dimension
sharding = NamedSharding(mesh, PartitionSpec('data', 'tensor', None))
# Training loop with automatic GSPMD
@jax.jit
def train_step(state, batch):
def loss_fn(params):
logits = state.apply_fn({'params': params}, batch['input_ids'], batch['decoder_input_ids'])
loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['labels'])
return jnp.mean(loss)
loss, grads = jax.value_and_grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state, loss
# Results:
# - Throughput: 125k tokens/sec (vs 78k on GPU)
# - MFU: 67% (vs 42% on GPU) - 60% better efficiency!
# - Total cost: $246k (vs $442k GPU, 44% savings)
# - Training time: 10 days (vs 18 days GPU, 45% faster)
Migration Challenges:
- Challenge: PyTorch codebase conversion to JAX
  - Solution: 3-week engineering effort, ~2,500 lines rewritten
  - Tools: Used the `jax2torch` converter for reference, plus manual fixes
- Challenge: Dynamic sequence lengths causing recompilation
  - Solution: Implemented a bucketing strategy (128, 256, 512, 1024, 2048)
  - Result: 95% of samples fit into 3 buckets, <5% padding waste
- Challenge: Debugging compilation errors
  - Solution: Disabled JIT initially, debugged in Python, then re-enabled
  - Tools: `JAX_DISABLE_JIT=1 python train.py` for debugging
Key Learnings:
- TPU v5p’s optical circuit switching eliminated GPU’s network bottleneck
- MFU improvement (42% → 67%) was the critical cost driver
- JAX migration ROI: 3 weeks investment saved $196k (1 training run)
- Bucketing strategy essential for variable-length sequences
13.2.11. Advanced Optimization Techniques
Technique 1: Efficient Data Loading with grain
# Google's grain library for efficient TPU data loading
import grain.python as grain
class TranslationDataset:
"""Custom dataset for TPU-optimized loading"""
def __init__(self, data_dir, split='train'):
# Use ArrayRecord format (optimized for TPU)
self.arrayrecord_path = f"{data_dir}/{split}.arrayrecord"
self.data_source = grain.ArrayRecordDataSource(self.arrayrecord_path)
def __len__(self):
return len(self.data_source)
def __getitem__(self, idx):
record = self.data_source[idx]
# Parse serialized example
example = parse_example(record)
return {
'input_ids': example['source'],
'decoder_input_ids': example['target'][:-1],
'labels': example['target'][1:]
}
# Create optimized dataloader
def create_tpu_dataloader(dataset, batch_size=128, num_epochs=None):
"""Create dataloader with TPU-specific optimizations"""
# Shuffle with large buffer
sampler = grain.IndexSampler(
len(dataset),
shuffle=True,
seed=42,
num_epochs=num_epochs
)
# Batch with padding to fixed shapes
operations = [
grain.Batch(batch_size=batch_size, drop_remainder=True),
grain.PadToMaxLength(
max_length={'input_ids': 512, 'decoder_input_ids': 512, 'labels': 512},
pad_value=0
)
]
loader = grain.DataLoader(
data_source=dataset,
sampler=sampler,
operations=operations,
worker_count=32, # Parallel workers
worker_buffer_size=2 # Prefetch depth
)
return loader
# Result: Eliminates data loading bottleneck
# TPU utilization: 95%+ (vs 70% with naive loading)
Technique 2: Topology-Aware Model Sharding
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
def create_optimal_mesh(num_chips):
"""Create 3D mesh matching physical TPU topology"""
# For v5p-128: 8×4×4 topology
# For v5p-256: 8×8×4 topology
# For v5p-512: 8×8×8 topology
if num_chips == 128:
mesh_shape = (8, 4, 4)
elif num_chips == 256:
mesh_shape = (8, 8, 4)
elif num_chips == 512:
mesh_shape = (8, 8, 8)
else:
raise ValueError(f"Unsupported chip count: {num_chips}")
devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = Mesh(devices, axis_names=('data', 'fsdp', 'tensor'))
return mesh
def shard_params_optimally(params, mesh):
"""Shard model parameters across mesh dimensions"""
# Embedding tables: shard vocab dimension across 'tensor'
embedding_sharding = NamedSharding(mesh, PartitionSpec(None, 'tensor'))
# Attention weights: shard across 'fsdp' and 'tensor'
attention_sharding = NamedSharding(mesh, PartitionSpec('fsdp', 'tensor'))
# FFN weights: shard across 'tensor' only
ffn_sharding = NamedSharding(mesh, PartitionSpec(None, 'tensor'))
# Apply sharding spec
sharded_params = {
'embeddings': jax.device_put(params['embeddings'], embedding_sharding),
'attention': jax.device_put(params['attention'], attention_sharding),
'ffn': jax.device_put(params['ffn'], ffn_sharding)
}
return sharded_params
# Usage
mesh = create_optimal_mesh(num_chips=128)
sharded_params = shard_params_optimally(model_params, mesh)
# Result: Near-linear scaling efficiency
# 128 chips: 67% MFU
# 256 chips: 65% MFU (only 3% drop when doubling scale!)
Technique 3: Gradient Accumulation for Large Batch Training
import jax
import jax.numpy as jnp
def create_accumulation_step(train_step_fn, accumulation_steps=4):
"""Implement gradient accumulation for effective large batches"""
def accumulate_gradients(state, batches):
"""Accumulate gradients over multiple micro-batches"""
accumulated_grads = jax.tree_map(jnp.zeros_like, state.params)
total_loss = 0.0
for micro_batch in batches:
# Compute gradients for micro-batch
def loss_fn(params):
logits = state.apply_fn({'params': params}, **micro_batch)
loss = compute_loss(logits, micro_batch['labels'])
return loss / accumulation_steps # Scale loss
loss, grads = jax.value_and_grad(loss_fn)(state.params)
# Accumulate
accumulated_grads = jax.tree_map(
lambda acc, g: acc + g,
accumulated_grads,
grads
)
total_loss += loss
# Apply accumulated gradients
state = state.apply_gradients(grads=accumulated_grads)
return state, total_loss
return accumulate_gradients
# Usage: Effective batch size = micro_batch × accumulation_steps × num_chips
# Example: 32 × 4 × 128 = 16,384 effective batch size
# Fits in memory while achieving large-batch training benefits
13.2.12. Cost Optimization Strategies
Strategy 1: Preemptible TPU Pods
# Create preemptible TPU pod for 60-70% savings
from google.cloud import tpu_v2
def create_preemptible_tpu_pod(
project_id,
zone,
tpu_name,
accelerator_type="v5litepod-16",
runtime_version="tpu-vm-tf-2.14.0"
):
"""Create preemptible TPU pod with automatic checkpointing"""
client = tpu_v2.TpuClient()
tpu = tpu_v2.Node(
name=f"projects/{project_id}/locations/{zone}/nodes/{tpu_name}",
accelerator_type=accelerator_type,
runtime_version=runtime_version,
network_config=tpu_v2.NetworkConfig(
enable_external_ips=True
),
scheduling_config=tpu_v2.SchedulingConfig(
preemptible=True # 60-70% discount
),
metadata={
# Startup script for automatic checkpoint restoration
"startup-script": """#!/bin/bash
gsutil cp gs://my-bucket/checkpoint-latest/* /tmp/checkpoint/
python3 /home/user/train.py --restore_from=/tmp/checkpoint
"""
}
)
operation = client.create_node(
parent=f"projects/{project_id}/locations/{zone}",
node_id=tpu_name,
node=tpu
)
print(f"Creating TPU pod: {tpu_name}")
result = operation.result() # Wait for completion
return result
# Savings example:
# v5p-128 on-demand: $1,024/hr
# v5p-128 preemptible: $307/hr (70% savings!)
# 10-day training: $245k → $74k
Strategy 2: Reserved Capacity for Long Training Runs
# Reserved TPU capacity for predictable costs
def calculate_tpu_reservation_savings(
monthly_chip_hours,
chip_type="v5p",
on_demand_rate=8.00, # $/chip-hr
commitment_months=12
):
"""Calculate savings from TPU reserved capacity"""
# Reservation discounts (approximate)
reservation_discounts = {
1: 0.15, # 15% for 1-month
3: 0.25, # 25% for 3-month
12: 0.40 # 40% for 1-year
}
discount = reservation_discounts[commitment_months]
reserved_rate = on_demand_rate * (1 - discount)
monthly_cost_on_demand = monthly_chip_hours * on_demand_rate
monthly_cost_reserved = monthly_chip_hours * reserved_rate
total_savings = (monthly_cost_on_demand - monthly_cost_reserved) * commitment_months
print(f"Chip type: {chip_type}")
print(f"Monthly chip-hours: {monthly_chip_hours}")
print(f"On-demand: ${monthly_cost_on_demand:,.2f}/month")
print(f"Reserved ({commitment_months}mo): ${monthly_cost_reserved:,.2f}/month")
print(f"Total savings over {commitment_months} months: ${total_savings:,.2f}")
return total_savings
# Example: v5p-128 running 50% of the time
savings = calculate_tpu_reservation_savings(
monthly_chip_hours=128 * 24 * 30 * 0.5, # 50% utilization
commitment_months=12
)
# Output: Total savings: $1,769,472 over 12 months
Strategy 3: TPU v5e for Cost-Optimized Training
# Use TPU v5e for models <100B parameters
# Cost comparison (approximate):
# v5p: $8/chip-hr, 128 chips = $1,024/hr
# v5e: $2/chip-hr, 256 chips = $512/hr (50% cheaper!)
# Performance comparison:
# v5p: 67% MFU, 125k tokens/sec
# v5e: 58% MFU, 89k tokens/sec (71% of v5p throughput)
# Cost per token:
# v5p: $1,024/hr ÷ (125k tokens/sec × 3600) ≈ $2.28 per 1M tokens
# v5e: $512/hr ÷ (89k tokens/sec × 3600) ≈ $1.60 per 1M tokens (30% cheaper!)
# Decision framework:
# - Model <70B: Use v5e (best cost/token)
# - Model 70-200B: Use v5p if budget allows, v5e otherwise
# - Model >200B: Use v5p (v5e lacks HBM capacity)
13.2.13. Monitoring and Debugging
TPU Profiling:
import jax
from jax import profiler
def profile_training_step(train_step_fn, state, batch):
"""Profile TPU execution to identify bottlenecks"""
# Start profiling server
profiler.start_server(port=9999)
# Run training step with profiling
    # Trace a window of steps; view the result in TensorBoard's profile plugin
    with profiler.trace("/tmp/tensorboard"):
        for step in range(100):  # Profile 100 steps
            # Annotate each step so it appears as a named block in the trace viewer
            with profiler.StepTraceAnnotation("train_step", step_num=step):
                state, loss = train_step_fn(state, batch)
print("Profiling complete. View in TensorBoard:")
print("tensorboard --logdir=/tmp/tensorboard --port=6006")
# Key metrics to analyze:
# 1. Device compute time (should be >90% of total)
# 2. Host-to-device transfer time (should be <5%)
# 3. Compilation time (only on first step)
# 4. Idle time (should be <2%)
# Common issues:
# - High transfer time → Data loading bottleneck
# - High idle time → Unbalanced sharding
# - Frequent compilation → Dynamic shapes (need bucketing)
Cloud Monitoring Integration:
from google.cloud import monitoring_v3
import jax
import time
def publish_tpu_metrics(project_id):
"""Publish custom TPU training metrics"""
client = monitoring_v3.MetricServiceClient()
project_name = f"projects/{project_id}"
# Get TPU device info
devices = jax.devices()
num_devices = len(devices)
# Metrics to track
metrics_data = {
'tpu/mfu': 0.67, # Model FLOPs Utilization
'tpu/tokens_per_second': 125000,
'tpu/cost_per_million_tokens': 0.0082,
'tpu/training_loss': 2.45,
'tpu/num_active_devices': num_devices
}
for metric_name, value in metrics_data.items():
series = monitoring_v3.TimeSeries()
series.metric.type = f"custom.googleapis.com/{metric_name}"
series.resource.type = "gce_instance"
        # Each data point needs an interval end time
        now = int(time.time())
        interval = monitoring_v3.TimeInterval({"end_time": {"seconds": now}})
        point = monitoring_v3.Point({"interval": interval, "value": {"double_value": value}})
        series.points = [point]
client.create_time_series(name=project_name, time_series=[series])
print(f"Published {len(metrics_data)} metrics to Cloud Monitoring")
# Create alert for low MFU
def create_mfu_alert(project_id, threshold=0.50):
"""Alert when MFU drops below threshold"""
alert_client = monitoring_v3.AlertPolicyServiceClient()
alert_policy = monitoring_v3.AlertPolicy(
display_name=f"Low TPU MFU (<{threshold*100}%)",
conditions=[{
"display_name": "MFU threshold",
"condition_threshold": {
"filter": 'metric.type="custom.googleapis.com/tpu/mfu"',
"comparison": "COMPARISON_LT",
"threshold_value": threshold,
"duration": {"seconds": 600}
}
}]
)
policy = alert_client.create_alert_policy(
name=f"projects/{project_id}",
alert_policy=alert_policy
)
print(f"Created MFU alert: {policy.name}")
13.2.14. Troubleshooting Guide
| Issue | Symptoms | Diagnosis | Solution |
|---|---|---|---|
| Compilation taking forever | First step >30min | Complex graph, dynamic shapes | Enable bucketing, simplify model, use static shapes |
| Low MFU (<40%) | Slow training, TPU idle | Data loading bottleneck | Use ArrayRecord format, increase prefetch, optimize data pipeline |
| OOM during compilation | Compilation fails with OOM | Graph too large for compiler | Reduce model size, enable rematerialization, split into sub-graphs |
| NaN losses | Training diverges early | Numerical instability | Use BF16 instead of FP16, reduce learning rate, enable gradient clipping |
| Slow cross-pod communication | Doesn’t scale beyond 128 chips | Network bottleneck | Verify ICI topology, increase tensor parallelism, reduce pipeline parallelism |
| JAX XLA errors | Cryptic C++ stack traces | Unsupported operation | Disable JIT (JAX_DISABLE_JIT=1), debug in Python, rewrite operation |
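The “rematerialization” fix in the OOM row refers to activation checkpointing via jax.checkpoint (also exposed as jax.remat), which recomputes intermediate activations during the backward pass instead of storing them. A minimal sketch with arbitrary shapes:
import jax
import jax.numpy as jnp

@jax.checkpoint  # don't store this block's activations; recompute them on the backward pass
def heavy_block(x, w):
    return jax.nn.gelu(x @ w)

def loss_fn(w, x):
    return jnp.sum(heavy_block(x, w).astype(jnp.float32) ** 2)

w = jnp.ones((2048, 2048), dtype=jnp.bfloat16)
x = jnp.ones((128, 2048), dtype=jnp.bfloat16)
grads = jax.grad(loss_fn)(w, x)  # memory saved at the cost of extra compute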
Debug Commands:
# Check TPU status
gcloud compute tpus tpu-vm list --zone=us-central2-b
# SSH into TPU VM
gcloud compute tpus tpu-vm ssh my-tpu --zone=us-central2-b
# Check TPU chip status
python3 -c "import jax; print(jax.devices())"
# Monitor TPU utilization
python3 -c "
import jax
from jax.experimental import profiler
profiler.start_server(9999)
"
# Then open tensorboard
# Test ICI bandwidth
python3 -c "
import jax
import jax.numpy as jnp
# Create a per-device array and all-reduce it across the local chips
x = jnp.ones((jax.local_device_count(), 1000, 1000))
result = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
print('ICI test passed')
"
# Check for compilation cache
ls -lh ~/.cache/jax_cache/
13.2.15. Best Practices
- Always Use Static Shapes: Pad sequences to fixed lengths, avoid dynamic control flow
- Implement Bucketing: Group inputs by length to minimize padding waste
- Use BF16 for Training: Native hardware support, no loss scaling needed
- Profile Early: Use JAX profiler to identify bottlenecks before scaling
- Optimize Data Pipeline: Use ArrayRecord format, prefetch aggressively
- Start Small: Debug on v5e-8 before scaling to v5p-512
- Monitor MFU: Target >60%, investigate if <50%
- Use Topology-Aware Sharding: Align model parallelism with physical mesh
- Enable Preemptible for Dev: Save 70% on experimental training runs
- Checkpoint Frequently: Every 500-1000 steps for resilience
13.2.16. Comparison: TPU vs GPU Deep Dive
| Aspect | TPU v5p | NVIDIA H100 |
|---|---|---|
| Architecture | Systolic array, OCS | SIMT GPU, NVLink |
| Peak Performance | ~460 TFLOPS (BF16) | ~1,000 TFLOPS (FP8) |
| MFU (Typical) | 60-70% | 40-50% |
| Effective Performance | ~300 TFLOPS | ~450 TFLOPS |
| Memory per Chip | 95 GB HBM | 80 GB HBM3 |
| Interconnect | 600 GB/s ICI (optical) | 900 GB/s NVLink |
| Cluster Scale | 10,000+ chips (native) | Limited by InfiniBand |
| Cost per Chip-Hour | ~$8 | ~$12-15 |
| Ecosystem | JAX/TensorFlow (narrow) | PyTorch/All frameworks |
| Programming Model | XLA (compilation required) | CUDA (imperative) |
| Best For | Large-scale training, JAX/TF | Research, PyTorch, flexibility |
When to Choose TPU:
- Training models >50B parameters at scale
- Using JAX or TensorFlow framework
- Cost is primary concern (>$100k training budget)
- Can invest in XLA/JAX ecosystem learning
- Google Cloud committed strategy
When to Choose GPU:
- Research with rapidly changing architectures
- PyTorch-first organization
- Need maximum ecosystem flexibility
- Small scale experiments (<64 accelerators)
- Multi-cloud portability required
13.2.17. Exercises
Exercise 1: JAX Migration Assessment. For your PyTorch model:
- Identify dynamic shapes and control flow
- Estimate rewrite effort (% of code)
- Calculate potential MFU improvement (GPU baseline vs TPU target)
- Determine TPU ROI break-even point
Exercise 2: Bucketing Strategy Design. Analyze your dataset:
- Plot sequence length distribution
- Design bucket sizes to minimize padding (<10% waste)
- Implement bucketing logic
- Measure throughput improvement
Exercise 3: TPU Profiling. Profile a training step:
- Run JAX profiler for 100 steps
- Identify top 3 bottlenecks
- Calculate time breakdown (compute/transfer/idle)
- Optimize bottlenecks and re-profile
Exercise 4: MFU Calculation. Measure your actual MFU:
- Count model FLOPs per forward+backward pass
- Measure wall-clock time per step
- Calculate observed TFLOPS
- Compare to theoretical peak
- Identify gap causes
Exercise 5: Cost Optimization. Compare strategies for your workload:
- On-demand TPU v5p
- Preemptible TPU v5p (with interruption handling)
- Reserved TPU v5p (1-year)
- TPU v5e alternative
- Calculate total cost and risk for each
13.2.18. Summary
The TPU represents Google’s vertical integration vision for AI compute: custom silicon, networking, compilers, and frameworks co-designed for maximum efficiency at planet scale.
Key Takeaways:
- Systolic Arrays for Efficiency: 60-70% MFU vs 40-50% for GPUs
- Optical Circuit Switching: Enables 10,000+ chip supercomputers
- XLA Compilation: Required paradigm shift from imperative to declarative
- Static Shapes Essential: Dynamic shapes destroy performance
- Cost Advantage: 30-50% cheaper per effective TFLOP
- Ecosystem Trade-off: JAX/TensorFlow required, PyTorch immature
- Scaling Efficiency: Near-linear scaling to thousands of chips
- MFU is King: Focus on utilization, not peak specs
Decision Framework:
- Foundation model training (JAX/TF): TPU v5p strongly recommended
- Mid-size models (<100B): TPU v5e for best cost/performance
- Research (PyTorch): GPU ecosystem more mature
- Cost-constrained: TPU delivers 30-50% savings at scale
- Multi-cloud: GPU for portability, TPU for GCP-only
ROI Calculation:
- JAX migration: 2-4 engineer-weeks (~$30k)
- Training cost savings: 30-50% (~$150k on $300k job)
- Break-even: 1-2 large training runs
- Long-term: Compounds with every training iteration
TPUs are not universally better than GPUs, but for organizations training large models repeatedly on Google Cloud with JAX/TensorFlow, they offer compelling economics and technical advantages that justify the ecosystem investment.
The choice between TPU and GPU is ultimately a choice between vertical integration (efficiency, scale, cost) and horizontal compatibility (flexibility, ecosystem, portability). Choose wisely based on your organization’s strategic priorities.