Chapter 17: Model Compression & Compilation

17.1. Pruning & Distillation: Teacher-Student Architectures

“To attain knowledge, add things every day. To attain wisdom, remove things every day.” — Lao Tzu

In the previous chapters, we focused on scaling up—training massive models on distributed clusters of H100s and TPUs. We discussed the architecture of abundance. Now, we must pivot to the architecture of constraint.

The economic reality of AI is asymmetric: you train once, but you infer billions of times. A model that costs $1 million to train but is inefficient at inference can bankrupt a company if deployed at scale. If your Large Language Model (LLM) requires 4x A100 GPUs to serve a single request, your cost per query might be $0.10. For a search engine receiving 100 million queries a day, that is $10 million in daily infrastructure burn.

Model compression is not just an optimization; it is the difference between a research prototype and a viable product. It is the discipline of making models smaller, faster, and cheaper without significantly sacrificing intelligence.

This section covers two of the most powerful techniques in the compression arsenal: Pruning (making the model sparse) and Distillation (transferring knowledge from a large “Teacher” to a compact “Student”).


17.1.1. The Physics of Redundancy

Why do compression techniques work? Why can we remove 90% of a neural network’s weights and lose only 1% of its accuracy?

The answer lies in the Over-Parameterization Hypothesis. Modern deep learning models are vastly over-parameterized. The optimization landscape of high-dimensional non-convex functions is treacherous; to ensure Gradient Descent finds a global minimum (or a good local minimum), we need a massive search space. We need billions of parameters to find the solution, but we do not need billions of parameters to represent the solution.

Think of the training process as erecting a complex scaffolding to build an arch. Once the keystone is in place and the arch is self-supporting, the scaffolding—which constitutes the bulk of the material—can be removed. Pruning is the systematic removal of this scaffolding.


17.1.2. Pruning: The Art of Sparsity

Pruning is the process of setting specific weights in a neural network to zero, effectively severing the synaptic connections between neurons.

$$ \mathbf{W}_{pruned} = \mathbf{W} \odot \mathbf{M} $$

Where $\mathbf{W}$ is the weight matrix, $\mathbf{M} \in \{0, 1\}^{m \times n}$ is a binary mask, and $\odot$ is the Hadamard (element-wise) product.

Unstructured vs. Structured Pruning

The primary architectural decision in pruning is the granularity of the mask.

1. Unstructured Pruning (Fine-Grained Sparsity)

  • Mechanism: We look at individual weights $w_{ij}$. If $|w_{ij}| < \text{threshold}$, we set it to zero.
  • Result: The weight matrix becomes a sparse matrix. It might look like Swiss cheese.
  • Pros: Can achieve extremely high compression rates (90-95%) with minimal accuracy loss because the algorithm can surgically remove the least important connections.
  • Cons: Standard hardware (GPUs/CPUs) hates random memory access. A dense matrix multiplication is highly optimized (BLAS, cuBLAS). A sparse matrix multiplication requires specialized indexing (CSR/CSC formats), which often adds overhead that negates the speedup unless sparsity is very high (>95%).
  • Hardware Note: NVIDIA Ampere (A100) and Hopper (H100) architectures introduced Sparse Tensor Cores, which provide up to a 2x speedup for “2:4 sparsity” (at most 2 non-zero values in every contiguous block of 4 weights). This is currently the only mainstream hardware support for this kind of semi-structured, fine-grained sparsity; a minimal sketch of enforcing the pattern appears after this list.

2. Structured Pruning (Coarse-Grained Sparsity)

  • Mechanism: We remove entire structural units—columns, filters, channels, or attention heads.
  • Result: The weight matrix shrinks. A $1024 \times 1024$ matrix becomes $512 \times 512$.
  • Pros: The resulting model is a standard dense model, just smaller. It runs faster on any hardware without specialized kernels.
  • Cons: More destructive. Removing an entire filter might kill a feature detector that was 80% useless but 20% vital. Accuracy drops faster than with unstructured pruning.
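
To make the 2:4 pattern concrete, here is a minimal sketch in plain PyTorch (not NVIDIA’s production tooling) that keeps the two largest-magnitude weights in every contiguous group of four along the input dimension:

import torch

def apply_2_4_sparsity(weight):
    """Zero the 2 smallest-magnitude weights in every contiguous group of 4 (last dim).

    Assumes the last dimension is divisible by 4; real tooling pads or masks edge cases.
    """
    out_features, in_features = weight.shape
    groups = weight.reshape(out_features, in_features // 4, 4)
    keep = groups.abs().topk(k=2, dim=-1).indices           # 2 largest-magnitude per group
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, keep, True)
    return (groups * mask).reshape(out_features, in_features)

w_sparse = apply_2_4_sparsity(torch.randn(8, 16))
assert (w_sparse.reshape(8, -1, 4) != 0).sum(dim=-1).max() <= 2   # every group has >= 2 zeros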

Magnitude-Based Pruning (The Baseline)

The simplest heuristic for importance is magnitude. “If a weight is close to zero, it doesn’t contribute much to the output.”

The Algorithm (Iterative Magnitude Pruning - IMP):

  1. Train the network to convergence.
  2. Prune the bottom $p\%$ of weights by magnitude (globally or layer-wise).
  3. Fine-tune the pruned network to recover accuracy.
  4. Repeat steps 2-3 until target sparsity is reached.

The fine-tuning step is critical. Pruning is a shock to the system; the remaining weights need to adjust to compensate for the missing connections.

Implementation: PyTorch Pruning

PyTorch provides a robust pruning API in torch.nn.utils.prune.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # ... standard forward pass ...
        return x

model = LeNet()

# 1. Unstructured Pruning (L1 Unstructured)
# Prune 30% of connections in conv1 based on L1 norm (magnitude)
prune.l1_unstructured(model.conv1, name="weight", amount=0.3)

# The weight is not actually deleted. 
# PyTorch creates 'weight_orig' and a buffer 'weight_mask'.
# 'weight' becomes a computed attribute: weight_orig * weight_mask.
print(list(model.conv1.named_parameters())) 

# 2. Structured Pruning (L2 Structured)
# Prune 20% of CHANNELS (dim=0) in conv2
prune.ln_structured(model.conv2, name="weight", amount=0.2, n=2, dim=0)

# 3. Global Pruning
# Often better to prune globally. Maybe layer 1 needs all its weights,
# but layer 10 is redundant.
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# 4. Finalizing (Making it permanent)
# This removes the _orig and _mask, applying the mask permanently.
for module, param in parameters_to_prune:
    prune.remove(module, param)
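
After prune.remove, the zeros are baked into the weight tensors themselves. As a quick sanity check, a small helper (a sketch; model is the pruned LeNet from above) can report per-layer and global sparsity:

import torch

def report_sparsity(model):
    """Print the fraction of zero-valued weights per layer and globally."""
    total, zeros = 0, 0
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            layer_zeros = int((weight == 0).sum())
            zeros += layer_zeros
            total += weight.numel()
            print(f"{name}: {100.0 * layer_zeros / weight.numel():.1f}% sparse")
    print(f"Global sparsity: {100.0 * zeros / total:.1f}%")

report_sparsity(model)   # `model` is the pruned LeNet from the snippet above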

The Lottery Ticket Hypothesis

In 2018, Frankle and Carbin published a seminal paper: “The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks”.

They discovered that within a large, randomly initialized dense network, there exist small subnetworks (“winning tickets”) that, when trained in isolation from the same initialization, reach the same accuracy as the original network in the same number of steps.

Implications for MLOps:

  • Retraining Stability: If you prune a model and want to retrain it from scratch (rather than fine-tuning), you must reset the remaining weights to their original initialization values, not random new ones. The specific initialization “geometry” matters.
  • Early-Bird Tickets: Recent research suggests these tickets emerge early in training. This led to Early Pruning techniques—pruning the model after just a few epochs to save compute on the rest of the training run.
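
A minimal sketch of the prune-and-rewind round described above, assuming a train_fn callable that trains the model in place:

import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def lottery_ticket_round(model, train_fn, amount=0.8):
    """One prune-and-rewind round in the spirit of Frankle & Carbin."""
    init_state = copy.deepcopy(model.state_dict())   # 1. snapshot the initialization

    train_fn(model)                                  # 2. train to convergence

    params = [(m, "weight") for m in model.modules()
              if isinstance(m, (nn.Linear, nn.Conv2d))]
    prune.global_unstructured(params,                # 3. prune by global magnitude
                              pruning_method=prune.L1Unstructured,
                              amount=amount)

    with torch.no_grad():                            # 4. rewind survivors to their ORIGINAL values
        for name, module in model.named_modules():   #    (the mask stays; weight = weight_orig * mask)
            if hasattr(module, "weight_orig"):
                module.weight_orig.copy_(init_state[name + ".weight"])
    return model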

17.1.3. Knowledge Distillation: The Teacher and The Student

Pruning tries to fix a bloated architecture. Knowledge Distillation (KD) accepts that we need two architectures: a massive one to learn, and a tiny one to run.

The Concept: We have a large, accurate Teacher model (e.g., BERT-Large, ResNet-152, GPT-4). We want to train a small Student model (e.g., DistilBERT, MobileNet, Llama-7B).

If we just train the Student on the original dataset (One-Hot Encoded labels), it struggles. The dataset labels are “hard targets”—they tell the model that an image is a “Dog” (1.0) and not a “Cat” (0.0). They contain zero information about the relationship between classes.

The Teacher, however, knows more. For a specific image of a Dog, the Teacher might output:

  • Dog: 0.90
  • Cat: 0.09
  • Car: 0.0001

The Teacher is telling the Student: “This is a Dog, but it looks a lot like a Cat. It looks nothing like a Car.” This “Dark Knowledge” (inter-class relationships) provides a richer signal for the Student to learn from.

The Mathematics of Distillation

The Student is trained to minimize a combined loss function:

$$ L_{total} = \alpha L_{task} + (1 - \alpha) L_{KD} $$

  1. Task Loss ($L_{task}$): Standard Cross-Entropy between Student predictions and Ground Truth labels.
  2. Distillation Loss ($L_{KD}$): Kullback-Leibler (KL) Divergence between the Student’s soft predictions and the Teacher’s soft predictions.

The Temperature Parameter ($T$): To expose the hidden details in the Teacher’s output distribution (which is often very sharp, e.g., 0.999 vs 0.001), we divide the logits by a temperature $T > 1$ before applying Softmax.

$$ p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$

As $T \to \infty$, the distribution becomes uniform. At moderate values (e.g., $T=3$ to $T=10$), the tiny probabilities of incorrect classes get magnified, making them learnable.
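
A two-line illustration with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([8.0, 2.0, -3.0])    # hypothetical teacher logits
print(F.softmax(logits, dim=0))            # ~[0.998, 0.002, 0.000] — nearly one-hot
print(F.softmax(logits / 4.0, dim=0))      # ~[0.78, 0.17, 0.05]   — "dark knowledge" now visible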

Implementation: A PyTorch Distillation Trainer

Below is a production-grade snippet for a Distillation Loop.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainer:
    def __init__(self, teacher, student, device, alpha=0.5, temperature=4.0):
        self.teacher = teacher.to(device)
        self.student = student.to(device)
        self.device = device
        self.alpha = alpha
        self.T = temperature
        
        # Teacher is usually frozen during distillation
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False
            
    def train_step(self, inputs, labels, optimizer):
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        
        # 1. Forward pass of Student
        student_logits = self.student(inputs)
        
        # 2. Forward pass of Teacher (no grad)
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
            
        # 3. Calculate Hard Target Loss (Standard Cross Entropy)
        loss_task = F.cross_entropy(student_logits, labels)
        
        # 4. Calculate Soft Target Loss (KL Divergence)
        # Note: F.log_softmax needs to be applied to Student
        # F.softmax needs to be applied to Teacher
        # We scale logits by T
        
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
            reduction='batchmean'
        ) * (self.T ** 2) 
        # We multiply by T^2 to keep gradients scaled correctly as T changes
        
        # 5. Combined Loss
        loss = self.alpha * loss_task + (1 - self.alpha) * distillation_loss
        
        # 6. Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        return loss.item()

# Usage Scenario:
# teacher_model = ResNet50(pretrained=True)
# student_model = MobileNetV3()
# trainer = DistillationTrainer(teacher_model, student_model, device="cuda")

17.1.4. Advanced Distillation Patterns

Beyond the basic “Response-Based” distillation described above, modern architectures use more sophisticated alignment.

1. Feature-Based Distillation

Instead of just matching the final output, we force the Student’s intermediate layers to mimic the Teacher’s intermediate layers.

  • Challenge: The Teacher has 1024 channels, the Student has 128. You cannot compare them directly.
  • Solution: Learn a Linear Projection (1x1 Conv) that maps the Student’s 128 channels to the Teacher’s 1024, then minimize the MSE loss between the feature maps.
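
A minimal sketch of that projection trick (the channel counts are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDistillHead(nn.Module):
    """1x1 conv that projects student features to the teacher's channel count, then MSE."""
    def __init__(self, student_channels=128, teacher_channels=1024):
        super().__init__()
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # student_feat: [B, 128, H, W], teacher_feat: [B, 1024, H, W]
        return F.mse_loss(self.proj(student_feat), teacher_feat.detach())

head = FeatureDistillHead()
loss = head(torch.randn(2, 128, 14, 14), torch.randn(2, 1024, 14, 14))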

2. Relation-Based Distillation

We want the Student to understand how data points relate to each other.

  • If the Teacher thinks Image A and Image B are similar (embedding cosine similarity is high), the Student should also map them close together.
  • This preserves the structure of the embedding space.
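
One common formulation, sketched below, matches the pairwise cosine-similarity matrices computed over a batch of embeddings:

import torch
import torch.nn.functional as F

def relation_distillation_loss(student_emb, teacher_emb):
    """MSE between the batch-wise cosine-similarity matrices of student and teacher.

    student_emb: [B, D_s], teacher_emb: [B, D_t] — embedding dimensions may differ,
    because only the B x B relation matrices are compared.
    """
    sim_s = F.normalize(student_emb, dim=1) @ F.normalize(student_emb, dim=1).T
    sim_t = F.normalize(teacher_emb, dim=1) @ F.normalize(teacher_emb, dim=1).T
    return F.mse_loss(sim_s, sim_t.detach())

loss = relation_distillation_loss(torch.randn(16, 256), torch.randn(16, 1024))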

3. Data-Free Distillation

What if you don’t have the original training data (privacy/GDPR)?

  • You can treat the Teacher model as a “Generator.”
  • Invert the network: start with a random image, and optimize the pixels until the Teacher outputs “Goldfish” with high confidence.
  • Use these “DeepDreamed” synthetic images to train the Student.
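
A bare-bones sketch of the inversion step for a single image and class; practical data-free methods add priors such as total-variation smoothing and BatchNorm-statistic matching:

import torch
import torch.nn.functional as F

def invert_teacher(teacher, target_class, steps=500, lr=0.05, image_shape=(1, 3, 224, 224)):
    """Optimize pixels until the frozen teacher predicts `target_class` with high confidence."""
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)

    x = torch.randn(image_shape, requires_grad=True)
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(teacher(x), torch.tensor([target_class]))
        loss.backward()
        opt.step()
    return x.detach()   # one synthetic "training" image for the student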

17.1.5. Distilling Large Language Models (LLMs)

In the GenAI era, distillation has taken a new form. We are no longer just matching logits; we are transferring reasoning capabilities.

Black-Box Distillation (Synthetic Data Generation)

When distilling a closed-source model (like GPT-4) into an open model (like Llama-3-8B), you often don’t have access to the teacher’s logits or weights. You only have the text output.

Methodology:

  1. Prompt Engineering: Ask GPT-4 to generate a high-quality dataset.
    • Input: “Write a Python function to compute Fibonacci numbers, with detailed comments.”
    • Output: (High-quality code).
  2. Step-by-Step Distillation: Use “Chain of Thought” (CoT) prompting.
    • Instead of just “Question -> Answer”, generate “Question -> Reasoning -> Answer”.
    • Train the Student on the Reasoning trace. This teaches the Student how to think, not just what to say.
  3. Fine-Tuning: Train the smaller model on this synthetic dataset (Standard SFT).

The “Alpaca” Paradigm: This was made famous by Stanford’s Alpaca model, which was Llama-7B fine-tuned on 52k instruction-following examples generated by text-davinci-003.
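
Mechanically, this reduces to building an SFT dataset file. A sketch, assuming a hypothetical query_teacher(prompt) wrapper around whatever closed-source API you use (the record format is the point, not the client code):

import json

def query_teacher(prompt: str) -> str:
    """Hypothetical wrapper around the closed-source teacher's API."""
    raise NotImplementedError

def build_cot_dataset(questions, out_path="distill_sft.jsonl"):
    """Write (instruction, reasoning + answer) records for standard SFT."""
    with open(out_path, "w") as f:
        for q in questions:
            response = query_teacher(f"{q}\nThink step by step, then state the final answer.")
            f.write(json.dumps({"instruction": q, "output": response}) + "\n")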

White-Box Distillation (Minitron / Sheared Llama)

If you own the Teacher model (e.g., you trained a 70B model and want a 7B version), you can be more aggressive.

NVIDIA’s Minitron Approach:

  1. Width Pruning: Prune attention heads and MLP intermediate dimensions based on importance scores.
  2. Depth Pruning: Remove entire Transformer blocks (layers). A common heuristic is to keep every $n$-th layer (e.g., layers 0, 2, 4, …).
  3. Retraining: Continue training the pruned model on a small percentage of the original tokens.
  4. Distillation Loss: Use the original 70B model to supervise the retraining, ensuring the 7B model’s logits match the 70B model’s logits on the training tokens.
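
A sketch of the depth-pruning step (step 2) for a generic Transformer whose blocks live in an nn.ModuleList; the model.layers attribute path is an assumption and varies by implementation:

import torch.nn as nn

def depth_prune(model, keep_every=2):
    """Keep every `keep_every`-th Transformer block (layers 0, 2, 4, ...).

    Assumes the blocks are stored as an nn.ModuleList at `model.layers`;
    adapt the attribute path to your architecture.
    """
    kept = [block for i, block in enumerate(model.layers) if i % keep_every == 0]
    model.layers = nn.ModuleList(kept)
    return model

The pruned model is then retrained under the teacher’s supervision (steps 3-4), using the same KL-based logit-matching loss shown earlier.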

17.1.6. Cloud Implementation: AWS vs. GCP

How do we execute these workflows in the cloud?

AWS Implementation: SageMaker & Neuron

1. AWS Model Optimizer (formerly SageMaker Neo) AWS provides a managed compilation service that optimizes trained models for specific hardware targets.

  • It performs graph-level optimizations (operator fusion).
  • It can quantize models to FP16 or INT8.
  • Key Feature: It specifically compiles for AWS Inferentia (Inf1/Inf2).

2. Distillation on Trainium (Trn1) Distillation is compute-intensive. You are running forward passes on a massive Teacher and forward/backward on a Student.

  • Architecture:
    • Load the Teacher model onto AWS Trainium chips (Trn1.32xlarge). Trainium has huge memory bandwidth.
    • Since the Teacher is frozen, you can use Neuron Cast to run the Teacher in FP16 or BF16 for speed, while keeping the Student in FP32/BF16.
    • Use the high-speed EFA (Elastic Fabric Adapter) networking to synchronize gradients if training a large student across multiple nodes.

GCP Implementation: Vertex AI

1. Vertex AI Model Optimization GCP offers a suite of tools within Vertex AI for pruning and quantization.

  • Supports Quantization Aware Training (QAT) directly in the pipeline.
  • Integrates with TensorFlow Model Optimization Toolkit (TFMOT).

2. TPU-based Distillation TPUs are exceptionally good at distillation because of their large high-bandwidth memory (HBM) and systolic array architecture.

  • TPU Strategy:
    • Place the Teacher and Student on the same TPU core if they fit (minimizes data transfer latency).
    • If not, use Model Parallelism to shard the Teacher across 4 TPU chips, and Data Parallelism for the Student.
    • Google’s JAX framework shines here, allowing you to define the distillation loss function and jit compile the entire teacher-student interaction into a single XLA executable graph.

17.1.7. Operationalizing Compression in CI/CD

Model compression should not be a one-off “science project.” It should be a stage in your MLOps pipeline.

The Compression Pipeline Pattern:

graph LR
    A[Model Training] --> B["Evaluation (Accuracy: 95%)"]
    B --> C{Passes Threshold?}
    C -- Yes --> D[Compression Stage]
    C -- No --> A
    D --> E[Pruning / Distillation]
    E --> F[Fine-Tuning]
    F --> G["Evaluation (Accuracy: 94%?)"]
    G --> H{Acceptable Drop?}
    H -- Yes --> I["Quantization (FP32 -> INT8)"]
    H -- No --> D
    I --> J["Compile for Target (TensorRT/Neuron)"]
    J --> K[Production Registry]

Automated Budget Checks: Your CI system should enforce constraints:

  • assert model_size_mb < 50
  • assert inference_latency_ms < 10
  • assert accuracy_drop < 0.01

If the compressed model fails these checks, the pipeline fails. This prevents “bloat creep” where models slowly get slower over months of development.
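
A sketch of such a gate as a standalone script; the thresholds mirror the asserts above, and the measured values would come from your benchmarking step:

import sys

def check_budgets(model_size_mb, latency_ms, accuracy_drop,
                  max_size_mb=50, max_latency_ms=10, max_accuracy_drop=0.01):
    """Exit non-zero (failing the CI job) if any compression budget is violated."""
    failures = []
    if model_size_mb >= max_size_mb:
        failures.append(f"size {model_size_mb:.1f} MB >= {max_size_mb} MB")
    if latency_ms >= max_latency_ms:
        failures.append(f"latency {latency_ms:.2f} ms >= {max_latency_ms} ms")
    if accuracy_drop >= max_accuracy_drop:
        failures.append(f"accuracy drop {accuracy_drop:.3f} >= {max_accuracy_drop}")
    if failures:
        print("Budget check FAILED: " + "; ".join(failures))
        sys.exit(1)
    print("Budget check passed.")

# Values would be produced by your benchmark step
check_budgets(model_size_mb=42.0, latency_ms=7.8, accuracy_drop=0.006)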


17.1.8. Advanced Structured Pruning: Channel and Filter Selection

Structured pruning is more practical for production deployment because the resulting model is a standard dense network. However, deciding which channels or filters to prune is non-trivial.

Importance Metrics for Structured Pruning

1. L1 Norm (Magnitude-Based): The sum of absolute values of all weights in a channel/filter. $$I_{L1}(F_i) = \sum_{j} |w_{ij}|$$

Rationale: If all weights in a filter are close to zero, that filter contributes little to the output.

2. Percentage of Zeros (Sparsity-Based): After unstructured pruning, some filters become very sparse. Remove filters that are >90% zero.

3. Taylor Expansion (Loss-Aware): Approximate the change in loss if filter $F_i$ is removed using a first-order Taylor expansion: $$\Delta L \approx \nabla_W L \cdot \delta W$$

Filters with minimum $|\Delta L|$ are candidates for pruning.

4. Activation-Based Importance (APoZ): Average Percentage of Zero activations. Run the model on a validation set and measure: $$\text{APoZ}(F_i) = \frac{1}{N \cdot M} \sum_{n=1}^N \sum_{m=1}^M \mathbb{1}(F_i(x_n)_m = 0)$$

Filters that frequently produce zero outputs (dead neurons) can be pruned.
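
A sketch of measuring APoZ with forward hooks, assuming the ReLUs follow convolutions (activations shaped [B, C, H, W]):

import torch
import torch.nn as nn

@torch.no_grad()
def measure_apoz(model, dataloader, device="cpu"):
    """Return {relu_layer_name: per-channel fraction of zero activations}."""
    stats = {}   # name -> [running zero count per channel, running element count per channel]
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            zeros = (output == 0).float().sum(dim=(0, 2, 3))        # zeros per channel
            elements_per_channel = output.numel() / output.size(1)  # B * H * W
            if name not in stats:
                stats[name] = [torch.zeros_like(zeros), 0.0]
            stats[name][0] += zeros
            stats[name][1] += elements_per_channel
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(make_hook(name)))

    model.eval().to(device)
    for x, _ in dataloader:
        model(x.to(device))

    for h in hooks:
        h.remove()
    return {name: zeros / count for name, (zeros, count) in stats.items()}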

Progressive Structured Pruning (Three-Phase Approach)

Instead of pruning all filters at once, use a gradual approach:

Phase 1: Coarse Pruning (50% reduction)

  • Prune 50% of filters based on L1 norm
  • Fine-tune for 5 epochs
  • Checkpoint

Phase 2: Fine-Grained Pruning (75% reduction)

  • Prune another 25% based on the Taylor-expansion criterion
  • Fine-tune for 10 epochs
  • Checkpoint

Phase 3: Final Polish (85% reduction)

  • Prune another 10% based on APoZ
  • Fine-tune for 15 epochs
  • Final model

Implementation:

import torch
import torch.nn as nn
import numpy as np

def compute_filter_importance_l1(conv_layer):
    """
    Compute L1 norm of each filter in a Conv2d layer.
    Returns: Tensor of shape [num_filters]
    """
    weights = conv_layer.weight  # Shape: [out_channels, in_channels, H, W]
    importance = torch.sum(torch.abs(weights), dim=(1, 2, 3))
    return importance

def prune_filters_by_threshold(model, layer_name, prune_ratio):
    """
    Prune filters in a specific convolutional layer.
    """
    for name, module in model.named_modules():
        if name == layer_name and isinstance(module, nn.Conv2d):
            importance = compute_filter_importance_l1(module)

            # Keep the top (1 - prune_ratio) filters by importance
            num_filters = len(importance)
            num_to_keep = int(num_filters * (1 - prune_ratio))
            keep_idx = torch.topk(importance, num_to_keep, largest=True).indices

            # Build a boolean mask over output channels (exactly num_to_keep True entries)
            mask = torch.zeros(num_filters, dtype=torch.bool)
            mask[keep_idx] = True

            # Slice out the surviving filters (detached copies, not views of the live weights)
            pruned_weights = module.weight.data[mask].clone()
            pruned_bias = module.bias.data[mask].clone() if module.bias is not None else None

            # Create new layer with reduced channels
            new_conv = nn.Conv2d(
                in_channels=module.in_channels,
                out_channels=num_to_keep,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                bias=(module.bias is not None)
            )

            new_conv.weight.data = pruned_weights
            if pruned_bias is not None:
                new_conv.bias.data = pruned_bias

            # Replace in model (requires careful handling of connections)
            # This is simplified; production requires graph surgery
            return new_conv, mask

# Progressive Pruning Scheduler
class ProgressivePruningScheduler:
    def __init__(self, model, target_sparsity=0.85, phases=3):
        self.model = model
        self.target_sparsity = target_sparsity
        self.phases = phases
        self.current_phase = 0

    def get_phase_sparsity(self):
        """Calculate sparsity target for current phase"""
        phase_targets = [0.5, 0.75, 0.85]  # Predefined schedule
        return phase_targets[min(self.current_phase, len(phase_targets)-1)]

    def step_phase(self):
        """Move to next pruning phase"""
        self.current_phase += 1
        sparsity = self.get_phase_sparsity()
        print(f"Entering Phase {self.current_phase}: Target sparsity {sparsity}")
        return sparsity

17.1.9. Domain-Specific Distillation Strategies

Distillation is not one-size-fits-all. Different modalities (vision, language, speech) require different alignment strategies.

Computer Vision: Attention Transfer

In CNNs, attention maps (spatial activations) are critical. The student should not just match final logits; it should “look at” the same regions of the image as the teacher.

Attention Transfer Loss: $$L_{AT} = \sum_l \left\| \frac{A_l^S}{\|A_l^S\|_2} - \frac{A_l^T}{\|A_l^T\|_2} \right\|_2^2$$

Where $A_l$ is the attention map at layer $l$, computed as the sum of squared activations across channels: $$A_l(x) = \sum_c F_{l,c}(x)^2$$

Implementation:

def attention_map(feature_map):
    """
    Compute attention map from feature tensor.
    feature_map: [B, C, H, W]
    Returns: [B, H, W]
    """
    return torch.sum(feature_map ** 2, dim=1)

def attention_transfer_loss(student_features, teacher_features):
    """
    Compute attention transfer loss between student and teacher.
    """
    total_loss = 0

    for s_feat, t_feat in zip(student_features, teacher_features):
        # Compute attention maps
        s_attn = attention_map(s_feat)
        t_attn = attention_map(t_feat)

        # Normalize
        s_attn_norm = s_attn / (torch.norm(s_attn, p=2, dim=(1,2), keepdim=True) + 1e-8)
        t_attn_norm = t_attn / (torch.norm(t_attn, p=2, dim=(1,2), keepdim=True) + 1e-8)

        # L2 distance
        loss = torch.mean((s_attn_norm - t_attn_norm) ** 2)
        total_loss += loss

    return total_loss

Natural Language Processing: Logit Matching with Token-Level Alignment

For LLMs, we care about the distribution over the entire vocabulary for each token position.

Standard Approach: KL divergence on the final logits.

Advanced Approach: Layer-wise Distillation (Used in DistilBERT).

  • Match the hidden states at each Transformer layer
  • Use a linear projection to map student hidden dim to teacher hidden dim if they differ

Implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BERTDistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=2.0):
        super().__init__()
        self.alpha = alpha
        self.T = temperature
        self.cosine_loss = nn.CosineEmbeddingLoss()

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, labels):
        """
        student_logits: [B, seq_len, vocab_size]
        teacher_logits: [B, seq_len, vocab_size]
        student_hidden: List of [B, seq_len, hidden_dim]
        teacher_hidden: List of [B, seq_len, hidden_dim]
        """
        # 1. Soft Target Loss (KL Divergence)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=-1),
            F.softmax(teacher_logits / self.T, dim=-1),
            reduction='batchmean'
        ) * (self.T ** 2)

        # 2. Hard Target Loss (Cross Entropy with labels)
        hard_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1)
        )

        # 3. Hidden State Alignment
        hidden_loss = 0
        for s_h, t_h in zip(student_hidden, teacher_hidden):
            # Cosine similarity loss
            # Target = 1 (maximize similarity)
            target = torch.ones(s_h.size(0) * s_h.size(1)).to(s_h.device)
            hidden_loss += self.cosine_loss(
                s_h.view(-1, s_h.size(-1)),
                t_h.view(-1, t_h.size(-1)),
                target
            )

        # Combine losses
        total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss + 0.1 * hidden_loss
        return total_loss

Speech Recognition: Sequence-Level Distillation

In ASR (Automatic Speech Recognition), distillation must account for variable-length sequences.

Challenge: Teacher outputs a sequence of phonemes/characters. Student must learn the temporal alignment, not just the final transcript.

CTC (Connectionist Temporal Classification) Distillation:

  • Use the teacher’s CTC alignment probabilities as soft targets
  • This teaches the student not just “what” to predict, but “when” to predict it

Encoder-Decoder Distillation:

  • For attention-based models (Transformer ASR), distill:
    1. Encoder outputs (acoustic features)
    2. Attention weights (where the model “listens”)
    3. Decoder outputs (predicted tokens)

17.1.10. Self-Distillation and Born-Again Networks

What if you don’t have a larger teacher? Can a model distill into itself?

Born-Again Networks (BANs)

Procedure:

  1. Train a network $N_1$ to convergence.
  2. Create an identical architecture $N_2$ (same size).
  3. Train $N_2$ to mimic the soft targets of $N_1$ (distillation).
  4. $N_2$ often achieves higher accuracy than $N_1$, despite being the same size!

Why This Works:

  • The soft targets from $N_1$ provide a “smoothed” version of the label space.
  • $N_2$ doesn’t have to discover these patterns from scratch; it learns from the refined knowledge.

Multi-Generation Distillation:

  • Train $N_3$ from $N_2$, $N_4$ from $N_3$, etc.
  • Research shows accuracy improvements for 2-3 generations, then plateaus.

Production Use Case:

  • After deploying a model for 6 months and collecting user feedback (corrections), retrain a “Born-Again” version using the old model’s outputs as soft targets. This preserves the good behaviors while adapting to new data.

Online Distillation (Co-Training)

Instead of the teacher-student being sequential (train teacher first, then student), train them simultaneously.

DML (Deep Mutual Learning):

  • Train 2 models (can be different architectures) in parallel.
  • At each step, each model acts as a “teacher” for the other.
  • Loss for Model A: $$L_A = L_{CE}(y_A, y_{true}) + \lambda \cdot L_{KL}(y_A, y_B)$$

Benefit: Both models improve by teaching each other. No need for a pre-trained large teacher.
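
A sketch of one DML update for model A (the update for model B is symmetric):

import torch
import torch.nn.functional as F

def dml_step(model_a, model_b, x, y, opt_a, lam=1.0, T=1.0):
    """One Deep Mutual Learning update for model A; model B acts as its peer teacher."""
    logits_a = model_a(x)
    with torch.no_grad():                    # peer's logits are a fixed target for this step
        logits_b = model_b(x)

    ce = F.cross_entropy(logits_a, y)
    kl = F.kl_div(F.log_softmax(logits_a / T, dim=1),
                  F.softmax(logits_b / T, dim=1),
                  reduction="batchmean")
    loss = ce + lam * kl

    opt_a.zero_grad()
    loss.backward()
    opt_a.step()
    return loss.item()

# The symmetric call, dml_step(model_b, model_a, x, y, opt_b), updates model B.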


17.1.11. Pruning for Edge Deployment: The MobileNet Philosophy

When deploying to mobile (iOS/Android) or embedded devices (Raspberry Pi, Jetson Nano), the constraints are different:

  • Limited DRAM: 1-4GB total system memory
  • No GPU: Or weak GPU (Mali, Adreno)
  • Battery Life: Power consumption matters

Depthwise Separable Convolutions (MobileNet)

Standard convolution is expensive. For an input of size $H \times W \times C_{in}$ and a kernel of size $K \times K$ with $C_{out}$ output channels:

FLOPs: $H \times W \times C_{in} \times C_{out} \times K^2$

MobileNet’s Innovation:

  1. Depthwise Convolution: Apply a $K \times K$ kernel to each input channel separately.
    • FLOPs: $H \times W \times C_{in} \times K^2$
  2. Pointwise Convolution: Use a $1 \times 1$ kernel to mix channels.
    • FLOPs: $H \times W \times C_{in} \times C_{out}$

Total FLOPs: $H \times W \times C_{in} \times (K^2 + C_{out})$

Speedup: For typical values ($K=3, C_{out}=256$), this is 8-9x fewer FLOPs.
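
The factorization in code, as a standard depthwise-plus-pointwise block:

import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """K x K depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv."""
    def __init__(self, c_in, c_out, kernel_size=3, stride=1):
        super().__init__()
        self.depthwise = nn.Conv2d(c_in, c_in, kernel_size, stride=stride,
                                   padding=kernel_size // 2, groups=c_in, bias=False)
        self.pointwise = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

block = DepthwiseSeparableConv(64, 256)
out = block(torch.randn(1, 64, 56, 56))   # -> [1, 256, 56, 56]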

Channel Pruning for MobileNet

Even MobileNets can be pruned further. Use AutoML for Channel Search (NAS) to find optimal channel counts per layer.

Google’s MNasNet Approach:

  • Search space: For each layer, channel count can be $[0.5x, 0.75x, 1.0x, 1.25x]$ of baseline.
  • Objective: Maximize accuracy subject to latency constraint (e.g., <50ms on Pixel 3).
  • Search algorithm: Reinforcement Learning with measured latency as reward.

Practical Approximation: Use a greedy search:

  1. Start with full MobileNetV3.
  2. For each layer, try reducing channels by 25%. Measure accuracy drop.
  3. Keep reductions where accuracy drop is <0.5%.
  4. Iterate.

17.1.12. Distillation for Multimodal Models (CLIP, Flamingo)

Multimodal models (vision + language) present unique distillation challenges.

CLIP Distillation

CLIP learns a shared embedding space for images and text.

Distillation Strategy:

  • Dual Encoders: Distill both the Image Encoder (Vision Transformer) and the Text Encoder (BERT) separately.
  • Contrastive Loss Alignment: The student must preserve the teacher’s alignment in the embedding space.

$$L_{CLIP} = -\log \frac{\exp(\text{sim}(\mathbf{i}_s, \mathbf{t}_s) / \tau)}{\sum_{j} \exp(\text{sim}(\mathbf{i}_s, \mathbf{t}_j) / \tau)}$$

Where the similarity function must match between teacher and student.

Smaller CLIP Models:

  • DistilCLIP: Distill OpenAI CLIP-ViT-L/14 into a ResNet-50 backbone.
  • Use case: Running CLIP on edge devices for real-time image-text matching (e.g., accessibility tools).

Vision-Language Model (VLM) Distillation

For models like Flamingo or GPT-4V that generate text from images:

Challenge: The teacher might hallucinate or have inconsistent behaviors.

Solution: Selective Distillation:

  1. Run teacher on 1M image-caption pairs.
  2. Filter outputs: Keep only samples where BLEU score vs ground truth >0.7.
  3. Distill student on this “high-quality subset.”

This prevents the student from learning the teacher’s errors.


17.1.13. Quantization-Aware Pruning (QAP)

Pruning and Quantization are often applied sequentially: Prune → Fine-tune → Quantize → Fine-tune.

But compounding errors occur. A weight that survives pruning might become problematic after quantization.

Solution: Joint Optimization.

The QAP Loss Function

$$L_{QAP} = L_{task} + \lambda_1 R_{prune} + \lambda_2 R_{quant}$$

Where:

  • $R_{prune} = \sum |w|$ (L1 regularization to encourage sparsity)
  • $R_{quant} = \sum (w - \text{Quantize}(w))^2$ (minimizes quantization error)

Training Procedure:

  1. Start with dense FP32 model.
  2. Apply gradual pruning (increase $\lambda_1$ over epochs).
  3. Simultaneously apply Fake Quantization (simulates INT8).
  4. The model learns to find weights that are:
    • Sparse (many near-zero)
    • Quantization-friendly (cluster around quantization levels)

Result: A model that is both pruned (90% sparse) and quantized (INT8) with minimal accuracy loss.
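
A sketch of the two regularizers from the loss above, using symmetric per-tensor INT8 quantization as the target; the $\lambda$ values are placeholders, and the fake-quantized forward pass of step 3 is omitted:

import torch

def quantize_int8(w):
    """Symmetric per-tensor INT8 quantization (used as a constant target below)."""
    scale = w.abs().max().clamp(min=1e-8) / 127
    return torch.clamp(torch.round(w / scale), -128, 127) * scale

def qap_regularizers(model, lambda_prune=1e-5, lambda_quant=1e-4):
    """R_prune = sum |w| (L1 sparsity); R_quant = sum (w - Quantize(w))^2."""
    r_prune, r_quant = 0.0, 0.0
    for p in model.parameters():
        if p.dim() > 1:                      # weight matrices/filters only; skip biases and norms
            r_prune = r_prune + p.abs().sum()
            with torch.no_grad():
                target = quantize_int8(p)    # treated as a constant; the gradient pulls w toward it
            r_quant = r_quant + ((p - target) ** 2).sum()
    return lambda_prune * r_prune + lambda_quant * r_quant

# total_loss = task_loss + qap_regularizers(model)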


17.1.14. Production Deployment Patterns

Compression is not just a research experiment. It must integrate into your MLOps pipeline.

Pattern 1: The Compression Pipeline

# .github/workflows/model-compression.yml
name: Model Compression Pipeline

on:
  workflow_dispatch:
    inputs:
      model_id:
        description: 'S3 path to base model'
        required: true
      target_compression:
        description: 'Target size reduction (%)'
        default: '75'

jobs:
  compress:
    runs-on: [self-hosted, gpu]
    steps:
      - name: Download base model
        run: aws s3 cp ${{ github.event.inputs.model_id }} ./model.pt

      - name: Apply pruning
        run: |
          python scripts/prune_model.py \
            --input model.pt \
            --output model_pruned.pt \
            --sparsity 0.9

      - name: Fine-tune pruned model
        run: |
          python scripts/finetune.py \
            --model model_pruned.pt \
            --epochs 10 \
            --lr 1e-5

      - name: Distill into student
        run: |
          python scripts/distill.py \
            --teacher model_pruned.pt \
            --student mobilenet_v3 \
            --output model_distilled.pt \
            --temperature 4.0

      - name: Quantize
        run: |
          python scripts/quantize.py \
            --model model_distilled.pt \
            --output model_quantized.pt \
            --precision int8

      - name: Validate accuracy
        run: |
          python scripts/validate.py \
            --model model_quantized.pt \
            --dataset val_set \
            --baseline-accuracy 0.95

      - name: Upload to model registry
        if: success()
        run: |
          aws s3 cp model_quantized.pt \
            s3://ml-models/compressed/$(date +%Y%m%d)_model.pt

Pattern 2: A/B Testing Compressed Models

Before rolling out a compressed model, run A/B tests.

Setup:

  • Control Group: 50% of traffic → Original FP32 model
  • Treatment Group: 50% of traffic → Compressed INT8 model

Metrics to Track:

  • Accuracy/F1 (should be within 1% of baseline)
  • P99 Latency (should decrease by 2x+)
  • Cost per 1M inferences (should decrease by 60%+)
  • User Engagement Metrics (e.g., click-through rate)

Decision Rule:

  • If accuracy drop >1% OR user engagement drops >3%: Rollback.
  • Else: Promote compressed model to 100%.

Pattern 3: Model Versioning and Lineage Tracking

Compressed models should maintain lineage to their parent.

MLflow Example:

import mlflow

with mlflow.start_run(run_name="compression_pipeline"):
    # Log parent model ID
    mlflow.set_tag("parent_model_id", "resnet50_v1_fp32")
    mlflow.set_tag("compression_method", "pruning+distillation")

    # Log compression config
    mlflow.log_params({
        "pruning_ratio": 0.9,
        "distillation_temperature": 4.0,
        "student_architecture": "mobilenet_v3_small",
        "quantization": "int8"
    })

    # Train compressed model
    compressed_model = apply_compression(base_model)

    # Log metrics
    mlflow.log_metrics({
        "accuracy": 0.94,
        "model_size_mb": 12,
        "inference_latency_ms": 8,
        "compression_ratio": 0.85
    })

    # Log model artifact
    mlflow.pytorch.log_model(compressed_model, "compressed_model")

17.1.15. Cost-Benefit Analysis: When Compression Pays Off

Compression introduces engineering complexity. When is it worth it?

Break-Even Calculation

Scenario: Deploying a recommendation model for 100M daily inferences.

Option A: No Compression (Baseline)

  • Model: BERT-Large (330M params, FP32)
  • Instance: AWS g5.xlarge (1x A10G, $1.006/hr)
  • Throughput: 100 inferences/sec
  • Hours needed: 100M / (100 * 3600) = 278 hours
  • Cost: 278 * $1.006 = $279.67/day

Option B: Compression (Pruned + Quantized)

  • Model: Pruned BERT (50M params, INT8)
  • Instance: AWS g5.xlarge (same)
  • Throughput: 400 inferences/sec (4x faster due to compression)
  • Hours needed: 100M / (400 * 3600) = 69 hours
  • Cost: 69 * $1.006 = $69.41/day

Savings: $210/day = $76,650/year

Engineering Cost:

  • Compression pipeline development: 2 engineer-weeks = $10,000
  • Validation and testing: 1 engineer-week = $5,000
  • Total: $15,000

Payback Period: 15,000 / 210 = 71 days

Conclusion: Compression pays for itself in 2.5 months.

When Compression is NOT Worth It

  • Low-scale inference: <1M inferences/month. The engineering cost exceeds savings.
  • Rapidly changing models: If you retrain weekly, the compression pipeline becomes a bottleneck.
  • Extreme accuracy requirements: Medical imaging, autonomous driving. 1% accuracy drop is unacceptable.

17.1.16. Summary: The Efficiency Mindset

Pruning and Distillation are mechanisms to pay down the “Compute Debt” incurred during training.

  1. Use Pruning when you need to reduce the model size and FLOPs, but want to keep the same architecture. It is most effective when you have specialized hardware (Sparse Tensor Cores) or are doing structured pruning.
  2. Use Distillation when you want to change the architecture entirely (e.g., replacing a Transformer with a CNN, or a Deep network with a Shallow one). It is the most robust way to train small models.
  3. Combine Them: The state-of-the-art approach is often:
    • Train a large Teacher.
    • Prune the Teacher to create a “Sparse Teacher”.
    • Distill the Sparse Teacher into a Student.
    • Quantize the Student.
  4. Domain Specialization: Adapt distillation strategies to your modality (CV: attention transfer, NLP: hidden state matching, Speech: temporal alignment).
  5. Production Integration: Build compression into CI/CD pipelines with automated validation gates.
  6. Economics: Always perform break-even analysis. Compression is an investment that typically pays back in 2-3 months for high-scale deployments.
  7. Progressive Approach: Don’t compress everything at once. Use gradual pruning with checkpoints to find the optimal sparsity-accuracy trade-off.
  8. Validation is Critical: Compressed models must undergo rigorous testing—unit tests for accuracy, latency tests, A/B tests in production. Never deploy without validation.

Future Directions

The field of model compression is rapidly evolving:

Neural Architecture Search (NAS) for Compression: Instead of manually designing student architectures, use NAS to discover optimal compressed architectures automatically. EfficientNet is an example of this approach.

Hardware-Aware Compression: Optimize models specifically for target hardware (e.g., prune to match Sparse Tensor Core patterns, or quantize to align with INT8 SIMD instructions).

Dynamic Compression: Models that can adjust their size/precision at runtime based on available resources. For example, serving a 7B model on GPU but falling back to a 1B distilled version on CPU.

Compound Scaling: Simultaneously optimize depth, width, and resolution (as in EfficientNet) rather than compressing one dimension at a time.

The Architect’s Checklist for Compression

Before deploying a compressed model to production:

  • Baseline Metrics: Record FP32 baseline accuracy, latency, memory usage
  • Compression Method Selected: Document whether using pruning, distillation, or both
  • Target Metrics Defined: Set acceptable accuracy drop threshold (e.g., <1%)
  • Validation Dataset: Use production-representative data for calibration/validation
  • Lineage Tracking: Maintain links between compressed model and parent model
  • Performance Testing: Benchmark latency/throughput on target hardware
  • A/B Test Plan: Design experiment to validate in production before full rollout
  • Rollback Strategy: Plan for reverting if compressed model underperforms
  • Monitoring: Set up alerts for accuracy degradation, latency SLA violations
  • Cost Analysis: Calculate ROI and payback period
  • Documentation: Record compression configuration, metrics, and decisions

In the next section, we will delve deeper into Quantization, exploring how moving from 32-bit floats to 8-bit integers can quadruple your throughput.