45.2. The Rust ML Ecosystem

Note

Not Just Wrappers: A few years ago, “Rust ML” meant “calling libtorch C++ bindings from Rust.” Today, we have a native ecosystem. Burn and Candle are written in pure Rust. They don’t segfault when C++ throws an exception.

45.2.1. The Landscape: A Feature Matrix

Before diving into code, let’s map the Python ecosystem to Rust.

| Domain | Python Standard | Rust Standard | Maturity (1-10) | Notes |
|---|---|---|---|---|
| Deep Learning | PyTorch / TensorFlow | Burn | 8 | Dynamic graphs, multiple backends (WGPU, Torch, Ndarray). |
| LLM Inference | vLLM / CTranslate2 | Candle / Mistral.rs | 9 | Backed by Hugging Face. Production ready. |
| Classical ML | Scikit-Learn | Linfa / SmartCore | 7 | Good for K-Means/SVM, missing esoteric algos. |
| Dataframes | Pandas | Polars | 10 | Faster than Pandas. Industry standard. |
| Tensors | NumPy | ndarray | 9 | Mature, BLAS-backed. |
| Visualization | Matplotlib | Plotters | 7 | Verbose, but produces high-quality SVG/PNG. |
| AutoDiff | Autograd | dfdx | 6 | Compile-time shape checking (experimental). |

45.2.2. Burn: The “PyTorch of Rust”

Burn is the most promising general-purpose deep learning framework in Rust.

  • Philosophy: “Dynamic Graphs, Static Performance.” It feels like PyTorch (eager execution) but compiles down to highly optimized kernels.
  • Backends (selected at the type level; see the sketch after this list):
    • wgpu: Runs on any GPU (Vulkan/Metal/DX12). No CUDA lock-in!
    • tch: Libtorch (if you really need CUDA).
    • ndarray: Pure-CPU execution.
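
To make the backend choice concrete, here is a minimal sketch of how the same generic model code targets different backends purely through type aliases. It assumes the `wgpu` and `ndarray` Cargo features are enabled, and the module paths follow recent Burn releases:

use burn::backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu};

// Training needs gradients, so the backend is wrapped in the Autodiff decorator.
type TrainBackend = Autodiff<Wgpu>;
// Plain CPU backend for inference.
type InferBackend = NdArray;

fn main() {
    let gpu = WgpuDevice::default();
    let cpu = NdArrayDevice::Cpu;
    // The generic Model<B> defined in the next listing compiles once per backend:
    // Model::<TrainBackend>::new(784, 128, 10, &gpu);
    // Model::<InferBackend>::new(784, 128, 10, &cpu);
    println!("devices ready: {gpu:?} / {cpu:?}");
}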

1. Defining the Model (The Module Trait)

In PyTorch, you subclass nn.Module. In Burn, you derive Module.

#![allow(unused)]
fn main() {
use burn::{
    nn::{loss::CrossEntropyLossConfig, Linear, LinearConfig, Relu},
    prelude::*,
    tensor::backend::Backend,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    linear1: Linear<B>,
    relu: Relu,
    linear2: Linear<B>,
}

impl<B: Backend> Model<B> {
    // Constructor (Note: Config driven)
    pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize, device: &B::Device) -> Self {
        let linear1 = LinearConfig::new(input_dim, hidden_dim).init(device);
        let linear2 = LinearConfig::new(hidden_dim, output_dim).init(device);
        
        Self {
            linear1,
            relu: Relu::new(),
            linear2,
        }
    }
    
    // The Forward Pass
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(input);
        let x = self.relu.forward(x);
        self.linear2.forward(x)
    }
}
}

Key Differences from PyTorch:

  1. Generics: <B: Backend>. The model is monomorphized once per backend you actually instantiate it with (ndarray CPU, WGPU, Torch, ...).
  2. Explicit Device: You pass device to .init(). No more “Expected tensor on cuda:0 but got cpu”.

2. The Training Loop (Learner)

Burn uses a Learner struct (similar to PyTorch Lightning) to abstract the loop.

#![allow(unused)]
fn main() {
use burn::data::dataloader::DataLoaderBuilder;
use burn::optim::AdamConfig;
use burn::record::CompactRecorder;
use burn::tensor::backend::AutodiffBackend;
use burn::train::{metric::AccuracyMetric, LearnerBuilder};
// `TrainingConfig`, `ModelConfig`, `MnistBatcher` and `MnistDataset` are the usual
// project-local types from Burn's MNIST example.

pub fn train<B: AutodiffBackend>(device: B::Device) {
    // 1. Config
    let config = TrainingConfig::new(ModelConfig::new(10), AdamConfig::new());
    
    // 2. DataLoaders
    let batcher = MnistBatcher::<B>::new(device.clone());
    let dataloader_train = DataLoaderBuilder::new(batcher.clone())
        .batch_size(64)
        .shuffle(42)
        .num_workers(4)
        .build(MnistDataset::train());
        
    let dataloader_test = DataLoaderBuilder::new(batcher.clone())
        .batch_size(64)
        .build(MnistDataset::test());

    // 3. Learner
    let learner = LearnerBuilder::new("/tmp/artifacts")
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device.clone()])
        .num_epochs(10)
        .build(
            ModelConfig::new(10).init(&device),
            config.optimizer.init(),
            config.learning_rate,
        );

    // 4. Fit
    let model_trained = learner.fit(dataloader_train, dataloader_test);
}
}

3. Custom Training Step (Under the Hood)

If you need a custom loop (e.g., GANs, RL), you implement TrainStep.

#![allow(unused)]
fn main() {
impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let output = self.forward(batch.images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), batch.targets.clone());
        
        // AutoDiff happens here
        let grads = loss.backward();
        
        TrainOutput::new(self, grads, ClassificationOutput::new(loss, output, batch.targets))
    }
}
}

45.2.3. Candle: Minimalist Inference (Hugging Face)

Candle is built by Hugging Face specifically for LLM Inference.

  • Goal: remove the massive 5GB torch dependency. Candle binaries are tiny (~10MB).
  • Features: Quantization (4-bit/8-bit, sketched below), Flash Attention v2, SafeTensors support.
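
For the quantization feature, here is a rough sketch of loading a 4-bit GGUF checkpoint, modeled on candle's quantized-llama example. Exact module paths can shift between candle releases, and `model.q4_k_m.gguf` is a placeholder file name:

use candle_core::quantized::gguf_file;
use candle_core::Device;
use candle_transformers::models::quantized_llama::ModelWeights;

fn load_quantized() -> Result<ModelWeights, Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    // GGUF bundles the quantized tensors plus metadata in one file.
    let mut file = std::fs::File::open("model.q4_k_m.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    // Builds the llama weights directly from the 4-bit/8-bit quantized tensors.
    let model = ModelWeights::from_gguf(content, &mut file, &device)?;
    Ok(model)
}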

1. Minimal Llama 2 Inference

This is a condensed example of loading Llama 2 and generating text; error handling and the KV-cache plumbing are trimmed for brevity.

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Select Device (CUDA -> Metal -> CPU)
    let device = Device::cuda_if_available(0)
        .unwrap_or(Device::new_metal(0).unwrap_or(Device::Cpu));
        
    println!("Running on: {:?}", device);

    // 2. Download Weights (Hugging Face Hub)
    let api = Api::new()?;
    let repo = api.repo(Repo::new("meta-llama/Llama-2-7b-chat-hf".to_string(), RepoType::Model));
    let model_path = repo.get("model.safetensors")?;
    let config_path = repo.get("config.json")?;

    // 3. Load Model (Zero Copy Mmap)
    let config = std::fs::read_to_string(config_path)?;
    let config: LlamaConfig = serde_json::from_str(&config)?;
    
    let vb = unsafe { 
        VarBuilder::from_mmaped_safetensors(&[model_path], DType::F16, &device)? 
    };
    
    let model = Llama::load(vb, &config)?;

    // 4. Tokenization (Using 'tokenizers' crate)
    let tokenizer = Tokenizer::from_file(repo.get("tokenizer.json")?)?;
    let tokens = tokenizer.encode("The capital of France is", true)?.get_ids().to_vec();
    let mut input = Tensor::new(tokens, &device)?.unsqueeze(0)?;

    // 5. Generation Loop
    let mut logits_processor = LogitsProcessor::new(42, None, None); // greedy when temperature is None
    for _ in 0..20 {
        let logits = model.forward(&input)?;
        // Sample next token (argmax / greedy)
        let next_token_id = logits_processor.sample(&logits)?;
        
        print!("{}", tokenizer.decode(&[next_token_id], true)?);
        
        // Feed only the new token back in. (Candle's Llama takes the KV cache that
        // carries the history as an explicit forward() argument, omitted here.)
        input = Tensor::new(&[next_token_id], &device)?.unsqueeze(0)?;
    }
    
    Ok(())
}

2. Custom Kernels (CUDA in Rust)

Candle allows you to write custom CUDA kernels. Unlike PyTorch (where you write C++), Candle uses cudarc to compile PTX at runtime or load pre-compiled cubins.

#![allow(unused)]
fn main() {
// Simplified Custom Op
struct MyCustomOp;

impl CustomOp1 for MyCustomOp {
    fn name(&self) -> &'static str { "my_custom_op" }
    
    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Layout)> {
        // CPU fallback implementation
    }
    
    fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Layout)> {
        // Launch CUDA Kernel
        let function_name = "my_kernel_Function";
        let kernel = s.device().get_or_load_func(function_name, PTX_SOURCE)?;
        
        unsafe { kernel.launch(...) }
    }
}
}

45.2.4. Linfa: The “Scikit-Learn of Rust”

For classical ML (K-Means, PCA, SVM), Linfa is the standard. It uses ndarray for data representation.

1. K-Means Clustering Full Example

use linfa::prelude::*;
use linfa_clustering::KMeans;
use linfa_datasets::iris;
use plotters::prelude::*; // Visualization

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Load Data
    let dataset = iris();
    
    // 2. Train KMeans
    let model = KMeans::params(3)
        .max_n_iterations(200)
        .tolerance(1e-5)
        .fit(&dataset)
        .expect("KMeans failed");
        
    // 3. Predict Cluster Labels
    let labels = model.predict(&dataset);
    
    // 4. Visualization (Plotters)
    let root = BitMapBackend::new("clusters.png", (640, 480)).into_drawing_area();
    root.fill(&WHITE)?;
    
    let mut chart = ChartBuilder::on(&root)
        .caption("Iris K-Means in Rust", ("sans-serif", 50).into_font())
        .margin(5)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(4.0f32..8.0f32, 2.0f32..4.5f32)?;
        
    chart.configure_mesh().draw()?;
    
    // Scatter plot of the first two features, colored by cluster
    chart.draw_series(
        dataset.records.outer_iter().zip(labels.iter()).map(|(point, &label)| {
            let x = point[0] as f32;
            let y = point[1] as f32;
            let color = match label {
                0 => RED,
                1 => GREEN,
                _ => BLUE,
            };
            Circle::new((x, y), 5, color.filled())
        })
    )?;
    
    println!("Chart saved to clusters.png");
    Ok(())
}

2. Supported Algorithms

| Algorithm | Status | Notes |
|---|---|---|
| K-Means | Stable | Fast, supports parallel init. |
| DBSCAN | Stable | Good for noise handling. |
| Logistic Regression | Stable | L1/L2 regularization. |
| SVM | Beta | Supports RBF kernels. |
| PCA | Stable | Uses SVD under the hood. |
| Random Forest | Alpha | Trees are hard to optimize in Rust without unsafe pointers. |
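
As a taste of the table above, PCA follows the same fit/predict pattern as K-Means. A minimal sketch using the linfa-reduction crate, assuming the API of recent linfa releases:

use linfa::prelude::*;
use linfa_datasets::iris;
use linfa_reduction::Pca;

fn main() {
    // Project the 4-dimensional iris features down to 2 principal components.
    let dataset = iris();
    let pca = Pca::params(2).fit(&dataset).expect("PCA failed");

    // `predict` maps each sample onto the learned 2D embedding.
    let embedded = pca.predict(dataset);
    println!("embedded {} samples", embedded.nsamples());
}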

45.2.5. ndarray: The Tensor Foundation

If you know NumPy, you know ndarray. It provides the ArrayBase struct that underpins linfa and burn (CPU backend).

Slicing and Views

In Python, slicing a[:] creates a view. In Rust, you must be explicit.

use ndarray::{Array3, s};

fn main() {
    // Create 3D tensor (Depth, Height, Width)
    let mut image = Array3::<u8>::zeros((3, 224, 224));
    
    // Slice: Center Crop
    // s! macro simulates Python slicing syntax
    let mut crop = image.slice_mut(s![.., 50..150, 50..150]);
    
    // 'crop' is a mutable view (ArrayViewMut3). No data is copied;
    // writing through the view writes into `image`.
    crop.fill(255); 
}

Broadcasting

#![allow(unused)]
fn main() {
use ndarray::Array;

let a = Array::from_elem((3, 4), 1.0);
let b = Array::from_elem((4,), 2.0);

// Python: a + b (Implicit broadcasting)
// Rust: &a + &b (Explicit borrowing)
let c = &a + &b; 

assert_eq!(c.shape(), &[3, 4]);
}

45.2.6. dfdx: Compile-Time Shape Checking

DFDX (Derivatives for Dummies) is an experimental library that prevents shape mismatch errors at compile time.

The Problem it Solves

In PyTorch, you define: self.layer = nn.Linear(10, 20) Then forward: self.layer(tensor_with_30_features) Runtime Error: “Size mismatch”. This happens 10 hours into training.

The DFDX Solution

In Rust, const generics let us encode dimensions directly in the type.

use dfdx::prelude::*;

// Define Network Architecture as a Type
type MLP = (
    Linear<10, 50>,
    ReLU,
    Linear<50, 20>, // Output must match next Input
    Tanh,
    Linear<20, 2>,
);

fn main() {
    let dev: Cpu = Default::default();
    let model: MLP = dev.build_module(Default::default(), Default::default());
    
    let x: Tensor<Rank1<10>> = dev.zeros(); // Shape is [10]
    let y = model.forward(x); // Works!
    
    // let z: Tensor<Rank1<30>> = dev.zeros();
    // let out = model.forward(z); 
    // ^^^ COMPILER ERROR: "Expected Rank1<10>, found Rank1<30>"
}

This guarantees that if your binary builds, your tensor shapes line up perfectly across the entire network.

45.2.7. The Ecosystem Map: “What is the X of Rust?”

If you are coming from Python, this map is your survival guide.

| Python | Rust | Maturity | Notes |
|---|---|---|---|
| NumPy | ndarray | High | Just as fast, but stricter broadcasting. |
| Pandas | polars | High | Faster, lazy execution, Arrow-native. |
| Scikit-Learn | linfa | Mid | Good coverage, API is similar. |
| PyTorch | burn | High | Dynamic graphs, cross-platform. |
| TensorFlow | tensorflow-rust | Mid | Just bindings to the C++ library; best avoided. |
| Requests | reqwest | High | Async by default, extremely robust. |
| FastAPI | axum | High | Ergonomic, built on Tokio. |
| Flask/Django | actix-web | High | Consistently among the fastest web frameworks in benchmarks. |
| Jupyter | evcxr | Mid | Rust kernel for Jupyter. |
| Matplotlib | plotters | Mid | Good for static charts, less interactive. |
| OpenCV | opencv-rust | Mid | Bindings to C++. Heavy build time. |
| Pillow (PIL) | image | High | Pure Rust image decoding (JPEG/PNG). Safe. |
| Librosa | symphonia | High | Pure Rust audio decoding (MP3/WAV/AAC). |
| Tqdm | indicatif | High | Beautiful progress bars. |
| Click | clap | High | Best-in-class CLI parser. |

45.2.8. Domain Specifics: Vision and Audio

MLOps is rarely just “Vectors”. It involves decoding complex binary formats. In Python, this relies on libjpeg, ffmpeg, etc. (unsafe C libs). In Rust, we have safe alternatives.

1. Computer Vision (image crate)

The image crate is a pure Rust implementation of image decoders. No more scrambling when the next libpng CVE drops.

#![allow(unused)]
fn main() {
use image::{GenericImageView, imageops::FilterType};

fn process_image(path: &str) {
    // 1. Load (Detects format automatically)
    let img = image::open(path).expect("File not found");
    
    // 2. Metadata
    println!("Dimensions: {:?}", img.dimensions());
    println!("Color: {:?}", img.color());
    
    // 3. Resize (Lanczos3 is high quality)
    let resized = img.resize(224, 224, FilterType::Lanczos3);
    
    // 4. Convert to Tensor (ndarray)
    let raw_pixels = resized.to_rgb8().into_raw();
    // ... feed to Burn ...
}
}

2. Audio Processing (symphonia)

Decoding MP3s correctly is famously hard. symphonia is a pure-Rust media framework that handles probing, demuxing, and decoding (MP3, AAC, FLAC, WAV, and more).

#![allow(unused)]
fn main() {
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;

fn decode_mp3() {
    let file = std::fs::File::open("music.mp3").unwrap();
    let mss = MediaSourceStream::new(Box::new(file), Default::default());
    
    // Probe the container format (the hint can carry the file extension)
    let probed = symphonia::default::get_probe()
        .format(&Hint::new(), mss, &FormatOptions::default(), &MetadataOptions::default())
        .expect("unsupported format");
        
    let mut format = probed.format;
    let track = format.default_track().unwrap();
    let track_id = track.id;
    
    // Build a decoder for the track's codec
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &DecoderOptions::default())
        .expect("unsupported codec");
    
    // Decode Loop
    while let Ok(packet) = format.next_packet() {
        if packet.track_id() != track_id {
            continue;
        }
        let _decoded = decoder.decode(&packet).unwrap();
        // ... access PCM samples via the returned AudioBufferRef ...
    }
}
}

45.2.9. SmartCore: The Alternative to Linfa

SmartCore is another ML library. Unlike Linfa (which splits into many crates), SmartCore is a monolith. It puts emphasis on Linear Algebra traits.

use smartcore::linear::logistic_regression::LogisticRegression;
use smartcore::metrics::accuracy;

fn main() {
    // Load Iris. smartcore ships the dataset as a flat Vec plus shape metadata;
    // depending on the smartcore version, `data` may need to be wrapped in a
    // DenseMatrix before being passed to `fit`.
    let iris_data = smartcore::dataset::iris::load_dataset();
    let x = iris_data.data;
    let y = iris_data.target;
    
    // Train
    let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
    
    // Predict
    let y_hat = lr.predict(&x).unwrap();
    
    // Evaluate
    println!("Accuracy: {}", accuracy(&y, &y_hat));
}

Linfa vs SmartCore:

  • Use Linfa if you want modularity and ndarray first-class support.
  • Use SmartCore if you want a “Batteries Included” experience similar to R.

45.2.10. Rust Notebooks (Evcxr)

You don’t have to give up Jupyter. Evcxr is a collection of tools (REPL + Jupyter Kernel) that allows executing Rust incrementally.

Installation:

cargo install evcxr_jupyter
evcxr_jupyter --install

Cell 1:

:dep ndarray = "0.15"
:dep plotters = "0.3"

use ndarray::Array;
let x = Array::linspace(0., 10., 100);
let y = x.map(|v| v.sin());

Cell 2:

// Plotting inline in Jupyter!
use plotters::prelude::*;
let root = BitMapBackend::new("output.png", (600, 400)).into_drawing_area();
// ... drawing code ...
// (With plotters' `evcxr` feature, figures can also be rendered inline as SVG.)

45.2.11. Final Exam: Choosing your Stack

  1. “I need to train a Transformer from scratch.”

    • Burn. Use WGPU backend for Mac execution, or Torch backend for Cluster execution.
  2. “I need to deploy Llama-3 to a Raspberry Pi.”

    • Candle or Mistral.rs. Use 4-bit Quantization.
  3. “I need to cluster 1 Million customer vectors.”

    • Linfa (K-Means). Compile with --release. It will scream past Scikit-Learn.
  4. “I need to analyze 1TB of CSV logs.”

    • Polars. Do not use Pandas. Do not use Spark (unless it’s >10TB). Use Polars’ lazy/streaming engine (see the sketch below).
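
For the Polars route, a minimal lazy-scan sketch. The file name and column names (`logs.csv`, `status`, `service`, `latency_ms`) are hypothetical, and the API follows recent Polars releases with the `lazy` and `csv` features:

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    // Lazily scan the CSV: nothing is read until `collect()`,
    // and the optimizer pushes the filter down into the scan.
    let df = LazyCsvReader::new("logs.csv")
        .finish()?
        .filter(col("status").eq(lit(500)))
        .group_by([col("service")])
        .agg([col("latency_ms").mean().alias("avg_latency_ms")])
        .collect()?;

    println!("{df}");
    Ok(())
}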

45.2.12. Deep Dive: GPGPU with WGPU

CUDA is vendor lock-in. WGPU is the portable future. It runs on Vulkan (Linux), Metal (Mac), DX12 (Windows), and WebGPU (Browser). Burn uses WGPU by default. But you can write raw shaders.

The Compute Shader (WGSL)

// shader.wgsl
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    // ReLU Activation Kernel
    output[index] = max(0.0, input[index]);
}

The Rust Host Code

#![allow(unused)]
fn main() {
use wgpu::util::DeviceExt;

async fn run_compute() {
    let instance = wgpu::Instance::default();
    let adapter = instance.request_adapter(&Default::default()).await.unwrap();
    let (device, queue) = adapter.request_device(&Default::default(), None).await.unwrap();
    
    // 1. Load Shader
    let cs_module = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));
    
    // 2. Create Buffers (Input/Output)
    let input_data: Vec<f32> = vec![1.0, -1.0, 2.0, -2.0];
    let input_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Input"),
        contents: bytemuck::cast_slice(&input_data),
        usage: wgpu::BufferUsages::STORAGE,
    });
    
    // 3. Dispatch (creating the compute pipeline and bind group from cs_module is omitted for brevity)
    let mut encoder = device.create_command_encoder(&Default::default());
    {
        let mut cpass = encoder.begin_compute_pass(&Default::default());
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups(input_data.len() as u32 / 64 + 1, 1, 1);
    }
    queue.submit(Some(encoder.finish()));
    
    // 4. Readback (Async)
    // ... map_async ...
}
}

This runs on your MacBook (Metal) and your NVIDIA Server (Vulkan) without changing a line of code.

45.2.13. Reinforcement Learning: gym-rs

Python has OpenAI Gym. Rust has gym-rs, which exposes Gym-style environments (CartPole and friends) so the agent loop can be written entirely in Rust.

use gym_rs::{Action, Env, GymClient};

fn main() {
    let client = GymClient::default();
    let env = client.make("CartPole-v1");
    let mut observation = env.reset();
    
    let mut total_reward = 0.0;
    
    for _ in 0..1000 {
        // Random Agent
        let action = if rand::random() { 1 } else { 0 };
        
        let step = env.step(action);
        observation = step.observation;
        total_reward += step.reward;
        
        if step.done {
            println!("Episode finished inside Rust! Reward: {}", total_reward);
            break;
        }
    }
}

45.2.14. Graph Neural Networks: petgraph + Burn

Graph theory is where Rust shines due to strict ownership of nodes/edges. petgraph is the standard graph library.

#![allow(unused)]
fn main() {
use petgraph::graph::{Graph, NodeIndex};
use burn::tensor::Tensor;

struct GNNNode {
    features: Vec<f32>,
}

fn build_and_traverse() {
    let mut g = Graph::<GNNNode, ()>::new();
    
    let n1 = g.add_node(GNNNode { features: vec![0.1, 0.2] });
    let n2 = g.add_node(GNNNode { features: vec![0.5, 0.9] });
    g.add_edge(n1, n2, ());
    
    // Message Passing Step: mean-aggregate neighbour features
    for node in g.node_indices() {
        let neighbors: Vec<NodeIndex> = g.neighbors(node).collect();
        if neighbors.is_empty() {
            continue;
        }
        let mut agg = vec![0.0f32; g[node].features.len()];
        for n in &neighbors {
            for (a, f) in agg.iter_mut().zip(&g[*n].features) {
                *a += f;
            }
        }
        for a in agg.iter_mut() {
            *a /= neighbors.len() as f32;
        }
        // `agg` would then be combined with g[node].features by a Burn layer.
    }
}
}

45.2.15. Rust vs Julia: The Systems Verdict

Julia is fantastic for Math (Multiple Dispatch is great). But Julia has a Heavy Runtime (LLVM JIT) and Garbage Collection. It suffers from the “Time to First Plot” problem.

  • Latency: Julia JIT compilation causes unpredictable latency spikes on first request. Not suitable for Lambda / Microservices.
  • Deployment: Julia images are large. Rust binaries are tiny.
  • Correctness: Julia is dynamic. Rust is static.

Verdict:

  • Use Julia for Research / Scientific Computing (replacing MATLAB).
  • Use Rust for MLOps / Production Engineering (replacing C++).

45.2.16. Advanced ndarray: Memory Layouts

Row-Major (C) vs Column-Major (Fortran). NumPy defaults to C. Linear Algebra libraries (BLAS) often prefer Fortran.

#![allow(unused)]
fn main() {
use ndarray::{Array2, ShapeBuilder};

fn memory_layouts() {
    // Standard (C-Contiguous)
    let a = Array2::<f32>::zeros((100, 100));
    
    // Fortran-Contiguous (f())
    let b = Array2::<f32>::zeros((100, 100).f());
    
    // Iteration Performance
    // Iterating 'a' by rows is fast.
    // Iterating 'b' by cols is fast.
}
}

Rust makes these layouts an explicit choice at construction time, preventing the cache-thrashing bugs that plague Python/NumPy users.

45.2.17. Serialization: serde is King

The superpower of the Rust ecosystem is serde (Serializer/Deserializer). Every ML crate (ndarray, burn, candle) implements Serialize and Deserialize.

This means you can dump your entire Model Config, Dataset, or Tensor to JSON/Bincode/MessagePack effortlessly.

#![allow(unused)]
fn main() {
use serde::{Serialize, Deserialize};

#[derive(Serialize, Deserialize)]
struct ExperimentLog {
    epoch: usize,
    loss: f32,
    hyperparams: HyperParams, // Nested struct
}

fn save_log(log: &ExperimentLog) {
    let json = serde_json::to_string(log).unwrap();
    std::fs::write("log.json", json).unwrap();
    
    // Or binary for speed
    let bin = bincode::serialize(log).unwrap();
    std::fs::write("log.bin", bin).unwrap();
}
}
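
Because ndarray's arrays also implement Serialize when its `serde` feature is enabled, a tensor can go straight to JSON as well. A minimal sketch:

use ndarray::Array2;

fn main() {
    // Requires: ndarray = { version = "0.15", features = ["serde"] }
    let weights = Array2::<f32>::zeros((2, 3));
    let json = serde_json::to_string(&weights).unwrap();
    // The JSON carries the shape and the flat data buffer.
    println!("{json}");
}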

45.2.18. Crate of the Day: rkyv (Archive)

serde is fast, but rkyv offers Zero-Copy Deserialization: it guarantees the same in-memory representation on disk as in RAM, so opening a 10GB checkpoint is essentially instant (the file is simply memory-mapped).

#![allow(unused)]
fn main() {
use rkyv::{Archive, Serialize, Deserialize};

#[derive(Archive, Serialize, Deserialize)]
struct Checkpoint {
    weights: Vec<f32>,
}

// Accessing fields without parsing
fn read_checkpoint() {
    let bytes = std::fs::read("ckpt.rkyv").unwrap();
    let archived = unsafe { rkyv::archived_root::<Checkpoint>(&bytes) };
    
    // Instant access!
    println!("{}", archived.weights[0]);
}
}

45.2.19. Final Ecosystem Checklist

If you are building an ML Platform in Rust, verify you have these crates:

  1. Core: tokio, anyhow, thiserror, serde, clap.
  2. Data: polars, ndarray, sqlx.
  3. ML: burn or candle.
  4. Observability: tracing, tracing-subscriber, metrics.
  5. Utils: itertools, rayon, dashmap (Concurrent HashMap).
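
As a small taste of the Utils entry, rayon turns an ordinary iterator chain into a parallel one. A minimal sketch (the normalize_all helper is made up for illustration):

use rayon::prelude::*;

// Normalize every feature vector to unit length, in parallel across all cores.
fn normalize_all(features: &mut [Vec<f32>]) {
    features.par_iter_mut().for_each(|v| {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter_mut().for_each(|x| *x /= norm);
        }
    });
}

fn main() {
    let mut batch = vec![vec![3.0, 4.0], vec![1.0, 0.0]];
    normalize_all(&mut batch);
    println!("{:?}", batch); // [[0.6, 0.8], [1.0, 0.0]]
}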

With this stack, you are unstoppable.


45.2.20. Accelerated Computing: cuDNN and Friends

For CUDA-accelerated training, Rust has bindings to NVIDIA’s libraries.

cudarc: Safe CUDA Bindings

#![allow(unused)]
fn main() {
use cudarc::driver::*;
use cudarc::cublas::CudaBlas;

fn gpu_matrix_multiply() -> Result<(), DriverError> {
    let dev = CudaDevice::new(0)?;
    
    let m = 1024;
    let n = 1024;
    let k = 1024;
    
    // Allocate GPU memory
    let a = dev.alloc_zeros::<f32>(m * k)?;
    let b = dev.alloc_zeros::<f32>(k * n)?;
    let mut c = dev.alloc_zeros::<f32>(m * n)?;
    
    // Use cuBLAS for GEMM (the exact sgemm/gemm signature varies between cudarc versions)
    let blas = CudaBlas::new(dev.clone())?;
    
    unsafe {
        blas.sgemm(
            false, false,
            m as i32, n as i32, k as i32,
            1.0, // alpha
            &a, m as i32,
            &b, k as i32,
            0.0, // beta
            &mut c, m as i32,
        )?;
    }
    
    Ok(())
}
}

Flash Attention in Rust

Flash Attention is critical for efficient LLM inference. Candle implements it directly.

#![allow(unused)]
fn main() {
use candle_core::{Error, Tensor};
// flash_attn lives in the optional `candle-flash-attn` crate
use candle_flash_attn::flash_attn;

fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    scale: f64,
) -> Result<Tensor, Error> {
    // Use Flash Attention when the feature is enabled
    if cfg!(feature = "flash-attn") {
        flash_attn(query, key, value, scale as f32, false)
    } else {
        // Fallback to standard attention
        let attn_weights = (query.matmul(&key.t()?)? * scale)?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        attn_weights.matmul(value)
    }
}
}

45.2.21. Model Compilation: Optimization at Compile Time

Tract NNEF/ONNX Optimization

#![allow(unused)]
fn main() {
use tract_onnx::prelude::*;

fn optimize_model(model_path: &str) -> TractResult<()> {
    // Load ONNX model
    let model = tract_onnx::onnx()
        .model_for_path(model_path)?
        .with_input_fact(0, f32::fact([1, 3, 224, 224]))?;
    
    // Optimize for inference
    let optimized = model
        .into_optimized()?
        .into_runnable()?;
    
    // Benchmark
    let input = tract_ndarray::Array4::<f32>::zeros((1, 3, 224, 224));
    let input = input.into_tensor();
    
    let start = std::time::Instant::now();
    for _ in 0..100 {
        let _ = optimized.run(tvec![input.clone().into()])?;
    }
    let elapsed = start.elapsed();
    
    println!("Average: {:.2}ms", elapsed.as_millis() as f64 / 100.0);
    
    Ok(())
}
}

Static Shapes for Performance

#![allow(unused)]
fn main() {
// Dynamic shapes (slow)
let model = model.with_input_fact(0, f32::fact(vec![dim_of(), dim_of(), dim_of()]))?;

// Static shapes (fast)  
let model = model.with_input_fact(0, f32::fact([1, 512]))?;

// The difference: 
// - Dynamic: Runtime shape inference + memory allocation per batch
// - Static: Compile-time shape propagation + pre-allocated buffers
}

45.2.22. Distributed Training

While Python dominates training, Rust can orchestrate distributed systems.

Gradient Aggregation with NCCL

#![allow(unused)]
fn main() {
use nccl_rs::{Comm, ReduceOp};

fn distributed_step(
    comm: &Comm,
    local_gradients: &mut [f32],
    world_size: usize,
) -> Result<(), Error> {
    // All-reduce gradients across GPUs
    // (Sketch: a real NCCL all_reduce operates on device buffers, not host slices.)
    comm.all_reduce(
        local_gradients,
        ReduceOp::Sum,
    )?;
    
    // Average
    for grad in local_gradients.iter_mut() {
        *grad /= world_size as f32;
    }
    
    Ok(())
}
}

Parameter Server Pattern

#![allow(unused)]
fn main() {
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, oneshot, RwLock};

// `Tensor` stands in for whichever framework tensor type is in use (burn, candle, ...).
pub struct ParameterServer {
    parameters: Arc<RwLock<HashMap<String, Tensor>>>,
    rx: mpsc::Receiver<WorkerMessage>,
}

pub enum WorkerMessage {
    GetParameters { layer: String, reply: oneshot::Sender<Tensor> },
    PushGradients { layer: String, gradients: Tensor },
}

impl ParameterServer {
    pub async fn run(&mut self) {
        while let Some(msg) = self.rx.recv().await {
            match msg {
                WorkerMessage::GetParameters { layer, reply } => {
                    let params = self.parameters.read().await;
                    let tensor = params.get(&layer).cloned().unwrap();
                    let _ = reply.send(tensor);
                }
                WorkerMessage::PushGradients { layer, gradients } => {
                    let mut params = self.parameters.write().await;
                    if let Some(p) = params.get_mut(&layer) {
                        // SGD update
                        *p = p.sub(&gradients.mul_scalar(0.01));
                    }
                }
            }
        }
    }
}
}

45.2.23. SIMD-Accelerated Operations

Rust exposes CPU SIMD directly: via std::arch intrinsics, the nightly std::simd (portable SIMD), or stable crates such as wide.

#![allow(unused)]
fn main() {
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[target_feature(enable = "avx2,fma")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    assert!(a.len() % 8 == 0);
    
    let mut sum = _mm256_setzero_ps();
    
    for i in (0..a.len()).step_by(8) {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }
    
    // Horizontal sum
    let low = _mm256_extractf128_ps::<0>(sum);
    let high = _mm256_extractf128_ps::<1>(sum);
    let sum128 = _mm_add_ps(low, high);
    
    let mut result = [0.0f32; 4];
    _mm_storeu_ps(result.as_mut_ptr(), sum128);
    result.iter().sum()
}
}

Portable SIMD

#![allow(unused)]
fn main() {
use wide::*;

fn relu_simd(data: &mut [f32]) {
    let zero = f32x8::ZERO;
    
    for chunk in data.chunks_exact_mut(8) {
        // Copy the 8 lanes into a register, apply max(0, x), write them back
        let lanes: [f32; 8] = (&*chunk).try_into().unwrap();
        let result = f32x8::from(lanes).max(zero);
        chunk.copy_from_slice(&result.to_array());
    }
    
    // Handle the tail that doesn't fill a full 8-lane chunk
    for x in data.chunks_exact_mut(8).into_remainder() {
        *x = x.max(0.0);
    }
}
}

45.2.24. The Future: Rust 2024 and Beyond

GATs (Generic Associated Types)

GATs enable more expressive tensor types:

#![allow(unused)]
fn main() {
trait TensorBackend {
    type Tensor<const N: usize>: Clone;
    
    fn zeros<const N: usize>(shape: [usize; N]) -> Self::Tensor<N>;
    fn add<const N: usize>(a: Self::Tensor<N>, b: Self::Tensor<N>) -> Self::Tensor<N>;
}

// Now we can write generic code that works for any rank!
fn normalize<B: TensorBackend, const N: usize>(t: B::Tensor<N>) -> B::Tensor<N> {
    // ...
}
}

Const Generics for Dimension Safety

#![allow(unused)]
fn main() {
struct Tensor<T, const R: usize, const D: [usize; R]> {
    data: Vec<T>,
}

// With the unstable adt_const_params / generic_const_exprs features, a call like
// this only compiles if the inner dimensions match at compile time
fn matmul<T: Num, const M: usize, const K: usize, const N: usize>(
    a: Tensor<T, 2, [M, K]>,
    b: Tensor<T, 2, [K, N]>,
) -> Tensor<T, 2, [M, N]> {
    // ...
}
}

45.2.25. Final Ecosystem Assessment

| Crate | Production Readiness | Recommended For |
|---|---|---|
| Burn | ⭐⭐⭐⭐ | Training + Inference |
| Candle | ⭐⭐⭐⭐⭐ | LLM Inference |
| Polars | ⭐⭐⭐⭐⭐ | Data Engineering |
| ndarray | ⭐⭐⭐⭐⭐ | Numerical Computing |
| Linfa | ⭐⭐⭐⭐ | Classical ML |
| tract | ⭐⭐⭐⭐ | Edge Inference |
| dfdx | ⭐⭐⭐ | Research/Experiments |

The Rust ML ecosystem is no longer experimental—it’s production-ready.

[End of Section 45.2]