44.3. Operationalizing Neural Architecture Search (NAS)

Neural Architecture Search (NAS) is the “Nuclear Option” of AutoML. Instead of tuning hyperparameters (learning rate, tree depth) of a fixed model, NAS searches for the model structure itself (number of layers, types of convolutions, attention heads, connection topology).

From an MLOps perspective, NAS is extremely dangerous. It converts compute into accuracy at a terrifying exchange rate. A naive NAS search (like the original RL-based NASNet) can easily cost 100x more than a standard training run (on the order of 2,000 GPU days for a ~1% accuracy gain). Operationalizing NAS means imposing strict constraints to treat it not as a research experiment, but as an engineering search problem.

44.3.1. The Cost of NAS: Efficiency is Mandatory

Early NAS methods trained thousands of models from scratch to convergence. In production, this is non-viable. We must use Efficient NAS (ENAS) techniques.

Comparison of NAS Strategies

| Strategy | Architecture | Cost | Ops Complexity | Best For |
| --- | --- | --- | --- | --- |
| Reinforcement Learning | Controller RNN samples architectures, trained by reward (accuracy). | High (~2,000 GPU days) | High (async updates) | Research only |
| Evolutionary (Genetic) | Mutate best architectures. Kill weak ones. | Medium (~100 GPU days) | Medium (embarrassingly parallel) | Black-box search |
| Differentiable (DARTS) | Continuous relaxation. Optimize structure with SGD. | Low (~1-4 GPU days) | High (sensitivity to hyperparams) | Standard Vision/NLP tasks |
| One-Shot (Weight Sharing) | Train one Supernet. Sample subgraphs. | Very Low (~1-2 GPU days) | High (Supernet design) | Production edge deployment |

1. One-Shot NAS (Weight Sharing)

Instead of training 1,000 separate models, we train one “Supernet” that contains all possible sub-architectures as paths (Over-parameterized Graph).

  • The Supernet: A massive graph where edges represent operations (Conv3x3, SkipConn); a minimal sketch follows this list.
  • Sub-network Selection: A “Controller” selects a path through the Supernet.
  • Weight Inheritance: The sub-network inherits weights from the Supernet, avoiding retraining from scratch.
  • Ops Benefit: Training cost is ~1-2x a standard model, not 1,000x.
  • Ops Complexity: The Supernet is huge and hard to fit in GPU memory. Gradient synchronization is complex.
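
As an illustration of the weight-sharing idea, here is a minimal PyTorch sketch of a toy Supernet: each edge holds weights for every candidate operation, and each training step samples one path through it. The op set and the MixedOp/Supernet names below are illustrative, not taken from any specific library.

import random
import torch
import torch.nn as nn

# Toy candidate operation set; real search spaces are much larger.
CANDIDATE_OPS = {
    "conv3x3": lambda c: nn.Conv2d(c, c, 3, padding=1),
    "conv5x5": lambda c: nn.Conv2d(c, c, 5, padding=2),
    "identity": lambda c: nn.Identity(),
}

class MixedOp(nn.Module):
    """One Supernet edge: holds the weights of every candidate op."""
    def __init__(self, channels):
        super().__init__()
        self.ops = nn.ModuleDict({name: build(channels) for name, build in CANDIDATE_OPS.items()})

    def forward(self, x, choice):
        # Only the chosen op runs; its weights are shared by every sub-network that picks it.
        return self.ops[choice](x)

class Supernet(nn.Module):
    def __init__(self, channels=32, depth=4):
        super().__init__()
        self.edges = nn.ModuleList([MixedOp(channels) for _ in range(depth)])

    def forward(self, x, path):
        for edge, choice in zip(self.edges, path):
            x = edge(x, choice)
        return x

supernet = Supernet()
# Single-path training: sample a random sub-network for this step.
path = [random.choice(list(CANDIDATE_OPS)) for _ in supernet.edges]
out = supernet(torch.randn(1, 32, 16, 16), path)
print(path, out.shape)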

2. Differentiable NAS (DARTS)

Instead of using a discrete controller (RL), we relax the architecture search space to be continuous, allowing us to optimize architecture parameters with gradient descent (a minimal sketch follows the list below).

  • Ops Benefit: Faster search.
  • Ops Risk: “Collapse” to simple operations (e.g., all Identity connections) if not regularized.
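
A minimal sketch of the continuous relaxation, assuming a toy set of three candidate ops (the DifferentiableEdge and alpha names are illustrative): each edge outputs a softmax-weighted mixture of its candidates, and the mixing weights alpha are trained by gradient descent alongside the model weights.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DifferentiableEdge(nn.Module):
    def __init__(self, ops):
        super().__init__()
        self.ops = nn.ModuleList(ops)
        # One architecture parameter per candidate op, optimized by gradient descent.
        self.alpha = nn.Parameter(torch.zeros(len(ops)))

    def forward(self, x):
        # Continuous relaxation: a softmax-weighted mixture of all candidate ops.
        weights = F.softmax(self.alpha, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.ops))

edge = DifferentiableEdge([
    nn.Conv2d(32, 32, 3, padding=1),
    nn.Conv2d(32, 32, 5, padding=2),
    nn.Identity(),
])
y = edge(torch.randn(1, 32, 16, 16))
# After the search, each edge is discretized by keeping the op with the largest alpha.
print(F.softmax(edge.alpha, dim=0), y.shape)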

3. Zero-Cost Proxies

How do you estimate accuracy without training?

  • Synflow: Measure how well gradients flow through the network at initialization. It computes the sum of the absolute products of gradients and weights. $$ R_{synflow} = \sum_{\theta} |\theta \odot \frac{\partial \mathcal{L}}{\partial \theta}| $$ Ops Note: This can be computed in a “Forward-Backward” pass on a single batch of data.
  • Fisher: Uses the Fisher Information Matrix to estimate the sensitivity of the loss to parameters.
  • Ops Impact: Allows pruning 99% of architectures in milliseconds before submitting the 1% to the GPU cluster.
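
A minimal sketch of such a scoring pass, following the R_synflow formula above: one forward-backward pass on a single batch, then the sum of |θ ⊙ ∂L/∂θ|. It uses the output sum as a surrogate loss and skips details of the full Synflow recipe (absolute weights, all-ones input), so treat it as an approximation.

import torch
import torch.nn as nn

def zero_cost_score(model: nn.Module, batch: torch.Tensor) -> float:
    model.zero_grad()
    out = model(batch)
    out.sum().backward()  # scalar surrogate loss so gradients reach every parameter
    score = 0.0
    for p in model.parameters():
        if p.grad is not None:
            score += (p * p.grad).abs().sum().item()
    return score

# Score an untrained candidate in a single forward-backward pass.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)
print(zero_cost_score(net, torch.randn(4, 3, 32, 32)))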

44.3.2. Hardware-Aware NAS

The killer app for NAS in production is not “1% better accuracy”; it is “100% faster inference”. Hardware-Aware NAS searches for the architecture that maximizes accuracy subject to a latency constraint on a specific target device (e.g., “Must run < 10ms on iPhone 12 NPU”).
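
One common way to encode the latency constraint is a soft objective in the MnasNet style, where the reward multiplies accuracy by a latency penalty term. The target and exponent below are illustrative values, not taken from the text.

# reward = accuracy * (latency / target)^w, with w < 0 penalizing slow candidates.
def hardware_aware_reward(accuracy: float, latency_ms: float,
                          target_ms: float = 10.0, w: float = -0.07) -> float:
    return accuracy * (latency_ms / target_ms) ** w

print(hardware_aware_reward(0.76, 8.0))   # under budget: small bonus
print(hardware_aware_reward(0.78, 20.0))  # over budget: loses to the faster candidate above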

The Latency Lookup Table (The “Proxy”)

To make this search efficient, we cannot run a real benchmark on an iPhone for every candidate architecture (the network round-trip for every benchmark would kill the search speed). Instead, we pre-build a Latency Table.

  1. Profiling: Isolate standard blocks (Conv3x3, MBConv, Attention) together with the input shapes they will see.
  2. Benchmarking: Run these micro-benchmarks on the physical target device (Device Farm).
  3. Lookup: Store (op_type, input_shape, stride) -> latency_ms.
  4. Search: During the NAS loop, the agent queries the table and sums the per-operation latencies instead of running the model. Each lookup is O(1), so scoring a candidate takes microseconds.

Reference Latency Table (Sample)

| Operation | Stride | Channels | iPhone 12 (NPU) ms | Jetson Nano (GPU) ms | T4 (Server GPU) ms |
| --- | --- | --- | --- | --- | --- |
| Conv3x3 | 1 | 32 | 0.045 | 0.082 | 0.005 |
| Conv3x3 | 2 | 64 | 0.038 | 0.070 | 0.005 |
| MBConv6_3x3 | 1 | 32 | 0.120 | 0.210 | 0.012 |
| SelfAttention | - | 128 | 0.450 | 0.890 | 0.025 |
| AvgPool | 2 | 128 | 0.010 | 0.015 | 0.001 |

Python Code: Building the Lookup Table

This runs on the target device to populate the table (the output filename below assumes an NVIDIA T4 server GPU).

import time
import torch
import torch.nn as nn
import json

def profile_block(block, input_shape, iterations=100):
    dummy_input = torch.randn(input_shape).cuda()
    block = block.cuda().eval()

    with torch.no_grad():
        # Warmup: let cuDNN pick kernels before timing
        for _ in range(10):
            _ = block(dummy_input)

        torch.cuda.synchronize()
        start = time.time()

        for _ in range(iterations):
            _ = block(dummy_input)

        torch.cuda.synchronize()

    avg_latency = (time.time() - start) / iterations
    return avg_latency * 1000  # ms

ops = {
    "Conv3x3_32": nn.Conv2d(32, 32, 3, padding=1),
    "Conv1x1_32": nn.Conv2d(32, 32, 1),
    "MaxPool": nn.MaxPool2d(2),
    "MBConv3_3x3_32": nn.Sequential(
        nn.Conv2d(32, 32*3, 1), # Expand
        nn.Conv2d(32*3, 32*3, 3, groups=32*3, padding=1), # Depthwise
        nn.Conv2d(32*3, 32, 1) # Project
    )
}

results = {}
for name, layer in ops.items():
    lat = profile_block(layer, (1, 32, 224, 224))
    results[name] = lat
    print(f"{name}: {lat:.4f} ms")

with open("latency_table_nvidia_t4.json", "w") as f:
    json.dump(results, f)
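
Querying the table during the search is then just a dictionary lookup and a sum. A sketch, assuming the JSON file written above and candidates encoded as lists of the op names used as keys in `ops`:

import json

def estimate_latency_ms(candidate_ops, table_path="latency_table_nvidia_t4.json"):
    with open(table_path) as f:
        table = json.load(f)
    # Unknown ops fall back to a pessimistic default so they are never "free".
    return sum(table.get(op, 1.0) for op in candidate_ops)

print(estimate_latency_ms(["Conv3x3_32", "MaxPool", "MBConv3_3x3_32"]))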

44.3.3. Rust Implementation: A Search Space Pruner

Below is a Rust snippet for a high-performance “Pruner” that rejects invalid architectures before they hit the training queue. This is crucial because Python-based graph traversal can be a bottleneck when evaluating millions of candidates in a Genetic Algorithm.

use std::collections::HashMap;
use serde::{Deserialize, Serialize};

// A simple representation of a Neural Network Layer
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
enum LayerType {
    Conv3x3,
    Conv5x5,
    Identity,
    MaxPool,
    AvgPool,
    DepthwiseConv3x3,
    MBConv3,
    MBConv6,
}

#[derive(Debug, Deserialize)]
struct Architecture {
    layers: Vec<LayerType>,
    input_resolution: u32,
    channels: Vec<u32>, // Width search
}

#[derive(Debug, Deserialize)]
struct Constraint {
    max_layers: usize,
    max_flops: u64,
    max_params: u64,
    max_conv5x5: usize,
    estimated_latency_budget_ms: f32,
}

impl Architecture {
    // Fast estimation of latency using a lookup table
    // In production, this allows interpolation for resolutions
    fn estimate_latency(&self, lookup: &HashMap<LayerType, f32>) -> f32 {
        self.layers.iter().map(|l| lookup.get(l).unwrap_or(&0.1)).sum()
    }

    // Estimate FLOPs (simplified)
    // Estimate FLOPs (simplified)
    fn estimate_flops(&self) -> u64 {
        let mut flops: u64 = 0;
        for (i, layer) in self.layers.iter().enumerate() {
            // Widen to u64 up front to avoid u32 overflow at high resolutions.
            let ch = u64::from(*self.channels.get(i).unwrap_or(&32));
            let res = u64::from(self.input_resolution); // Assume no downsampling for simplicity

            let ops = match layer {
                LayerType::Conv3x3 => 3 * 3 * res.pow(2) * ch.pow(2),
                LayerType::Conv5x5 => 5 * 5 * res.pow(2) * ch.pow(2),
                LayerType::MBConv6 => 6 * res.pow(2) * ch.pow(2), // simplified
                _ => 0,
            };
            flops += ops;
        }
        flops
    }

    // The Gatekeeper function
    // Returns Option<String> where None = Valid, Some(Reason) = Invalid
    fn check_validity(&self, constraints: &Constraint, lookup: &HashMap<LayerType, f32>) -> Option<String> {
        if self.layers.len() > constraints.max_layers {
            return Some(format!("Too many layers: {}", self.layers.len()));
        }

        let conv5_count = self.layers.iter()
            .filter(|&l| *l == LayerType::Conv5x5)
            .count();
        
        if conv5_count > constraints.max_conv5x5 {
            return Some(format!("Too many expensive Conv5x5: {}", conv5_count));
        }

        let latency = self.estimate_latency(lookup);
        if latency > constraints.estimated_latency_budget_ms {
            return Some(format!("Latency budget exceeded: {:.2} > {:.2}", latency, constraints.estimated_latency_budget_ms));
        }
        
        let flops = self.estimate_flops();
        if flops > constraints.max_flops {
            return Some(format!("FLOPs budget exceeded: {} > {}", flops, constraints.max_flops));
        }

        None
    }
}

fn load_latency_table() -> HashMap<LayerType, f32> {
    let mut map = HashMap::new();
    map.insert(LayerType::Conv3x3, 1.5);
    map.insert(LayerType::Conv5x5, 4.2);
    map.insert(LayerType::MaxPool, 0.5);
    map.insert(LayerType::Identity, 0.05);
    map.insert(LayerType::AvgPool, 0.6);
    map.insert(LayerType::DepthwiseConv3x3, 0.8);
    map.insert(LayerType::MBConv3, 2.1);
    map.insert(LayerType::MBConv6, 3.5);
    map
}

fn main() {
    // 1. Setup
    let latency_table = load_latency_table();
    
    // 2. Define Production Constraints
    let constraints = Constraint {
        max_layers: 50,
        max_flops: 1_000_000_000,
        max_params: 5_000_000,
        max_conv5x5: 5, // Strictly limit expensive ops
        estimated_latency_budget_ms: 25.0, 
    };

    // 3. Batch Process Candidates (e.g., from Kafka or a file)
    let candidate = Architecture {
        layers: vec![
            LayerType::Conv3x3,
            LayerType::Identity,
            LayerType::MBConv6,
            LayerType::MaxPool,
            LayerType::Conv5x5,
        ],
        input_resolution: 224,
        channels: vec![32, 32, 64, 64, 128],
    };

    // 4. MLOps Gatekeeping
    match candidate.check_validity(&constraints, &latency_table) {
        None => println!("Candidate ACCEPTED for finetuning."),
        Some(reason) => println!("Candidate REJECTED: {}", reason),
    }
}

44.3.4. Managing the Search Space Cache

NAS is often wasteful because it re-discovers the same architectures (Isomorphic Graphs). An “Architecture Database” is a critical MLOps component for NAS teams.

Schema for an Architecture DB (Postgres/DynamoDB)

  • Arch Hash: Unique SHA signature of the graph topology (Canonicalized to handle isomorphism).
  • Metrics: Accuracy, Latency (Mobile), Latency (Server), FLOPs, Params.
  • Training State: Untrained, OneShot, FineTuned.
  • Artifacts: Weights URL (S3).
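
A sketch of computing the cache key, assuming a candidate described by an ordered list of (op, channels) pairs; a production version would first canonicalize the graph (e.g., a canonical node ordering) so isomorphic architectures hash identically.

import hashlib
import json

def arch_hash(layers, input_resolution):
    # Serialize the architecture in a deterministic form, then hash it.
    canonical = {
        "resolution": input_resolution,
        "layers": [{"op": op, "channels": ch} for op, ch in layers],
    }
    blob = json.dumps(canonical, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode()).hexdigest()

key = arch_hash([("Conv3x3", 32), ("MBConv6", 64), ("Conv5x5", 128)], 224)
print(key[:16])  # cache lookup key, checked before launching any training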

The latency lookup table populated in the profiling step maps to a simple SQL schema:

CREATE TABLE latency_lookup (
    hardware_id VARCHAR(50), -- e.g. "iphone12_npu"
    op_type VARCHAR(50),     -- e.g. "Conv3x3"
    input_h INT,
    input_w INT,
    channels_in INT,
    channels_out INT,
    stride INT,
    latency_micros FLOAT,    -- The golden number
    energy_mj FLOAT,         -- Power consumption
    PRIMARY KEY (hardware_id, op_type, input_h, input_w, channels_in, channels_out, stride)
);

Search Space Configuration (YAML)

Define your priors in a config file, not code.

# nas_search_config_v1.yaml
search_space:
  backbone:
    type: "MobileNetV3"
    width_mult: [0.5, 0.75, 1.0]
    depth_mult: [1.0, 1.2]
  head:
    type: "FPN"
    channels: [64, 128]

constraints:
  latency:
    target_device: "pixel6_tpu"
    max_ms: 15.0
  size:
    max_params_m: 3.5

strategy:
  algorithm: "DNA (Block-Wisely)"
  supernet_epochs: 50
  finetune_epochs: 100
  population_size: 50
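
A sketch of consuming this config at the start of a search run, assuming PyYAML is installed and the filename from the comment above:

import yaml

with open("nas_search_config_v1.yaml") as f:
    cfg = yaml.safe_load(f)

latency_budget = cfg["constraints"]["latency"]["max_ms"]
target_device = cfg["constraints"]["latency"]["target_device"]
print(f"Gating candidates at {latency_budget} ms on {target_device}")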

44.3.5. Troubleshooting Common NAS Issues

1. The “Identity Collapse”

  • Symptom: DARTS converges to a network of all “Skip Connections”. Accuracy is terrible, but loss was low during search.
  • Why: Skip connections are “easy” for gradient flow. The optimizer took the path of least resistance.
  • Fix: Add “Topology Regularization” or force a minimum number of FLOPs.

2. The “Supernet Gap”

  • Symptom: The best architecture found on the Supernet performs poorly when trained from scratch.
  • Why: The ranking correlation between shared-weight and standalone accuracy is low; the weights in the Supernet were fighting each other (interference).
  • Fix: Use “One-Shot NAS with Fine-Tuning” or “Few-Shot NAS”. Measure the Kendall-Tau correlation between Supernet accuracy and Standalone accuracy.
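
Measuring that correlation is a one-liner with SciPy; the accuracy lists below are illustrative numbers for a handful of candidate architectures.

from scipy.stats import kendalltau

supernet_acc   = [0.71, 0.69, 0.74, 0.66, 0.72]  # measured with shared weights
standalone_acc = [0.76, 0.73, 0.78, 0.70, 0.74]  # same archs trained from scratch

tau, p_value = kendalltau(supernet_acc, standalone_acc)
print(f"Kendall-Tau = {tau:.2f}")  # close to 1.0 means the Supernet ranking can be trusted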

3. Latency Mismatch

  • Symptom: NAS predicts 10 ms; the real device measures 20 ms.
  • Why: The Latency Lookup Table ignored memory access costs (MACs) or cache misses.
  • Fix: Incorporate “fragmentation penalty” in the lookup table.

44.3.6. FAQ

Q: Should I use NAS for tabular data? A: No. Use Gradient Boosting (AutoGluon/XGBoost). NAS is useful for perceptual tasks (Vision, Audio) where inductive biases matter (e.g., finding the right receptive field size).

Q: Do I need a GPU cluster for NAS? A: For One-Shot NAS, a single 8-GPU node is sufficient. For standard Evolution NAS, you need massive scale (hundreds of GPUs).

Q: What is the difference between HPO and NAS? A: HPO tunes scalar values (learning rate, layers). NAS tunes the graph topology (connections, operations). HPO is a subset of NAS.

44.3.7. Glossary

  • DARTS (Differentiable Architecture Search): A continuous relaxation of the architecture representation, allowing gradient descent to find architectures.
  • Supernet: A mega-network containing all possible operations. Subgraphs are sampled from this during search.
  • Zero-Cost Proxy: A metric (like Synflow) that evaluates an untrained network’s potential in milliseconds.
  • Hardware-Aware: Incorporating physical device latency into the loss function of the search.
  • Kendall-Tau: A rank correlation coefficient used to measure if the Supernet ranking matches the true standalone capability ranking.
  • Macro-Search: Searching for the connections between blocks.
  • Micro-Search: Searching for the operations inside a block (e.g., cell search).

44.3.8. Summary

NAS is powerful but expensive. To operationalize it:

  1. Use Weight Sharing to reduce training costs from N * Cost to 1.5 * Cost.
  2. Optimize for Hardware Latency using Lookup Tables, not just accuracy.
  3. Use Architecture Caching to avoid redundant work.
  4. Implement fast Pruning Gates to filter candidates cheaply before they consume GPU cycles.