35.2. Streaming ASR Architectures: Real-Time Transcription
Important
The Latency Trap: Users expect subtitles to appear as they speak. If you wait for the end of the utterance (EOU) before emitting text, perceived latency is 3-5 seconds. You must stream partial results.
Real-time Automatic Speech Recognition (ASR) is one of the most demanding MLOps challenges. Unlike batch transcription—where you can take minutes to process an hour-long podcast—streaming ASR requires sub-second latency while maintaining high accuracy. This section covers the complete architecture for building production-grade streaming ASR systems.
35.2.1. The Streaming Architecture
The streaming ASR pipeline consists of several critical components that must work in harmony:
graph LR
A[Client/Browser] -->|WebSocket/gRPC| B[Load Balancer]
B --> C[VAD Service]
C -->|Speech Segments| D[ASR Engine]
D -->|Partial Results| E[Post-Processor]
E -->|Final Results| F[Client]
subgraph "State Management"
G[Session Store]
D <--> G
end
Component Breakdown
- Client Layer: The browser captures microphone audio via the Web Audio API and sends fixed-size chunks over a WebSocket.
- VAD (Voice Activity Detection): “Is this silence?” If yes, drop packet. If no, pass to queue.
- ASR Engine: Maintains state (RNN/Transformer Memory). Updates partial transcript.
- Post-Processor: Punctuation, capitalization, number formatting.
- Stabilization: Partial hypotheses are revised as more audio arrives (“Hello W…” becomes “Hello World”), so the displayed text keeps changing until it is finalized.
Latency Budget Breakdown
| Component | Target Latency | Notes |
|---|---|---|
| Client Capture | 20-50ms | WebAudio buffer size |
| Network Transit | 10-50ms | Depends on geography |
| VAD Processing | 5-10ms | Must be ultra-fast |
| ASR Inference | 50-200ms | GPU-dependent |
| Post-Processing | 10-20ms | Punctuation/formatting |
| Total E2E | ~95-330ms | Target < 300ms for UX |
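A quick way to sanity-check a deployment against this budget is to timestamp each stage and compare the sum with the end-to-end target. A minimal sketch (the stage names and millisecond values below are illustrative, mirroring the table above):
# latency_budget.py - illustrative end-to-end budget check
from dataclasses import dataclass, field

@dataclass
class LatencyBudget:
    """Accumulate per-stage latencies (ms) and compare against an E2E target."""
    target_ms: float = 300.0
    stages: dict = field(default_factory=dict)

    def record(self, stage: str, ms: float) -> None:
        self.stages[stage] = ms

    def total(self) -> float:
        return sum(self.stages.values())

    def within_budget(self) -> bool:
        return self.total() <= self.target_ms

budget = LatencyBudget()
budget.record("client_capture", 30)
budget.record("network", 25)
budget.record("vad", 8)
budget.record("asr_inference", 120)
budget.record("post_processing", 15)
print(f"E2E: {budget.total():.0f} ms, within budget: {budget.within_budget()}")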
35.2.2. Protocol: WebSocket vs. gRPC
The choice of streaming protocol significantly impacts architecture decisions.
WebSocket Architecture
- Pros: Ubiquitous. Works in Browser JS natively. Good for B2C apps.
- Cons: No built-in message schema or multiplexing; binary frames are supported, but framing and backpressure are on you; long-lived connections make L7 load balancing harder.
gRPC Streaming
- Pros: Lower overhead (ProtoBuf). Better for backend-to-backend (e.g., Phone Switch -> ASR).
- Cons: Not native in browsers (requires grpc-web proxy).
Comparison Matrix
| Feature | WebSocket | gRPC |
|---|---|---|
| Browser Support | Native | Requires Proxy |
| Binary Efficiency | Moderate | Excellent |
| Bidirectional | Yes | Yes |
| Load Balancing | L7 (Complex) | L4/L7 |
| TLS | WSS | mTLS Native |
| Multiplexing | Per-connection | HTTP/2 Streams |
FastAPI WebSocket Server Implementation
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import numpy as np
from dataclasses import dataclass
from typing import Optional
import logging
logger = logging.getLogger(__name__)
app = FastAPI(title="Streaming ASR Service")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@dataclass
class TranscriptionResult:
text: str
is_final: bool
confidence: float
start_time: float
end_time: float
words: list
class ASRSession:
"""Manages state for a single ASR streaming session."""
def __init__(self, model_name: str = "base"):
self.model = self._load_model(model_name)
self.buffer = np.array([], dtype=np.float32)
self.context = None
self.sample_rate = 16000
self.chunk_duration = 0.5 # seconds
self.total_audio_processed = 0.0
def _load_model(self, model_name: str):
# Initialize ASR model with streaming support
from faster_whisper import WhisperModel
return WhisperModel(
model_name,
device="cuda",
compute_type="int8"
)
async def process_chunk(self, audio_bytes: bytes) -> TranscriptionResult:
# Convert bytes to numpy array
audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
# Append to buffer
self.buffer = np.concatenate([self.buffer, audio])
# Check if we have enough audio to process
min_samples = int(self.sample_rate * self.chunk_duration)
if len(self.buffer) < min_samples:
return TranscriptionResult(
text="",
is_final=False,
confidence=0.0,
start_time=self.total_audio_processed,
end_time=self.total_audio_processed,
words=[]
)
        # Process the buffer. NOTE: faster-whisper transcription is synchronous;
        # in production, offload it (e.g. via asyncio.to_thread) so a long decode
        # does not block the event loop serving other WebSocket sessions.
segments, info = self.model.transcribe(
self.buffer,
beam_size=5,
language="en",
vad_filter=True,
vad_parameters=dict(
min_silence_duration_ms=500,
speech_pad_ms=400
)
)
# Collect results
text_parts = []
words = []
for segment in segments:
text_parts.append(segment.text)
if hasattr(segment, 'words') and segment.words:
words.extend([
{"word": w.word, "start": w.start, "end": w.end, "probability": w.probability}
for w in segment.words
])
result_text = " ".join(text_parts).strip()
# Update tracking
buffer_duration = len(self.buffer) / self.sample_rate
self.total_audio_processed += buffer_duration
# Clear processed buffer (keep last 0.5s for context)
overlap_samples = int(self.sample_rate * 0.5)
self.buffer = self.buffer[-overlap_samples:]
return TranscriptionResult(
text=result_text,
is_final=False,
confidence=info.language_probability if info else 0.0,
start_time=self.total_audio_processed - buffer_duration,
end_time=self.total_audio_processed,
words=words
)
def close(self):
self.buffer = np.array([], dtype=np.float32)
self.context = None
class VADProcessor:
"""Voice Activity Detection to filter silence."""
def __init__(self, threshold: float = 0.5):
import torch
self.model, self.utils = torch.hub.load(
repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False
)
self.threshold = threshold
self.sample_rate = 16000
def is_speech(self, audio_bytes: bytes) -> bool:
import torch
audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
audio_tensor = torch.from_numpy(audio)
        # Get speech probability. NOTE: recent Silero VAD releases expect fixed-size
        # chunks (512 samples at 16 kHz); window the input accordingly.
        speech_prob = self.model(audio_tensor, self.sample_rate).item()
return speech_prob > self.threshold
# Global instances (in production, use dependency injection)
vad_processor = VADProcessor()
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
    # Initialize ASR session for this connection.
    # NOTE: loading a model per connection keeps the example simple; in production
    # share one model (or a pool) across sessions and keep only per-stream buffers.
    session = ASRSession(model_name="base")
silence_count = 0
    max_silence_chunks = 10  # ~5 seconds of silence at 0.5 s per chunk
try:
while True:
# 1. Receive audio chunk (bytes)
audio_chunk = await websocket.receive_bytes()
# 2. VAD Check - skip silence
if not vad_processor.is_speech(audio_chunk):
silence_count += 1
                if silence_count == max_silence_chunks:
                    # Send the end-of-stream signal once, not on every subsequent silent chunk
                    await websocket.send_json({
                        "text": "",
                        "is_final": True,
                        "event": "end_of_speech"
                    })
                continue
silence_count = 0 # Reset on speech
# 3. Process through ASR
result = await session.process_chunk(audio_chunk)
# 4. Send partial result back
await websocket.send_json({
"text": result.text,
"is_final": result.is_final,
"confidence": result.confidence,
"start_time": result.start_time,
"end_time": result.end_time,
"words": result.words
})
except WebSocketDisconnect:
logger.info("Client disconnected")
except Exception as e:
logger.error(f"ASR session error: {e}")
await websocket.send_json({"error": str(e)})
finally:
session.close()
@app.get("/health")
async def health_check():
return {"status": "healthy", "service": "streaming-asr"}
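For local testing, a minimal Python client for the /ws/transcribe endpoint above might look like the following. It assumes the websockets package, a server on localhost:8000, and a raw 16 kHz mono 16-bit PCM file named sample.raw; the 0.5 s chunk size matches the server's chunk_duration.
# ws_client.py - minimal test client for the endpoint above (assumes `pip install websockets`)
import asyncio
import json
import websockets

SERVER_URL = "ws://localhost:8000/ws/transcribe"  # adjust to your deployment
CHUNK_BYTES = 16000  # 0.5 s of 16 kHz mono int16 PCM (8000 samples * 2 bytes)

async def sender(ws, path: str) -> None:
    with open(path, "rb") as f:
        while chunk := f.read(CHUNK_BYTES):
            await ws.send(chunk)       # raw PCM bytes, as the server expects
            await asyncio.sleep(0.5)   # pace the stream like a live microphone

async def receiver(ws) -> None:
    # Replies are not 1:1 with chunks (the server skips silent chunks), so read independently
    async for message in ws:
        result = json.loads(message)
        if result.get("text"):
            print(result["text"])

async def main(path: str) -> None:
    async with websockets.connect(SERVER_URL) as ws:
        recv_task = asyncio.create_task(receiver(ws))
        await sender(ws, path)
        await ws.close()               # closing the socket ends the receiver loop
        await recv_task

if __name__ == "__main__":
    asyncio.run(main("sample.raw"))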
gRPC Server Implementation
// asr_service.proto
syntax = "proto3";
package asr;
service StreamingASR {
rpc StreamingRecognize(stream AudioRequest) returns (stream TranscriptionResponse);
rpc GetSupportedLanguages(Empty) returns (LanguageList);
}
message AudioRequest {
oneof request {
StreamingConfig config = 1;
bytes audio_content = 2;
}
}
message StreamingConfig {
string language_code = 1;
int32 sample_rate_hertz = 2;
string encoding = 3; // LINEAR16, FLAC, OGG_OPUS
bool enable_word_timestamps = 4;
bool enable_punctuation = 5;
int32 max_alternatives = 6;
}
message TranscriptionResponse {
repeated SpeechRecognitionResult results = 1;
string error = 2;
}
message SpeechRecognitionResult {
repeated SpeechRecognitionAlternative alternatives = 1;
bool is_final = 2;
float stability = 3;
}
message SpeechRecognitionAlternative {
string transcript = 1;
float confidence = 2;
repeated WordInfo words = 3;
}
message WordInfo {
string word = 1;
float start_time = 2;
float end_time = 3;
float confidence = 4;
}
message Empty {}
message LanguageList {
repeated string languages = 1;
}
# asr_grpc_server.py
import grpc
from concurrent import futures
import asyncio
from typing import Iterator
import asr_service_pb2 as pb2
import asr_service_pb2_grpc as pb2_grpc
class StreamingASRServicer(pb2_grpc.StreamingASRServicer):
def __init__(self):
self.model = self._load_model()
def _load_model(self):
from faster_whisper import WhisperModel
return WhisperModel("base", device="cuda", compute_type="int8")
def StreamingRecognize(
self,
request_iterator: Iterator[pb2.AudioRequest],
context: grpc.ServicerContext
) -> Iterator[pb2.TranscriptionResponse]:
config = None
audio_buffer = b""
for request in request_iterator:
if request.HasField("config"):
config = request.config
continue
audio_buffer += request.audio_content
            # Process once we have ~0.5 s of audio
            # (sample_rate_hertz bytes of 16-bit PCM = sample_rate / 2 samples = 0.5 s)
            if config and len(audio_buffer) >= config.sample_rate_hertz:
                # Convert and transcribe
                import numpy as np
                audio = np.frombuffer(audio_buffer, dtype=np.int16).astype(np.float32) / 32768.0
                segments, _ = self.model.transcribe(
                    audio,
                    language=config.language_code[:2] if config.language_code else "en",
                    word_timestamps=config.enable_word_timestamps
                )
for segment in segments:
                    alternative = pb2.SpeechRecognitionAlternative(
                        transcript=segment.text,
                        # avg_logprob is a log-probability; exp() gives a rough (0, 1] confidence
                        confidence=float(np.exp(segment.avg_logprob))
                    )
if config.enable_word_timestamps and segment.words:
for word in segment.words:
alternative.words.append(pb2.WordInfo(
word=word.word,
start_time=word.start,
end_time=word.end,
confidence=word.probability
))
result = pb2.SpeechRecognitionResult(
alternatives=[alternative],
is_final=False,
stability=0.9
)
yield pb2.TranscriptionResponse(results=[result])
# Keep overlap for context
audio_buffer = audio_buffer[-config.sample_rate_hertz // 2:]
def GetSupportedLanguages(self, request, context):
return pb2.LanguageList(languages=[
"en-US", "en-GB", "es-ES", "fr-FR", "de-DE",
"it-IT", "pt-BR", "ja-JP", "ko-KR", "zh-CN"
])
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
pb2_grpc.add_StreamingASRServicer_to_server(StreamingASRServicer(), server)
server.add_insecure_port('[::]:50051')
server.start()
    server.wait_for_termination()

if __name__ == "__main__":
    serve()
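A matching client sketch for the proto above, assuming the generated asr_service_pb2 / asr_service_pb2_grpc stubs, a server on localhost:50051, and the same raw 16 kHz PCM input. The first message carries the StreamingConfig; the rest carry audio bytes.
# asr_grpc_client.py - minimal bidirectional-streaming client (sketch)
import grpc
import asr_service_pb2 as pb2
import asr_service_pb2_grpc as pb2_grpc

def request_stream(path: str):
    # First message: configuration; subsequent messages: raw audio bytes
    yield pb2.AudioRequest(config=pb2.StreamingConfig(
        language_code="en-US",
        sample_rate_hertz=16000,
        encoding="LINEAR16",
        enable_word_timestamps=True,
    ))
    with open(path, "rb") as f:
        while chunk := f.read(8000):   # ~0.25 s of 16 kHz int16 audio
            yield pb2.AudioRequest(audio_content=chunk)

def main(path: str) -> None:
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = pb2_grpc.StreamingASRStub(channel)
        for response in stub.StreamingRecognize(request_stream(path)):
            for result in response.results:
                if result.alternatives:
                    print(result.alternatives[0].transcript)

if __name__ == "__main__":
    main("sample.raw")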
35.2.3. Models: Whisper vs. Conformer vs. Kaldi
Model Comparison
| Model | Architecture | Latency | Accuracy (WER) | Streaming | Resource Usage |
|---|---|---|---|---|---|
| Kaldi | WFST + GMM/DNN | Ultra-low | Moderate | Native | Low (CPU) |
| Whisper | Transformer | High | Excellent | Adapted | High (GPU) |
| Conformer | Conv + Transformer | Medium | Excellent | Native | Medium-High |
| DeepSpeech | RNN (LSTM/GRU) | Low | Good | Native | Medium |
| Wav2Vec2 | Transformer | Medium | Excellent | Adapted | High |
Kaldi: The Classic Approach
- Architecture: Uses Weighted Finite State Transducers (WFST). Extremely fast, low CPU.
- Pros: Battle-tested, CPU-only, millisecond-scale latency.
- Cons: Complex deployment, steep learning curve, harder to customize.
# Kaldi online2 decoder example
online2-wav-nnet3-latgen-faster \
--online=true \
--do-endpointing=true \
--config=conf/online.conf \
--max-active=7000 \
--beam=15.0 \
--lattice-beam=6.0 \
--acoustic-scale=1.0 \
final.mdl \
graph/HCLG.fst \
ark:spk2utt \
scp:wav.scp \
ark:/dev/null
Whisper: The Modern Standard
- Architecture: Encoder-Decoder Transformer (~1.55B params for large-v3).
- Pros: State-of-the-art accuracy, multilingual, robust to noise.
- Cons: Not natively streaming, high GPU requirements.
Streaming Whisper Implementations:
- faster-whisper: CTranslate2 backend with INT8 quantization
- whisper.cpp: C/C++ port for edge devices
- whisper-streaming: Buffered streaming with LocalAgreement
# faster-whisper streaming implementation
from faster_whisper import WhisperModel
import numpy as np
class StreamingWhisper:
def __init__(self, model_size: str = "base"):
self.model = WhisperModel(
model_size,
device="cuda",
compute_type="int8", # INT8 for 2x speedup
cpu_threads=4
)
self.buffer = np.array([], dtype=np.float32)
self.min_chunk_size = 16000 # 1 second at 16kHz
self.overlap_size = 8000 # 0.5 second overlap
def process_chunk(self, audio_chunk: np.ndarray) -> str:
self.buffer = np.concatenate([self.buffer, audio_chunk])
if len(self.buffer) < self.min_chunk_size:
return ""
# Transcribe current buffer
segments, _ = self.model.transcribe(
self.buffer,
beam_size=5,
best_of=5,
language="en",
condition_on_previous_text=True,
vad_filter=True
)
text = " ".join([s.text for s in segments])
# Keep overlap for context continuity
self.buffer = self.buffer[-self.overlap_size:]
return text.strip()
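Because the buffer is re-transcribed with overlap, consecutive hypotheses can disagree about the trailing words. The LocalAgreement policy used by whisper-streaming commits only the prefix on which the last two hypotheses agree; everything after it remains "unstable" and may still change. A minimal sketch of the idea (class and method names here are illustrative, not the whisper-streaming API):
# local_agreement.py - LocalAgreement-2 stabilization sketch
class LocalAgreement:
    """Commit only the word prefix shared by the last two hypotheses."""

    def __init__(self):
        self.committed: list = []   # words that will never change
        self.previous: list = []    # uncommitted tail of the last hypothesis

    def update(self, hypothesis: str):
        # Drop the already-committed prefix, then find agreement with the previous tail
        words = hypothesis.split()[len(self.committed):]
        agreed = []
        for prev, curr in zip(self.previous, words):
            if prev.lower() == curr.lower():
                agreed.append(curr)
            else:
                break
        self.committed.extend(agreed)
        self.previous = words[len(agreed):]
        return " ".join(self.committed), " ".join(self.previous)

stabilizer = LocalAgreement()
print(stabilizer.update("hello wor"))            # ('', 'hello wor')
print(stabilizer.update("hello world how are"))  # ('hello', 'world how are') - only the agreed prefix commits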
Conformer: Best of Both Worlds
The Conformer architecture combines convolutional layers (local patterns) with Transformer attention (global context).
# Using NVIDIA NeMo Conformer for streaming
import torch
import nemo.collections.asr as nemo_asr
class ConformerStreaming:
def __init__(self):
self.model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
"nvidia/stt_en_conformer_transducer_large"
)
self.model.eval()
self.model.cuda()
def transcribe_stream(self, audio_chunks):
"""Process audio in streaming mode."""
        # Enable streaming mode. NOTE: the streaming-configuration API differs across
        # NeMo releases (cache-aware models expose setup_streaming_params); treat this
        # call as a sketch and follow the NeMo streaming examples for your version.
self.model.encoder.set_streaming_cfg(
chunk_size=160, # 10ms chunks at 16kHz
left_context=32,
right_context=0 # Causal for streaming
)
for chunk in audio_chunks:
# Process each chunk
with torch.inference_mode():
logits, logits_len = self.model.encoder(
audio_signal=chunk.cuda(),
length=torch.tensor([len(chunk)])
)
hypotheses = self.model.decoding.rnnt_decoder_predictions_tensor(
logits, logits_len
)
yield hypotheses[0]
35.2.4. Metrics: Word Error Rate (WER) and Beyond
Word Error Rate (WER)
The standard metric for ASR quality:
$$ WER = \frac{S + D + I}{N} \times 100\% $$
Where:
- S: Substitutions (“Cat” -> “Bat”)
- D: Deletions (“The Cat” -> “Cat”)
- I: Insertions (“Cat” -> “The Cat”)
- N: Total words in reference
# WER calculation implementation
from jiwer import wer, cer
from dataclasses import dataclass
from typing import List
@dataclass
class ASRMetrics:
wer: float
cer: float
substitutions: int
deletions: int
insertions: int
reference_words: int
def calculate_asr_metrics(reference: str, hypothesis: str) -> ASRMetrics:
"""Calculate comprehensive ASR metrics."""
from jiwer import compute_measures
# Normalize text
reference = reference.lower().strip()
hypothesis = hypothesis.lower().strip()
measures = compute_measures(reference, hypothesis)
return ASRMetrics(
wer=measures['wer'] * 100,
cer=cer(reference, hypothesis) * 100,
substitutions=measures['substitutions'],
deletions=measures['deletions'],
insertions=measures['insertions'],
reference_words=len(reference.split())
)
def batch_evaluate(references: List[str], hypotheses: List[str]) -> dict:
"""Evaluate a batch of transcriptions."""
total_wer = wer(references, hypotheses)
# Per-sample analysis
metrics = [
calculate_asr_metrics(ref, hyp)
for ref, hyp in zip(references, hypotheses)
]
return {
"overall_wer": total_wer * 100,
"mean_wer": sum(m.wer for m in metrics) / len(metrics),
"median_wer": sorted(m.wer for m in metrics)[len(metrics) // 2],
"samples_above_10_wer": sum(1 for m in metrics if m.wer > 10),
"substitution_rate": sum(m.substitutions for m in metrics) / sum(m.reference_words for m in metrics) * 100,
"deletion_rate": sum(m.deletions for m in metrics) / sum(m.reference_words for m in metrics) * 100,
"insertion_rate": sum(m.insertions for m in metrics) / sum(m.reference_words for m in metrics) * 100
}
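A quick worked example using the helpers above: with reference “the cat sat on the mat” and hypothesis “the cat sat the mat”, there is one deletion (“on”) over six reference words, so WER ≈ 16.7%.
# Worked example for the metric helpers above
reference = "the cat sat on the mat"
hypothesis = "the cat sat the mat"

metrics = calculate_asr_metrics(reference, hypothesis)
print(f"WER: {metrics.wer:.1f}%")   # ~16.7% (1 deletion / 6 reference words)
print(f"S/D/I: {metrics.substitutions}/{metrics.deletions}/{metrics.insertions}")

batch = batch_evaluate([reference], [hypothesis])
print(f"Overall WER: {batch['overall_wer']:.1f}%")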
Real-Time Factor (RTF)
Measures processing speed relative to audio duration:
$$ RTF = \frac{\text{Processing Time}}{\text{Audio Duration}} $$
- RTF < 1: Real-time capable
- RTF < 0.5: Good for streaming (leaves headroom)
- RTF < 0.1: Excellent, supports batching
import time
from typing import List
import numpy as np
def benchmark_rtf(model, audio_samples: List[np.ndarray], sample_rate: int = 16000) -> dict:
"""Benchmark Real-Time Factor for ASR model."""
total_audio_duration = 0
total_processing_time = 0
for audio in audio_samples:
audio_duration = len(audio) / sample_rate
total_audio_duration += audio_duration
start_time = time.perf_counter()
_ = model.transcribe(audio)
end_time = time.perf_counter()
total_processing_time += (end_time - start_time)
rtf = total_processing_time / total_audio_duration
return {
"rtf": rtf,
"is_realtime": rtf < 1.0,
"throughput_factor": 1.0 / rtf,
"total_audio_hours": total_audio_duration / 3600,
"processing_time_hours": total_processing_time / 3600
}
Streaming-Specific Metrics
| Metric | Description | Target |
|---|---|---|
| First Byte Latency | Time to first partial result | < 200ms |
| Partial WER | WER of unstable partials | < 30% |
| Final WER | WER of finalized text | < 10% |
| Word Stabilization Time | Time for word to become final | < 2s |
| Endpoint Detection Latency | Time to detect end of utterance | < 500ms |
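Unlike offline WER, these metrics have to be measured against the live stream of partial and final results. A minimal client-side tracker might look like the sketch below (the method names are illustrative; it assumes it is fed the partial/final texts emitted by the WebSocket server above):
# streaming_metrics.py - track first-partial latency and word stabilization time (sketch)
import time

class StreamingMetricsTracker:
    def __init__(self):
        self.stream_start = time.monotonic()
        self.first_partial_latency = None   # seconds until the first non-empty partial
        self.word_first_seen = {}           # word -> time it first appeared in a partial
        self.word_stabilization = {}        # word -> seconds from first appearance to final

    def on_partial(self, text: str) -> None:
        now = time.monotonic()
        if text and self.first_partial_latency is None:
            self.first_partial_latency = now - self.stream_start
        for word in text.split():
            self.word_first_seen.setdefault(word, now)

    def on_final(self, text: str) -> None:
        now = time.monotonic()
        for word in text.split():
            self.word_stabilization[word] = now - self.word_first_seen.get(word, now)

    def summary(self) -> dict:
        times = list(self.word_stabilization.values())
        return {
            "first_partial_latency_s": self.first_partial_latency,
            "mean_word_stabilization_s": sum(times) / len(times) if times else None,
        }
Keying on the raw word is crude (repeated words collapse into one entry); a production tracker would key on word position or timestamps.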
35.2.5. Handling Whisper Hallucinations
Whisper is notorious for hallucinating when fed silence or low-quality audio.
Common Hallucination Patterns
- “Thank you for watching” - YouTube training data artifact
- Repeated phrases - Getting stuck in loops
- Language switching - Random multilingual outputs
- Phantom speakers - Inventing conversation partners
Mitigation Strategies
from dataclasses import dataclass
from typing import Optional, List
import numpy as np
@dataclass
class HallucinationDetector:
"""Detect and filter ASR hallucinations."""
# Known hallucination phrases
KNOWN_HALLUCINATIONS = [
"thank you for watching",
"thanks for watching",
"please subscribe",
"like and subscribe",
"see you in the next video",
"don't forget to subscribe",
"if you enjoyed this video",
]
# Repetition detection
MAX_REPEAT_RATIO = 0.7
# Silence detection threshold
SILENCE_THRESHOLD = 0.01
def is_hallucination(
self,
text: str,
audio: np.ndarray,
previous_texts: List[str] = None
) -> tuple[bool, str]:
"""
Check if transcription is likely a hallucination.
Returns (is_hallucination, reason)
"""
text_lower = text.lower().strip()
# Check 1: Known phrases
for phrase in self.KNOWN_HALLUCINATIONS:
if phrase in text_lower:
return True, f"known_hallucination:{phrase}"
# Check 2: Audio is silence
if self._is_silence(audio):
return True, "silence_detected"
# Check 3: Excessive repetition
if self._has_repetition(text):
return True, "excessive_repetition"
# Check 4: Exact repeat of previous output
if previous_texts and text_lower in [t.lower() for t in previous_texts[-3:]]:
return True, "exact_repeat"
return False, "valid"
def _is_silence(self, audio: np.ndarray) -> bool:
"""Check if audio is effectively silence."""
rms = np.sqrt(np.mean(audio ** 2))
return rms < self.SILENCE_THRESHOLD
def _has_repetition(self, text: str) -> bool:
"""Detect word-level repetition."""
words = text.lower().split()
if len(words) < 4:
return False
# Check for bigram repetition
bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
unique_bigrams = set(bigrams)
repeat_ratio = 1 - (len(unique_bigrams) / len(bigrams))
return repeat_ratio > self.MAX_REPEAT_RATIO
class RobustASR:
"""ASR with hallucination filtering."""
def __init__(self, model):
self.model = model
self.detector = HallucinationDetector()
self.history = []
def transcribe(self, audio: np.ndarray) -> Optional[str]:
# Step 1: Pre-filter with aggressive VAD
if self._is_low_energy(audio):
return None
# Step 2: Transcribe
segments, _ = self.model.transcribe(
audio,
no_speech_threshold=0.6, # More aggressive filtering
            log_prob_threshold=-1.0,
compression_ratio_threshold=2.4, # Detect repetition
condition_on_previous_text=False # Reduce hallucination propagation
)
text = " ".join(s.text for s in segments).strip()
# Step 3: Post-filter hallucinations
is_hallucination, reason = self.detector.is_hallucination(
text, audio, self.history
)
if is_hallucination:
return None
# Update history
self.history.append(text)
if len(self.history) > 10:
self.history.pop(0)
return text
def _is_low_energy(self, audio: np.ndarray, threshold: float = 0.005) -> bool:
return np.sqrt(np.mean(audio ** 2)) < threshold
35.2.6. Voice Activity Detection (VAD) Deep Dive
VAD is the first line of defense against wasted compute and hallucinations.
Silero VAD: Production Standard
import torch
import numpy as np
from typing import List, Tuple
class SileroVAD:
"""Production-ready VAD using Silero."""
def __init__(
self,
threshold: float = 0.5,
min_speech_duration_ms: int = 250,
min_silence_duration_ms: int = 100,
sample_rate: int = 16000
):
self.model, self.utils = torch.hub.load(
'snakers4/silero-vad',
'silero_vad',
trust_repo=True
)
self.threshold = threshold
self.min_speech_samples = int(sample_rate * min_speech_duration_ms / 1000)
self.min_silence_samples = int(sample_rate * min_silence_duration_ms / 1000)
self.sample_rate = sample_rate
# State for streaming
self.reset()
def reset(self):
"""Reset internal state for new stream."""
self.model.reset_states()
self._in_speech = False
self._speech_start = 0
self._current_position = 0
def process_chunk(self, audio: np.ndarray) -> List[Tuple[int, int]]:
"""
Process audio chunk and return speech segments.
Returns list of (start_sample, end_sample) tuples.
"""
audio_tensor = torch.from_numpy(audio).float()
# Process in 30ms windows
window_size = int(self.sample_rate * 0.030)
segments = []
for i in range(0, len(audio), window_size):
window = audio_tensor[i:i + window_size]
if len(window) < window_size:
# Pad final window
window = torch.nn.functional.pad(window, (0, window_size - len(window)))
speech_prob = self.model(window, self.sample_rate).item()
if speech_prob >= self.threshold:
if not self._in_speech:
self._in_speech = True
self._speech_start = self._current_position + i
else:
if self._in_speech:
self._in_speech = False
speech_end = self._current_position + i
duration = speech_end - self._speech_start
if duration >= self.min_speech_samples:
segments.append((self._speech_start, speech_end))
self._current_position += len(audio)
return segments
def get_speech_timestamps(
self,
audio: np.ndarray
) -> List[dict]:
"""Get speech timestamps for entire audio."""
self.reset()
segments = self.process_chunk(audio)
return [
{
"start": start / self.sample_rate,
"end": end / self.sample_rate,
"duration": (end - start) / self.sample_rate
}
for start, end in segments
]
class EnhancedVAD:
"""VAD with additional features for production."""
def __init__(self):
self.vad = SileroVAD()
self.energy_threshold = 0.01
def is_speech_segment(self, audio: np.ndarray) -> dict:
"""Comprehensive speech detection."""
# Energy check (fast, first filter)
energy = np.sqrt(np.mean(audio ** 2))
if energy < self.energy_threshold:
return {"is_speech": False, "reason": "low_energy", "confidence": 0.0}
# Zero-crossing rate (detect static noise)
zcr = np.mean(np.abs(np.diff(np.sign(audio))))
if zcr > 0.5: # High ZCR often indicates noise
return {"is_speech": False, "reason": "high_zcr", "confidence": 0.3}
# Neural VAD
segments = self.vad.get_speech_timestamps(audio)
if not segments:
return {"is_speech": False, "reason": "vad_reject", "confidence": 0.0}
# Calculate speech ratio
total_speech = sum(s["duration"] for s in segments)
audio_duration = len(audio) / 16000
speech_ratio = total_speech / audio_duration
return {
"is_speech": True,
"reason": "speech_detected",
"confidence": min(speech_ratio * 1.5, 1.0),
"segments": segments,
"speech_ratio": speech_ratio
}
35.2.7. Production Infrastructure: Kubernetes Deployment
AWS EKS Architecture
# asr-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: streaming-asr
namespace: ml-inference
spec:
replicas: 3
selector:
matchLabels:
app: streaming-asr
template:
metadata:
labels:
app: streaming-asr
spec:
nodeSelector:
node.kubernetes.io/instance-type: g5.xlarge
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
containers:
- name: asr-server
image: 123456789.dkr.ecr.us-east-1.amazonaws.com/streaming-asr:v1.2.0
ports:
- containerPort: 8000
name: websocket
- containerPort: 50051
name: grpc
resources:
requests:
memory: "8Gi"
cpu: "2"
nvidia.com/gpu: "1"
limits:
memory: "16Gi"
cpu: "4"
nvidia.com/gpu: "1"
env:
- name: MODEL_SIZE
value: "large-v3"
- name: COMPUTE_TYPE
value: "int8"
- name: MAX_CONCURRENT_STREAMS
value: "50"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 60
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 5
volumeMounts:
- name: model-cache
mountPath: /root/.cache/huggingface
volumes:
- name: model-cache
persistentVolumeClaim:
claimName: model-cache-pvc
---
apiVersion: v1
kind: Service
metadata:
name: streaming-asr
namespace: ml-inference
spec:
selector:
app: streaming-asr
ports:
- name: websocket
port: 80
targetPort: 8000
- name: grpc
port: 50051
targetPort: 50051
type: ClusterIP
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: streaming-asr-ingress
namespace: ml-inference
annotations:
kubernetes.io/ingress.class: "alb"
alb.ingress.kubernetes.io/scheme: "internet-facing"
alb.ingress.kubernetes.io/target-type: "ip"
alb.ingress.kubernetes.io/healthcheck-path: "/health"
alb.ingress.kubernetes.io/backend-protocol: "HTTP"
# WebSocket support
alb.ingress.kubernetes.io/load-balancer-attributes: "idle_timeout.timeout_seconds=3600"
spec:
rules:
- host: asr.example.com
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: streaming-asr
port:
number: 80
GCP GKE with NVIDIA L4 GPUs
# gke-asr-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: streaming-asr-gke
namespace: ml-inference
spec:
replicas: 2
selector:
matchLabels:
      app: streaming-asr-gke
template:
metadata:
labels:
        app: streaming-asr-gke
spec:
nodeSelector:
cloud.google.com/gke-accelerator: nvidia-l4
containers:
- name: asr-server
image: gcr.io/my-project/streaming-asr:v1.2.0
ports:
- containerPort: 8000
resources:
requests:
memory: "8Gi"
cpu: "4"
nvidia.com/gpu: "1"
limits:
memory: "16Gi"
nvidia.com/gpu: "1"
env:
- name: GOOGLE_CLOUD_PROJECT
value: "my-project"
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: streaming-asr-hpa
namespace: ml-inference
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
    name: streaming-asr-gke
minReplicas: 2
maxReplicas: 20
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
- type: Pods
pods:
metric:
name: active_websocket_connections
target:
type: AverageValue
averageValue: "100"
Terraform Infrastructure
# asr_infrastructure.tf
# AWS EKS Node Group for ASR
resource "aws_eks_node_group" "asr_gpu" {
cluster_name = aws_eks_cluster.main.name
node_group_name = "asr-gpu-nodes"
node_role_arn = aws_iam_role.eks_node.arn
subnet_ids = var.private_subnet_ids
scaling_config {
desired_size = 3
max_size = 10
min_size = 2
}
instance_types = ["g5.xlarge"]
ami_type = "AL2_x86_64_GPU"
capacity_type = "ON_DEMAND"
labels = {
workload = "asr-inference"
gpu = "true"
}
taint {
key = "nvidia.com/gpu"
value = "true"
effect = "NO_SCHEDULE"
}
tags = {
Environment = var.environment
Service = "streaming-asr"
}
}
# ElastiCache for session state
resource "aws_elasticache_cluster" "asr_sessions" {
cluster_id = "asr-sessions"
engine = "redis"
node_type = "cache.r6g.large"
  # aws_elasticache_cluster supports only a single node for Redis; use
  # aws_elasticache_replication_group for replicas / multi-AZ HA
  num_cache_nodes      = 1
parameter_group_name = "default.redis7"
port = 6379
subnet_group_name = aws_elasticache_subnet_group.main.name
security_group_ids = [aws_security_group.redis.id]
tags = {
Service = "streaming-asr"
}
}
# CloudWatch Dashboard for ASR metrics
resource "aws_cloudwatch_dashboard" "asr" {
dashboard_name = "streaming-asr-metrics"
dashboard_body = jsonencode({
widgets = [
{
type = "metric"
x = 0
y = 0
width = 12
height = 6
properties = {
metrics = [
["ASR", "ActiveConnections", "Service", "streaming-asr"],
[".", "TranscriptionsPerSecond", ".", "."],
[".", "P99Latency", ".", "."]
]
title = "ASR Performance"
region = var.aws_region
}
},
{
type = "metric"
x = 12
y = 0
width = 12
height = 6
properties = {
metrics = [
["ASR", "GPU_Utilization", "Service", "streaming-asr"],
[".", "GPU_Memory", ".", "."]
]
title = "GPU Metrics"
region = var.aws_region
}
}
]
})
}
35.2.8. Speaker Diarization: “Who Said What?”
Transcription is useless for meetings without speaker attribution.
The Diarization Pipeline
graph LR
A[Audio Input] --> B[VAD]
B --> C[Speaker Embedding Extraction]
C --> D[Clustering]
D --> E[Speaker Assignment]
E --> F[Merge with ASR]
F --> G[Labeled Transcript]
pyannote.audio Implementation
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from dataclasses import dataclass
from typing import List, Dict
import torch
@dataclass
class SpeakerSegment:
speaker: str
start: float
end: float
text: str = ""
class SpeakerDiarization:
"""Speaker diarization using pyannote.audio."""
def __init__(self, hf_token: str):
self.pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=hf_token
)
# Move to GPU if available
if torch.cuda.is_available():
self.pipeline.to(torch.device("cuda"))
def diarize(
self,
audio_path: str,
num_speakers: int = None,
min_speakers: int = 1,
max_speakers: int = 10
) -> List[SpeakerSegment]:
"""Run speaker diarization on audio file."""
# Configure speaker count
if num_speakers:
diarization = self.pipeline(
audio_path,
num_speakers=num_speakers
)
else:
diarization = self.pipeline(
audio_path,
min_speakers=min_speakers,
max_speakers=max_speakers
)
segments = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
segments.append(SpeakerSegment(
speaker=speaker,
start=turn.start,
end=turn.end
))
return segments
def merge_with_transcription(
self,
diarization_segments: List[SpeakerSegment],
asr_words: List[Dict]
) -> List[SpeakerSegment]:
"""Merge ASR word timestamps with speaker labels."""
result = []
current_segment = None
for word_info in asr_words:
word_mid = (word_info["start"] + word_info["end"]) / 2
# Find speaker for this word
speaker = self._find_speaker_at_time(
diarization_segments, word_mid
)
if current_segment is None or current_segment.speaker != speaker:
if current_segment:
result.append(current_segment)
current_segment = SpeakerSegment(
speaker=speaker,
start=word_info["start"],
end=word_info["end"],
text=word_info["word"]
)
else:
current_segment.end = word_info["end"]
current_segment.text += " " + word_info["word"]
if current_segment:
result.append(current_segment)
return result
def _find_speaker_at_time(
self,
segments: List[SpeakerSegment],
time: float
) -> str:
"""Find which speaker was talking at given time."""
for segment in segments:
if segment.start <= time <= segment.end:
return segment.speaker
return "UNKNOWN"
# Usage example
async def transcribe_meeting(audio_path: str) -> str:
# Step 1: Diarization
diarizer = SpeakerDiarization(hf_token="hf_xxx")
speaker_segments = diarizer.diarize(audio_path, max_speakers=4)
# Step 2: ASR with word timestamps
from faster_whisper import WhisperModel
model = WhisperModel("large-v3", device="cuda")
segments, _ = model.transcribe(
audio_path,
word_timestamps=True
)
# Collect words
words = []
for segment in segments:
if segment.words:
words.extend([
{"word": w.word, "start": w.start, "end": w.end}
for w in segment.words
])
# Step 3: Merge
labeled_segments = diarizer.merge_with_transcription(speaker_segments, words)
# Format output
output = []
for seg in labeled_segments:
output.append(f"[{seg.speaker}] ({seg.start:.1f}s): {seg.text.strip()}")
return "\n".join(output)
35.2.9. Load Balancing for Stateful Streams
WebSockets are stateful. Traditional round-robin doesn’t work.
Session Affinity Architecture
graph TB
subgraph "Client Layer"
C1[Client 1]
C2[Client 2]
C3[Client 3]
end
subgraph "Load Balancer"
LB[ALB/Nginx]
SS[(Session Store)]
LB <--> SS
end
subgraph "ASR Pool"
S1[Server 1]
S2[Server 2]
S3[Server 3]
end
C1 --> LB
C2 --> LB
C3 --> LB
LB -->|"Session Affinity"| S1
LB -->|"Session Affinity"| S2
LB -->|"Session Affinity"| S3
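The Session Store in the diagram can be as simple as a Redis mapping from session ID to the backend instance that owns the stream, so a reconnecting client (or a smarter L7 router) can be sent back to the server holding its ASR state. A minimal sketch, assuming the redis Python client; the key prefix and TTL are illustrative:
# session_store.py - Redis-backed session registry (sketch)
from typing import Optional
import redis

class SessionStore:
    """Map ASR session IDs to the backend instance that holds their state."""

    def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 3600):
        self.client = redis.Redis.from_url(redis_url)
        self.ttl = ttl_seconds

    def bind(self, session_id: str, server_id: str) -> None:
        # TTL lets abandoned sessions expire on their own
        self.client.set(f"asr:session:{session_id}", server_id, ex=self.ttl)

    def lookup(self, session_id: str) -> Optional[str]:
        value = self.client.get(f"asr:session:{session_id}")
        return value.decode() if value else None

    def release(self, session_id: str) -> None:
        self.client.delete(f"asr:session:{session_id}")

# On connect:   store.bind(session_id, my_pod_name)
# On reconnect: route the client to store.lookup(session_id)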
NGINX Configuration for WebSocket Sticky Sessions
# nginx.conf for ASR WebSocket load balancing
upstream asr_backend {
# IP Hash for session affinity
ip_hash;
server asr-1.internal:8000 weight=1;
server asr-2.internal:8000 weight=1;
server asr-3.internal:8000 weight=1;
    # Keep idle connections to the upstreams open for reuse
    keepalive 32;
}
map $http_upgrade $connection_upgrade {
default upgrade;
'' close;
}
server {
listen 443 ssl http2;
server_name asr.example.com;
ssl_certificate /etc/nginx/ssl/cert.pem;
ssl_certificate_key /etc/nginx/ssl/key.pem;
# WebSocket timeout (1 hour for long calls)
proxy_read_timeout 3600s;
proxy_send_timeout 3600s;
location /ws/ {
proxy_pass http://asr_backend;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# Buffer settings for streaming
proxy_buffering off;
proxy_cache off;
}
location /health {
proxy_pass http://asr_backend/health;
}
}
Graceful Shutdown for Scaling
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)
class GracefulShutdown:
"""Manage graceful shutdown for WebSocket server."""
def __init__(self):
self.is_shutting_down = False
self.active_connections = set()
self.shutdown_event = asyncio.Event()
def register_connection(self, connection_id: str):
self.active_connections.add(connection_id)
def unregister_connection(self, connection_id: str):
self.active_connections.discard(connection_id)
if self.is_shutting_down and not self.active_connections:
self.shutdown_event.set()
async def initiate_shutdown(self, timeout: int = 3600):
"""Start graceful shutdown process."""
logger.info("Initiating graceful shutdown")
self.is_shutting_down = True
if not self.active_connections:
return
logger.info(f"Waiting for {len(self.active_connections)} connections to close")
try:
await asyncio.wait_for(
self.shutdown_event.wait(),
timeout=timeout
)
logger.info("All connections closed gracefully")
except asyncio.TimeoutError:
logger.warning(f"Shutdown timeout, {len(self.active_connections)} connections remaining")
shutdown_manager = GracefulShutdown()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger.info("Starting ASR server")
yield
# Shutdown
await shutdown_manager.initiate_shutdown()
app = FastAPI(lifespan=lifespan)
@app.get("/health")
async def health():
    if shutdown_manager.is_shutting_down:
        # Return 503 so the load balancer stops routing new connections here
        return JSONResponse(
            status_code=503,
            content={
                "status": "draining",
                "active_connections": len(shutdown_manager.active_connections)
            }
        )
    return {"status": "healthy"}
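The connection registry is only useful if the WebSocket handler registers and unregisters itself and refuses new streams while draining. A sketch of how the transcription endpoint above could be tied to the shutdown manager (in practice you would fold this into the existing /ws/transcribe handler):
# Tying the WebSocket endpoint to the shutdown manager (sketch)
import uuid
from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/ws/transcribe")
async def websocket_with_draining(websocket: WebSocket):
    if shutdown_manager.is_shutting_down:
        # Refuse new streams while draining; clients retry and land on another instance
        await websocket.close(code=1012)  # 1012 = "service restart"
        return
    await websocket.accept()
    connection_id = str(uuid.uuid4())
    shutdown_manager.register_connection(connection_id)
    try:
        while True:
            audio_chunk = await websocket.receive_bytes()
            # ... VAD + ASR processing as in the transcription endpoint above ...
    except WebSocketDisconnect:
        pass
    finally:
        shutdown_manager.unregister_connection(connection_id)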
35.2.10. Observability and Monitoring
Prometheus Metrics
from prometheus_client import Counter, Histogram, Gauge, generate_latest
from fastapi import FastAPI, Response
# Metrics definitions
TRANSCRIPTION_REQUESTS = Counter(
'asr_transcription_requests_total',
'Total transcription requests',
['status', 'language']
)
TRANSCRIPTION_LATENCY = Histogram(
'asr_transcription_latency_seconds',
'Transcription latency in seconds',
['model_size'],
buckets=[0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]
)
ACTIVE_CONNECTIONS = Gauge(
'asr_active_websocket_connections',
'Number of active WebSocket connections'
)
AUDIO_PROCESSED = Counter(
'asr_audio_processed_seconds_total',
'Total seconds of audio processed',
['language']
)
WER_SCORE = Histogram(
'asr_wer_score',
'Word Error Rate distribution',
buckets=[0.01, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5]
)
GPU_MEMORY_USED = Gauge(
'asr_gpu_memory_used_bytes',
'GPU memory usage in bytes',
['gpu_id']
)
class MetricsCollector:
"""Collect and expose ASR metrics."""
@staticmethod
def record_transcription(
language: str,
latency: float,
audio_duration: float,
wer: float = None,
success: bool = True
):
status = "success" if success else "error"
TRANSCRIPTION_REQUESTS.labels(status=status, language=language).inc()
TRANSCRIPTION_LATENCY.labels(model_size="large-v3").observe(latency)
AUDIO_PROCESSED.labels(language=language).inc(audio_duration)
if wer is not None:
WER_SCORE.observe(wer)
@staticmethod
def update_gpu_metrics():
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
GPU_MEMORY_USED.labels(gpu_id=str(i)).set(mem_info.used)
# Assumes the FastAPI `app` defined in the WebSocket server above
@app.get("/metrics")
async def metrics():
MetricsCollector.update_gpu_metrics()
return Response(
generate_latest(),
media_type="text/plain"
)
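The gauges and histograms above are only useful if the streaming path actually updates them. A sketch of how the WebSocket handler might be instrumented (it reuses the ASRSession and metric objects defined earlier; the 16 kHz int16 assumption matches the rest of the chapter):
# Instrumenting the streaming path with the metrics above (sketch)
import time

async def handle_stream(websocket, session, language: str = "en"):
    ACTIVE_CONNECTIONS.inc()
    try:
        while True:
            audio_chunk = await websocket.receive_bytes()
            start = time.perf_counter()
            result = await session.process_chunk(audio_chunk)
            latency = time.perf_counter() - start
            MetricsCollector.record_transcription(
                language=language,
                latency=latency,
                audio_duration=len(audio_chunk) / (16000 * 2),  # 16 kHz mono int16
            )
            await websocket.send_json({"text": result.text, "is_final": result.is_final})
    finally:
        ACTIVE_CONNECTIONS.dec()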
Grafana Dashboard JSON
{
"dashboard": {
"title": "Streaming ASR Monitoring",
"panels": [
{
"title": "Active Connections",
"type": "stat",
"targets": [
{
"expr": "sum(asr_active_websocket_connections)"
}
]
},
{
"title": "Transcription Latency P99",
"type": "graph",
"targets": [
{
"expr": "histogram_quantile(0.99, rate(asr_transcription_latency_seconds_bucket[5m]))"
}
]
},
{
"title": "Audio Processed (hours/min)",
"type": "graph",
"targets": [
{
"expr": "rate(asr_audio_processed_seconds_total[5m]) * 60 / 3600"
}
]
},
{
"title": "GPU Memory Usage",
"type": "graph",
"targets": [
{
"expr": "asr_gpu_memory_used_bytes / 1024 / 1024 / 1024",
"legendFormat": "GPU {{gpu_id}}"
}
]
}
]
}
}
35.2.11. Cost Optimization Strategies
Cloud Cost Comparison
| Provider | Service | Cost (per hour audio) | RTF | Notes |
|---|---|---|---|---|
| AWS | Transcribe Streaming | ~$1.44 | N/A | Fully managed ($0.024/min) |
| GCP | Speech-to-Text | ~$1.44 | N/A | Fully managed ($0.024/min) |
| Azure | Speech Services | ~$1.00 | N/A | Standard real-time tier |
| Self-hosted (g5.xlarge) | Whisper Large | ~$0.008 | 0.3 | At scale |
| Self-hosted (g4dn.xlarge) | Whisper Base | ~$0.002 | 0.5 | Budget option |
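The self-hosted figures depend entirely on how many concurrent real-time streams one GPU can sustain (which is what a low RTF buys you). A back-of-the-envelope sketch; the instance price, stream count, and utilization below are assumptions, not quotes:
# Back-of-the-envelope cost per audio hour for self-hosted ASR (illustrative numbers)
def cost_per_audio_hour(instance_cost_per_hour: float,
                        concurrent_streams: int,
                        utilization: float = 0.7) -> float:
    """Effective cost of transcribing one hour of audio.

    concurrent_streams: real-time streams one instance sustains (bounded by RTF and memory)
    utilization: average fraction of that capacity actually in use
    """
    return instance_cost_per_hour / (concurrent_streams * utilization)

# Example: a ~$1.00/hr GPU instance holding ~40 real-time streams at 70% utilization
print(f"${cost_per_audio_hour(1.00, 40):.4f} per audio hour")  # ~$0.036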
Spot Instance Strategy
# Spot instance handler for ASR workloads
import requests
class SpotInterruptionHandler:
"""Handle EC2 Spot interruption for ASR servers."""
def __init__(self):
self.metadata_url = "http://169.254.169.254/latest/meta-data"
def check_for_interruption(self) -> bool:
"""Check if Spot interruption notice has been issued."""
import requests
try:
response = requests.get(
f"{self.metadata_url}/spot/instance-action",
timeout=1
)
if response.status_code == 200:
return True
        except requests.RequestException:
            # Metadata endpoint unreachable (e.g. not running on an EC2 Spot instance)
            pass
return False
async def handle_interruption(self, shutdown_manager):
"""Handle Spot interruption gracefully."""
        # Spot gives a 2-minute warning; drain within 90 s to leave headroom
await shutdown_manager.initiate_shutdown(timeout=90)
# Persist any necessary state
# Drain connections to other instances
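A sketch of how the handler could be wired into the server: a background task polls the metadata endpoint every few seconds and starts the drain as soon as a notice appears (the 5-second interval is an assumption).
# Background polling task for Spot interruptions (sketch; reuses the classes above)
import asyncio

async def watch_for_spot_interruption(handler: SpotInterruptionHandler,
                                      shutdown_manager,
                                      poll_interval_s: float = 5.0) -> None:
    while not shutdown_manager.is_shutting_down:
        # requests is blocking, so run the metadata check off the event loop
        interrupted = await asyncio.to_thread(handler.check_for_interruption)
        if interrupted:
            await handler.handle_interruption(shutdown_manager)
            break
        await asyncio.sleep(poll_interval_s)

# e.g. in the FastAPI lifespan: asyncio.create_task(watch_for_spot_interruption(handler, shutdown_manager))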
35.2.12. Summary Checklist for Streaming ASR Operations
Architecture
- WebSocket/gRPC protocol based on client requirements
- Session affinity configured in load balancer
- Graceful shutdown for scaling events
Models
- VAD (Silero) for silence filtering
- Streaming-capable ASR (faster-whisper, Conformer)
- Hallucination detection and filtering
Infrastructure
- GPU nodes with appropriate instance types
- Horizontal Pod Autoscaler on connection count
- Redis for session state (if distributed)
Observability
- Prometheus metrics for latency, throughput, errors
- GPU memory and utilization monitoring
- WER tracking on sampled data
Cost
- Spot instances for non-critical traffic
- Model quantization (INT8) for efficiency
- Aggressive VAD to reduce GPU load
[End of Section 35.2]