35.1. Audio Feature Extraction: The Spectrogram Pipeline

Note

Waveforms vs. Spectrograms: A neural network cannot usefully “hear” a raw WAV file ($16,000$ samples/sec): the raw time-domain signal is extremely long, and its structure is hard for standard architectures to learn directly. We therefore convert time-domain signals into frequency-domain images (spectrograms) so standard CNNs can “see” the sound.

Audio machine learning requires careful feature engineering to bridge the gap between raw waveforms and the tensor representations that neural networks understand. This chapter covers the complete pipeline from audio capture to production-ready feature representations.


35.1.1. Understanding Audio Fundamentals

The Audio Signal

Raw audio is a 1D time-series signal representing air pressure changes over time:

Property     | Typical Value                       | Notes
-------------|-------------------------------------|--------------------
Sample Rate  | 16,000 Hz (ASR), 44,100 Hz (Music)  | Samples per second
Bit Depth    | 16-bit, 32-bit float                | Dynamic range
Channels     | 1 (Mono), 2 (Stereo)                | Spatial dimensions
Format       | WAV, FLAC, MP3, Opus                | Compression type
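
To inspect these properties for a given file, a quick check with the soundfile library looks like this (the file name is a placeholder):

import soundfile as sf

info = sf.info("speech.wav")  # placeholder path
print(info.samplerate)        # e.g. 16000
print(info.channels)          # 1 for mono, 2 for stereo
print(info.subtype)           # e.g. 'PCM_16' (bit depth)
print(info.format)            # e.g. 'WAV'
print(f"{info.duration:.2f} s")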

The Nyquist Theorem

To capture frequency $f$, you need sample rate $\geq 2f$:

  • Human speech: ~8kHz max → 16kHz sample rate sufficient
  • Music: ~20kHz max → 44.1kHz sample rate required

The snippet below resamples an 8 kHz tone to show what happens at and below the Nyquist limit:
import numpy as np
import librosa

def demonstrate_nyquist():
    """Demonstrate aliasing when Nyquist is violated."""
    
    # Generate an 8 kHz tone at a high sample rate
    sr = 44100
    duration = 1.0
    t = np.linspace(0, duration, int(sr * duration), endpoint=False)
    tone_8k = np.sin(2 * np.pi * 8000 * t)
    
    # Downsample to 16 kHz (Nyquist = 8 kHz: the tone sits exactly at the
    # limit, so even this rate can only marginally represent it)
    tone_16k = librosa.resample(tone_8k, orig_sr=sr, target_sr=16000)
    
    # Downsample to 12 kHz (Nyquist = 6 kHz: the anti-aliasing filter removes
    # the 8 kHz tone; naive decimation would instead alias it down to 4 kHz)
    tone_12k = librosa.resample(tone_8k, orig_sr=sr, target_sr=12000)
    
    return {
        'original': tone_8k,
        'resampled_ok': tone_16k,
        'aliased': tone_12k
    }

35.1.2. The Standard Feature Extraction Pipeline

graph LR
    A[Raw Audio] --> B[Pre-emphasis]
    B --> C[Framing]
    C --> D[Windowing]
    D --> E[STFT]
    E --> F[Power Spectrum]
    F --> G[Mel Filterbank]
    G --> H[Log Compression]
    H --> I[Delta Features]
    I --> J[Normalization]
    J --> K[Output Tensor]

Step-by-Step Breakdown

  1. Raw Audio: 1D Array (float32, [-1, 1]).
  2. Pre-emphasis: High-pass filter to boost high frequencies.
  3. Framing: Cutting into 25ms frames with a 10ms hop, so consecutive frames overlap by 15ms (see the sketch after this list).
  4. Windowing: Applying a Hamming (or Hann) window to reduce spectral leakage.
  5. STFT (Short-Time Fourier Transform): Complex spectrum per frame, squared to obtain the power spectrum.
  6. Mel Filterbank: Mapping linear Hz to the human-perceived “Mel” scale.
  7. Log: Compressing dynamic range (decibels).
  8. Delta Features: First and second derivatives (optional).
  9. Normalization: Zero-mean, unit-variance normalization.
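
To make steps 2–4 concrete, here is a minimal NumPy sketch of pre-emphasis, framing, and windowing, assuming a 16 kHz mono signal at least one frame long. The frame_signal helper is purely illustrative (not a library function); the 25 ms / 10 ms values are the conventional ones mentioned above.

import numpy as np

def frame_signal(audio: np.ndarray, sr: int = 16000,
                 frame_ms: float = 25.0, hop_ms: float = 10.0,
                 pre_emphasis: float = 0.97) -> np.ndarray:
    """Pre-emphasize, frame, and window a waveform (sketch only)."""
    # Step 2: pre-emphasis high-pass filter y[t] = x[t] - a * x[t-1]
    emphasized = np.append(audio[0], audio[1:] - pre_emphasis * audio[:-1])
    
    # Step 3: framing (25 ms frames, 10 ms hop => 15 ms overlap)
    frame_len = int(sr * frame_ms / 1000)   # 400 samples at 16 kHz
    hop_len = int(sr * hop_ms / 1000)       # 160 samples at 16 kHz
    n_frames = 1 + max(0, len(emphasized) - frame_len) // hop_len
    frames = np.stack([
        emphasized[i * hop_len: i * hop_len + frame_len]
        for i in range(n_frames)
    ])
    
    # Step 4: Hamming window to reduce spectral leakage
    return frames * np.hamming(frame_len)

frames = frame_signal(np.random.randn(16000))  # 1 s of dummy audio
print(frames.shape)  # (n_frames, 400)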

35.1.3. Complete Feature Extraction Implementation

Librosa: Research-Grade Implementation

import librosa
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple
import soundfile as sf


@dataclass
class AudioFeatureConfig:
    """Configuration for audio feature extraction."""
    
    sample_rate: int = 16000
    n_fft: int = 2048          # FFT window size
    hop_length: int = 512       # Hop between frames
    n_mels: int = 128           # Number of Mel bands
    fmin: float = 0.0           # Minimum frequency
    fmax: Optional[float] = 8000.0  # Maximum frequency
    pre_emphasis: float = 0.97  # Pre-emphasis coefficient
    normalize: bool = True      # Whether to normalize output
    add_deltas: bool = False    # Add delta features


class AudioFeatureExtractor:
    """Production-ready audio feature extractor."""
    
    def __init__(self, config: Optional[AudioFeatureConfig] = None):
        self.config = config or AudioFeatureConfig()
        self._mel_basis = None
        self._setup_mel_basis()
    
    def _setup_mel_basis(self):
        """Pre-compute Mel filterbank for efficiency."""
        self._mel_basis = librosa.filters.mel(
            sr=self.config.sample_rate,
            n_fft=self.config.n_fft,
            n_mels=self.config.n_mels,
            fmin=self.config.fmin,
            fmax=self.config.fmax
        )
    
    def load_audio(
        self, 
        path: str, 
        mono: bool = True
    ) -> Tuple[np.ndarray, int]:
        """Load audio file with consistent format."""
        
        y, sr = librosa.load(
            path, 
            sr=self.config.sample_rate,
            mono=mono
        )
        
        return y, sr
    
    def extract_features(self, audio: np.ndarray) -> np.ndarray:
        """Extract log-mel spectrogram features."""
        
        # Step 1: Pre-emphasis
        if self.config.pre_emphasis > 0:
            audio = np.append(
                audio[0], 
                audio[1:] - self.config.pre_emphasis * audio[:-1]
            )
        
        # Step 2: STFT
        stft = librosa.stft(
            audio,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            window='hann',
            center=True,
            pad_mode='reflect'
        )
        
        # Step 3: Power spectrum
        power_spec = np.abs(stft) ** 2
        
        # Step 4: Mel filterbank
        mel_spec = np.dot(self._mel_basis, power_spec)
        
        # Step 5: Log compression
        log_mel = librosa.power_to_db(
            mel_spec, 
            ref=np.max,
            top_db=80.0
        )
        
        # Step 6: Delta features (optional)
        if self.config.add_deltas:
            delta = librosa.feature.delta(log_mel, order=1)
            delta2 = librosa.feature.delta(log_mel, order=2)
            log_mel = np.concatenate([log_mel, delta, delta2], axis=0)
        
        # Step 7: Normalization
        if self.config.normalize:
            log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-8)
        
        return log_mel  # Shape: (n_mels, time) or (3 * n_mels, time) with deltas
    
    def extract_from_file(self, path: str) -> np.ndarray:
        """Convenience method for file-based extraction."""
        audio, _ = self.load_audio(path)
        return self.extract_features(audio)


# Example usage
extractor = AudioFeatureExtractor(AudioFeatureConfig(
    sample_rate=16000,
    n_mels=80,
    add_deltas=True,
    normalize=True
))

features = extractor.extract_from_file("speech.wav")
print(f"Feature shape: {features.shape}")  # (240, T) with deltas

Torchaudio: GPU-Accelerated Production

import numpy as np
import torch
import torchaudio
from torchaudio import transforms as T
from typing import Tuple


class TorchAudioFeatureExtractor(torch.nn.Module):
    """
    GPU-accelerated feature extraction for training loops.
    
    Key Advantages:
    - Runs on GPU alongside model
    - Differentiable (for end-to-end training)
    - Batched processing
    """
    
    def __init__(
        self,
        sample_rate: int = 16000,
        n_mels: int = 80,
        n_fft: int = 1024,
        hop_length: int = 256,
        f_min: float = 0.0,
        f_max: float = 8000.0
    ):
        super().__init__()
        
        self.sample_rate = sample_rate
        
        # Pre-emphasis filter
        self.register_buffer(
            'pre_emphasis_filter',
            torch.FloatTensor([[-0.97, 1]])
        )
        
        # Mel spectrogram transform
        self.mel_spectrogram = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            power=2.0,
            normalized=False,
            mel_scale='htk'
        )
        
        # Amplitude to dB
        self.amplitude_to_db = T.AmplitudeToDB(
            stype='power',
            top_db=80.0
        )
    
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Extract features from waveform.
        
        Args:
            waveform: (batch, samples) or (batch, channels, samples)
            
        Returns:
            features: (batch, n_mels, time)
        """
        
        # Ensure 2D: (batch, samples)
        if waveform.dim() == 3:
            waveform = waveform.mean(dim=1)  # Mix to mono
        
        # Pre-emphasis
        waveform = torch.nn.functional.conv1d(
            waveform.unsqueeze(1),
            self.pre_emphasis_filter.unsqueeze(0),
            padding=1
        ).squeeze(1)[:, :-1]
        
        # Mel spectrogram
        mel_spec = self.mel_spectrogram(waveform)
        
        # Log scale
        log_mel = self.amplitude_to_db(mel_spec)
        
        # Instance normalization
        mean = log_mel.mean(dim=(1, 2), keepdim=True)
        std = log_mel.std(dim=(1, 2), keepdim=True)
        log_mel = (log_mel - mean) / (std + 1e-8)
        
        return log_mel
    
    @torch.no_grad()
    def extract_batch(
        self, 
        waveforms: list,
        max_length: int = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract features from variable-length batch with padding.
        
        Returns:
            features: (batch, n_mels, max_time)
            lengths: (batch,) actual lengths before padding
        """
        
        # Get device from the registered buffer (this module has no parameters)
        device = self.pre_emphasis_filter.device
        
        # Extract features
        features_list = []
        lengths = []
        
        for wf in waveforms:
            if isinstance(wf, np.ndarray):
                wf = torch.from_numpy(wf)
            wf = wf.to(device)
            
            if wf.dim() == 1:
                wf = wf.unsqueeze(0)
            
            feat = self.forward(wf)
            features_list.append(feat.squeeze(0))
            lengths.append(feat.shape[-1])
        
        # Pad to max length
        max_len = max_length or max(lengths)
        batch_features = torch.zeros(
            len(features_list),
            features_list[0].shape[0],
            max_len,
            device=device
        )
        
        for i, feat in enumerate(features_list):
            batch_features[i, :, :feat.shape[-1]] = feat
        
        return batch_features, torch.tensor(lengths, device=device)


# GPU training loop example
def training_step(model, batch_waveforms, labels, feature_extractor):
    """Training step with GPU-accelerated feature extraction."""
    
    # Extract features on GPU
    features, lengths = feature_extractor.extract_batch(batch_waveforms)
    
    # Forward pass
    logits = model(features, lengths)
    
    # Loss computation (compute_ctc_loss is a project-specific helper, not defined here)
    loss = compute_ctc_loss(logits, labels, lengths)
    
    return loss
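
A minimal usage sketch with dummy waveforms (in practice these come from your data loader); the extractor is moved to the GPU when one is available and applied to a variable-length batch:

import torch

extractor = TorchAudioFeatureExtractor(n_mels=80)
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = extractor.to(device)

# Two dummy clips of 3 s and 5 s at 16 kHz
waveforms = [torch.randn(16000 * 3), torch.randn(16000 * 5)]
features, lengths = extractor.extract_batch(waveforms)
print(features.shape)  # (2, 80, max_time)
print(lengths)         # frame counts before padding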

35.1.4. Advanced Feature Representations

MFCC (Mel-Frequency Cepstral Coefficients)

Traditional ASR feature that decorrelates Mel filterbank outputs:

class MFCCExtractor:
    """Extract MFCC features with optional deltas."""
    
    def __init__(
        self,
        sample_rate: int = 16000,
        n_mfcc: int = 13,
        n_mels: int = 40,
        n_fft: int = 2048,
        hop_length: int = 512
    ):
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
    
    def extract(
        self, 
        audio: np.ndarray,
        include_deltas: bool = True
    ) -> np.ndarray:
        """Extract MFCC with optional delta features."""
        
        # Base MFCCs
        mfcc = librosa.feature.mfcc(
            y=audio,
            sr=self.sample_rate,
            n_mfcc=self.n_mfcc,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length
        )
        
        if include_deltas:
            # Delta (velocity)
            delta = librosa.feature.delta(mfcc, order=1)
            # Delta-delta (acceleration)
            delta2 = librosa.feature.delta(mfcc, order=2)
            
            # Stack: (39, time) for n_mfcc=13
            mfcc = np.concatenate([mfcc, delta, delta2], axis=0)
        
        return mfcc
    
    def compare_with_mel(self, audio: np.ndarray):
        """Compare MFCC vs Mel spectrogram features."""
        
        # Mel spectrogram
        mel = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length
        )
        log_mel = librosa.power_to_db(mel)
        
        # MFCC
        mfcc = self.extract(audio, include_deltas=False)
        
        return {
            'mel_shape': log_mel.shape,      # (n_mels, time)
            'mfcc_shape': mfcc.shape,        # (n_mfcc, time)
            'mel_correlated': True,          # Adjacent bands correlated
            'mfcc_decorrelated': True        # DCT removes correlation
        }
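
The decorrelation mentioned above comes from the final DCT step. Given $M$ log Mel-filterbank energies $\log E_m$ for one frame, the $n$-th cepstral coefficient is (up to a scaling convention):

$c_n = \sum_{m=0}^{M-1} \log(E_m)\,\cos\!\left[\frac{\pi n}{M}\left(m + \tfrac{1}{2}\right)\right], \qquad n = 0, 1, \dots, N_{\mathrm{mfcc}} - 1$

Keeping only the first $N_{\mathrm{mfcc}}$ coefficients (13 above) smooths the spectral envelope, which is why MFCCs are compact and decorrelated but discard fine spectral detail that log-Mel features retain.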

Wav2Vec 2.0 Embeddings

Modern approach using self-supervised representations:

import numpy as np
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model


class Wav2VecFeatureExtractor:
    """
    Extract contextual representations from Wav2Vec 2.0.
    
    Advantages:
    - Pre-trained self-supervised on large unlabeled corpora
      (the base model on 960 hours of LibriSpeech; the large LV-60 variants on ~60k hours)
    - Captures high-level phonetic information
    - Strong results for low-resource ASR
    
    Disadvantages:
    - Computationally expensive
    - Large model size (hundreds of MB)
    """
    
    def __init__(
        self,
        model_name: str = "facebook/wav2vec2-base-960h",
        layer: int = -1  # Which layer to extract from
    ):
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2Model.from_pretrained(model_name)
        self.model.eval()
        self.layer = layer
    
    @torch.no_grad()
    def extract(self, audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
        """Extract Wav2Vec representations."""
        
        # Ensure 16kHz
        if sample_rate != 16000:
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
        
        # Process input
        inputs = self.processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt"
        )
        
        # Extract features
        outputs = self.model(
            inputs.input_values,
            output_hidden_states=True
        )
        
        # Get specified layer
        if self.layer == -1:
            features = outputs.last_hidden_state
        else:
            features = outputs.hidden_states[self.layer]
        
        return features.squeeze(0).numpy()  # (time, 768)
    
    def get_feature_dimensions(self) -> dict:
        """Get feature dimension information."""
        return {
            'hidden_size': self.model.config.hidden_size,  # 768
            'num_layers': self.model.config.num_hidden_layers,  # 12
            'output_rate': 50,  # 50 frames per second
        }
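
A minimal usage sketch (the model weights are downloaded from the Hugging Face Hub on first use; the file path is a placeholder):

import librosa

extractor = Wav2VecFeatureExtractor("facebook/wav2vec2-base-960h")
audio, sr = librosa.load("speech.wav", sr=16000)  # placeholder path
embeddings = extractor.extract(audio, sample_rate=sr)
print(embeddings.shape)  # (time_frames, 768), roughly 50 frames per second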

35.1.5. Data Augmentation for Audio

Audio models overfit easily. Augmentation is critical for generalization.

Time-Domain Augmentations

import librosa
import numpy as np
import torch
import torchaudio.transforms as T
from typing import Tuple


class WaveformAugmentor(torch.nn.Module):
    """
    Time-domain augmentations applied to raw waveform.
    
    Apply before feature extraction for maximum effectiveness.
    """
    
    def __init__(
        self,
        sample_rate: int = 16000,
        noise_snr_range: Tuple[float, float] = (5, 20),
        speed_range: Tuple[float, float] = (0.9, 1.1),
        pitch_shift_range: Tuple[int, int] = (-2, 2),
        enable_noise: bool = True,
        enable_speed: bool = True,
        enable_pitch: bool = True
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.noise_snr_range = noise_snr_range
        self.speed_range = speed_range
        self.pitch_shift_range = pitch_shift_range
        
        self.enable_noise = enable_noise
        self.enable_speed = enable_speed
        self.enable_pitch = enable_pitch
        
        # Pre-load noise samples for efficiency
        self.noise_samples = self._load_noise_samples()
    
    def _load_noise_samples(self):
        """Load background noise samples for mixing."""
        # In production, load from MUSAN or similar dataset
        return {
            'white': torch.randn(self.sample_rate * 10),
            'pink': self._generate_pink_noise(self.sample_rate * 10),
        }
    
    def _generate_pink_noise(self, samples: int) -> torch.Tensor:
        """Generate pink (1/f) noise."""
        white = torch.randn(samples)
        # Simple approximation via filtering
        pink = torch.nn.functional.conv1d(
            white.unsqueeze(0).unsqueeze(0),
            torch.ones(1, 1, 3) / 3,
            padding=1
        ).squeeze()
        return pink
    
    def add_noise(
        self, 
        waveform: torch.Tensor,
        snr_db: float = None
    ) -> torch.Tensor:
        """Add background noise at specified SNR."""
        
        if snr_db is None:
            snr_db = np.random.uniform(*self.noise_snr_range)
        
        # Select random noise type
        noise_type = np.random.choice(list(self.noise_samples.keys()))
        noise = self.noise_samples[noise_type]
        
        # Repeat/truncate noise to match length
        if len(noise) < len(waveform):
            repeats = len(waveform) // len(noise) + 1
            noise = noise.repeat(repeats)
        noise = noise[:len(waveform)]
        
        # Calculate scaling for target SNR
        signal_power = (waveform ** 2).mean()
        noise_power = (noise ** 2).mean()
        
        snr_linear = 10 ** (snr_db / 10)
        scale = torch.sqrt(signal_power / (snr_linear * noise_power))
        
        return waveform + scale * noise
    
    def change_speed(
        self, 
        waveform: torch.Tensor,
        factor: float = None
    ) -> torch.Tensor:
        """Change playback speed without pitch change."""
        
        if factor is None:
            factor = np.random.uniform(*self.speed_range)
        
        # Resample to change speed
        orig_freq = self.sample_rate
        new_freq = int(self.sample_rate * factor)
        
        resampler = T.Resample(orig_freq, new_freq)
        stretched = resampler(waveform)
        
        # Resample back to original rate
        restore = T.Resample(new_freq, orig_freq)
        return restore(stretched)
    
    def shift_pitch(
        self, 
        waveform: torch.Tensor,
        steps: int = None
    ) -> torch.Tensor:
        """Shift pitch by semitones."""
        
        if steps is None:
            # randint's upper bound is exclusive; +1 makes the shift range inclusive
            steps = np.random.randint(
                self.pitch_shift_range[0], self.pitch_shift_range[1] + 1
            )
        
        # Convert to numpy for librosa processing
        audio_np = waveform.numpy()
        shifted = librosa.effects.pitch_shift(
            audio_np,
            sr=self.sample_rate,
            n_steps=steps
        )
        return torch.from_numpy(shifted)
    
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Apply random augmentations."""
        
        # Apply each augmentation with 50% probability
        if self.enable_noise and np.random.random() < 0.5:
            waveform = self.add_noise(waveform)
        
        if self.enable_speed and np.random.random() < 0.5:
            waveform = self.change_speed(waveform)
        
        if self.enable_pitch and np.random.random() < 0.5:
            waveform = self.shift_pitch(waveform)
        
        return waveform
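
A short usage sketch, assuming a mono waveform tensor at 16 kHz (augmentation is typically applied inside the Dataset's __getitem__, before feature extraction):

import torch

augmentor = WaveformAugmentor(sample_rate=16000)
waveform = torch.randn(16000 * 2)       # 2 s of dummy audio
augmented = augmentor(waveform)         # randomly noised / sped up / pitch-shifted
print(waveform.shape, augmented.shape)  # speed changes may alter the length slightly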

Spectrogram-Domain Augmentations (SpecAugment)

class SpecAugmentor(torch.nn.Module):
    """
    SpecAugment: A Simple Augmentation Method (Google Brain, 2019)
    
    Applied AFTER feature extraction on the spectrogram.
    
    Key insight: Masking forces the model to rely on context,
    improving robustness to missing information.
    """
    
    def __init__(
        self,
        freq_mask_param: int = 27,     # Maximum frequency mask width
        time_mask_param: int = 100,    # Maximum time mask width
        num_freq_masks: int = 2,       # Number of frequency masks
        num_time_masks: int = 2,       # Number of time masks
        replace_with_zero: bool = False  # False = mean value
    ):
        super().__init__()
        
        self.freq_masking = T.FrequencyMasking(freq_mask_param)
        self.time_masking = T.TimeMasking(time_mask_param)
        
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks
        self.replace_with_zero = replace_with_zero
    
    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Apply SpecAugment to spectrogram.
        
        Args:
            spectrogram: (batch, freq, time) or (freq, time)
            
        Returns:
            Augmented spectrogram with same shape
        """
        
        # Calculate replacement value
        if not self.replace_with_zero:
            mask_value = spectrogram.mean()
        else:
            mask_value = 0.0
        
        # Apply frequency masks
        for _ in range(self.num_freq_masks):
            spectrogram = self.freq_masking(spectrogram, mask_value)
        
        # Apply time masks
        for _ in range(self.num_time_masks):
            spectrogram = self.time_masking(spectrogram, mask_value)
        
        return spectrogram


class AdvancedSpecAugment(torch.nn.Module):
    """
    Advanced SpecAugment with adaptive parameters.
    
    Adds time warping and randomized mask counts/widths
    on top of the basic frequency/time masking.
    """
    
    def __init__(
        self,
        freq_mask_range: Tuple[int, int] = (0, 27),
        time_mask_range: Tuple[int, int] = (0, 100),
        num_masks_range: Tuple[int, int] = (1, 3),
        warp_window: int = 80
    ):
        super().__init__()
        self.freq_mask_range = freq_mask_range
        self.time_mask_range = time_mask_range
        self.num_masks_range = num_masks_range
        self.warp_window = warp_window
    
    def time_warp(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """Apply time warping (non-linear time stretching)."""
        
        batch, freq, time = spectrogram.shape
        
        if time < self.warp_window * 2:
            return spectrogram
        
        # Random warp point
        center = time // 2
        warp_distance = np.random.randint(-self.warp_window, self.warp_window)
        
        # Create warped indices
        left_indices = torch.linspace(0, center + warp_distance, center).long()
        right_indices = torch.linspace(center + warp_distance, time - 1, time - center).long()
        indices = torch.cat([left_indices, right_indices])
        
        # Apply warping
        warped = spectrogram[:, :, indices.clamp(0, time - 1)]
        
        return warped
    
    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """Apply advanced SpecAugment."""
        
        # Ensure batch dimension
        squeeze = False
        if spectrogram.dim() == 2:
            spectrogram = spectrogram.unsqueeze(0)
            squeeze = True
        
        batch, freq, time = spectrogram.shape
        
        # Time warping
        if np.random.random() < 0.5:
            spectrogram = self.time_warp(spectrogram)
        
        # Adaptive frequency masks
        num_freq_masks = np.random.randint(*self.num_masks_range)
        for _ in range(num_freq_masks):
            width = np.random.randint(*self.freq_mask_range)
            start = np.random.randint(0, max(1, freq - width))
            spectrogram[:, start:start + width, :] = spectrogram.mean()
        
        # Adaptive time masks
        num_time_masks = np.random.randint(*self.num_masks_range)
        for _ in range(num_time_masks):
            width = np.random.randint(*self.time_mask_range)
            width = min(width, time // 4)  # Limit to 25% of time
            start = np.random.randint(0, max(1, time - width))
            spectrogram[:, :, start:start + width] = spectrogram.mean()
        
        if squeeze:
            spectrogram = spectrogram.squeeze(0)
        
        return spectrogram
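
Applied to a log-mel tensor during training, usage looks like the following sketch (random values stand in for real features):

import torch

spec_augment = SpecAugmentor(freq_mask_param=27, time_mask_param=100)
log_mel = torch.randn(4, 80, 300)   # (batch, n_mels, time) dummy features
augmented = spec_augment(log_mel)   # masked copy, same shape
print(augmented.shape)              # torch.Size([4, 80, 300])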

35.1.6. Handling Variable-Length Sequences

Audio samples have different durations. Batching requires careful handling.

Padding and Masking Strategy

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from typing import List, Dict


class AudioDataset(Dataset):
    """Audio dataset with variable-length handling."""
    
    def __init__(
        self,
        audio_paths: List[str],
        labels: List[str],
        feature_extractor: AudioFeatureExtractor,
        max_length_seconds: float = 30.0
    ):
        self.audio_paths = audio_paths
        self.labels = labels
        self.feature_extractor = feature_extractor
        self.max_samples = int(max_length_seconds * feature_extractor.config.sample_rate)
    
    def __len__(self):
        return len(self.audio_paths)
    
    def __getitem__(self, idx):
        # Load audio
        audio, sr = self.feature_extractor.load_audio(self.audio_paths[idx])
        
        # Truncate if too long
        if len(audio) > self.max_samples:
            audio = audio[:self.max_samples]
        
        # Extract features
        features = self.feature_extractor.extract_features(audio)
        
        return {
            'features': torch.from_numpy(features).float(),
            'length': features.shape[1],
            'label': self.labels[idx]
        }


def audio_collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Custom collate function for variable-length audio.
    
    Returns:
        features: (batch, freq, max_time)
        lengths: (batch,)
        labels: List of labels
    """
    
    features = [item['features'] for item in batch]
    lengths = torch.tensor([item['length'] for item in batch])
    labels = [item['label'] for item in batch]
    
    # Pad to max length in batch
    max_len = max(f.shape[1] for f in features)
    
    padded = torch.zeros(len(features), features[0].shape[0], max_len)
    
    for i, feat in enumerate(features):
        padded[i, :, :feat.shape[1]] = feat
    
    # Create attention mask
    mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    
    return {
        'features': padded,
        'lengths': lengths,
        'attention_mask': mask,
        'labels': labels
    }


class LengthBucketingSampler:
    """
    Bucket samples by length for efficient batching.
    
    Minimizes padding by grouping similar-length samples.
    Training speedup: 20-30%
    """
    
    def __init__(
        self,
        lengths: List[int],
        batch_size: int,
        num_buckets: int = 10
    ):
        self.lengths = lengths
        self.batch_size = batch_size
        self.num_buckets = num_buckets
        
        # Create length-sorted indices
        self.sorted_indices = np.argsort(lengths)
        
        # Create buckets
        self.buckets = np.array_split(self.sorted_indices, num_buckets)
    
    def __iter__(self):
        # Shuffle within buckets
        for bucket in self.buckets:
            np.random.shuffle(bucket)
        
        # Yield batches
        all_indices = np.concatenate(self.buckets)
        
        for i in range(0, len(all_indices), self.batch_size):
            yield all_indices[i:i + self.batch_size].tolist()
    
    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size
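
Since the sampler yields lists of indices, it plugs into a DataLoader as a batch_sampler together with the collate function above. A sketch, assuming dataset is the AudioDataset defined earlier and lengths are approximated from file durations:

import soundfile as sf
from torch.utils.data import DataLoader

# Approximate lengths (in samples) once, from file metadata
lengths = [int(sf.info(p).duration * 16000) for p in dataset.audio_paths]

sampler = LengthBucketingSampler(lengths, batch_size=16, num_buckets=10)
loader = DataLoader(
    dataset,
    batch_sampler=sampler,        # yields index lists, so batch_size is not set here
    collate_fn=audio_collate_fn,
    num_workers=4
)

for batch in loader:
    features = batch['features']        # (batch, freq, max_time)
    mask = batch['attention_mask']      # True where frames are real, not padding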

35.1.7. Storage and Data Loading at Scale

WebDataset for Large-Scale Training

import io
import webdataset as wds
from pathlib import Path
import tarfile
import json


def create_audio_shards(
    audio_files: List[str],
    labels: List[str],
    output_dir: str,
    shard_size: int = 10000
):
    """
    Create WebDataset shards for efficient loading.
    
    Benefits:
    - Sequential reads instead of random IO
    - Works with cloud storage (S3, GCS)
    - Streaming without full download
    """
    
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    shard_idx = 0
    sample_idx = 0
    
    current_tar = None
    
    for audio_path, label in zip(audio_files, labels):
        # Start new shard if needed
        if sample_idx % shard_size == 0:
            if current_tar:
                current_tar.close()
            
            shard_name = f"shard-{shard_idx:06d}.tar"
            current_tar = tarfile.open(output_path / shard_name, "w")
            shard_idx += 1
        
        # Create sample key
        key = f"{sample_idx:08d}"
        
        # Add audio file
        with open(audio_path, 'rb') as f:
            audio_bytes = f.read()
        
        audio_info = tarfile.TarInfo(name=f"{key}.wav")
        audio_info.size = len(audio_bytes)
        current_tar.addfile(audio_info, fileobj=io.BytesIO(audio_bytes))
        
        # Add metadata
        metadata = json.dumps({'label': label, 'path': audio_path})
        meta_bytes = metadata.encode('utf-8')
        
        meta_info = tarfile.TarInfo(name=f"{key}.json")
        meta_info.size = len(meta_bytes)
        current_tar.addfile(meta_info, fileobj=io.BytesIO(meta_bytes))
        
        sample_idx += 1
    
    if current_tar:
        current_tar.close()
    
    print(f"Created {shard_idx} shards with {sample_idx} samples")


def create_webdataset_loader(
    shard_pattern: str,
    batch_size: int = 32,
    num_workers: int = 4,
    shuffle_buffer: int = 1000
):
    """Create streaming WebDataset loader."""
    
    # Log-mel transforms so the loader yields the keys audio_collate_fn expects
    mel_transform = T.MelSpectrogram(
        sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80
    )
    amplitude_to_db = T.AmplitudeToDB(stype='power', top_db=80.0)
    
    def decode_audio(sample):
        """Decode audio from bytes and extract log-mel features."""
        audio, sr = torchaudio.load(io.BytesIO(sample['wav']))
        
        # Resample if needed
        if sr != 16000:
            audio = T.Resample(sr, 16000)(audio)
        
        # Mix multi-channel audio down to mono
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        
        metadata = json.loads(sample['json'])
        
        # (n_mels, time), matching the 'features' key expected by audio_collate_fn
        features = amplitude_to_db(mel_transform(audio)).squeeze(0)
        
        return {
            'features': features,
            'length': features.shape[1],
            'label': metadata['label']
        }
    
    dataset = (
        wds.WebDataset(shard_pattern)
        .shuffle(shuffle_buffer)
        .map(decode_audio)
        .batched(batch_size, collation_fn=audio_collate_fn)
    )
    
    loader = wds.WebLoader(
        dataset,
        num_workers=num_workers,
        batch_size=None  # Batching done in dataset
    )
    
    return loader


# Usage
loader = create_webdataset_loader(
    "s3://my-bucket/audio-shards/shard-{000000..000100}.tar",
    batch_size=32
)

for batch in loader:
    features = batch['features']
    labels = batch['labels']
    # Training step...

35.1.8. Production Serving Pipeline

Triton Inference Server Configuration

# config.pbtxt for audio feature extraction ensemble

name: "audio_feature_ensemble"
platform: "ensemble"
max_batch_size: 32

input [
    {
        name: "AUDIO_BYTES"
        data_type: TYPE_UINT8
        dims: [ -1 ]
    }
]

output [
    {
        name: "FEATURES"
        data_type: TYPE_FP32
        dims: [ 80, -1 ]
    }
]

ensemble_scheduling {
    step [
        {
            model_name: "audio_decoder"
            model_version: 1
            input_map {
                key: "BYTES"
                value: "AUDIO_BYTES"
            }
            output_map {
                key: "WAVEFORM"
                value: "decoded_audio"
            }
        },
        {
            model_name: "feature_extractor"
            model_version: 1
            input_map {
                key: "AUDIO"
                value: "decoded_audio"
            }
            output_map {
                key: "MEL_SPECTROGRAM"
                value: "FEATURES"
            }
        }
    ]
}
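
On the client side, a request against this ensemble could look roughly like the following sketch using the tritonclient HTTP API. The endpoint, file path, and batch-of-one reshape are assumptions about your deployment; the tensor names match the config above.

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Raw file bytes as a uint8 tensor, with a leading batch dimension of 1
audio_bytes = np.fromfile("speech.wav", dtype=np.uint8).reshape(1, -1)

inp = httpclient.InferInput("AUDIO_BYTES", list(audio_bytes.shape), "UINT8")
inp.set_data_from_numpy(audio_bytes)

result = client.infer("audio_feature_ensemble", inputs=[inp])
features = result.as_numpy("FEATURES")  # (1, 80, time)
print(features.shape)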

ONNX Export for Feature Extraction

import numpy as np
import torch
import onnx
import onnxruntime as ort


def export_feature_extractor_onnx(
    extractor: TorchAudioFeatureExtractor,
    output_path: str,
    max_audio_length: int = 160000  # 10 seconds at 16kHz
):
    """Export feature extractor to ONNX for production serving."""
    
    extractor.eval()
    
    # Create dummy input
    dummy_input = torch.randn(1, max_audio_length)
    
    # Export
    torch.onnx.export(
        extractor,
        dummy_input,
        output_path,
        input_names=['audio'],
        output_names=['features'],
        dynamic_axes={
            'audio': {0: 'batch', 1: 'samples'},
            'features': {0: 'batch', 2: 'time'}
        },
        opset_version=17  # the ONNX STFT operator requires opset >= 17
    )
    
    # Verify export
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    
    # Test inference
    session = ort.InferenceSession(output_path)
    test_input = np.random.randn(1, 16000).astype(np.float32)
    output = session.run(None, {'audio': test_input})
    
    print(f"Exported to {output_path}")
    print(f"Output shape: {output[0].shape}")
    
    return output_path


class ONNXFeatureExtractor:
    """Production ONNX feature extractor."""
    
    def __init__(self, model_path: str, use_gpu: bool = True):
        providers = ['CUDAExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=providers)
    
    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Extract features using ONNX runtime."""
        
        if audio.ndim == 1:
            audio = audio.reshape(1, -1)
        
        audio = audio.astype(np.float32)
        
        features = self.session.run(
            None,
            {'audio': audio}
        )[0]
        
        return features

35.1.9. Cloud-Specific Implementations

AWS SageMaker Processing

# sagemaker_processor.py

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.pytorch import PyTorchProcessor


def create_feature_extraction_job(
    input_s3_uri: str,
    output_s3_uri: str,
    role: str
):
    """Create SageMaker Processing job for batch feature extraction."""
    
    processor = PyTorchProcessor(
        role=role,
        instance_count=4,
        instance_type="ml.g4dn.xlarge",
        framework_version="2.0",
        py_version="py310"
    )
    
    processor.run(
        code="extract_features.py",
        source_dir="./processing_scripts",
        inputs=[
            ProcessingInput(
                source=input_s3_uri,
                destination="/opt/ml/processing/input"
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output",
                destination=output_s3_uri
            )
        ],
        arguments=[
            "--sample-rate", "16000",
            "--n-mels", "80",
            "--batch-size", "64"
        ]
    )

GCP Vertex AI Pipeline

from kfp.v2 import dsl
from kfp.v2.dsl import component


@component(
    packages_to_install=["librosa", "torch", "torchaudio"],
    base_image="python:3.10"
)
def extract_audio_features(
    input_gcs_path: str,
    output_gcs_path: str,
    sample_rate: int = 16000,
    n_mels: int = 80
):
    """Vertex AI component for audio feature extraction."""
    
    from google.cloud import storage
    import librosa
    import numpy as np
    import os
    
    # ... feature extraction logic
    pass


@dsl.pipeline(
    name="audio-feature-pipeline",
    description="Extract audio features at scale"
)
def audio_pipeline(
    input_path: str,
    output_path: str
):
    extract_task = extract_audio_features(
        input_gcs_path=input_path,
        output_gcs_path=output_path
    ).set_gpu_limit(1).set_memory_limit("16G")
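
To run the pipeline, compile it to a job spec and submit it to Vertex AI. A sketch; the project, region, and bucket values are placeholders:

from kfp.v2 import compiler
from google.cloud import aiplatform

compiler.Compiler().compile(
    pipeline_func=audio_pipeline,
    package_path="audio_feature_pipeline.json"
)

aiplatform.init(project="my-project", location="us-central1")  # placeholders

job = aiplatform.PipelineJob(
    display_name="audio-feature-pipeline",
    template_path="audio_feature_pipeline.json",
    parameter_values={
        "input_path": "gs://my-bucket/raw-audio/",
        "output_path": "gs://my-bucket/features/"
    }
)
job.run()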

35.1.10. Summary Checklist for Audio Feature Extraction

Data Preparation

  • Standardize sample rate (16kHz for ASR, 44.1kHz for music)
  • Convert to mono unless stereo is required
  • Handle variable lengths with padding/truncation
  • Create efficient storage format (WebDataset)

Feature Extraction

  • Choose appropriate feature type (Log-Mel vs MFCC vs Wav2Vec)
  • Configure FFT parameters for use case
  • Implement GPU acceleration for training
  • Export ONNX model for production

Augmentation

  • Apply time-domain augmentations (noise, speed, pitch)
  • Apply SpecAugment during training
  • Use length bucketing for efficient batching

Production

  • Ensure preprocessing consistency (train == inference)
  • Set up Triton or custom serving pipeline
  • Monitor feature distributions for drift

[End of Section 35.1]