19.3 Debugging: Visualizing Activation Maps & Gradients

Debugging software is hard. Debugging Machine Learning is harder.

In traditional software, a bug usually causes a Crash (Segmentation Fault) or an Error (Exception). In Machine Learning, a bug usually causes… nothing. The model trains. The loss goes down. It predicts “Dog” for everything. Or it gets 90% accuracy but fails in production. This is the Silent Failure of ML.

This chapter covers the tactical skills of debugging Deep Neural Networks: Visualizing what they see, monitoring their internal blood pressure (Gradients), and diagnosing their illnesses (Dead ReLUs, Collapse).


1. The Taxonomy of ML Bugs

Before we open the toolbox, let’s classify the enemy.

1.1. Implementation Bugs (Code)

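  • Tensor Shape Mismatch: Broadcasting aligns trailing dimensions, so combining (B, C, H, W) with (B, C) without first reshaping to (B, C, 1, 1) either errors or, when sizes happen to line up, silently adds along the wrong axes and produces garbage.
  • Pre-processing Mismatch: Training on 0..255 but inferring on 0..1 floats. Every input is roughly 255x smaller than anything seen in training, so the model effectively sees a near-black image.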
  • Flip-Flop Labels: Class 0 is Cat in the dataloader, but Class 0 is Dog in the evaluation script.
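A classic variant of the shape bug is a regression head that outputs (B, 1) compared against targets stored as (B,). The minimal sketch below (tensor names are illustrative) shows that the subtraction runs without error but broadcasts to a (B, B) matrix, so the "loss" averages every prediction against every target.

```python
import torch

B = 4
preds = torch.randn(B, 1)    # e.g. the output of nn.Linear(hidden, 1)
targets = torch.randn(B)     # labels loaded as a flat vector

bad = (preds - targets) ** 2               # (B, 1) - (B,) broadcasts to (B, B): no error raised
good = (preds.squeeze(1) - targets) ** 2   # align shapes explicitly: (B,)

print(bad.shape, good.shape)  # torch.Size([4, 4]) torch.Size([4])
print(bad.mean().item(), good.mean().item())  # two different "MSE" values; only the second is correct
```

Recent PyTorch versions emit a UserWarning from F.mse_loss when input and target sizes differ, but a hand-written loss like the one above gives no warning at all, which is why asserting shapes explicitly is worth the extra line.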

1.2. Convergence Bugs (Math)

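  • Vanishing Gradients: Network is too deep; the signal dies before reaching the early layers.
  • Exploding Gradients: Gradients grow layer by layer, often made worse by a learning rate that is too high; weights diverge to NaN.
  • Dead ReLUs: Neurons get stuck outputting 0 and never recover (ReLU's gradient is 0 for negative inputs, so no update reaches their weights).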

1.3. Logic Bugs (Data)

  • Leakage: Target variable contained in features (e.g., “Future Date”).
  • Clever Hans: Model learns background artifacts instead of the object.

2. Visualizing CNNs: Opening the Vision Black Box

Convolutional Neural Networks (CNNs) are spatial. We can visualize their internals.

2.1. Feature Map Visualization

The simplest debug step: “What does Layer 1 see?”

The Technique:

  1. Hook into the model.
  2. Pass an image.
  3. Plot the outputs of the Convolutional filters.

Implementation (PyTorch):

import torch
import torchvision.models as models
import matplotlib.pyplot as plt

# 1. Load Model (torchvision >= 0.13; older versions used pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()

# 2. Define Hook
# A dict to store the activations, keyed by layer name
activations = {}

def get_activation(name):
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

# 3. Register Hook on the first conv of the first residual stage (the stem conv is model.conv1)
handle = model.layer1[0].conv1.register_forward_hook(get_activation("layer1_conv1"))

# 4. Pass Data
input_image = torch.rand(1, 3, 224, 224)  # Normalize this properly in real life
with torch.no_grad():
    output = model(input_image)
handle.remove()  # Hooks stay attached until removed

# 5. Visualize
act = activations["layer1_conv1"].squeeze(0)
# act shape is [64, 56, 56] (64 filters)

fig, axes = plt.subplots(8, 8, figsize=(12, 12))
for i in range(64):
    row, col = divmod(i, 8)
    axes[row, col].imshow(act[i].numpy(), cmap='viridis')
    axes[row, col].axis('off')

plt.show()

Interpretation:

  • Good: You see edges, textures, blobs. Some filters look like diagonal line detectors.
  • Bad: You see solid colors (dead filters) or white noise (random initialization). If Layer 1 looks like noise after training, the model learned nothing.

2.2. Grad-CAM from Scratch

We discussed Grad-CAM conceptually in 19.1. Now let’s implement the Backward Hook logic from scratch. This is essential for debugging models that typical libraries don’t support.

The Math:

$$ w_k = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}} $$

Weight $w_k$ is the global average of the gradients of class score $y^c$ with respect to feature map $A^k$, where $Z$ is the number of spatial positions $(i, j)$ in $A^k$.
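Steps 6 and 7 of the code below then combine these weights with the feature maps and drop negative evidence. Written in the same notation, the resulting heatmap for class $c$ is:

$$ L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left( \sum_k w_k A^k \right) $$

This map has the spatial size of $A^k$ (e.g., 7x7 for the last ResNet stage), which is why step 9 resizes it to the input resolution.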

import numpy as np
import torch
import torch.nn.functional as F

class GradCAMExplainer:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # Hooks (keep the handles so they can be removed later)
        self.handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_full_backward_hook(self.save_gradient),
        ]
        
    def save_activation(self, module, input, output):
        self.activations = output
        
    def save_gradient(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient of the backpropagated score w.r.t. this layer's output
        self.gradients = grad_output[0]
        
    def generate_cam(self, input_tensor, target_class_idx):
        # 1. Forward Pass
        model_output = self.model(input_tensor)
        
        # 2. Zero Grads
        self.model.zero_grad()
        
        # 3. Backward Pass
        # We want the gradient of the specific class score y^c
        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[0][target_class_idx] = 1
        
        model_output.backward(gradient=one_hot_output, retain_graph=True)
        
        # 4. Get captured values
        grads = self.gradients.detach().cpu().numpy()[0]  # [C, H, W]
        fmaps = self.activations.detach().cpu().numpy()[0]  # [C, H, W]
        
        # 5. Global Average Pooling of Gradients
        weights = np.mean(grads, axis=(1, 2))  # [C]
        
        # 6. Weighted Combination
        # cam = sum(weight * fmap)
        cam = np.zeros(fmaps.shape[1:], dtype=np.float32)
        for i, w in enumerate(weights):
            cam += w * fmaps[i, :, :]
            
        # 7. ReLU (Discard negative influence)
        cam = np.maximum(cam, 0)
        
        # 8. Normalize (0..1) for visualization
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        
        # 9. Resize to input size (bilinear upsampling; cv2.resize works too)
        cam = F.interpolate(
            torch.from_numpy(cam)[None, None],  # [1, 1, h, w]
            size=tuple(input_tensor.shape[-2:]),
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()
        
        return cam

    def remove_hooks(self):
        for h in self.handles:
            h.remove()

Debugging Use Case: You feed the model a photo in which a doctor is holding a stethoscope.

  • The model predicts “Medical”. OK.
  • You look at Grad-CAM. It is highlighting the Doctor’s Face, not the Stethoscope.
  • Diagnosis: The model has learned “Face + White Coat = Medical”. It doesn’t know what a stethoscope is.

3. Debugging Transformers: Attention Viz

Transformers don’t have “feature maps” in the same way. Instead, they have Attention Weights: one tensor per layer, of shape (Batch, Heads, SeqLen, SeqLen).

3.1. Attention Collapse

A common bug in Transformer training is “Attention Collapse”.

  • Pattern: All attention heads focus on the [CLS] token, the [SEP] separator token, or punctuation such as “.”.
  • Meaning: The model has failed to find relationships between words. It is basically becoming a bag-of-words model.

3.2. Visualizing with BertViz

bertviz is a Jupyter-optimized inspection tool.

from transformers import AutoTokenizer, AutoModel
from bertviz import head_view

# Load
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)

# Run
inputs = tokenizer.encode("The quick brown fox jumps over the dog", return_tensors='pt')
outputs = model(inputs)

# Attentions: a tuple of tensors (one per layer), each [Batch, Heads, SeqLen, SeqLen]
attention = outputs.attentions

# Viz (head_view takes the attentions and the list of token strings)
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(attention, tokens)

What to look for:

  1. Diagonal Patterns: Looking at previous/next word (local context). Common in early layers.
  2. Vertical Stripes: Looking at the same word (e.g., [SEP]) for everything. Too much of this = Collapse.
  3. Syntactic Patterns: Nouns looking at Adjectives.
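Eyeballing plots does not scale to 144 heads. One way to put a number on "too much of this" is to measure, per head, the average share of attention that lands on special tokens. The helper below (`special_token_attention_mass` is a hypothetical name, not part of bertviz or transformers) is a minimal sketch, checked here on synthetic attention tensors rather than a real model.

```python
import torch

def special_token_attention_mass(attentions, special_mask):
    """Average fraction of attention each head spends on special tokens.

    attentions:   sequence of per-layer tensors, each [Batch, Heads, SeqLen, SeqLen]
    special_mask: bool tensor [Batch, SeqLen], True at special positions ([CLS], [SEP])
    Returns:      tensor [Layers, Heads] with values in [0, 1]
    """
    key_mask = special_mask[:, None, None, :].to(attentions[0].dtype)  # broadcast over heads and queries
    per_layer = []
    for attn in attentions:
        mass = (attn * key_mask).sum(dim=-1)     # [B, H, SeqLen]: mass each query puts on special keys
        per_layer.append(mass.mean(dim=(0, 2)))  # average over batch and query positions -> [H]
    return torch.stack(per_layer)

# Synthetic check: 4 tokens, positions 0 and 3 are special ([CLS] ... [SEP])
S = 4
mask = torch.tensor([[True, False, False, True]])
uniform = torch.full((1, 2, S, S), 1.0 / S)  # every head spreads attention evenly
collapsed = torch.zeros(1, 2, S, S)
collapsed[..., 0] = 1.0                       # every query attends only to [CLS]

result = special_token_attention_mass([uniform, collapsed], mask)
print(result)  # layer 0 (uniform): 0.5 per head; layer 1 (collapsed): 1.0 per head
```

With a Hugging Face tokenizer, a mask like this can be built from `tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)`. What counts as "collapsed" (say, most heads above 0.9) is a judgment call, not an established threshold.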

4. Monitoring Training Dynamics

If the visualizations look fine but the model isn’t learning (Loss is flat), we need to look at the Gradients.

4.1. The Gradient Norm

The global L2 norm of all parameter gradients in the network, i.e., the norm of every gradient concatenated into one vector.

  • High and Spiky: Exploding Gradients. Learning Rate is too high.
  • Near Zero: Vanishing Gradients. Network too deep or initialization failed.
  • Steady: Good.
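The global norm is the square root of the sum of every parameter's squared gradient entries. A minimal sketch (the helper name `global_grad_norm` and the toy model are illustrative) that also cross-checks the value against `torch.nn.utils.clip_grad_norm_`, which returns the total norm it measured before clipping:

```python
import torch
import torch.nn as nn

def global_grad_norm(model: nn.Module) -> float:
    """L2 norm of all parameter gradients concatenated into one vector."""
    squared = [p.grad.detach().pow(2).sum() for p in model.parameters() if p.grad is not None]
    if not squared:
        return 0.0
    return torch.stack(squared).sum().sqrt().item()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

norm = global_grad_norm(model)
# max_norm=inf means "measure only": the gradients are left unchanged
reference = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
print(f"manual: {norm:.6f}  clip_grad_norm_: {reference.item():.6f}")
```

In a real loop, calling `clip_grad_norm_` with a finite `max_norm` gives you both the logged value and the clipping in one call.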

4.2. Implementing a Gradient Monitor (PyTorch Lightning)

Don’t write training loops manually. Use Callbacks.

import numpy as np
import pytorch_lightning as pl

class GradientMonitor(pl.Callback):
    def on_after_backward(self, trainer, pl_module):
        # Called after loss.backward() but before optimizer.step().
        # With AMP, gradients here may still be multiplied by the loss scale.
        grad_norms = [
            param.grad.norm().item()
            for param in pl_module.parameters()
            if param.grad is not None
        ]
        if not grad_norms:
            return  # nothing to measure (e.g., everything frozen)

        # Log to the configured logger (TensorBoard by default)
        avg_grad = float(np.mean(grad_norms))
        max_grad = float(np.max(grad_norms))

        pl_module.log("grad/avg", avg_grad)
        pl_module.log("grad/max", max_grad)

        # Alerting logic (thresholds are rules of thumb; tune them per model)
        if avg_grad < 1e-6:
            print(f"WARNING: Vanishing Gradient detected at step {trainer.global_step}")
        if max_grad > 10.0:
            print(f"WARNING: Exploding Gradient at step {trainer.global_step}! Consider Gradient Clipping.")

# Usage
trainer = pl.Trainer(callbacks=[GradientMonitor()])

4.3. The Dead ReLU Detector

ReLU units output max(0, x). If a neuron’s weights and bias shift so that its pre-activation is negative for every input in the dataset, it always outputs 0. Its local gradient is then 0 as well, so its incoming weights never update again. It is dead.
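The mechanism is easy to reproduce on purpose. In this toy sketch, a huge negative bias (an artificial stand-in for weights that drifted into the negative regime) makes every pre-activation negative, and backpropagation then delivers exactly zero gradient to the layer's weights and bias:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
with torch.no_grad():
    layer.bias.fill_(-100.0)  # force every pre-activation negative -> all units dead

x = torch.randn(32, 4)
out = torch.relu(layer(x))
out.sum().backward()

print(out.abs().max().item())                # 0.0: nothing fires
print(layer.weight.grad.abs().max().item())  # 0.0: no learning signal reaches the weights
```

Because the gradient is exactly zero rather than merely small, no optimizer setting can revive these units; only re-initialization or a change upstream of the layer can.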

Top-tier MLOps pipelines monitor Activation Sparsity.

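import torch
import torch.nn as nn

@torch.no_grad()
def check_dead_neurons(model, dataloader):
    """Fraction of units per nn.ReLU module that never output > 0 on any input.
    Assumes each ReLU module always sees the same number of channels."""
    ever_active = {}

    def make_hook(name):
        def hook(module, input, output):
            # [B, C, ...] -> [C, B * ...]: one row per unit (per channel for conv maps)
            flat = output.transpose(0, 1).reshape(output.shape[1], -1)
            fired = (flat > 0).any(dim=1)
            ever_active[name] = fired if name not in ever_active else ever_active[name] | fired
        return hook

    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules() if isinstance(m, nn.ReLU)]
    model.eval()  # inference-mode BatchNorm/Dropout so the check doesn't alter the model
    for inputs, _ in dataloader:
        model(inputs)
    for h in handles:
        h.remove()

    return {name: 1.0 - active.float().mean().item() for name, active in ever_active.items()}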

If almost all units in a layer are dead (e.g., 99% sparsity in Layer 3), your initialization scheme (He/Xavier) might be wrong, or your Learning Rate may be too high (large updates push pre-activations permanently into the negative regime).


5. Tooling: TensorBoard vs Weights & Biases

5.1. TensorBoard

The original. Runs locally. Good for privacy.

  • Embedding Projector: Visualize PCA/t-SNE of your embeddings. This is critical for debugging retrieval models. If your “Dogs” and “Cats” embeddings are intermingled, your encoder is broken.

5.2. Weights & Biases (W&B)

The modern standard. Cloud-hosted.

  • Gradients: Automatically logs gradient histograms (wandb.watch(model)).
  • System Metrics: Correlates GPU memory usage with Loss spikes (OOM debugging).
  • Comparisons: Overlays Loss curves from experiment A vs B.

Pro Tip: Always log your Configuration (Hyperparams) and Git Commit Hash. “Model 12 worked, Model 13 failed.” “What changed?” “I don’t know.” -> Instant firing offense in MLOps.
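A minimal, tracker-agnostic sketch of that habit: bundle the hyperparameters with the current commit hash before the run starts. The helper name `run_metadata` is hypothetical; it only shells out to `git rev-parse HEAD` and falls back to "unknown" outside a repository.

```python
import json
import subprocess

def run_metadata(config: dict) -> dict:
    """Bundle hyperparameters with the exact code version that produced a run."""
    try:
        commit = subprocess.run(
            ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True
        ).stdout.strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        commit = "unknown"  # not inside a git repo, or git is not installed
    return {"config": config, "git_commit": commit}

meta = run_metadata({"lr": 1e-3, "batch_size": 32})
print(json.dumps(meta, indent=2))
```

With W&B, a dict like this can be passed as `wandb.init(config=meta)`; with MLflow, the flat hyperparameters can go through `mlflow.log_params(...)` and the hash through `mlflow.set_tag(...)`.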


6. Interactive Debugging Patterns

6.1. The “Overfit One Batch” Test

Before training on 1TB of data, try to train on 1 Batch of 32 images.

  • Goal: Drive Loss to 0.00000. Accuracy to 100%.
  • Why: A Neural Network is a universal function approximator. It should be able to memorize 32 images easily.
    • If it CANNOT memorize 1 batch: You have a Code Bug (Forward pass broken, Labels flipped, Gradient not connected).
    • If it CAN memorize: Your model architecture works. Now you can try generalization.
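A minimal version of the test, assuming a classification setup with cross-entropy. The MLP, the feature size, and the random stand-in batch are placeholders for your real model and one real batch; random labels are used deliberately, because they can only be fit by memorization.

```python
import torch
import torch.nn as nn

def overfit_one_batch(model, inputs, targets, steps=500, lr=1e-2):
    """Repeatedly train on a single fixed batch; return the final loss and accuracy."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        logits = model(inputs)
        acc = (logits.argmax(dim=1) == targets).float().mean().item()
    return loss_fn(logits, targets).item(), acc

torch.manual_seed(0)
inputs = torch.randn(32, 20)           # stand-in for one batch of 32 (flattened) images
targets = torch.randint(0, 10, (32,))  # random labels: only memorization can fit them
model = nn.Sequential(nn.Linear(20, 128), nn.ReLU(), nn.Linear(128, 10))
loss, acc = overfit_one_batch(model, inputs, targets)
print(f"loss={loss:.4f} acc={acc:.2%}")
```

If the same function cannot drive your real model to near-zero loss on one real batch, look for the code bugs listed above before touching hyperparameters.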

6.2. Using ipdb / pdb

You can insert breakpoints in your forward() pass.

def forward(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    
    # Debugging shape mismatch
    import ipdb; ipdb.set_trace()
    
    x = self.fc(x) # Error happens here usually
    return x

Check shapes: x.shape. Check stats: x.mean(). If NaN, you know the previous layer blew up.


7. The Checklist: Analyzing a Broken Model

When a model fails, follow this procedure:

  1. Check Data:
    • Visualize inputs directly before they hit the model (fix normalization bugs).
    • Check statistics of Labels (is it all Class 0?).
  2. Check Initialization:
    • Is loss starting at ln(NumClasses)? (e.g., 2.3 for 10 classes). If it starts at 50, your init is garbage.
  3. Check Overfit:
    • Does “Overfit One Batch” work?
  4. Check Dynamics:
    • Are Gradients non-zero?
    • Is Loss oscillating? (Lower LR).
  5. Check Activation:
    • Are ReLUs dead?
    • Does Grad-CAM look at the object?
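The ln(NumClasses) rule in step 2 comes from the fact that a freshly initialized classifier should spread probability roughly uniformly over C classes, and the cross-entropy of a uniform prediction is $-\ln(1/C) = \ln C$. A quick sketch of the check on a toy MLP with random data (all sizes are arbitrary assumptions):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 10
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, num_classes))
inputs = torch.randn(256, 64)
targets = torch.randint(0, num_classes, (256,))

with torch.no_grad():
    initial_loss = nn.functional.cross_entropy(model(inputs), targets).item()

expected = math.log(num_classes)  # about 2.303 for 10 classes
print(f"initial loss {initial_loss:.3f} vs ln(C) = {expected:.3f}")
```

A value far above ln(C) usually means the initial logits are large and confidently wrong (for example, a badly scaled final layer or unnormalized inputs).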

In the next chapter, we move from the Development phase to the Operations phase: Deployment and MLOps Infrastructure.


8. Captum: The PyTorch Standard

Writing hooks manually (as we did in Section 2.2) is educational, but in production, you use Captum. Developed by Facebook, it provides a unified API for model interpretability.

8.1. Installation & Setup

pip install captum

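Captum algorithms are broadly grouped into:

  • Gradient-based attribution: Which pixels/features matter? (Integrated Gradients, Saliency).
  • Perturbation-based attribution: What happens if I remove this region? (Occlusion).
  • Concept-based interpretability: What high-level concept (e.g., “Stripes”) matters? (TCAV).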

8.2. Integrated Gradients with Captum

Captum also ships a Grad-CAM implementation (LayerGradCam), but here we use another method it implements robustly: Integrated Gradients.

import numpy as np
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz

# 1. Init Algorithm
ig = IntegratedGradients(model)

# 2. Compute Attribution
# input_img: (1, 3, 224, 224)
# target: Class Index (e.g., 208 for Labrador)
attributions, delta = ig.attribute(
    input_img, 
    target=208, 
    return_convergence_delta=True
)

# 3. Visualize
# Captum provides helper functions to overlay heatmaps
viz.visualize_image_attr(
    np.transpose(attributions.squeeze().cpu().detach().numpy(), (1,2,0)),
    np.transpose(input_img.squeeze().cpu().detach().numpy(), (1,2,0)),
    method="blended_heat_map",
    sign="all",
    show_colorbar=True,
    title="Integrated Gradients"
)

8.3. Occlusion Analysis

Saliency relies on Gradients. But sometimes gradients are misleading (e.g., for discrete inputs such as token IDs, or when the output saturates so the local gradient is flat even though the feature matters). Occlusion is a perturbation method: “Slide a gray box over the image and see when the probability drops.”

Algorithm:

  1. Define a sliding window (e.g., 15x15 pixels).
  2. Slide it over the image with stride 5.
  3. Mask the window area (set to 0).
  4. Measure drop in target class probability.
from captum.attr import Occlusion

occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(
    input_img,
    strides = (3, 8, 8), # (Channels, H, W)
    target=208,
    sliding_window_shapes=(3, 15, 15),
    baselines=0
)

# The result gives a coarse heatmap showing "Critical Regions"
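To see what the four algorithm steps compute, here is a hand-rolled sketch rather than Captum's implementation. It assumes a `score_fn` that maps a `[1, C, H, W]` image to the target-class score (a hypothetical helper, not a Captum API) and averages the score drop over every window that covers a pixel. The toy "model" only looks at a central patch, so only that region should light up.

```python
import torch

def occlusion_map(score_fn, image, window=15, stride=8, baseline=0.0):
    """Average drop in score_fn when square patches of `image` are replaced by `baseline`.

    score_fn: callable mapping a [1, C, H, W] tensor to a scalar tensor (target class score)
    image:    tensor [1, C, H, W]
    Returns a [H, W] map; pixels covered by no window stay 0.
    """
    _, _, H, W = image.shape
    with torch.no_grad():
        base_score = score_fn(image).item()
        heat = torch.zeros(H, W)
        counts = torch.zeros(H, W)
        for top in range(0, H - window + 1, stride):
            for left in range(0, W - window + 1, stride):
                occluded = image.clone()
                occluded[:, :, top:top + window, left:left + window] = baseline  # step 3: mask
                drop = base_score - score_fn(occluded).item()                    # step 4: score drop
                heat[top:top + window, left:left + window] += drop
                counts[top:top + window, left:left + window] += 1
    return heat / counts.clamp(min=1)

image = torch.ones(1, 3, 64, 64)
score_fn = lambda x: x[:, :, 24:40, 24:40].sum()  # toy "model": only the center patch matters
heat = occlusion_map(score_fn, image)
print(heat[32, 32].item() > 0, heat[0, 0].item() == 0.0)  # True True
```

The cost is one forward pass per window position, which is why occlusion maps are coarse and why stride and window size trade resolution for speed.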

Debug Insight: If Occlusion highlights the background (e.g., the grass behind the dog) while Integrated Gradients highlights the dog, your model might be relying on Context Correlations rather than the object features.


9. Profiling: Debugging Performance Bugs

Sometimes the bug isn’t “Wrong Accuracy,” it’s “Too Slow.” “My GPU usage is 20%. Why?”

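This usually points to a Data Loading Bottleneck or to time wasted outside the GPU kernels (host-device copies, synchronization, many tiny kernel launches). We use the PyTorch Profiler.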

9.1. Using the Context Manager

import torch.profiler

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profiler'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    
    for step, batch in enumerate(dataloader):
        train_step(batch)
        prof.step()

9.2. Analyzing the Trace

Open TensorBoard and go to the “PyTorch Profiler” tab.

  1. Overview: Look at “GPU Utilization”. If it looks like a comb (Spikes of activity separated by silence), your CPU is too slow feeding the GPU.
    • Fix: Increase num_workers in DataLoader. Use pin_memory=True. Prefetch data.
  2. Kernel View: Which operations take time?
    • Finding: You might see aten::copy_ taking 40% of time.
    • Meaning: You are moving tensors between CPU and GPU constantly inside the loop.
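    • Fix: Move the model (and any constant tensors) to the GPU once at the start; move each batch once per step (non_blocking=True with pinned memory).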
  3. Memory View:
    • Finding: Memory usage grows linearly step after step, then crashes.
    • Meaning: You are appending tensors to a list (e.g., losses.append(loss)) without .detach() or .item(). Each stored loss keeps its Computation Graph alive, so the memory is never released.
    • Fix: losses.append(loss.item()).
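The difference between the two list-append styles is visible directly on the stored objects. In this small sketch (toy model, no backward pass, as in an evaluation loop), the tensor version still carries a `grad_fn`, and therefore a reference to that step's autograd graph, while the `.item()` version is a plain Python float:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
leaky, safe = [], []
for _ in range(3):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    leaky.append(loss)        # keeps loss.grad_fn, and with it this step's autograd graph
    safe.append(loss.item())  # plain Python float: the graph can be freed

print(leaky[0].grad_fn)  # a backward node: the graph is still referenced
print(type(safe[0]))     # <class 'float'>
```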

10. Advanced TensorBoard: Beyond Scalars

Most people only log Loss. You should be logging everything.

10.1. Logging Images with Predictions

Don’t just inspect metrics. Inspect qualitative results during training.

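from torch.utils.tensorboard import SummaryWriter
import torch
import torchvision

writer = SummaryWriter('runs/experiment_1')

# In your validation loop (assumes `model`, `val_loader`, and `global_step` exist)
images, labels = next(iter(val_loader))
with torch.no_grad():
    preds = model(images).argmax(dim=1)

# Create a grid of images (normalize=True rescales pixel values to [0, 1] for display)
img_grid = torchvision.utils.make_grid(images, normalize=True)

# Log
writer.add_image('Validation Images', img_grid, global_step)

# Log "prediction/label" pairs as text alongside the grid
pairs = ", ".join(f"{p}/{l}" for p, l in zip(preds.tolist(), labels.tolist()))
writer.add_text('Validation Predictions', pairs, global_step)

# Advanced: Overlay Text Labels
# TensorBoard doesn't natively support overlay text on images well,
# so we usually modify the image tensor using OpenCV or PIL before logging
# (or draw a matplotlib figure and log it with writer.add_figure).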

10.2. Logging Embeddings (The Projector)

If you are doing Metric Learning (Siamese Networks, Contrastive Learning), you MUST verify your latent space topology.

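import torch

# 1. Collect a batch of features (assumes `model.encoder`, `images`, `labels`, `writer`, `global_step`)
with torch.no_grad():
    features = model.encoder(images)  # [B, 512]
class_labels = labels.tolist()  # metadata is written one string per row, so pass plain values

# 2. Add to Embedding Projector
writer.add_embedding(
    features,
    metadata=class_labels,
    label_img=images,  # Shows the tiny image sprite in 3D space!
    global_step=global_step
)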

Debug Value:

  • Spin the 3D visualization.
  • Do you see distinct clusters for each class?
  • Do you see a “collapsed sphere” (everything mapped to same point)?
  • This catches bugs that “Accuracy” metrics hide (e.g., the model works but the margin is tiny).

10.3. Logging Histograms (Weight Health)

Are your weights dying?

for name, param in model.named_parameters():
    writer.add_histogram(f'weights/{name}', param, global_step)
    if param.grad is not None:
        writer.add_histogram(f'grads/{name}', param.grad, global_step)

Interpretation:

  • Bell Curve: Healthy.
  • Uniform: Random (Hasn’t learned).
  • Spike at 0: Dead / Sparsity.
  • Gradients at 0: Vanishing Gradient.

11. Debugging “Silent” Data Bugs

11.1. The “Off-by-One” Normalization

Common bug:

  • Pre-trained Model (ImageNet) expects: Mean=[0.485, 0.456, 0.406], Std=[0.229, 0.224, 0.225].
  • You provide: Mean=[0.5, 0.5, 0.5].
  • Result: Accuracy drops from 78% to 74%. It doesn’t fail, it’s just suboptimal. This is HARD to find.

The Fix: Always use a Data Sanity Check script that runs before training.

  1. Iterate the dataloader.
  2. Reverse the normalization.
  3. Save the images to disk.
  4. Look at them with your eyes. Do the colors look weird? Is Red swapped with Blue (BGR vs RGB)?

11.2. The Dataloader Shuffle Bug

  • Bug: DataLoader(train_set, shuffle=False).
  • Symptom: Model refuses to learn, or learns very slowly.
  • Reason: Batches contain only “Class A”, then only “Class B”. The optimizer oscillates wildly (Catastrophic Forgetting) instead of averaging the gradient direction.
  • Fix: Always verify shuffle=True for Train, shuffle=False for Val.
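A cheap way to catch this before training is to count how many distinct labels appear in the first few batches. The sketch below uses a toy dataset stored sorted by class (as folder-per-class datasets often are); `classes_per_batch` is an illustrative helper name.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def classes_per_batch(loader, num_batches=5):
    """Number of distinct labels in each of the first few batches."""
    counts = []
    for i, (_, batch_labels) in enumerate(loader):
        if i >= num_batches:
            break
        counts.append(batch_labels.unique().numel())
    return counts

# Toy dataset stored sorted by class: 100 samples of class 0, then 100 of class 1, ...
labels = torch.arange(10).repeat_interleave(100)
data = TensorDataset(torch.randn(1000, 4), labels)

torch.manual_seed(0)
unshuffled = classes_per_batch(DataLoader(data, batch_size=32, shuffle=False))
shuffled = classes_per_batch(DataLoader(data, batch_size=32, shuffle=True))
print(unshuffled)  # [1, 1, 1, 2, 1]: almost every batch is a single class
print(shuffled)    # several classes in every batch
```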

12. Conclusion: Principles of ML Debugging

  1. Visualize First, optimize later: Don’t tune hyperparameters if you haven’t looked at the input images and the output heatmaps.
  2. Start Small: Overfit one batch. If the model can’t even memorize (cheat on) a single batch, it won’t learn the general pattern.
  3. Monitor Dynamics: Watch the gradient norms. Loss is a lagging indicator; Gradients are a leading indicator.
  4. Use Frameworks: Don’t write your own loops if you can help it. Use Lightning/Captum/W&B. They have solved these edge cases.

In the next chapter, we move to Generative AI Operations, with specific tooling for LLMs.


13. Advanced: Debugging LLMs with the Logit Lens

Debugging Large Language Models requires new techniques. These models can be dozens of layers deep (e.g., 80), so just looking at “Layer 1” tells you little. A powerful technique is the Logit Lens (nostalgebraist, 2020).

13.1. The Concept

In a Transformer, the hidden state at layer $L$ ($h_L$) has the same dimension as the final hidden state. Hypothesis: We can apply the model's final LayerNorm and the Unembedding Matrix (Linear Head) to intermediate hidden states to see “what the model thinks the next token is” at Layer 10 vs Layer 80.

13.2. Implementation

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

# Hook to capture hidden states
hidden_states = {}
def get_activation(name):
    def hook(module, input, output):
        # Transformer blocks may return a tuple whose first element is the hidden state
        hs = output[0] if isinstance(output, tuple) else output
        hidden_states[name] = hs.detach()
    return hook

# Register hooks on all layers
handles = [
    layer.register_forward_hook(get_activation(f"layer_{i}"))
    for i, layer in enumerate(model.transformer.h)
]

# Forward Pass
with torch.no_grad():
    out = model(input_ids)

# The Decoding path: final LayerNorm, then the Unembedding
# (GPT-2 ties lm_head.weight to the token embeddings wte, so logits = ln_f(h) @ wte.T)
ln_f = model.transformer.ln_f
lm_head = model.lm_head

print("Logit Lens Analysis:")
print("Input: 'The capital of France is'")
print("-" * 30)

with torch.no_grad():
    for i in range(len(model.transformer.h)):
        # Get hidden state for the LAST token position
        h = hidden_states[f"layer_{i}"][0, -1, :]

        # Decode
        logits = lm_head(ln_f(h))
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Top prediction
        top_token_id = torch.argmax(probs).item()
        top_token = tokenizer.decode(top_token_id)

        print(f"Layer {i}: '{top_token}' (p={probs[top_token_id]:.2f})")

for handle in handles:
    handle.remove()

# Illustrative output trail (not a verified run; exact tokens depend on model and version):
# Layer 0: 'the' (Random/Grammar)
# Layer 6: 'a'
# Layer 10: 'Paris' (Recall)
# Layer 11: 'Paris' (Refinement)

Reasoning: This suggests at which depth the answer emerges. If “Paris” appears at Layer 10 but the final output is wrong, the corruption happens in the layers after it (in 12-layer GPT-2, that is Layer 11).


14. Debugging Mixed Precision (AMP)

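Training in FP16/BF16 is standard. FP16 introduces a new bug: Overflow to Inf, which then turns into NaN. FP16's max value is 65,504, and large activations or loss-scaled gradients can exceed it. (BF16 has the same exponent range as FP32, so it rarely overflows.)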
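The overflow is easy to trigger directly. This small sketch casts the same values to FP16 and BF16 and shows how a single Inf becomes a NaN as soon as it meets another Inf:

```python
import torch

print(torch.finfo(torch.float16).max)  # 65504.0
x = torch.tensor([60000.0, 70000.0])
y = x.half()
print(y)             # 70000 does not fit in FP16 and becomes inf
print(x.bfloat16())  # BF16 keeps FP32's exponent range: no inf (but coarser precision)
print(y - y)         # inf - inf = nan: how a single overflow turns into a NaN loss
```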

14.1. The Symptoms

  • Loss suddenly becomes NaN.
  • The GradScaler's scale factor keeps shrinking toward 0 (every step is skipped because of Inf/NaN gradients).

14.2. Debugging with PyTorch Hooks

PyTorch provides tools to detect where NaNs originate.

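import torch

# Enable Anomaly Detection
# WARNING: dramatically slows down training. Use only for debugging.
torch.autograd.set_detect_anomaly(True)

# Created once, before the loop (older PyTorch: torch.cuda.amp.GradScaler())
scaler = torch.amp.GradScaler("cuda")

# Training Loop (assumes `model` returns the loss, and `optimizer`/`inputs` exist)
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(inputs)

scaler.scale(loss).backward()

# Custom NaN/Inf Inspector (gradients here are still multiplied by the loss scale)
for name, param in model.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"Non-finite gradient detected in {name}")
        break

scaler.step(optimizer)  # skips the update if any gradient is Inf/NaN
scaler.update()         # ...and reduces the scale for the next iteration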

15. Hands-on Lab: The Case of the Frozen ResNet

Scenario: You are training a ResNet-50 on a custom dataset of Car Parts. The Bug: Epoch 1 Accuracy: 1.5%. Epoch 10 Accuracy: 1.5%. The model predicts “Tire” for everything. Loss: Constant at 4.6.

Let’s debug this step-by-step.

Step 1: Overfit One Batch

  • Action: Take 1 batch (32 images). Run 100 epochs.
  • Result: Loss drops to 0.01. Accuracy 100%.
  • Conclusion: Code is functional. Layers are connected. Backprop works.

Step 2: Check Labels

  • Action: Inspect y_train.
  • Code: print(y_train[:10]) -> [0, 0, 0, 0, 0...]
  • Finding: The Dataloader is faulty! It is biased or shuffling is broken.
  • Fix: DataLoader(..., shuffle=True).

Step 3: Re-Train (Still Failed)

  • Result: Accuracy 5%. Loss fluctuates wildly.
  • Action: Monitor Gradient Norms via TensorBoard.
  • Finding: Gradients are 1e4 (Huge).
  • Hypothesis: Learning Rate 1e-3 is too high for a Finetuning task (destroying pre-trained weights).
  • Fix: Lower LR to 1e-5. Freeze early layers.

Step 4: Re-Train (Success)

  • Result: Accuracy climbs to 80%.

Lesson: Systematic debugging beats “Staring at the code” every time.


16. Appendix: PyTorch Hook Reference

A cheatsheet for the register_hook ecosystem.

| Hook Type | Method | Signature | Use Case |
|---|---|---|---|
| Forward | .register_forward_hook() | fn(module, input, output) | Saving activations, modifying outputs. |
| Forward Pre | .register_forward_pre_hook() | fn(module, input) | Modifying inputs before they hit the layer. |
| Backward | .register_full_backward_hook() | fn(module, grad_in, grad_out) | Visualizing gradients, clipping. |
| Tensor | tensor.register_hook() | fn(grad) | Debugging specific tensor flows. |

Example: Clipping Gradients locally

def clip_hook(grad):
    return torch.clamp(grad, -1, 1)

# Register on specific weight
model.fc.weight.register_hook(clip_hook)
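Tensor hooks are not limited to parameters; they also work on intermediate activations, which is handy when you want the gradient flowing through one specific tensor. A minimal sketch with a toy two-layer network: the hook stores dLoss/dhidden, and for this network that gradient should equal `fc2.weight` masked by where the ReLU was active.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 8), nn.Linear(8, 1)
grads = {}

def save_grad(grad):
    # Returning None leaves the gradient unchanged; returning a tensor would replace it
    grads["hidden"] = grad

x = torch.randn(2, 4)
hidden = fc1(x)               # intermediate tensor (not a parameter)
hidden.register_hook(save_grad)
loss = fc2(torch.relu(hidden)).sum()
loss.backward()

print(grads["hidden"].shape)  # torch.Size([2, 8])
```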

17. Final Summary

In this section (Part VIII - Observability & Control), we have journeyed from detecting Drift (Ch 18) to understanding Why (Ch 19).

  • Ch 19.1: Explaining the What. (SHAP/LIME).
  • Ch 19.2: Operationalizing Explanations at Scale. (AWS/GCP).
  • Ch 19.3: Debugging the Why. (Hooks, Gradients, Profiling).

You now possess the complete toolkit to own the full lifecycle of the model, not just the .fit() call.


18. Advanced Debugging: Distributed & Guided

18.1. Guided Backpropagation

Vanilla Saliency maps are noisy. Guided Backprop modifies the backward pass of ReLU so that gradient flows only through units that were active in the forward pass, and only where the incoming gradient is positive. It produces much sharper images.

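import torch

# Minimal hook implementation for Guided Backprop
class GuidedBackprop:
    def __init__(self, model):
        self.model = model
        self.hooks = []
        self._register_hooks()
        
    def _register_hooks(self):
        def relu_backward_hook_function(module, grad_in, grad_out):
            # grad_in[0] already equals grad_out masked by (forward input > 0);
            # clamping also cuts off negative gradients (the "guided" part)
            return (torch.clamp(grad_in[0], min=0.0),)
        
        # Only nn.ReLU modules are affected; functional F.relu calls are not hooked
        for module in self.model.modules():
            if isinstance(module, torch.nn.ReLU):
                module.inplace = False  # full backward hooks don't support in-place outputs
                self.hooks.append(module.register_full_backward_hook(relu_backward_hook_function))
                
    def generate_gradients(self, input_image, target_class):
        # We want the gradient w.r.t. the input image, so it must require grad
        input_image = input_image.clone().requires_grad_(True)
        output = self.model(input_image)
        self.model.zero_grad()
        
        one_hot = torch.zeros_like(output)
        one_hot[0][target_class] = 1
        
        output.backward(gradient=one_hot)
        
        return input_image.grad.detach().cpu().numpy()[0]

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()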

18.2. Debugging DDP (Distributed Data Parallel)

Debugging single-GPU is hard. Multi-GPU is exponentially harder.

Common Bug: The “Hanged” Process

  • Symptom: Training starts, prints “Epoch 0”, and freezes forever. No GPU usage.
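  • Cause: One rank crashed (OOM?), but the other ranks are blocked waiting on a collective operation (an all-reduce or a .barrier() synchronization).
  • Fix: Set the NCCL_DEBUG=INFO env var (and TORCH_DISTRIBUTED_DEBUG=DETAIL) to see which rank died and where the others are stuck.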

Common Bug: Unused Parameters

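  • Symptom: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss.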
  • Cause: You have a layer in your model self.fc2 that you defined but didn’t use in forward(). DDP breaks because it expects gradients for everything.
  • Fix: DistributedDataParallel(model, find_unused_parameters=True). (Warning: Performance cost).

19. Tooling: MLFlow Integration

While W&B is popular, MLFlow is often the enterprise standard for on-premise tracking.

19.1. Logging Artifacts (Debugging Outputs)

Don’t just log metrics. Log the debug artifacts (Grad-CAM images) associated with the run.

import mlflow
import matplotlib.pyplot as plt

mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("resnet-debugging")

with mlflow.start_run():
    # 1. Log Hyperparams
    mlflow.log_param("lr", 0.001)
    
    # 2. Log Metrics
    mlflow.log_metric("loss", 0.45)
    
    # 3. Log Debugging Artifacts
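    # Generate Grad-CAM with the GradCAMExplainer from Section 2.2
    # (assumes a torchvision ResNet `model` and a preprocessed `input_img`; 208 = Labrador)
    explainer = GradCAMExplainer(model, target_layer=model.layer4)
    cam_img = explainer.generate_cam(input_img, target_class_idx=208)
    
    # Save locally first
    plt.imsave("gradcam.png", cam_img, cmap="jet")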
    
    # Upload to MLFlow Artifact Store (S3/GCS)
    mlflow.log_artifact("gradcam.png")

Now, in the MLFlow UI, you can click on Run ID a1b2c3 and view the exact heatmaps produced by that specific version of the model code.


20. Glossary of Debugging Terms

  • Hook: A function callback in PyTorch that executes automatically during the forward or backward pass.
  • Activation: The output of a neuron (or layer) after the non-linearity (ReLU).
  • Logit: The raw, unnormalized output of the last linear layer, before Softmax.
  • Saliency: The gradient of the Class Score with respect to the Input Image. Represents “Sensitivity”.
  • Vanishing Gradient: When gradients become so small (e.g., below $10^{-7}$) that weights in the early layers effectively stop updating.
  • Exploding Gradient: When gradients become so large that weights become NaN or Infinity.
  • Dead ReLU: A neuron that always outputs 0 for all inputs in the dataset.
  • Mode Collapse: (GANs) When the generator produces the exact same image regardless of noise input.
  • Attention Collapse: (Transformers) When all heads focus on the same token (usually padding or separator).

21. Annotated Bibliography

1. “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps”

  • Simonyan et al. (2013): The paper that introduced Saliency Maps (backprop to pixels). Simple but fundamental.

2. “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization”

  • Selvaraju et al. (2017): The paper defining Grad-CAM. It produces class-discriminative localization maps for a wide range of CNNs without requiring architectural changes or retraining (unlike CAM).

3. “A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks”

  • Hendrycks & Gimpel (2016): Showed that the maximum softmax probability (confidence) is a simple baseline for detecting misclassified and out-of-distribution examples, even though softmax scores are often overconfident.

4. “Interpreting GPT: The Logit Lens”

  • nostalgebraist (2020): A blog post, not a paper, but seminal in the field of Mechanistic Interpretability for Transformers.

22. Final Checklist: The “5 Whys” of a Bad Model

  1. Is it code? (Overfit one batch).
  2. Is it data? (Visualize inputs, check label distribution).
  3. Is it math? (Check gradient norms, check for NaNs).
  4. Is it architecture? (Check for Dead ReLUs, Attention Collapse).
  5. Is it the world? (Maybe the features simply don’t contain the signal).

Only if you pass 1-4 can you conclude that the data lacks the signal. Most people blame the data at Step 0. Don’t be “most people.”


23. Special Topic: Mechanistic Interpretability

Traditional XAI (SHAP) tells you which input features mattered. Mechanistic Interpretability asks: How did the weights implement the algorithm?

This is the cutting edge of AI safety research (Anthropic, OpenAI). It treats NNs as “compiled programs” that we are trying to reverse engineer.

23.1. Key Concepts

  1. Circuits: Subgraphs of the network that perform a specific task (e.g., “Curve Detector” -> “Ear Detector” -> “Dog Detector”).
  2. Induction Heads: A specific attention mechanism discovered in Transformers, hypothesized to account for much of “In-Context Learning”. It looks for the previous occurrence of the current token [A] and copies the token that followed it [B]. Algorithm: “If I have seen [A][B] before and now see [A], predict [B]”.
  3. Polysemantic Neurons: The “Superposition” problem. One single neuron might fire for “Cats” AND “Biblical Verses”. Why? Because a high-dimensional space can hold many more nearly orthogonal directions than it has dimensions (cf. the Johnson-Lindenstrauss lemma), so the network packs more concepts than it has neurons.
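The geometric room behind superposition can be illustrated with random directions (not learned features, so this is an intuition pump rather than evidence about any real model): 2,000 random unit vectors in 512 dimensions are all close to orthogonal to each other, even though there are four times more of them than dimensions.

```python
import torch

torch.manual_seed(0)
dim, num_features = 512, 2000  # 4x more "concepts" than dimensions
features = torch.nn.functional.normalize(torch.randn(num_features, dim), dim=1)

cos = features @ features.T    # pairwise cosine similarities
off_diag = cos[~torch.eye(num_features, dtype=torch.bool)]
print(f"max |cos| between distinct features: {off_diag.abs().max():.3f}")
```

The price of this packing is interference: each direction overlaps slightly with many others, which is one way a single neuron ends up responding to unrelated concepts.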

23.2. Tooling: TransformerLens

The standard library for this research is TransformerLens (created by Neel Nanda). It allows you to hook into every meaningful intermediate state (Attention Patterns, Value Vectors, Residual Stream) easily.

pip install transformer_lens

23.3. Exploratory Analysis Code

Let’s analyze the Residual Stream.

from transformer_lens import HookedTransformer

# 1. Load a model (HookedTransformer re-implements GPT-2 with hook points everywhere)
model = HookedTransformer.from_pretrained("gpt2-small")

# 2. Run with Cache
text = "When Mary and John went to the store, John gave a drink to"
# We expect next token: " Mary" (Indirect Object Identification task)

logits, cache = model.run_with_cache(text)

# 3. Inspect Attention Patterns
# cache maps hook names to tensors
layer0_attn = cache["blocks.0.attn.hook_pattern"]
print(layer0_attn.shape) # [Batch, Heads, SeqLen, SeqLen]

# 4. Intervention (Patching)
# We can modify the internal state during inference!
def patch_residual_stream(resid, hook):
    # Set the residual stream to zero at position 5
    resid[:, 5, :] = 0
    return resid

patched_logits = model.run_with_hooks(
    text,
    fwd_hooks=[("blocks.5.hook_resid_pre", patch_residual_stream)]
)

# Compare the top next-token prediction with and without the intervention
before = model.tokenizer.decode(logits[0, -1].argmax().item())
after = model.tokenizer.decode(patched_logits[0, -1].argmax().item())
print(f"clean: {before!r} | patched: {after!r}")

Why this matters: Debugging “Why did the model generate hate speech?” might eventually move from “The prompt was bad” (Input level) to “The Hate Circuit in Layer 5 fired” (Mechanism level). This opens the door to Model Editing: manually turning off specific bad behaviors by ablating activations or editing weights.


24. Final Words

Debugging ML models is a journey from the External (Loss Curves, Metrics) to the Internal (Gradients, Activations) to the Mechanistic (Circuits, Weights).

The best ML Engineers are not the ones who know the most architectures. They are the ones who can look at a flat loss curve and know exactly which three lines of Python code to check first.


25. Appendix: PyTorch Error Dictionary

| Error Message | Translation | Likely Cause | Fix |
|---|---|---|---|
| RuntimeError: shape '[...]' is invalid for input of size X | “You tried to .view() or .reshape() a tensor but the number of elements doesn’t match.” | Miscalculating the flattened Conv2d output size before a Linear layer. | Check x.shape before the reshape. Use nn.Flatten(). |
| RuntimeError: Expected object of scalar type Long but got Float | “You passed Floats to a function that needs Integers.” | Passing float class indices (0.0 instead of 0) as CrossEntropyLoss targets. | Use .long() on your targets. |
| RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same | “Your Data is on CPU but your Model is on GPU.” | Forgot .to(device) on the input batch. | inputs = inputs.to(device) |
| CUDA out of memory | “Your GPU VRAM is full.” | Batch size too large. | Reduce batch size. Use torch.utils.checkpoint. Use mixed precision (fp16/bf16). |
| RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation | “You did x += 1 inside the graph.” | In-place operations (+=, x[0]=1) break autograd history. | Use out-of-place ops (x = x + 1) or .clone(). |

If you are setting up a team, standardizing tools prevents “Debugging Hell”.

  1. Logging: Weights & Biases (Cloud) or MLFlow (On-Prem). Mandatory.
  2. Profiler: PyTorch Profiler (TensorBoard plugin). For Optimization.
  3. Visualization:
    • Images: Grad-CAM (Custom hook or Captum).
    • Tabular: SHAP (TreeSHAP).
    • NLP: BertViz.
  4. Anomaly Detection: torch.autograd.set_detect_anomaly(True). Use sparingly.
  5. Interactive: ipdb.

Happy Debugging!


27. Code Snippet: The Ultimate Debugging Hook

Sometimes you just need to see the “health” of every layer at once.

import torch

class ModelThermometer:
    """
    Attaches to every leaf layer and prints stats.
    Useful for finding where the signal dies (Vanishing) or explodes (NaN).
    """
    def __init__(self, model):
        self.hooks = []
        # Recursively register on all leaf modules
        for name, module in model.named_modules():
            # If module has no children, it's a leaf (like Conv2d, ReLU)
            if len(list(module.children())) == 0:
                self.hooks.append(
                    module.register_forward_hook(self.make_hook(name))
                )

    @staticmethod
    def _stats(x):
        # Cast to float so integer tensors (e.g. token ids fed to nn.Embedding) don't break mean()/std();
        # non-tensors and single elements (where std is undefined) report 0.0
        if isinstance(x, torch.Tensor) and x.numel() > 1:
            x = x.detach().float()
            return x.mean().item(), x.std().item()
        return 0.0, 0.0

    def make_hook(self, name):
        def hook(module, input, output):
            # Input is a tuple (empty if the module was called with keyword arguments only)
            in_mean, in_std = self._stats(input[0] if input else None)
            # Output is usually a tensor
            out_mean, out_std = self._stats(output)
            print(f"[{name}] In: {in_mean:.3f}+/-{in_std:.3f} | Out: {out_mean:.3f}+/-{out_std:.3f}")
        return hook

    def remove(self):
        for h in self.hooks:
            h.remove()

Usage:

thermometer = ModelThermometer(model)
with torch.no_grad():
    output = model(inputs)
# Prints stats for every layer.
# Look for:
# 1. Output Mean = 0.0 and Std = 0.0 (Dead Layer)
# 2. Output Std = NaN (Explosion)
thermometer.remove()