35.1. Audio Feature Extraction: The Spectrogram Pipeline
Note
Waveforms vs. Spectrograms: A standard CNN cannot usefully “hear” a raw wav file ($16,000$ samples/sec): the signal is extremely high-dimensional and the relevant structure lives in frequencies, not individual samples. We therefore convert time-domain signals into frequency-domain images (spectrograms) so standard CNNs can “see” the sound.
Audio machine learning requires careful feature engineering to bridge the gap between raw waveforms and the tensor representations that neural networks understand. This chapter covers the complete pipeline from audio capture to production-ready feature representations.
35.1.1. Understanding Audio Fundamentals
The Audio Signal
Raw audio is a 1D time-series signal representing air pressure changes over time:
| Property | Typical Value | Notes |
|---|---|---|
| Sample Rate | 16,000 Hz (ASR), 44,100 Hz (Music) | Samples per second |
| Bit Depth | 16-bit, 32-bit float | Dynamic range |
| Channels | 1 (Mono), 2 (Stereo) | Spatial dimensions |
| Format | WAV, FLAC, MP3, Opus | Compression type |
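These properties can be read straight from a file header before any processing. A minimal sketch using soundfile (the speech.wav path is a placeholder):

```python
import soundfile as sf

# Inspect format metadata without loading the samples into memory.
info = sf.info("speech.wav")  # placeholder path
print(f"sample rate: {info.samplerate} Hz")
print(f"channels:    {info.channels}")
print(f"subtype:     {info.subtype}")       # e.g. PCM_16 -> 16-bit depth
print(f"duration:    {info.duration:.2f} s")
```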
The Nyquist Theorem
To capture frequency $f$, you need sample rate $\geq 2f$:
- Human speech: ~8kHz max → 16kHz sample rate sufficient
- Music: ~20kHz max → 44.1kHz sample rate required
import numpy as np
import librosa
def demonstrate_nyquist():
"""Demonstrate aliasing when Nyquist is violated."""
# Generate 8kHz tone
sr = 44100 # High sample rate
duration = 1.0
t = np.linspace(0, duration, int(sr * duration))
tone_8k = np.sin(2 * np.pi * 8000 * t)
    # Downsample to 16kHz (Nyquist = 8kHz: the tone sits right at the limit)
    tone_16k = librosa.resample(tone_8k, orig_sr=sr, target_sr=16000)
    # Downsample to 12kHz (Nyquist = 6kHz): the 8kHz tone cannot be represented.
    # librosa's anti-aliasing filter suppresses it; naive decimation would alias it to 4kHz.
tone_12k = librosa.resample(tone_8k, orig_sr=sr, target_sr=12000)
return {
'original': tone_8k,
'resampled_ok': tone_16k,
'aliased': tone_12k
}
35.1.2. The Standard Feature Extraction Pipeline
graph LR
A[Raw Audio] --> B[Pre-emphasis]
B --> C[Framing]
C --> D[Windowing]
D --> E[STFT]
E --> F[Power Spectrum]
F --> G[Mel Filterbank]
G --> H[Log Compression]
H --> I[Delta Features]
I --> J[Normalization]
J --> K[Output Tensor]
Step-by-Step Breakdown
- Raw Audio: 1D array (`float32`, values in [-1, 1]).
- Pre-emphasis: High-pass filter to boost high frequencies, $y[t] = x[t] - 0.97\,x[t-1]$.
- Framing: Cutting into 25 ms windows with a 10 ms hop, so consecutive frames overlap by 15 ms (see the frame-count sketch after this list).
- Windowing: Applying a window function (Hamming or Hann) to reduce spectral leakage.
- STFT (Short-Time Fourier Transform): Converting each frame to the frequency domain, then taking the squared magnitude (power spectrum).
- Mel Filterbank: Mapping linear Hz to the perceptual Mel scale (e.g., $m = 2595 \log_{10}(1 + f/700)$).
- Log: Compressing dynamic range (decibels).
- Delta Features: First and second time derivatives (optional).
- Normalization: Zero-mean, unit-variance normalization.
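The framing step determines the time dimension of everything downstream. A minimal sketch of the frame-count arithmetic, assuming the classic 25 ms / 10 ms ASR framing (not the librosa defaults used in the next subsection):

```python
# Frame-count arithmetic for a 25 ms window / 10 ms hop at 16 kHz.
sample_rate = 16_000
win_length = int(0.025 * sample_rate)   # 25 ms -> 400 samples
hop_length = int(0.010 * sample_rate)   # 10 ms -> 160 samples
num_samples = 1 * sample_rate           # 1 second of audio

# Without padding (center=False): one frame per full window that fits.
n_frames = 1 + (num_samples - win_length) // hop_length
print(n_frames)  # 98 frames per second of audio

# With center-padding (librosa's default, center=True) the count becomes
# 1 + num_samples // hop_length = 101 frames.
```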
35.1.3. Complete Feature Extraction Implementation
Librosa: Research-Grade Implementation
import librosa
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple
import soundfile as sf
@dataclass
class AudioFeatureConfig:
"""Configuration for audio feature extraction."""
sample_rate: int = 16000
n_fft: int = 2048 # FFT window size
hop_length: int = 512 # Hop between frames
n_mels: int = 128 # Number of Mel bands
fmin: float = 0.0 # Minimum frequency
fmax: Optional[float] = 8000.0 # Maximum frequency
pre_emphasis: float = 0.97 # Pre-emphasis coefficient
normalize: bool = True # Whether to normalize output
add_deltas: bool = False # Add delta features
class AudioFeatureExtractor:
"""Production-ready audio feature extractor."""
def __init__(self, config: AudioFeatureConfig = None):
self.config = config or AudioFeatureConfig()
self._mel_basis = None
self._setup_mel_basis()
def _setup_mel_basis(self):
"""Pre-compute Mel filterbank for efficiency."""
self._mel_basis = librosa.filters.mel(
sr=self.config.sample_rate,
n_fft=self.config.n_fft,
n_mels=self.config.n_mels,
fmin=self.config.fmin,
fmax=self.config.fmax
)
def load_audio(
self,
path: str,
mono: bool = True
) -> Tuple[np.ndarray, int]:
"""Load audio file with consistent format."""
y, sr = librosa.load(
path,
sr=self.config.sample_rate,
mono=mono
)
return y, sr
def extract_features(self, audio: np.ndarray) -> np.ndarray:
"""Extract log-mel spectrogram features."""
# Step 1: Pre-emphasis
if self.config.pre_emphasis > 0:
audio = np.append(
audio[0],
audio[1:] - self.config.pre_emphasis * audio[:-1]
)
# Step 2: STFT
stft = librosa.stft(
audio,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
window='hann',
center=True,
pad_mode='reflect'
)
# Step 3: Power spectrum
power_spec = np.abs(stft) ** 2
# Step 4: Mel filterbank
mel_spec = np.dot(self._mel_basis, power_spec)
# Step 5: Log compression
log_mel = librosa.power_to_db(
mel_spec,
ref=np.max,
top_db=80.0
)
# Step 6: Delta features (optional)
if self.config.add_deltas:
delta = librosa.feature.delta(log_mel, order=1)
delta2 = librosa.feature.delta(log_mel, order=2)
log_mel = np.concatenate([log_mel, delta, delta2], axis=0)
# Step 7: Normalization
if self.config.normalize:
log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-8)
return log_mel # Shape: (n_mels * (1 + 2*add_deltas), time_steps)
def extract_from_file(self, path: str) -> np.ndarray:
"""Convenience method for file-based extraction."""
audio, _ = self.load_audio(path)
return self.extract_features(audio)
# Example usage
extractor = AudioFeatureExtractor(AudioFeatureConfig(
sample_rate=16000,
n_mels=80,
add_deltas=True,
normalize=True
))
features = extractor.extract_from_file("speech.wav")
print(f"Feature shape: {features.shape}") # (240, T) with deltas
Torchaudio: GPU-Accelerated Production
import torch
import torchaudio
import numpy as np
from torchaudio import transforms as T
from typing import Tuple
class TorchAudioFeatureExtractor(torch.nn.Module):
"""
GPU-accelerated feature extraction for training loops.
Key Advantages:
- Runs on GPU alongside model
- Differentiable (for end-to-end training)
- Batched processing
"""
def __init__(
self,
sample_rate: int = 16000,
n_mels: int = 80,
n_fft: int = 1024,
hop_length: int = 256,
f_min: float = 0.0,
f_max: float = 8000.0
):
super().__init__()
self.sample_rate = sample_rate
# Pre-emphasis filter
self.register_buffer(
'pre_emphasis_filter',
torch.FloatTensor([[-0.97, 1]])
)
# Mel spectrogram transform
self.mel_spectrogram = T.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
power=2.0,
normalized=False,
mel_scale='htk'
)
# Amplitude to dB
self.amplitude_to_db = T.AmplitudeToDB(
stype='power',
top_db=80.0
)
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
"""
Extract features from waveform.
Args:
waveform: (batch, samples) or (batch, channels, samples)
Returns:
features: (batch, n_mels, time)
"""
# Ensure 2D: (batch, samples)
if waveform.dim() == 3:
waveform = waveform.mean(dim=1) # Mix to mono
# Pre-emphasis
waveform = torch.nn.functional.conv1d(
waveform.unsqueeze(1),
self.pre_emphasis_filter.unsqueeze(0),
padding=1
).squeeze(1)[:, :-1]
# Mel spectrogram
mel_spec = self.mel_spectrogram(waveform)
# Log scale
log_mel = self.amplitude_to_db(mel_spec)
# Instance normalization
mean = log_mel.mean(dim=(1, 2), keepdim=True)
std = log_mel.std(dim=(1, 2), keepdim=True)
log_mel = (log_mel - mean) / (std + 1e-8)
return log_mel
@torch.no_grad()
def extract_batch(
self,
waveforms: list,
max_length: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Extract features from variable-length batch with padding.
Returns:
features: (batch, n_mels, max_time)
lengths: (batch,) actual lengths before padding
"""
# Get device
        device = self.pre_emphasis_filter.device  # module registers buffers, not parameters
# Extract features
features_list = []
lengths = []
for wf in waveforms:
if isinstance(wf, np.ndarray):
wf = torch.from_numpy(wf)
wf = wf.to(device)
if wf.dim() == 1:
wf = wf.unsqueeze(0)
feat = self.forward(wf)
features_list.append(feat.squeeze(0))
lengths.append(feat.shape[-1])
# Pad to max length
max_len = max_length or max(lengths)
batch_features = torch.zeros(
len(features_list),
features_list[0].shape[0],
max_len,
device=device
)
for i, feat in enumerate(features_list):
batch_features[i, :, :feat.shape[-1]] = feat
return batch_features, torch.tensor(lengths, device=device)
# GPU training loop example
def training_step(model, batch_waveforms, labels, feature_extractor):
"""Training step with GPU-accelerated feature extraction."""
# Extract features on GPU
features, lengths = feature_extractor.extract_batch(batch_waveforms)
# Forward pass
logits = model(features, lengths)
# Loss computation
loss = compute_ctc_loss(logits, labels, lengths)
return loss
35.1.4. Advanced Feature Representations
MFCC (Mel-Frequency Cepstral Coefficients)
Traditional ASR feature that decorrelates Mel filterbank outputs:
class MFCCExtractor:
"""Extract MFCC features with optional deltas."""
def __init__(
self,
sample_rate: int = 16000,
n_mfcc: int = 13,
n_mels: int = 40,
n_fft: int = 2048,
hop_length: int = 512
):
self.sample_rate = sample_rate
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.n_fft = n_fft
self.hop_length = hop_length
def extract(
self,
audio: np.ndarray,
include_deltas: bool = True
) -> np.ndarray:
"""Extract MFCC with optional delta features."""
# Base MFCCs
mfcc = librosa.feature.mfcc(
y=audio,
sr=self.sample_rate,
n_mfcc=self.n_mfcc,
n_mels=self.n_mels,
n_fft=self.n_fft,
hop_length=self.hop_length
)
if include_deltas:
# Delta (velocity)
delta = librosa.feature.delta(mfcc, order=1)
# Delta-delta (acceleration)
delta2 = librosa.feature.delta(mfcc, order=2)
# Stack: (39, time) for n_mfcc=13
mfcc = np.concatenate([mfcc, delta, delta2], axis=0)
return mfcc
def compare_with_mel(self, audio: np.ndarray):
"""Compare MFCC vs Mel spectrogram features."""
# Mel spectrogram
mel = librosa.feature.melspectrogram(
y=audio,
sr=self.sample_rate,
n_mels=self.n_mels,
n_fft=self.n_fft,
hop_length=self.hop_length
)
log_mel = librosa.power_to_db(mel)
# MFCC
mfcc = self.extract(audio, include_deltas=False)
return {
'mel_shape': log_mel.shape, # (n_mels, time)
'mfcc_shape': mfcc.shape, # (n_mfcc, time)
'mel_correlated': True, # Adjacent bands correlated
'mfcc_decorrelated': True # DCT removes correlation
}
Wav2Vec 2.0 Embeddings
Modern approach using self-supervised representations:
import torch
import numpy as np
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2Model
class Wav2VecFeatureExtractor:
"""
Extract contextual representations from Wav2Vec 2.0.
Advantages:
- Pre-trained on 60k hours of unlabeled audio
- Captures high-level phonetic information
- State-of-the-art for low-resource ASR
Disadvantages:
- Computationally expensive
- Large model size (~300MB)
"""
def __init__(
self,
model_name: str = "facebook/wav2vec2-base-960h",
layer: int = -1 # Which layer to extract from
):
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
self.model = Wav2Vec2Model.from_pretrained(model_name)
self.model.eval()
self.layer = layer
@torch.no_grad()
def extract(self, audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
"""Extract Wav2Vec representations."""
# Ensure 16kHz
if sample_rate != 16000:
audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
# Process input
inputs = self.processor(
audio,
sampling_rate=16000,
return_tensors="pt"
)
# Extract features
outputs = self.model(
inputs.input_values,
output_hidden_states=True
)
# Get specified layer
if self.layer == -1:
features = outputs.last_hidden_state
else:
features = outputs.hidden_states[self.layer]
return features.squeeze(0).numpy() # (time, 768)
def get_feature_dimensions(self) -> dict:
"""Get feature dimension information."""
return {
'hidden_size': self.model.config.hidden_size, # 768
'num_layers': self.model.config.num_hidden_layers, # 12
'output_rate': 50, # 50 frames per second
}
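A short usage sketch of the class above (the checkpoint is downloaded on first use; speech.wav is a placeholder path):

```python
# Load a mono 16 kHz clip and extract contextual embeddings.
audio, sr = librosa.load("speech.wav", sr=16000, mono=True)

w2v = Wav2VecFeatureExtractor("facebook/wav2vec2-base-960h", layer=-1)
features = w2v.extract(audio, sample_rate=sr)
print(features.shape)  # roughly (duration_seconds * 50, 768)
```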
35.1.5. Data Augmentation for Audio
Audio models overfit easily. Augmentation is critical for generalization.
Time-Domain Augmentations
import torch
import torchaudio.transforms as T
import numpy as np
import librosa
from typing import Tuple
class WaveformAugmentor(torch.nn.Module):
"""
Time-domain augmentations applied to raw waveform.
Apply before feature extraction for maximum effectiveness.
"""
def __init__(
self,
sample_rate: int = 16000,
noise_snr_range: Tuple[float, float] = (5, 20),
speed_range: Tuple[float, float] = (0.9, 1.1),
pitch_shift_range: Tuple[int, int] = (-2, 2),
enable_noise: bool = True,
enable_speed: bool = True,
enable_pitch: bool = True
):
super().__init__()
self.sample_rate = sample_rate
self.noise_snr_range = noise_snr_range
self.speed_range = speed_range
self.pitch_shift_range = pitch_shift_range
self.enable_noise = enable_noise
self.enable_speed = enable_speed
self.enable_pitch = enable_pitch
# Pre-load noise samples for efficiency
self.noise_samples = self._load_noise_samples()
def _load_noise_samples(self):
"""Load background noise samples for mixing."""
# In production, load from MUSAN or similar dataset
return {
'white': torch.randn(self.sample_rate * 10),
'pink': self._generate_pink_noise(self.sample_rate * 10),
}
def _generate_pink_noise(self, samples: int) -> torch.Tensor:
"""Generate pink (1/f) noise."""
white = torch.randn(samples)
# Simple approximation via filtering
pink = torch.nn.functional.conv1d(
white.unsqueeze(0).unsqueeze(0),
torch.ones(1, 1, 3) / 3,
padding=1
).squeeze()
return pink
def add_noise(
self,
waveform: torch.Tensor,
snr_db: float = None
) -> torch.Tensor:
"""Add background noise at specified SNR."""
if snr_db is None:
snr_db = np.random.uniform(*self.noise_snr_range)
# Select random noise type
noise_type = np.random.choice(list(self.noise_samples.keys()))
noise = self.noise_samples[noise_type]
# Repeat/truncate noise to match length
if len(noise) < len(waveform):
repeats = len(waveform) // len(noise) + 1
noise = noise.repeat(repeats)
noise = noise[:len(waveform)]
# Calculate scaling for target SNR
signal_power = (waveform ** 2).mean()
noise_power = (noise ** 2).mean()
snr_linear = 10 ** (snr_db / 10)
scale = torch.sqrt(signal_power / (snr_linear * noise_power))
return waveform + scale * noise
    def change_speed(
        self,
        waveform: torch.Tensor,
        factor: float = None
    ) -> torch.Tensor:
        """Kaldi-style speed perturbation via resampling.

        Resampling to sample_rate / factor and reinterpreting the result at the
        original rate changes the duration (and pitch). For a pitch-preserving
        tempo change, use a phase vocoder (e.g., sox 'tempo') instead.
        """
        if factor is None:
            factor = np.random.uniform(*self.speed_range)
        # factor > 1 -> fewer samples -> faster playback at the original rate
        new_freq = int(self.sample_rate / factor)
        resampler = T.Resample(self.sample_rate, new_freq)
        return resampler(waveform)
def shift_pitch(
self,
waveform: torch.Tensor,
steps: int = None
) -> torch.Tensor:
"""Shift pitch by semitones."""
if steps is None:
steps = np.random.randint(*self.pitch_shift_range)
# Convert to numpy for librosa processing
audio_np = waveform.numpy()
shifted = librosa.effects.pitch_shift(
audio_np,
sr=self.sample_rate,
n_steps=steps
)
return torch.from_numpy(shifted)
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
"""Apply random augmentations."""
# Apply each augmentation with 50% probability
if self.enable_noise and np.random.random() < 0.5:
waveform = self.add_noise(waveform)
if self.enable_speed and np.random.random() < 0.5:
waveform = self.change_speed(waveform)
if self.enable_pitch and np.random.random() < 0.5:
waveform = self.shift_pitch(waveform)
return waveform
Spectrogram-Domain Augmentations (SpecAugment)
class SpecAugmentor(torch.nn.Module):
"""
SpecAugment: A Simple Augmentation Method (Google Brain, 2019)
Applied AFTER feature extraction on the spectrogram.
Key insight: Masking forces the model to rely on context,
improving robustness to missing information.
"""
def __init__(
self,
freq_mask_param: int = 27, # Maximum frequency mask width
time_mask_param: int = 100, # Maximum time mask width
num_freq_masks: int = 2, # Number of frequency masks
num_time_masks: int = 2, # Number of time masks
replace_with_zero: bool = False # False = mean value
):
super().__init__()
self.freq_masking = T.FrequencyMasking(freq_mask_param)
self.time_masking = T.TimeMasking(time_mask_param)
self.num_freq_masks = num_freq_masks
self.num_time_masks = num_time_masks
self.replace_with_zero = replace_with_zero
def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
"""
Apply SpecAugment to spectrogram.
Args:
spectrogram: (batch, freq, time) or (freq, time)
Returns:
Augmented spectrogram with same shape
"""
# Calculate replacement value
if not self.replace_with_zero:
mask_value = spectrogram.mean()
else:
mask_value = 0.0
# Apply frequency masks
for _ in range(self.num_freq_masks):
spectrogram = self.freq_masking(spectrogram, mask_value)
# Apply time masks
for _ in range(self.num_time_masks):
spectrogram = self.time_masking(spectrogram, mask_value)
return spectrogram
class AdvancedSpecAugment(torch.nn.Module):
"""
Advanced SpecAugment with adaptive parameters.
    Adds time warping (from the original SpecAugment recipe) and draws the
    number and width of masks at random instead of using fixed values.
"""
def __init__(
self,
freq_mask_range: Tuple[int, int] = (0, 27),
time_mask_range: Tuple[int, int] = (0, 100),
num_masks_range: Tuple[int, int] = (1, 3),
warp_window: int = 80
):
super().__init__()
self.freq_mask_range = freq_mask_range
self.time_mask_range = time_mask_range
self.num_masks_range = num_masks_range
self.warp_window = warp_window
def time_warp(self, spectrogram: torch.Tensor) -> torch.Tensor:
"""Apply time warping (non-linear time stretching)."""
batch, freq, time = spectrogram.shape
if time < self.warp_window * 2:
return spectrogram
# Random warp point
center = time // 2
warp_distance = np.random.randint(-self.warp_window, self.warp_window)
# Create warped indices
left_indices = torch.linspace(0, center + warp_distance, center).long()
right_indices = torch.linspace(center + warp_distance, time - 1, time - center).long()
indices = torch.cat([left_indices, right_indices])
# Apply warping
warped = spectrogram[:, :, indices.clamp(0, time - 1)]
return warped
def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
"""Apply advanced SpecAugment."""
# Ensure batch dimension
squeeze = False
if spectrogram.dim() == 2:
spectrogram = spectrogram.unsqueeze(0)
squeeze = True
batch, freq, time = spectrogram.shape
# Time warping
if np.random.random() < 0.5:
spectrogram = self.time_warp(spectrogram)
# Adaptive frequency masks
num_freq_masks = np.random.randint(*self.num_masks_range)
for _ in range(num_freq_masks):
width = np.random.randint(*self.freq_mask_range)
start = np.random.randint(0, max(1, freq - width))
spectrogram[:, start:start + width, :] = spectrogram.mean()
# Adaptive time masks
num_time_masks = np.random.randint(*self.num_masks_range)
for _ in range(num_time_masks):
width = np.random.randint(*self.time_mask_range)
width = min(width, time // 4) # Limit to 25% of time
start = np.random.randint(0, max(1, time - width))
spectrogram[:, :, start:start + width] = spectrogram.mean()
if squeeze:
spectrogram = spectrogram.squeeze(0)
return spectrogram
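Putting the two stages together, here is a hedged sketch of how the augmentors above might be composed around the GPU feature extractor from Section 35.1.3 (the random waveform stands in for real audio):

```python
# Composition order: waveform augmentation -> feature extraction -> SpecAugment.
wave_aug = WaveformAugmentor(sample_rate=16000, enable_pitch=False)  # skip pitch shift in this sketch
feature_extractor = TorchAudioFeatureExtractor(sample_rate=16000, n_mels=80)
spec_aug = SpecAugmentor(freq_mask_param=27, time_mask_param=100)

waveform = torch.randn(16000)                         # placeholder: 1 s of "audio"
waveform = wave_aug(waveform)                         # time-domain augmentations (1D input)
features = feature_extractor(waveform.unsqueeze(0))   # (1, 80, T) log-mel
features = spec_aug(features)                         # frequency/time masking
```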
35.1.6. Handling Variable-Length Sequences
Audio samples have different durations. Batching requires careful handling.
Padding and Masking Strategy
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from typing import List, Dict
class AudioDataset(Dataset):
"""Audio dataset with variable-length handling."""
def __init__(
self,
audio_paths: List[str],
labels: List[str],
feature_extractor: AudioFeatureExtractor,
max_length_seconds: float = 30.0
):
self.audio_paths = audio_paths
self.labels = labels
self.feature_extractor = feature_extractor
        self.max_samples = int(max_length_seconds * feature_extractor.config.sample_rate)
def __len__(self):
return len(self.audio_paths)
def __getitem__(self, idx):
# Load audio
audio, sr = self.feature_extractor.load_audio(self.audio_paths[idx])
# Truncate if too long
if len(audio) > self.max_samples:
audio = audio[:self.max_samples]
# Extract features
features = self.feature_extractor.extract_features(audio)
return {
'features': torch.from_numpy(features).float(),
'length': features.shape[1],
'label': self.labels[idx]
}
def audio_collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
"""
Custom collate function for variable-length audio.
Returns:
features: (batch, freq, max_time)
lengths: (batch,)
labels: List of labels
"""
features = [item['features'] for item in batch]
lengths = torch.tensor([item['length'] for item in batch])
labels = [item['label'] for item in batch]
# Pad to max length in batch
max_len = max(f.shape[1] for f in features)
padded = torch.zeros(len(features), features[0].shape[0], max_len)
for i, feat in enumerate(features):
padded[i, :, :feat.shape[1]] = feat
# Create attention mask
mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
return {
'features': padded,
'lengths': lengths,
'attention_mask': mask,
'labels': labels
}
class LengthBucketingSampler:
"""
Bucket samples by length for efficient batching.
Minimizes padding by grouping similar-length samples.
    Typical training speedup: roughly 20-30%, since less compute is spent on padding.
"""
def __init__(
self,
lengths: List[int],
batch_size: int,
num_buckets: int = 10
):
self.lengths = lengths
self.batch_size = batch_size
self.num_buckets = num_buckets
# Create length-sorted indices
self.sorted_indices = np.argsort(lengths)
# Create buckets
self.buckets = np.array_split(self.sorted_indices, num_buckets)
def __iter__(self):
# Shuffle within buckets
for bucket in self.buckets:
np.random.shuffle(bucket)
# Yield batches
all_indices = np.concatenate(self.buckets)
for i in range(0, len(all_indices), self.batch_size):
yield all_indices[i:i + self.batch_size].tolist()
def __len__(self):
return (len(self.lengths) + self.batch_size - 1) // self.batch_size
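Because LengthBucketingSampler yields lists of indices, it plugs directly into DataLoader as a batch_sampler. A hedged usage sketch reusing the classes above (the path and transcript lists are placeholders):

```python
# Bucketed batching with the custom collate function defined above.
extractor = AudioFeatureExtractor(AudioFeatureConfig(n_mels=80))
dataset = AudioDataset(audio_paths, transcripts, extractor)  # placeholder lists

# Approximate lengths (file durations in seconds) are enough for bucketing.
durations = [sf.info(p).duration for p in audio_paths]
sampler = LengthBucketingSampler(durations, batch_size=16)

loader = DataLoader(
    dataset,
    batch_sampler=sampler,        # yields index lists, so no batch_size here
    collate_fn=audio_collate_fn,
    num_workers=4
)
for batch in loader:
    features = batch['features']        # (batch, n_mels, max_time)
    mask = batch['attention_mask']      # (batch, max_time)
```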
35.1.7. Storage and Data Loading at Scale
WebDataset for Large-Scale Training
import webdataset as wds
from pathlib import Path
from typing import List
import tarfile
import json
import io
def create_audio_shards(
audio_files: List[str],
labels: List[str],
output_dir: str,
shard_size: int = 10000
):
"""
Create WebDataset shards for efficient loading.
Benefits:
- Sequential reads instead of random IO
- Works with cloud storage (S3, GCS)
- Streaming without full download
"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
shard_idx = 0
sample_idx = 0
current_tar = None
for audio_path, label in zip(audio_files, labels):
# Start new shard if needed
if sample_idx % shard_size == 0:
if current_tar:
current_tar.close()
shard_name = f"shard-{shard_idx:06d}.tar"
current_tar = tarfile.open(output_path / shard_name, "w")
shard_idx += 1
# Create sample key
key = f"{sample_idx:08d}"
# Add audio file
with open(audio_path, 'rb') as f:
audio_bytes = f.read()
audio_info = tarfile.TarInfo(name=f"{key}.wav")
audio_info.size = len(audio_bytes)
current_tar.addfile(audio_info, fileobj=io.BytesIO(audio_bytes))
# Add metadata
metadata = json.dumps({'label': label, 'path': audio_path})
meta_bytes = metadata.encode('utf-8')
meta_info = tarfile.TarInfo(name=f"{key}.json")
meta_info.size = len(meta_bytes)
current_tar.addfile(meta_info, fileobj=io.BytesIO(meta_bytes))
sample_idx += 1
if current_tar:
current_tar.close()
print(f"Created {shard_idx} shards with {sample_idx} samples")
def create_webdataset_loader(
shard_pattern: str,
batch_size: int = 32,
num_workers: int = 4,
shuffle_buffer: int = 1000
):
"""Create streaming WebDataset loader."""
def decode_audio(sample):
"""Decode audio from bytes."""
audio_bytes = sample['wav']
audio, sr = torchaudio.load(io.BytesIO(audio_bytes))
# Resample if needed
if sr != 16000:
resampler = T.Resample(sr, 16000)
audio = resampler(audio)
metadata = json.loads(sample['json'])
return {
'audio': audio.squeeze(0),
'label': metadata['label']
}
    def collate_waveforms(samples):
        """Pad raw waveforms to the longest clip in the batch."""
        waveforms = [s['audio'] for s in samples]
        lengths = torch.tensor([w.shape[-1] for w in waveforms])
        padded = torch.zeros(len(waveforms), int(lengths.max()))
        for i, w in enumerate(waveforms):
            padded[i, :w.shape[-1]] = w
        return {'audio': padded, 'lengths': lengths,
                'labels': [s['label'] for s in samples]}

    dataset = (
        wds.WebDataset(shard_pattern)
        .shuffle(shuffle_buffer)
        .map(decode_audio)
        .batched(batch_size, collation_fn=collate_waveforms)
    )
)
loader = wds.WebLoader(
dataset,
num_workers=num_workers,
batch_size=None # Batching done in dataset
)
return loader
# Usage
loader = create_webdataset_loader(
"s3://my-bucket/audio-shards/shard-{000000..000100}.tar",
batch_size=32
)
for batch in loader:
    waveforms = batch['audio']   # (batch, max_samples) padded raw waveforms
    labels = batch['labels']
    # Extract features (e.g., with TorchAudioFeatureExtractor) and run the training step...
35.1.8. Production Serving Pipeline
Triton Inference Server Configuration
# config.pbtxt for audio feature extraction ensemble
name: "audio_feature_ensemble"
platform: "ensemble"
max_batch_size: 32
input [
{
name: "AUDIO_BYTES"
data_type: TYPE_UINT8
dims: [ -1 ]
}
]
output [
{
name: "FEATURES"
data_type: TYPE_FP32
dims: [ 80, -1 ]
}
]
ensemble_scheduling {
step [
{
model_name: "audio_decoder"
model_version: 1
input_map {
key: "BYTES"
value: "AUDIO_BYTES"
}
output_map {
key: "WAVEFORM"
value: "decoded_audio"
}
},
{
model_name: "feature_extractor"
model_version: 1
input_map {
key: "AUDIO"
value: "decoded_audio"
}
output_map {
key: "MEL_SPECTROGRAM"
value: "FEATURES"
}
}
]
}
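A matching client call, sketched with the tritonclient HTTP API (assuming the ensemble above is deployed locally on port 8000 and speech.wav is a placeholder file):

```python
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Send the raw, undecoded file bytes; the ensemble decodes and extracts features.
with open("speech.wav", "rb") as f:
    audio_bytes = np.frombuffer(f.read(), dtype=np.uint8)

inp = httpclient.InferInput("AUDIO_BYTES", [1, len(audio_bytes)], "UINT8")
inp.set_data_from_numpy(audio_bytes.reshape(1, -1))
out = httpclient.InferRequestedOutput("FEATURES")

result = client.infer("audio_feature_ensemble", inputs=[inp], outputs=[out])
features = result.as_numpy("FEATURES")  # (1, 80, T)
```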
ONNX Export for Feature Extraction
import torch
import numpy as np
import onnx
import onnxruntime as ort
def export_feature_extractor_onnx(
extractor: TorchAudioFeatureExtractor,
output_path: str,
max_audio_length: int = 160000 # 10 seconds at 16kHz
):
"""Export feature extractor to ONNX for production serving."""
extractor.eval()
# Create dummy input
dummy_input = torch.randn(1, max_audio_length)
# Export
torch.onnx.export(
extractor,
dummy_input,
output_path,
input_names=['audio'],
output_names=['features'],
dynamic_axes={
'audio': {0: 'batch', 1: 'samples'},
'features': {0: 'batch', 2: 'time'}
},
        opset_version=17  # torch.stft export requires the ONNX STFT op (opset >= 17)
)
# Verify export
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
# Test inference
session = ort.InferenceSession(output_path)
test_input = np.random.randn(1, 16000).astype(np.float32)
output = session.run(None, {'audio': test_input})
print(f"Exported to {output_path}")
print(f"Output shape: {output[0].shape}")
return output_path
class ONNXFeatureExtractor:
"""Production ONNX feature extractor."""
def __init__(self, model_path: str, use_gpu: bool = True):
providers = ['CUDAExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
self.session = ort.InferenceSession(model_path, providers=providers)
def extract(self, audio: np.ndarray) -> np.ndarray:
"""Extract features using ONNX runtime."""
if audio.ndim == 1:
audio = audio.reshape(1, -1)
audio = audio.astype(np.float32)
features = self.session.run(
None,
{'audio': audio}
)[0]
return features
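Before shipping the exported graph, it is worth verifying that it reproduces the PyTorch extractor. A minimal parity-check sketch (file name and tolerance are assumptions):

```python
# Hypothetical parity check between the PyTorch module and its ONNX export.
torch_extractor = TorchAudioFeatureExtractor()
export_feature_extractor_onnx(torch_extractor, "feature_extractor.onnx")
onnx_extractor = ONNXFeatureExtractor("feature_extractor.onnx", use_gpu=False)

audio = np.random.randn(16000).astype(np.float32)  # 1 s of dummy audio
with torch.no_grad():
    ref = torch_extractor(torch.from_numpy(audio).unsqueeze(0)).numpy()
out = onnx_extractor.extract(audio)

print("max abs diff:", np.abs(ref - out).max())  # expect on the order of 1e-3 or less
```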
35.1.9. Cloud-Specific Implementations
AWS SageMaker Processing
# sagemaker_processor.py
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.pytorch import PyTorchProcessor
def create_feature_extraction_job(
input_s3_uri: str,
output_s3_uri: str,
role: str
):
"""Create SageMaker Processing job for batch feature extraction."""
processor = PyTorchProcessor(
role=role,
instance_count=4,
instance_type="ml.g4dn.xlarge",
framework_version="2.0",
py_version="py310"
)
processor.run(
code="extract_features.py",
source_dir="./processing_scripts",
inputs=[
ProcessingInput(
source=input_s3_uri,
destination="/opt/ml/processing/input"
)
],
outputs=[
ProcessingOutput(
source="/opt/ml/processing/output",
destination=output_s3_uri
)
],
arguments=[
"--sample-rate", "16000",
"--n-mels", "80",
"--batch-size", "64"
]
)
GCP Vertex AI Pipeline
from kfp.v2 import dsl
from kfp.v2.dsl import component
@component(
packages_to_install=["librosa", "torch", "torchaudio"],
base_image="python:3.10"
)
def extract_audio_features(
input_gcs_path: str,
output_gcs_path: str,
sample_rate: int = 16000,
n_mels: int = 80
):
"""Vertex AI component for audio feature extraction."""
from google.cloud import storage
import librosa
import numpy as np
import os
# ... feature extraction logic
pass
@dsl.pipeline(
name="audio-feature-pipeline",
description="Extract audio features at scale"
)
def audio_pipeline(
input_path: str,
output_path: str
):
extract_task = extract_audio_features(
input_gcs_path=input_path,
output_gcs_path=output_path
).set_gpu_limit(1).set_memory_limit("16G")
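To run it, the pipeline is compiled to a job spec and submitted to Vertex AI Pipelines; a hedged sketch (project, region, and bucket names are placeholders):

```python
from kfp.v2 import compiler
from google.cloud import aiplatform

# Compile the pipeline definition to a job spec.
compiler.Compiler().compile(
    pipeline_func=audio_pipeline,
    package_path="audio_feature_pipeline.json"
)

# Submit to Vertex AI Pipelines.
aiplatform.init(project="my-project", location="us-central1")
job = aiplatform.PipelineJob(
    display_name="audio-feature-pipeline",
    template_path="audio_feature_pipeline.json",
    parameter_values={
        "input_path": "gs://my-bucket/raw-audio/",
        "output_path": "gs://my-bucket/features/"
    }
)
job.run()
```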
35.1.10. Summary Checklist for Audio Feature Extraction
Data Preparation
- Standardize sample rate (16kHz for ASR, 44.1kHz for music)
- Convert to mono unless stereo is required
- Handle variable lengths with padding/truncation
- Create efficient storage format (WebDataset)
Feature Extraction
- Choose appropriate feature type (Log-Mel vs MFCC vs Wav2Vec)
- Configure FFT parameters for use case
- Implement GPU acceleration for training
- Export ONNX model for production
Augmentation
- Apply time-domain augmentations (noise, speed, pitch)
- Apply SpecAugment during training
- Use length bucketing for efficient batching
Production
- Ensure preprocessing consistency (train == inference)
- Set up Triton or custom serving pipeline
- Monitor feature distributions for drift
[End of Section 35.1]