Chapter 13: The GCP Compute Ecosystem

13.2. The TPU (Tensor Processing Unit) Deep Dive

“We are running out of computing capability. Moore’s Law is effectively dead… The solution is domain-specific architectures.” — John Hennessy, Turing Award Winner and Chairman of Alphabet

In the grand theater of cloud computing, the Graphics Processing Unit (GPU) is the charismatic rock star—versatile, powerful, and universally recognized. It was born for gaming, pivoted to crypto, and found its destiny in AI. However, inside Google’s data centers, there exists a different kind of entity. A silent, industrial-scale mathematician built for a singular purpose: matrix multiplication.

This is the Tensor Processing Unit (TPU).

For the Principal Engineer or Systems Architect, the TPU represents a fundamental divergence in philosophy. While AWS focuses on providing the best possible raw primitives (EC2 instances with NVIDIA cards attached via PCIe), Google Cloud offers a vertically integrated supercomputer.

Choosing the TPU is not just swapping one chip for another; it is adopting a different paradigm of parallelism, networking, and compilation. It is a choice that yields massive dividends in cost-performance and scalability, but demands a rigorous understanding of the underlying “Physics” of the hardware.

This section dissects the TPU from the silicon up to the pod level, contrasting it with the NVIDIA ecosystem, and laying out the architectural patterns required to tame this beast.


13.2.1. The Architecture of Efficiency: Systolic Arrays

To understand why a TPU is orders of magnitude more power-efficient than a general-purpose CPU or even a GPU for specific workloads, we must look at the von Neumann Bottleneck.

In a CPU or GPU, every operation typically involves:

  1. Fetching an instruction.
  2. Fetching data from memory (Registers/L1/L2/HBM) to the Arithmetic Logic Unit (ALU).
  3. Performing the calculation.
  4. Writing the result back to memory.

For a massive matrix multiplication (the beating heart of Deep Learning), this creates a traffic jam. The ALUs spend more time waiting for data to travel across the wires than they do calculating.

The Systolic Paradigm

The TPU abandons this “Fetch-Execute-Write” cycle for a Systolic Array architecture. The term “systolic” comes from biology (systole), referring to the rhythmic pumping of the heart.

Imagine a bucket brigade.

  • CPU/GPU Approach: The worker runs to the water source, fills a bucket, runs to the fire, throws it, and runs back.
  • TPU Approach: A line of workers stands still. They pass the full bucket to their left and the empty bucket to their right in a synchronized rhythm.

In the TPU’s Matrix Multiply Unit (MXU):

  1. Weight parameters are pre-loaded into the array and stay stationary.
  2. Data (activations) flows in from the left.
  3. Partial sums flow down from the top.
  4. In each clock cycle, a cell performs a multiply-accumulate (MAC) operation and passes the data to its neighbor.

$$ C_{ij} = \sum_{k} A_{ik} \times B_{kj} $$

The data flows through the chip like blood. No memory access is required for intermediate results. This allows the TPU to pack tens of thousands of multipliers into a tiny area with minimal heat generation, achieving a TOPS-per-watt ratio that traditional architectures cannot touch.
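
To make the dataflow concrete, here is a toy, plain-Python simulation of a weight-stationary array (not TPU code, and cycle-level pipelining is abstracted away): the weights stay put in the grid of processing elements, activations stream past them, and each output is built up purely from neighbor-to-neighbor MACs.

import numpy as np

def systolic_matmul(A, W):
    """Toy weight-stationary systolic array computing C = A @ W.

    PE (k, j) permanently holds W[k, j]. Activation A[i, k] streams across
    row k; the partial sum for C[i, j] flows down column j, picking up one
    multiply-accumulate (MAC) per PE it passes. Only the dataflow roles are
    shown, not the cycle-by-cycle timing.
    """
    M, K = A.shape
    K2, N = W.shape
    assert K == K2
    C = np.zeros((M, N))
    for i in range(M):              # each activation row streams through the array
        for j in range(N):          # each column of PEs produces one output element
            partial_sum = 0.0       # lives "in the wires", never in HBM
            for k in range(K):      # one MAC per stationary weight in the column
                partial_sum += A[i, k] * W[k, j]
            C[i, j] = partial_sum
    return C

A = np.random.randn(4, 3)
W = np.random.randn(3, 5)
assert np.allclose(systolic_matmul(A, W), A @ W)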

The Architect’s Constraint: Static Shapes and Padding

This physical reality imposes a strict software constraint: Uniformity.

The Systolic Array is a rigid physical grid (e.g., 128x128). It loves big, rectangular blocks of numbers. It hates irregularity.

  • The Scenario: You are processing sentences of variable length. One is 5 tokens, one is 100 tokens.
  • The CPU/GPU: Handles this via masking and dynamic control flow relatively well.
  • The TPU: The compiler must “pad” the 5-token sentence to a fixed size (e.g., 128) with zeros to fit the array rhythm.
  • The Debt: If you choose a bucket size of 128, and your average sentence length is 20, you are wasting 84% of your compute cycles multiplying zeros.

Architectural Mitigation:

  • Bucketing: Sort inputs by length and use multiple distinct compiled graphs for different length buckets (e.g., bucket_64, bucket_128, bucket_256).
  • Packing: Concatenate multiple short sequences into one long sequence to fill the buffer, using attention masking to prevent them from “seeing” each other.
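
A minimal sketch of the bucketing side of this (the bucket sizes and pad token are illustrative choices): each sequence is padded up to the smallest bucket that fits it, so the compiler only ever sees a handful of distinct shapes and compiles each one once.

import numpy as np

BUCKETS = (64, 128, 256)   # illustrative bucket sizes
PAD_ID = 0                 # illustrative padding token id

def pad_to_bucket(token_ids):
    """Pad a variable-length sequence to the smallest bucket that fits it."""
    length = len(token_ids)
    for bucket in BUCKETS:
        if length <= bucket:
            padded = np.full(bucket, PAD_ID, dtype=np.int32)
            padded[:length] = token_ids
            mask = np.arange(bucket) < length   # marks real tokens for attention/loss
            return padded, mask
    raise ValueError(f"Sequence of length {length} exceeds the largest bucket")

tokens, mask = pad_to_bucket([17, 23, 5, 912, 8])
print(tokens.shape)   # (64,) -> the 5-token sentence reuses the compiled 64-wide graph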

13.2.2. The Generations: Choosing the Right Silicon

Unlike NVIDIA’s relatively linear progression (V100 → A100 → H100), Google’s TPU lineup branches into specialized roles. Understanding the difference between “v5e” and “v5p” is critical for your budget and performance.

TPU v4: The Optical Supercomputer

  • Era: The workhorse of 2023-2024.
  • Key Innovation: Optical Circuit Switching (OCS).
    • Traditional clusters use electrical packet switches (InfiniBand/Ethernet).
    • TPU v4 pods connect 4,096 chips using mirrors. Yes, MEMS mirrors.
    • The Benefit: Reconfigurability. You can dynamically slice a 4,096-chip pod into arbitrary topologies (cubes, meshes) without recabling.
  • Use Case: Large-scale training where you need a dedicated “slice” of topology.

TPU v5e: The Efficiency Specialist (“Lite”)

  • Philosophy: “Not everyone is training GPT-4.”
  • Design: optimized for cost-performance (FLOPS/$).
  • Specs: Roughly half the chip area of a v4, but higher density.
  • Interconnect: High-speed, but optimized for smaller topologies (up to 256 chips).
  • Target:
    • Inference (Serving Llama-2-70b).
    • Fine-tuning (LoRA).
    • Training mid-sized models (< 100B parameters).
  • The Trap: Do not try to train a 1 Trillion parameter model on v5e; the cross-chip communication overhead will kill you.

TPU v5p: The Performance Beast

  • Philosophy: “We need to beat the H100.”
  • Design: Massive High Bandwidth Memory (HBM) capacity and bandwidth.
  • Specs: 2x-3x faster than v4.
  • Interconnect: 600 GB/s inter-chip links. Scales to tens of thousands of chips in a single pod.
  • Target: Frontier model training. If you are burning $10M+ on a training run, this is your chip.

The Decision Matrix

Constraint | Recommended Silicon | Reason
Workload: Serving Llama-3-8B | TPU v5e | Overkill to use v5p. v5e offers the best price per inference.
Workload: Training 7B-70B model | TPU v4 / v5e | Good balance. v5e for cost, v4 if you need faster convergence.
Workload: Training > 100B model | TPU v5p | You need the HBM capacity and the OCS scale.
Budget: Limited | TPU v5e | Highest FLOPS per dollar.
Codebase: PyTorch (Standard) | GPU (A100/H100) | While PyTorch/XLA exists, GPUs are still the path of least resistance for pure PyTorch.
Codebase: JAX / TensorFlow | TPU | Native compilation advantage.

13.2.3. Topology and Interconnects: The 3D Torus

In the NVIDIA world, we talk about NVLink within a server and InfiniBand/RoCE across servers. In the TPU world, these boundaries dissolve. The TPU interconnect (ICI) fuses the chips into a single logical mesh.

The 3D Torus

Imagine a 3D grid of chips (X, Y, Z axes).

  • Chip (0,0,0) is directly connected to (0,0,1), (0,1,0), and (1,0,0).
  • This allows extremely low-latency communication for “Neighbor” operations.

The Wrap-Around: In a Torus, the edge connects back to the beginning. Chip (N, 0, 0) connects to Chip (0, 0, 0). This reduces the maximum number of hops (diameter) across the network.
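
The effect on worst-case hop count is easy to quantify; the sketch below compares a plain mesh with a torus of the same dimensions (a simplification that ignores OCS reconfiguration).

def network_diameter(dims, torus=True):
    """Worst-case hop count across an X x Y x Z grid of chips."""
    # Mesh: the worst case walks the full length of each axis (d - 1 hops).
    # Torus: wrap-around links cut the worst case per axis to floor(d / 2).
    return sum(d // 2 if torus else d - 1 for d in dims)

dims = (8, 8, 8)   # e.g., a 512-chip cube
print(network_diameter(dims, torus=False))   # 21 hops without wrap-around
print(network_diameter(dims, torus=True))    # 12 hops with the torus links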

The Topology Awareness Trap

When you provision a TPU Pod Slice (e.g., v4-128), you are physically renting a sub-section of this 3D lattice.

  • The Default: You get a shape, say $4 \times 4 \times 8$.
  • The Code Impact: If your model parallelism strategy assumes a ring, but the hardware provides a cube, your gradients will take inefficient paths through the silicon.

Mitigation: Topology-Aware Placement

In XLA and JAX, you can explicitly map your model’s dimensions to the hardware mesh dimensions.

# JAX Topology Definition
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# We define the physical mesh provided by the TPU slice
# "x", "y", "z" map to the physical interconnect axes
device_mesh = mesh_utils.create_device_mesh((4, 4, 8))
mesh = Mesh(device_mesh, axis_names=('x', 'y', 'z'))

# We map Model Layers to the Mesh
# Here, we shard the 'batch' dimension across 'x' and 'y' (16-way data parallelism)
# And the 'embed' dimension across 'z' (8-way model parallelism)
sharding_spec = NamedSharding(mesh, PartitionSpec(('x', 'y'), 'z'))

By aligning the logical sharding with the physical wires, you can achieve near-linear scaling efficiency (90%+) where Ethernet-based clusters often drop to 60-70%.


13.2.4. The Software Stack: XLA and the “Graph”

Using a TPU effectively requires accepting a hard truth: You are not writing Python code; you are writing a meta-program that generates a computation graph.

The Compiler: XLA (Accelerated Linear Algebra)

When you run code on a CPU, the interpreter executes line-by-line. When you run code on a TPU via XLA:

  1. Tracing: Python runs. It records operations (Add, MatMul, Relu) into a symbolic graph. It does not execute them.
  2. Optimization: XLA analyzes the graph. It fuses operations.
    • Fusion Example: Relu(Add(MatMul(A, B), C)) becomes a single hardware kernel call. No writing intermediate memory.
  3. Compilation: The graph is lowered to machine code for the specific TPU version.
  4. Execution: The binary is uploaded to the TPU and run.
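
You can inspect this traced graph directly. The sketch below uses jax.make_jaxpr to show the symbolic program that XLA will fuse and compile; no arithmetic runs during tracing.

import jax
import jax.numpy as jnp

def fused_layer(a, b, c):
    return jax.nn.relu(jnp.matmul(a, b) + c)

a = jnp.ones((128, 256))
b = jnp.ones((256, 512))
c = jnp.ones((512,))

# Tracing only: prints the symbolic graph (dot_general, add, max) that XLA
# will fuse into a single kernel. No arithmetic is executed here.
print(jax.make_jaxpr(fused_layer)(a, b, c))

# jit compiles that graph once, then reuses the binary for every call
# with the same shapes and dtypes.
fast_layer = jax.jit(fused_layer)
print(fast_layer(a, b, c).shape)   # (128, 512)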

The Trap: Recompilation Hell

This compilation step takes time (seconds to minutes).

  • The Anti-Pattern: Passing a changing Python scalar or a varying tensor shape into a JIT-compiled function.
  • The Result: XLA sees a “new” function signature. It triggers a full recompilation. The system stalls for 30 seconds.
  • The Symptom: “My first batch takes 30 seconds, my second batch takes 30 seconds…” (It should take 10ms).

Code Example: The “Static Argument” Fix

from functools import partial

import jax
import jax.numpy as jnp

# BAD: 'dropout_rate' is passed as a dynamic tracer, but acts as a constant
@jax.jit
def train_step_bad(params, inputs, dropout_rate):
    # Logic utilizing dropout_rate
    pass

# GOOD: Tell JAX that 'dropout_rate' is a static configuration, not a tensor
@partial(jax.jit, static_argnames=['dropout_rate'])
def train_step_good(params, inputs, dropout_rate):
    # Logic utilizing dropout_rate
    pass

13.2.5. Operationalizing TPUs: The Host-Device Relationship

Operating a TPU is different from operating GPU instances on EC2.

The Split Brain: Worker vs. Accelerator

  • Single-Host (v2/v3/v5e small slices): One VM controls 1, 4, or 8 TPU chips. This feels like a standard GPU box.
  • Multi-Host (Pod Slices): This is where it gets weird.
    • You provision a v4-128.
    • GCP spins up 16 separate VMs (hosts).
    • Each VM controls 8 TPU chips.
    • Your Python code must run on all 16 VMs simultaneously.

The Orchestration Challenge

You cannot just ssh into one box and run python train.py. You need to launch the process on the entire fleet in sync.
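
Whatever launcher you use, every host ends up running the same SPMD script. A minimal sketch of what each worker executes (on Cloud TPU VMs, jax.distributed.initialize() can usually discover the coordinator and process count from the environment):

# Runs identically on every host VM in the slice (SPMD).
import jax

# On Cloud TPU VMs this can typically discover the coordinator address,
# process count and process index from the environment.
jax.distributed.initialize()

print(f"process {jax.process_index()} of {jax.process_count()}")
print(f"local chips on this host: {jax.local_device_count()}")   # e.g., 4 or 8
print(f"chips in the whole slice: {jax.device_count()}")

# From here, a Mesh / NamedSharding built over jax.devices() spans the global
# device set, so sharded arrays and collectives cross host boundaries transparently.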

Tooling Solution: Google Cloud TPU VM Architecture

Historically, Google used “TPU Nodes” (where you couldn’t SSH into the host). Now, with TPU VMs, you have root access to the machines physically attached to the TPUs.

The Startup Script (GKE / JobSet)

In Kubernetes (GKE), this is handled by the JobSet API or the TPU Operator. It creates a headless service to allow the workers to discover each other.

# Kubernetes JobSet snippet for TPU Multi-Host
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: llama-3-training
spec:
  replicatedJobs:
  - name: workers
    replicas: 1
    template:
      spec:
        parallelism: 4   # 4 VMs (e.g., v4-32 slice)
        completions: 4
        template:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-topology: 2x2x4  # Request specific topology
            containers:
            - name: train
              image: us-docker.pkg.dev/my-project/train:latest
              env:
              - name: JAX_COORDINATOR_ADDRESS
                value: "$(master-service-host):8471"

Fault Tolerance: The “Orbax” Checkpoint

In a system with 4,096 chips, the probability of a cosmic ray bit-flip or a hardware failure approaches 1.

  • Synchronous Failure: If one chip fails, the global barrier synchronization halts the entire pod.
  • The Mitigation: Frequent, asynchronous checkpointing.
  • Orbax: Google’s open-source library designed for checkpointing massive sharded arrays across distributed hosts without blocking the training loop for too long.
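
A minimal sketch using Orbax's CheckpointManager (the bucket path and intervals are illustrative, the exact API varies across Orbax versions, and train_step, state, and batch stand in for the training-loop pieces from the examples elsewhere in this section):

import orbax.checkpoint as ocp

# Save every 1,000 steps and keep the last 3; saves are asynchronous, so the
# chips keep training while the previous checkpoint streams to object storage.
options = ocp.CheckpointManagerOptions(save_interval_steps=1000, max_to_keep=3)
manager = ocp.CheckpointManager("gs://my-bucket/ckpts", options=options)  # illustrative path

for step in range(num_steps):
    state, loss = train_step(state, batch)
    # A no-op except on the save interval; sharded arrays are written per host.
    manager.save(step, args=ocp.args.StandardSave(state))

manager.wait_until_finished()   # block only at the very end of the run

# After a preemption or failure, resume from the most recent checkpoint:
restored_state = manager.restore(manager.latest_step(), args=ocp.args.StandardRestore(state))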

13.2.6. Benchmarking and Cost Economics: The MFU Metric

When comparing TPU v5p to H100, do not look at “Peak TFLOPS” in the spec sheet. That is a theoretical number assuming perfect spherical cows in a vacuum.

Look at MFU (Model FLOPs Utilization): $$ MFU = \frac{\text{Observed Throughput (FLOPs/sec)}}{\text{Theoretical Peak (FLOPs/sec)}} $$

  • GPU Reality: On large clusters, GPUs often struggle to sustain >40-50% MFU due to PCIe bottlenecks and Ethernet latency.
  • TPU Reality: Due to the OCS and native mesh networking, well-tuned TPU workloads frequently hit 60-75% MFU.
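
A back-of-the-envelope MFU measurement looks like the sketch below; the 6 × parameters × tokens FLOPs estimate is the usual transformer-training rule of thumb, and the per-chip peak is an approximation you should replace with the figure for your silicon.

import time
import jax

def measure_mfu(train_step, state, batch, num_params, tokens_per_step,
                peak_flops_per_chip=459e12):   # ~460 TFLOPS BF16 per v5p chip (approx.)
    """Rough Model FLOPs Utilization estimate for one training step."""
    # Transformer-training rule of thumb: ~6 FLOPs per parameter per token
    # (forward + backward pass).
    model_flops = 6 * num_params * tokens_per_step

    state, loss = train_step(state, batch)        # warm-up: triggers compilation
    jax.block_until_ready((state, loss))

    start = time.perf_counter()
    state, loss = train_step(state, batch)
    jax.block_until_ready((state, loss))          # wait for async dispatch to finish
    step_time = time.perf_counter() - start

    observed_flops_per_sec = model_flops / step_time
    peak_flops_per_sec = peak_flops_per_chip * jax.device_count()
    return observed_flops_per_sec / peak_flops_per_sec

# e.g., 52B params, 16,384 sequences of 512 tokens per step, on a 128-chip slice:
# mfu = measure_mfu(train_step, state, batch, 52e9, 16_384 * 512)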

The Economic Implications

If Chip A costs $2/hr and claims 100 TFLOPS (but delivers 40%), and Chip B costs $2/hr and claims 80 TFLOPS (but delivers 60%):

  • Chip A Effective: 40 TFLOPS
  • Chip B Effective: 48 TFLOPS

Chip B (the TPU, often) is 20% faster in reality, despite being “slower” on paper.

Cost Efficiency (v5e)

The TPU v5e is aggressively priced. For workloads that fit within its memory/interconnect constraints, it often delivers 3x-4x better performance-per-dollar than A100s. It is the “Toyota Camry” of AI chips—reliable, efficient, and everywhere.


13.2.7. Architecture Patterns for Large Scale Training

Scaling to thousands of chips requires sophisticated parallelism strategies.

SPMD (Single Program, Multiple Data)

You write one program. It runs on every chip. The only difference is the slice of data each chip sees.

The Sharding Dimensions

To train a model larger than the memory of a single chip (e.g., 70B params > 16GB HBM), you must shard.

  1. Data Parallelism (DP): Copy model to all chips. Split batch across chips.
    • Limit: Model must fit in one chip.
  2. Fully Sharded Data Parallel (FSDP): Shard the model parameters, gradients, and optimizer state across chips. Gather them only when needed for computation.
  3. Tensor Parallelism (TP): Split individual matrix multiplications across chips.
    • Requires: Ultra-fast interconnect (ICI). This is the TPU’s home turf.
  4. Pipeline Parallelism (PP): Put Layer 1 on Chip A, Layer 2 on Chip B.
    • Problem: “The Bubble”. Chip B waits for Chip A.
    • TPU Context: Often unnecessary on TPU pods because Tensor Parallelism scales so well on the mesh.

GSPMD: The Generalizer

Google developed GSPMD, a compiler pass in XLA that handles sharding automatically based on simple annotations.

# The "Magic" of GSPMD in JAX
# We annotate the weight matrix "W"
# "mesh" is our 2D grid of chips
# P('x', 'y') means: Shard the first dimension on mesh axis x, second on y.

W = jax.random.normal(key, (8192, 8192))
W_sharded = jax.device_put(W, NamedSharding(mesh, PartitionSpec('x', 'y')))

# Now, any operation on W_sharded is automatically distributed.
# A matmul: Y = X @ W
# XLA generates the necessary "All-Gather" and "Reduce-Scatter" collectives
# to move data across the ICI wires without the user writing communication code.

13.2.8. Common Pitfalls (The Anti-Pattern Zoo)

1. The Data Feed Starvation

The TPU is a Ferrari engine. If you feed it with a garden hose (standard Python DataLoader), it will stall.

  • Symptom: TPU utilization oscillates (0% -> 100% -> 0%).
  • Cause: The CPU host cannot unzip/parse images fast enough.
  • Fix: Use tf.data (TensorFlow Data) or grain (Google’s new JAX data loader) which are optimized for prefetching and C++ execution. Store data in ArrayRecord or TFRecord formats, not loose JPEGs.
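
A minimal tf.data sketch of that fix (the file pattern and feature spec are placeholders): parsing, batching, and prefetching all run in C++ threads, so the host can keep the accelerator fed.

import tensorflow as tf

def parse_record(serialized):
    # Placeholder parser: decode one serialized tf.train.Example into tensors.
    features = {"tokens": tf.io.FixedLenFeature([512], tf.int64)}
    return tf.io.parse_single_example(serialized, features)

files = tf.data.Dataset.list_files("gs://my-bucket/train-*.tfrecord")  # illustrative path
dataset = (
    tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE)
    .map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)   # C++ threads, not Python
    .shuffle(10_000)
    .batch(128, drop_remainder=True)    # fixed batch shape keeps XLA happy
    .prefetch(tf.data.AUTOTUNE)         # prepare the next batches while the TPU computes
)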

2. The Floating Point Trap (BF16 vs FP32)

TPUs are designed for BFloat16 (Brain Floating Point).

  • BF16 has the same range as FP32 (8-bit exponent) but lower precision (7-bit mantissa).
  • The Trap: Using standard FP16 (IEEE). TPUs emulate FP16 slowly or cast it.
  • The Fix: Always use BF16 for training. It is numerically stable (unlike FP16) and runs at peak speed on MXUs.
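
A common JAX pattern, sketched below: keep the master parameters in FP32 for optimizer stability and cast to BF16 for the matmuls, with no loss scaling required.

import jax
import jax.numpy as jnp

def forward(params_fp32, x):
    # Cast to BF16 for the matmul (full MXU speed); the FP32 "master" copy of
    # the parameters stays untouched for the optimizer update.
    params = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params_fp32)
    h = jnp.dot(x.astype(jnp.bfloat16), params["w"]) + params["b"]
    h = jax.nn.relu(h)
    # Cast back up where extra precision helps (losses, reductions, softmax).
    return h.astype(jnp.float32)

params = {"w": jnp.ones((512, 512), jnp.float32), "b": jnp.zeros((512,), jnp.float32)}
x = jnp.ones((8, 512))
print(forward(params, x).dtype)   # float32 output, BF16 matmul on the MXU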

3. The “Opaque Error”

When XLA crashes, it often emits a C++ stack trace from the compiler internals that looks like hieroglyphics.

  • Strategy:
    • Disable JIT (jax.disable_jit()) to debug logic errors in pure Python.
    • Use jax.debug.print() which injects print operations into the compiled graph (runtime printing).
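
A sketch of both techniques side by side:

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    y = jnp.log(x)
    # Executes at runtime inside the compiled graph; a plain Python print
    # would only fire once, during tracing.
    jax.debug.print("min(y) = {m}", m=jnp.min(y))
    return jnp.sum(y)

x = jnp.linspace(0.1, 1.0, 8)

# Eager, line-by-line execution: tracebacks point at your Python code,
# not at XLA compiler internals, and pdb works as usual.
with jax.disable_jit():
    step(x)

step(x)   # back on the compiled path once the logic is verified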

13.2.9. Conclusion: The Strategic Bet

Adopting TPUs is a strategic bet on Vertical Integration.

  • On AWS: You are integrating components from Intel (CPU), NVIDIA (GPU), and AWS (Nitro/EFA). You are the integrator.
  • On GCP: You are entering a walled garden where the cooler, the chip, the network switch, the compiler, and the orchestration software were all designed by the same company to do one thing: Math.

For generic, explorative work or teams deeply entrenched in legacy CUDA kernels, the friction may be too high. But for organizations aiming to train foundation models or serve inference at global scale, the TPU offers an architectural purity and economic efficiency that is arguably the highest in the cloud.

In the next section, we will look at how to orchestrate these powerful compute resources using Kubernetes, and the specific quirks of managing EKS vs GKE for AI workloads.


13.2.10. Real-World Case Study: Foundation Model Training on TPU v5p

Company: LangTech AI (anonymized)

Challenge: Train a 52B parameter encoder-decoder model (T5-style) for multilingual translation with <$150k budget.

Initial GPU Baseline (A100):

# Configuration: 64× a2-ultragpu-1g (64× A100 80GB)
# Cost: ~$16/hr per instance
# Total: $16 × 64 = $1,024/hr
# Estimated training time: 18 days
# Total cost: $1,024 × 24 × 18 = $442,368 (WAY OVER BUDGET)

# Standard PyTorch FSDP
from transformers import T5ForConditionalGeneration
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = T5ForConditionalGeneration.from_pretrained("t5-11b")
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

# Bottleneck: Cross-node communication for all-reduce
# Achieved MFU: ~42% (significant network overhead)

Migrated to TPU v5p Pod:

# Configuration: v5p-128 (128 TPU v5p chips)
# Cost: ~$8/hr per chip
# Total: $8 × 128 = $1,024/hr (same as GPU option)
# Actual training time: 10 days (45% faster!)
# Total cost: $1,024 × 24 × 10 = $245,760

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax

# JAX model definition
class T5Model(nn.Module):
    vocab_size: int = 32128
    d_model: int = 1024
    num_layers: int = 24

    @nn.compact
    def __call__(self, input_ids, decoder_input_ids):
        # Encoder
        encoder_embed = nn.Embed(self.vocab_size, self.d_model)(input_ids)
        encoder_output = encoder_embed

        # TransformerEncoderLayer / TransformerDecoderLayer are custom Flax
        # modules assumed to be defined elsewhere in the (anonymized) codebase.
        for _ in range(self.num_layers):
            encoder_output = TransformerEncoderLayer(self.d_model)(encoder_output)

        # Decoder
        decoder_embed = nn.Embed(self.vocab_size, self.d_model)(decoder_input_ids)
        decoder_output = decoder_embed

        for _ in range(self.num_layers):
            decoder_output = TransformerDecoderLayer(self.d_model)(
                decoder_output, encoder_output
            )

        logits = nn.Dense(self.vocab_size)(decoder_output)
        return logits

# Sharding specification for 128 TPUs (8×4×4 mesh)
from jax.sharding import Mesh, PartitionSpec, NamedSharding

devices = jax.devices()
device_mesh = np.array(devices).reshape(8, 4, 4)
mesh = Mesh(device_mesh, axis_names=('data', 'model', 'tensor'))

# Shard model across tensor dimension, data across data dimension
sharding = NamedSharding(mesh, PartitionSpec('data', 'tensor', None))

# Training loop with automatic GSPMD
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['input_ids'], batch['decoder_input_ids'])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['labels'])
        return jnp.mean(loss)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Results:
# - Throughput: 125k tokens/sec (vs 78k on GPU)
# - MFU: 67% (vs 42% on GPU) - 60% better efficiency!
# - Total cost: $246k (vs $442k GPU, 44% savings)
# - Training time: 10 days (vs 18 days GPU, 45% faster)

Migration Challenges:

  1. Challenge: PyTorch codebase conversion to JAX

    • Solution: 3-week engineer effort, ~2,500 lines rewritten
    • Tools: Used jax2torch converter for reference, manual fixes
  2. Challenge: Dynamic sequence lengths causing recompilation

    • Solution: Implemented bucketing strategy (128, 256, 512, 1024, 2048)
    • Result: 95% of samples fit into 3 buckets, <5% padding waste
  3. Challenge: Debugging compilation errors

    • Solution: Disabled JIT initially, debugged in Python, then re-enabled
    • Tools: JAX_DISABLE_JIT=1 python train.py for debugging

Key Learnings:

  • TPU v5p’s optical circuit switching eliminated GPU’s network bottleneck
  • MFU improvement (42% → 67%) was the critical cost driver
  • JAX migration ROI: 3 weeks investment saved $196k (1 training run)
  • Bucketing strategy essential for variable-length sequences

13.2.11. Advanced Optimization Techniques

Technique 1: Efficient Data Loading with grain

# Google's grain library for efficient TPU data loading
import grain.python as grain

class TranslationDataset:
    """Custom dataset for TPU-optimized loading"""

    def __init__(self, data_dir, split='train'):
        # Use ArrayRecord format (optimized for TPU)
        self.arrayrecord_path = f"{data_dir}/{split}.arrayrecord"
        self.data_source = grain.ArrayRecordDataSource(self.arrayrecord_path)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        record = self.data_source[idx]
        # Parse serialized example (parse_example is assumed to be defined elsewhere)
        example = parse_example(record)
        return {
            'input_ids': example['source'],
            'decoder_input_ids': example['target'][:-1],
            'labels': example['target'][1:]
        }

# Create optimized dataloader
def create_tpu_dataloader(dataset, batch_size=128, num_epochs=None):
    """Create dataloader with TPU-specific optimizations"""

    # Shuffle with large buffer
    sampler = grain.IndexSampler(
        len(dataset),
        shuffle=True,
        seed=42,
        num_epochs=num_epochs
    )

    # Batch with padding to fixed shapes
    operations = [
        grain.Batch(batch_size=batch_size, drop_remainder=True),
        grain.PadToMaxLength(
            max_length={'input_ids': 512, 'decoder_input_ids': 512, 'labels': 512},
            pad_value=0
        )
    ]

    loader = grain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=operations,
        worker_count=32,  # Parallel workers
        worker_buffer_size=2  # Prefetch depth
    )

    return loader

# Result: Eliminates data loading bottleneck
# TPU utilization: 95%+ (vs 70% with naive loading)

Technique 2: Topology-Aware Model Sharding

import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

def create_optimal_mesh(num_chips):
    """Create 3D mesh matching physical TPU topology"""

    # For v5p-128: 8×4×4 topology
    # For v5p-256: 8×8×4 topology
    # For v5p-512: 8×8×8 topology

    if num_chips == 128:
        mesh_shape = (8, 4, 4)
    elif num_chips == 256:
        mesh_shape = (8, 8, 4)
    elif num_chips == 512:
        mesh_shape = (8, 8, 8)
    else:
        raise ValueError(f"Unsupported chip count: {num_chips}")

    devices = mesh_utils.create_device_mesh(mesh_shape)
    mesh = Mesh(devices, axis_names=('data', 'fsdp', 'tensor'))

    return mesh

def shard_params_optimally(params, mesh):
    """Shard model parameters across mesh dimensions"""

    # Embedding tables: shard the embedding (second) dimension across 'tensor'
    embedding_sharding = NamedSharding(mesh, PartitionSpec(None, 'tensor'))

    # Attention weights: shard across 'fsdp' and 'tensor'
    attention_sharding = NamedSharding(mesh, PartitionSpec('fsdp', 'tensor'))

    # FFN weights: shard across 'tensor' only
    ffn_sharding = NamedSharding(mesh, PartitionSpec(None, 'tensor'))

    # Apply sharding spec
    sharded_params = {
        'embeddings': jax.device_put(params['embeddings'], embedding_sharding),
        'attention': jax.device_put(params['attention'], attention_sharding),
        'ffn': jax.device_put(params['ffn'], ffn_sharding)
    }

    return sharded_params

# Usage
mesh = create_optimal_mesh(num_chips=128)
sharded_params = shard_params_optimally(model_params, mesh)

# Result: Near-linear scaling efficiency
# 128 chips: 67% MFU
# 256 chips: 65% MFU (only 3% drop when doubling scale!)

Technique 3: Gradient Accumulation for Large Batch Training

import jax
import jax.numpy as jnp

def create_accumulation_step(accumulation_steps=4):
    """Implement gradient accumulation for effective large batches"""

    def accumulate_gradients(state, batches):
        """Accumulate gradients over multiple micro-batches"""

        accumulated_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
        total_loss = 0.0

        for micro_batch in batches:
            # Compute gradients for micro-batch (labels are only used by the loss)
            def loss_fn(params):
                inputs = {k: v for k, v in micro_batch.items() if k != 'labels'}
                logits = state.apply_fn({'params': params}, **inputs)
                # compute_loss is assumed to be defined elsewhere (e.g., cross-entropy)
                loss = compute_loss(logits, micro_batch['labels'])
                return loss / accumulation_steps  # Scale loss

            loss, grads = jax.value_and_grad(loss_fn)(state.params)

            # Accumulate
            accumulated_grads = jax.tree_util.tree_map(
                lambda acc, g: acc + g,
                accumulated_grads,
                grads
            )
            total_loss += loss

        # Apply accumulated gradients
        state = state.apply_gradients(grads=accumulated_grads)

        return state, total_loss

    return accumulate_gradients

# Usage: Effective batch size = micro_batch × accumulation_steps × num_chips
# Example: 32 × 4 × 128 = 16,384 effective batch size
# Fits in memory while achieving large-batch training benefits

13.2.12. Cost Optimization Strategies

Strategy 1: Preemptible TPU Pods

# Create preemptible TPU pod for 60-70% savings
from google.cloud import tpu_v2

def create_preemptible_tpu_pod(
    project_id,
    zone,
    tpu_name,
    accelerator_type="v5litepod-16",
    runtime_version="tpu-vm-tf-2.14.0"
):
    """Create preemptible TPU pod with automatic checkpointing"""

    client = tpu_v2.TpuClient()

    tpu = tpu_v2.Node(
        name=f"projects/{project_id}/locations/{zone}/nodes/{tpu_name}",
        accelerator_type=accelerator_type,
        runtime_version=runtime_version,
        network_config=tpu_v2.NetworkConfig(
            enable_external_ips=True
        ),
        scheduling_config=tpu_v2.SchedulingConfig(
            preemptible=True  # 60-70% discount
        ),
        metadata={
            # Startup script for automatic checkpoint restoration
            "startup-script": """#!/bin/bash
            gsutil cp gs://my-bucket/checkpoint-latest/* /tmp/checkpoint/
            python3 /home/user/train.py --restore_from=/tmp/checkpoint
            """
        }
    )

    operation = client.create_node(
        parent=f"projects/{project_id}/locations/{zone}",
        node_id=tpu_name,
        node=tpu
    )

    print(f"Creating TPU pod: {tpu_name}")
    result = operation.result()  # Wait for completion
    return result

# Savings example:
# v5p-128 on-demand: $1,024/hr
# v5p-128 preemptible: $307/hr (70% savings!)
# 10-day training: $245k → $74k

Strategy 2: Reserved Capacity for Long Training Runs

# Reserved TPU capacity for predictable costs
def calculate_tpu_reservation_savings(
    monthly_chip_hours,
    chip_type="v5p",
    on_demand_rate=8.00,  # $/chip-hr
    commitment_months=12
):
    """Calculate savings from TPU reserved capacity"""

    # Reservation discounts (approximate)
    reservation_discounts = {
        1: 0.15,   # 15% for 1-month
        3: 0.25,   # 25% for 3-month
        12: 0.40   # 40% for 1-year
    }

    discount = reservation_discounts[commitment_months]
    reserved_rate = on_demand_rate * (1 - discount)

    monthly_cost_on_demand = monthly_chip_hours * on_demand_rate
    monthly_cost_reserved = monthly_chip_hours * reserved_rate

    total_savings = (monthly_cost_on_demand - monthly_cost_reserved) * commitment_months

    print(f"Chip type: {chip_type}")
    print(f"Monthly chip-hours: {monthly_chip_hours}")
    print(f"On-demand: ${monthly_cost_on_demand:,.2f}/month")
    print(f"Reserved ({commitment_months}mo): ${monthly_cost_reserved:,.2f}/month")
    print(f"Total savings over {commitment_months} months: ${total_savings:,.2f}")

    return total_savings

# Example: v5p-128 running 50% of the time
savings = calculate_tpu_reservation_savings(
    monthly_chip_hours=128 * 24 * 30 * 0.5,  # 50% utilization
    commitment_months=12
)
# Output: Total savings over 12 months: $1,769,472.00

Strategy 3: TPU v5e for Cost-Optimized Training

# Use TPU v5e for models <100B parameters

# Cost comparison (approximate):
# v5p: $8/chip-hr, 128 chips = $1,024/hr
# v5e: $2/chip-hr, 256 chips = $512/hr (50% cheaper!)

# Performance comparison:
# v5p: 67% MFU, 125k tokens/sec
# v5e: 58% MFU, 89k tokens/sec (71% of v5p throughput)

# Cost per token:
# v5p: $1,024/hr ÷ (125k tokens/sec × 3,600 s/hr) ≈ $2.28 per 1M tokens
# v5e: $512/hr ÷ (89k tokens/sec × 3,600 s/hr) ≈ $1.60 per 1M tokens (30% cheaper!)

# Decision framework:
# - Model <70B: Use v5e (best cost/token)
# - Model 70-200B: Use v5p if budget allows, v5e otherwise
# - Model >200B: Use v5p (v5e lacks HBM capacity)

13.2.13. Monitoring and Debugging

TPU Profiling:

import jax
from jax import profiler

def profile_training_step(train_step_fn, state, batch):
    """Profile TPU execution to identify bottlenecks"""

    # Start profiling server
    profiler.start_server(port=9999)

    # Run training steps under the profiler trace
    with profiler.trace("/tmp/tensorboard"):
        for step in range(100):  # Profile 100 steps
            # Annotate each step so it appears as a named block in the trace viewer
            with profiler.StepTraceAnnotation("train_step", step_num=step):
                state, loss = train_step_fn(state, batch)

        # Ensure asynchronously dispatched work finishes before the trace closes
        jax.block_until_ready(loss)

    print("Profiling complete. View in TensorBoard:")
    print("tensorboard --logdir=/tmp/tensorboard --port=6006")

# Key metrics to analyze:
# 1. Device compute time (should be >90% of total)
# 2. Host-to-device transfer time (should be <5%)
# 3. Compilation time (only on first step)
# 4. Idle time (should be <2%)

# Common issues:
# - High transfer time → Data loading bottleneck
# - High idle time → Unbalanced sharding
# - Frequent compilation → Dynamic shapes (need bucketing)

Cloud Monitoring Integration:

import time

from google.cloud import monitoring_v3
import jax

def publish_tpu_metrics(project_id):
    """Publish custom TPU training metrics"""

    client = monitoring_v3.MetricServiceClient()
    project_name = f"projects/{project_id}"

    # Get TPU device info
    devices = jax.devices()
    num_devices = len(devices)

    # Metrics to track
    metrics_data = {
        'tpu/mfu': 0.67,  # Model FLOPs Utilization
        'tpu/tokens_per_second': 125000,
        'tpu/cost_per_million_tokens': 2.28,
        'tpu/training_loss': 2.45,
        'tpu/num_active_devices': num_devices
    }

    # Each written point needs an end time; without it the API rejects the write
    now = time.time()
    interval = monitoring_v3.TimeInterval(
        {"end_time": {"seconds": int(now), "nanos": int((now % 1) * 1e9)}}
    )

    for metric_name, value in metrics_data.items():
        series = monitoring_v3.TimeSeries()
        series.metric.type = f"custom.googleapis.com/{metric_name}"
        series.resource.type = "gce_instance"

        point = monitoring_v3.Point({"interval": interval, "value": {"double_value": value}})
        series.points = [point]
        client.create_time_series(name=project_name, time_series=[series])

    print(f"Published {len(metrics_data)} metrics to Cloud Monitoring")

# Create alert for low MFU
def create_mfu_alert(project_id, threshold=0.50):
    """Alert when MFU drops below threshold"""

    alert_client = monitoring_v3.AlertPolicyServiceClient()

    alert_policy = monitoring_v3.AlertPolicy(
        display_name=f"Low TPU MFU (<{threshold*100}%)",
        conditions=[{
            "display_name": "MFU threshold",
            "condition_threshold": {
                "filter": 'metric.type="custom.googleapis.com/tpu/mfu"',
                "comparison": "COMPARISON_LT",
                "threshold_value": threshold,
                "duration": {"seconds": 600}
            }
        }]
    )

    policy = alert_client.create_alert_policy(
        name=f"projects/{project_id}",
        alert_policy=alert_policy
    )

    print(f"Created MFU alert: {policy.name}")

13.2.14. Troubleshooting Guide

Issue | Symptoms | Diagnosis | Solution
Compilation taking forever | First step >30 min | Complex graph, dynamic shapes | Enable bucketing, simplify model, use static shapes
Low MFU (<40%) | Slow training, TPU idle | Data loading bottleneck | Use ArrayRecord format, increase prefetch, optimize data pipeline
OOM during compilation | Compilation fails with OOM | Graph too large for compiler | Reduce model size, enable rematerialization, split into sub-graphs
NaN losses | Training diverges early | Numerical instability | Use BF16 instead of FP16, reduce learning rate, enable gradient clipping
Slow cross-pod communication | Doesn’t scale beyond 128 chips | Network bottleneck | Verify ICI topology, increase tensor parallelism, reduce pipeline parallelism
JAX XLA errors | Cryptic C++ stack traces | Unsupported operation | Disable JIT (JAX_DISABLE_JIT=1), debug in Python, rewrite operation

Debug Commands:

# Check TPU status
gcloud compute tpus tpu-vm list --zone=us-central2-b

# SSH into TPU VM
gcloud compute tpus tpu-vm ssh my-tpu --zone=us-central2-b

# Check TPU chip status
python3 -c "import jax; print(jax.devices())"

# Monitor TPU utilization
python3 -c "
import jax
import jax.profiler
jax.profiler.start_server(9999)
"
# Then open tensorboard

# Test ICI bandwidth
python3 -c "
import jax
import jax.numpy as jnp

# Create one array shard per local device and all-reduce across them
n = jax.local_device_count()
x = jnp.ones((n, 1000, 1000))
result = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
print('ICI test passed')
"

# Check for compilation cache
ls -lh ~/.cache/jax_cache/

13.2.15. Best Practices

  1. Always Use Static Shapes: Pad sequences to fixed lengths, avoid dynamic control flow
  2. Implement Bucketing: Group inputs by length to minimize padding waste
  3. Use BF16 for Training: Native hardware support, no loss scaling needed
  4. Profile Early: Use JAX profiler to identify bottlenecks before scaling
  5. Optimize Data Pipeline: Use ArrayRecord format, prefetch aggressively
  6. Start Small: Debug on v5e-8 before scaling to v5p-512
  7. Monitor MFU: Target >60%, investigate if <50%
  8. Use Topology-Aware Sharding: Align model parallelism with physical mesh
  9. Enable Preemptible for Dev: Save 70% on experimental training runs
  10. Checkpoint Frequently: Every 500-1000 steps for resilience

13.2.16. Comparison: TPU vs GPU Deep Dive

Aspect | TPU v5p | NVIDIA H100
Architecture | Systolic array, OCS | SIMT GPU, NVLink
Peak Performance | ~460 TFLOPS (BF16) | ~1,000 TFLOPS (FP8)
MFU (Typical) | 60-70% | 40-50%
Effective Performance | ~300 TFLOPS | ~450 TFLOPS
Memory per Chip | 95 GB HBM | 80 GB HBM3
Interconnect | 600 GB/s ICI (optical) | 900 GB/s NVLink
Cluster Scale | 10,000+ chips (native) | Limited by InfiniBand
Cost per Chip-Hour | ~$8 | ~$12-15
Ecosystem | JAX/TensorFlow (narrow) | PyTorch / all frameworks
Programming Model | XLA (compilation required) | CUDA (imperative)
Best For | Large-scale training, JAX/TF | Research, PyTorch, flexibility

When to Choose TPU:

  • Training models >50B parameters at scale
  • Using JAX or TensorFlow framework
  • Cost is primary concern (>$100k training budget)
  • Can invest in XLA/JAX ecosystem learning
  • Google Cloud committed strategy

When to Choose GPU:

  • Research with rapidly changing architectures
  • PyTorch-first organization
  • Need maximum ecosystem flexibility
  • Small scale experiments (<64 accelerators)
  • Multi-cloud portability required

13.2.17. Exercises

Exercise 1: JAX Migration Assessment

For your PyTorch model:

  • Identify dynamic shapes and control flow
  • Estimate rewrite effort (% of code)
  • Calculate potential MFU improvement (GPU baseline vs TPU target)
  • Determine TPU ROI break-even point

Exercise 2: Bucketing Strategy Design

Analyze your dataset:

  • Plot sequence length distribution
  • Design bucket sizes to minimize padding (<10% waste)
  • Implement bucketing logic
  • Measure throughput improvement

Exercise 3: TPU Profiling

Profile a training step:

  • Run JAX profiler for 100 steps
  • Identify top 3 bottlenecks
  • Calculate time breakdown (compute/transfer/idle)
  • Optimize bottlenecks and re-profile

Exercise 4: MFU Calculation

Measure actual MFU:

  • Count model FLOPs per forward+backward pass
  • Measure wall-clock time per step
  • Calculate observed TFLOPS
  • Compare to theoretical peak
  • Identify gap causes

Exercise 5: Cost Optimization

Compare strategies for your workload:

  • On-demand TPU v5p
  • Preemptible TPU v5p (with interruption handling)
  • Reserved TPU v5p (1-year)
  • TPU v5e alternative
  • Calculate total cost and risk for each

13.2.18. Summary

The TPU represents Google’s vertical integration vision for AI compute: custom silicon, networking, compilers, and frameworks co-designed for maximum efficiency at planet scale.

Key Takeaways:

  1. Systolic Arrays for Efficiency: 60-70% MFU vs 40-50% for GPUs
  2. Optical Circuit Switching: Enables 10,000+ chip supercomputers
  3. XLA Compilation: Required paradigm shift from imperative to declarative
  4. Static Shapes Essential: Dynamic shapes destroy performance
  5. Cost Advantage: 30-50% cheaper per effective TFLOP
  6. Ecosystem Trade-off: JAX/TensorFlow required, PyTorch immature
  7. Scaling Efficiency: Near-linear scaling to thousands of chips
  8. MFU is King: Focus on utilization, not peak specs

Decision Framework:

  • Foundation model training (JAX/TF): TPU v5p strongly recommended
  • Mid-size models (<100B): TPU v5e for best cost/performance
  • Research (PyTorch): GPU ecosystem more mature
  • Cost-constrained: TPU delivers 30-50% savings at scale
  • Multi-cloud: GPU for portability, TPU for GCP-only

ROI Calculation:

  • JAX migration: 2-4 engineer-weeks (~$30k)
  • Training cost savings: 30-50% (~$150k on $300k job)
  • Break-even: 1-2 large training runs
  • Long-term: Compounds with every training iteration

TPUs are not universally better than GPUs, but for organizations training large models repeatedly on Google Cloud with JAX/TensorFlow, they offer compelling economics and technical advantages that justify the ecosystem investment.

The choice between TPU and GPU is ultimately a choice between vertical integration (efficiency, scale, cost) and horizontal compatibility (flexibility, ecosystem, portability). Choose wisely based on your organization’s strategic priorities.