35.2. Streaming ASR Architectures: Real-Time Transcription

Important

The Latency Trap: Users expect subtitles to appear as they speak. If you wait for the end of utterance (EOU) before emitting text, perceived latency is 3-5 seconds. You must stream partial results.

Real-time Automatic Speech Recognition (ASR) is one of the most demanding MLOps challenges. Unlike batch transcription—where you can take minutes to process an hour-long podcast—streaming ASR requires sub-second latency while maintaining high accuracy. This chapter covers the complete architecture for building production-grade streaming ASR systems.


35.2.1. The Streaming Architecture

The streaming ASR pipeline consists of several critical components that must work in harmony:

graph LR
    A[Client/Browser] -->|WebSocket/gRPC| B[Load Balancer]
    B --> C[VAD Service]
    C -->|Speech Segments| D[ASR Engine]
    D -->|Partial Results| E[Post-Processor]
    E -->|Final Results| F[Client]
    
    subgraph "State Management"
        G[Session Store]
        D <--> G
    end

Component Breakdown

  1. Client Layer: The browser captures microphone audio (Web Audio API) and sends fixed-size PCM chunks over WebSocket.
  2. VAD (Voice Activity Detection): “Is this silence?” If yes, drop the packet; if no, pass it to the processing queue.
  3. ASR Engine: Maintains decoding state (RNN/Transformer memory) and updates the partial transcript.
  4. Post-Processor: Punctuation, capitalization, number formatting.
  5. Stabilization: Partial hypotheses are revised as more audio arrives (“Hello W…” becomes “Hello World”); only text that has stopped changing should be committed to the UI (a minimal sketch follows this list).
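
Partial-result stabilization is usually handled with a local-agreement rule: only text that several consecutive hypotheses agree on is committed. Below is a minimal sketch of that idea; the class name and window size are illustrative, not a library API.

class LocalAgreementStabilizer:
    """Commit only the prefix that the last N partial hypotheses agree on."""

    def __init__(self, agreement_window: int = 2):
        self.agreement_window = agreement_window
        self.history: list[list[str]] = []   # recent hypotheses, tokenized into words
        self.committed: list[str] = []       # words already emitted as stable

    def update(self, partial_text: str) -> tuple[str, str]:
        """Return (stable_text, unstable_tail) for the latest partial hypothesis."""
        words = partial_text.split()
        self.history.append(words)
        self.history = self.history[-self.agreement_window:]

        if len(self.history) == self.agreement_window:
            # Longest common word-prefix across the agreement window
            prefix = self.history[0]
            for hyp in self.history[1:]:
                i = 0
                while i < min(len(prefix), len(hyp)) and prefix[i] == hyp[i]:
                    i += 1
                prefix = prefix[:i]
            if len(prefix) > len(self.committed):
                self.committed = list(prefix)

        unstable = words[len(self.committed):]
        return " ".join(self.committed), " ".join(unstable)


# The committed prefix only grows; the tail may still be revised
stab = LocalAgreementStabilizer()
for partial in ["hello w", "hello world", "hello world how", "hello world how are"]:
    stable, tail = stab.update(partial)
    print(f"stable='{stable}'  unstable='{tail}'")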

Latency Budget Breakdown

| Component | Target Latency | Notes |
| --- | --- | --- |
| Client Capture | 20-50ms | WebAudio buffer size |
| Network Transit | 10-50ms | Depends on geography |
| VAD Processing | 5-10ms | Must be ultra-fast |
| ASR Inference | 50-200ms | GPU-dependent |
| Post-Processing | 10-20ms | Punctuation/formatting |
| Total E2E | 100-350ms | Target < 300ms for UX |

35.2.2. Protocol: WebSocket vs. gRPC

The choice of streaming protocol significantly impacts architecture decisions.

WebSocket Architecture

  • Pros: Ubiquitous. Works in Browser JS natively. Good for B2C apps.
  • Cons: Text-based overhead, less efficient for binary data.

gRPC Streaming

  • Pros: Lower overhead (ProtoBuf). Better for backend-to-backend (e.g., Phone Switch -> ASR).
  • Cons: Not native in browsers (requires grpc-web proxy).

Comparison Matrix

| Feature | WebSocket | gRPC |
| --- | --- | --- |
| Browser Support | Native | Requires proxy (grpc-web) |
| Binary Efficiency | Moderate | Excellent |
| Bidirectional | Yes | Yes |
| Load Balancing | L7 (complex) | L4/L7 |
| TLS | WSS | mTLS native |
| Multiplexing | Per-connection | HTTP/2 streams |

FastAPI WebSocket Server Implementation

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import numpy as np
from dataclasses import dataclass
from typing import Optional
import logging

logger = logging.getLogger(__name__)

app = FastAPI(title="Streaming ASR Service")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@dataclass
class TranscriptionResult:
    text: str
    is_final: bool
    confidence: float
    start_time: float
    end_time: float
    words: list

class ASRSession:
    """Manages state for a single ASR streaming session."""
    
    def __init__(self, model_name: str = "base"):
        self.model = self._load_model(model_name)
        self.buffer = np.array([], dtype=np.float32)
        self.context = None
        self.sample_rate = 16000
        self.chunk_duration = 0.5  # seconds
        self.total_audio_processed = 0.0
        
    def _load_model(self, model_name: str):
        # Initialize ASR model with streaming support
        from faster_whisper import WhisperModel
        return WhisperModel(
            model_name, 
            device="cuda", 
            compute_type="int8"
        )
    
    async def process_chunk(self, audio_bytes: bytes) -> TranscriptionResult:
        # Convert bytes to numpy array
        audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
        
        # Append to buffer
        self.buffer = np.concatenate([self.buffer, audio])
        
        # Check if we have enough audio to process
        min_samples = int(self.sample_rate * self.chunk_duration)
        if len(self.buffer) < min_samples:
            return TranscriptionResult(
                text="",
                is_final=False,
                confidence=0.0,
                start_time=self.total_audio_processed,
                end_time=self.total_audio_processed,
                words=[]
            )
        
        # Process the buffer
        segments, info = self.model.transcribe(
            self.buffer,
            beam_size=5,
            language="en",
            vad_filter=True,
            vad_parameters=dict(
                min_silence_duration_ms=500,
                speech_pad_ms=400
            )
        )
        
        # Collect results
        text_parts = []
        words = []
        for segment in segments:
            text_parts.append(segment.text)
            if hasattr(segment, 'words') and segment.words:
                words.extend([
                    {"word": w.word, "start": w.start, "end": w.end, "probability": w.probability}
                    for w in segment.words
                ])
        
        result_text = " ".join(text_parts).strip()
        
        # Update tracking
        buffer_duration = len(self.buffer) / self.sample_rate
        self.total_audio_processed += buffer_duration
        
        # Clear processed buffer (keep last 0.5s for context)
        overlap_samples = int(self.sample_rate * 0.5)
        self.buffer = self.buffer[-overlap_samples:]
        
        return TranscriptionResult(
            text=result_text,
            is_final=False,
            confidence=info.language_probability if info else 0.0,
            start_time=self.total_audio_processed - buffer_duration,
            end_time=self.total_audio_processed,
            words=words
        )
    
    def close(self):
        self.buffer = np.array([], dtype=np.float32)
        self.context = None


class VADProcessor:
    """Voice Activity Detection to filter silence."""
    
    def __init__(self, threshold: float = 0.5):
        import torch
        self.model, self.utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False
        )
        self.threshold = threshold
        self.sample_rate = 16000
        
    def is_speech(self, audio_bytes: bytes) -> bool:
        import torch
        audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
        audio_tensor = torch.from_numpy(audio)
        
        # Get speech probability
        speech_prob = self.model(audio_tensor, self.sample_rate).item()
        return speech_prob > self.threshold


# Global instances (in production, use dependency injection)
vad_processor = VADProcessor()


@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    
    # Initialize ASR session for this connection
    session = ASRSession(model_name="base")
    silence_count = 0
    max_silence_chunks = 10  # signal end-of-speech after ~5s of silence (10 x 0.5s chunks)
    
    try:
        while True:
            # 1. Receive audio chunk (bytes)
            audio_chunk = await websocket.receive_bytes()
            
            # 2. VAD Check - skip silence
            if not vad_processor.is_speech(audio_chunk):
                silence_count += 1
                if silence_count >= max_silence_chunks:
                    # Send the end-of-stream signal once, then restart the count
                    await websocket.send_json({
                        "text": "",
                        "is_final": True,
                        "event": "end_of_speech"
                    })
                    silence_count = 0
                continue
            
            silence_count = 0  # Reset on speech
            
            # 3. Process through ASR
            result = await session.process_chunk(audio_chunk)
            
            # 4. Send partial result back
            await websocket.send_json({
                "text": result.text,
                "is_final": result.is_final,
                "confidence": result.confidence,
                "start_time": result.start_time,
                "end_time": result.end_time,
                "words": result.words
            })
            
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"ASR session error: {e}")
        await websocket.send_json({"error": str(e)})
    finally:
        session.close()


@app.get("/health")
async def health_check():
    return {"status": "healthy", "service": "streaming-asr"}

gRPC Server Implementation

// asr_service.proto
syntax = "proto3";

package asr;

service StreamingASR {
    rpc StreamingRecognize(stream AudioRequest) returns (stream TranscriptionResponse);
    rpc GetSupportedLanguages(Empty) returns (LanguageList);
}

message AudioRequest {
    oneof request {
        StreamingConfig config = 1;
        bytes audio_content = 2;
    }
}

message StreamingConfig {
    string language_code = 1;
    int32 sample_rate_hertz = 2;
    string encoding = 3;  // LINEAR16, FLAC, OGG_OPUS
    bool enable_word_timestamps = 4;
    bool enable_punctuation = 5;
    int32 max_alternatives = 6;
}

message TranscriptionResponse {
    repeated SpeechRecognitionResult results = 1;
    string error = 2;
}

message SpeechRecognitionResult {
    repeated SpeechRecognitionAlternative alternatives = 1;
    bool is_final = 2;
    float stability = 3;
}

message SpeechRecognitionAlternative {
    string transcript = 1;
    float confidence = 2;
    repeated WordInfo words = 3;
}

message WordInfo {
    string word = 1;
    float start_time = 2;
    float end_time = 3;
    float confidence = 4;
}

message Empty {}

message LanguageList {
    repeated string languages = 1;
}
# asr_grpc_server.py
import grpc
from concurrent import futures
import asyncio
from typing import Iterator
import asr_service_pb2 as pb2
import asr_service_pb2_grpc as pb2_grpc

class StreamingASRServicer(pb2_grpc.StreamingASRServicer):
    
    def __init__(self):
        self.model = self._load_model()
        
    def _load_model(self):
        from faster_whisper import WhisperModel
        return WhisperModel("base", device="cuda", compute_type="int8")
    
    def StreamingRecognize(
        self, 
        request_iterator: Iterator[pb2.AudioRequest],
        context: grpc.ServicerContext
    ) -> Iterator[pb2.TranscriptionResponse]:
        
        config = None
        audio_buffer = b""
        
        for request in request_iterator:
            if request.HasField("config"):
                config = request.config
                continue
                
            audio_buffer += request.audio_content
            
            # Process once ~500 ms has accumulated: for 16-bit audio,
            # sample_rate_hertz bytes == 0.5 s of samples (2 bytes/sample)
            if config and len(audio_buffer) >= config.sample_rate_hertz:
                # Convert and transcribe
                import numpy as np
                audio = np.frombuffer(audio_buffer, dtype=np.int16).astype(np.float32) / 32768.0
                
                segments, _ = self.model.transcribe(
                    audio,
                    language=config.language_code[:2] if config.language_code else "en"
                )
                
                for segment in segments:
                    alternative = pb2.SpeechRecognitionAlternative(
                        transcript=segment.text,
                        confidence=segment.avg_logprob
                    )
                    
                    if config.enable_word_timestamps and segment.words:
                        for word in segment.words:
                            alternative.words.append(pb2.WordInfo(
                                word=word.word,
                                start_time=word.start,
                                end_time=word.end,
                                confidence=word.probability
                            ))
                    
                    result = pb2.SpeechRecognitionResult(
                        alternatives=[alternative],
                        is_final=False,
                        stability=0.9
                    )
                    
                    yield pb2.TranscriptionResponse(results=[result])
                
                # Keep overlap for context
                audio_buffer = audio_buffer[-config.sample_rate_hertz // 2:]
    
    def GetSupportedLanguages(self, request, context):
        return pb2.LanguageList(languages=[
            "en-US", "en-GB", "es-ES", "fr-FR", "de-DE", 
            "it-IT", "pt-BR", "ja-JP", "ko-KR", "zh-CN"
        ])


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    pb2_grpc.add_StreamingASRServicer_to_server(StreamingASRServicer(), server)
    server.add_insecure_port('[::]:50051')
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
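
On the client side, the bidirectional stream is driven by a generator that yields the config message first and raw audio afterwards. A minimal client sketch against the generated stubs above; the file name and chunk size are illustrative.

# asr_grpc_client.py - minimal streaming client sketch
import grpc
import asr_service_pb2 as pb2
import asr_service_pb2_grpc as pb2_grpc

CHUNK_BYTES = 16000  # ~500 ms of 16 kHz 16-bit mono audio


def request_stream(pcm_path: str):
    # First message carries the streaming configuration
    yield pb2.AudioRequest(config=pb2.StreamingConfig(
        language_code="en-US",
        sample_rate_hertz=16000,
        encoding="LINEAR16",
        enable_word_timestamps=True,
    ))
    # Subsequent messages carry raw audio bytes
    with open(pcm_path, "rb") as f:
        while chunk := f.read(CHUNK_BYTES):
            yield pb2.AudioRequest(audio_content=chunk)


def main():
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = pb2_grpc.StreamingASRStub(channel)
        for response in stub.StreamingRecognize(request_stream("sample_16k.pcm")):
            for result in response.results:
                for alt in result.alternatives:
                    marker = "final" if result.is_final else "partial"
                    print(f"[{marker}] {alt.transcript}")


if __name__ == "__main__":
    main()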

35.2.3. Models: Whisper vs. Conformer vs. Kaldi

Model Comparison

| Model | Architecture | Latency | Accuracy (WER) | Streaming | Resource Usage |
| --- | --- | --- | --- | --- | --- |
| Kaldi | WFST + GMM/DNN | Ultra-low | Moderate | Native | Low (CPU) |
| Whisper | Transformer | High | Excellent | Adapted | High (GPU) |
| Conformer | Conv + Transformer | Medium | Excellent | Native | Medium-High |
| DeepSpeech | RNN (LSTM/GRU) | Low | Good | Native | Medium |
| Wav2Vec2 | Transformer | Medium | Excellent | Adapted | High |

Kaldi: The Classic Approach

  • Architecture: Uses Weighted Finite State Transducers (WFST). Extremely fast, low CPU.
  • Pros: Battle-tested, CPU-only, microsecond latency.
  • Cons: Complex deployment, steep learning curve, harder to customize.
# Kaldi online2 decoder example
online2-wav-nnet3-latgen-faster \
    --online=true \
    --do-endpointing=true \
    --config=conf/online.conf \
    --max-active=7000 \
    --beam=15.0 \
    --lattice-beam=6.0 \
    --acoustic-scale=1.0 \
    final.mdl \
    graph/HCLG.fst \
    ark:spk2utt \
    scp:wav.scp \
    ark:/dev/null

Whisper: The Modern Standard

  • Architecture: Encoder-Decoder Transformer (680M params for large-v3).
  • Pros: State-of-the-art accuracy, multilingual, robust to noise.
  • Cons: Not natively streaming, high GPU requirements.

Streaming Whisper Implementations:

  1. faster-whisper: CTranslate2 backend with INT8 quantization
  2. whisper.cpp: C/C++ port for edge devices
  3. whisper-streaming: Buffered streaming with LocalAgreement
# faster-whisper streaming implementation
from faster_whisper import WhisperModel
import numpy as np

class StreamingWhisper:
    def __init__(self, model_size: str = "base"):
        self.model = WhisperModel(
            model_size,
            device="cuda",
            compute_type="int8",  # INT8 for 2x speedup
            cpu_threads=4
        )
        self.buffer = np.array([], dtype=np.float32)
        self.min_chunk_size = 16000  # 1 second at 16kHz
        self.overlap_size = 8000     # 0.5 second overlap
        
    def process_chunk(self, audio_chunk: np.ndarray) -> str:
        self.buffer = np.concatenate([self.buffer, audio_chunk])
        
        if len(self.buffer) < self.min_chunk_size:
            return ""
        
        # Transcribe current buffer
        segments, _ = self.model.transcribe(
            self.buffer,
            beam_size=5,
            best_of=5,
            language="en",
            condition_on_previous_text=True,
            vad_filter=True
        )
        
        text = " ".join([s.text for s in segments])
        
        # Keep overlap for context continuity
        self.buffer = self.buffer[-self.overlap_size:]
        
        return text.strip()

Conformer: Best of Both Worlds

The Conformer architecture combines convolutional layers (local patterns) with Transformer attention (global context).

# Using NVIDIA NeMo Conformer for streaming
import torch
import nemo.collections.asr as nemo_asr

class ConformerStreaming:
    def __init__(self):
        self.model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
            "nvidia/stt_en_conformer_transducer_large"
        )
        self.model.eval()
        self.model.cuda()
        
    def transcribe_stream(self, audio_chunks):
        """Process audio in streaming mode."""
        # Enable limited-context (streaming) attention; the exact streaming API varies by NeMo version
        self.model.encoder.set_streaming_cfg(
            chunk_size=160,  # 10ms chunks at 16kHz
            left_context=32,
            right_context=0   # Causal for streaming
        )
        
        for chunk in audio_chunks:
            # Process each chunk
            with torch.inference_mode():
                logits, logits_len = self.model.encoder(
                    audio_signal=chunk.cuda(),
                    length=torch.tensor([len(chunk)])
                )
                hypotheses = self.model.decoding.rnnt_decoder_predictions_tensor(
                    logits, logits_len
                )
                yield hypotheses[0]

35.2.4. Metrics: Word Error Rate (WER) and Beyond

Word Error Rate (WER)

The standard metric for ASR quality:

$$ \text{WER} = \frac{S + D + I}{N} \times 100\% $$

Where:

  • S: Substitutions (“Cat” -> “Bat”)
  • D: Deletions (“The Cat” -> “Cat”)
  • I: Insertions (“Cat” -> “The Cat”)
  • N: Total words in reference
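
For example, if the reference is “the cat sat on the mat” (N = 6) and the hypothesis is “the cat sat on mat”, there is one deletion and no substitutions or insertions, so WER = (0 + 1 + 0) / 6 ≈ 16.7%.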
# WER calculation implementation
from jiwer import wer, cer
from dataclasses import dataclass
from typing import List

@dataclass
class ASRMetrics:
    wer: float
    cer: float
    substitutions: int
    deletions: int
    insertions: int
    reference_words: int

def calculate_asr_metrics(reference: str, hypothesis: str) -> ASRMetrics:
    """Calculate comprehensive ASR metrics."""
    from jiwer import compute_measures
    
    # Normalize text
    reference = reference.lower().strip()
    hypothesis = hypothesis.lower().strip()
    
    measures = compute_measures(reference, hypothesis)
    
    return ASRMetrics(
        wer=measures['wer'] * 100,
        cer=cer(reference, hypothesis) * 100,
        substitutions=measures['substitutions'],
        deletions=measures['deletions'],
        insertions=measures['insertions'],
        reference_words=len(reference.split())
    )

def batch_evaluate(references: List[str], hypotheses: List[str]) -> dict:
    """Evaluate a batch of transcriptions."""
    total_wer = wer(references, hypotheses)
    
    # Per-sample analysis
    metrics = [
        calculate_asr_metrics(ref, hyp) 
        for ref, hyp in zip(references, hypotheses)
    ]
    
    return {
        "overall_wer": total_wer * 100,
        "mean_wer": sum(m.wer for m in metrics) / len(metrics),
        "median_wer": sorted(m.wer for m in metrics)[len(metrics) // 2],
        "samples_above_10_wer": sum(1 for m in metrics if m.wer > 10),
        "substitution_rate": sum(m.substitutions for m in metrics) / sum(m.reference_words for m in metrics) * 100,
        "deletion_rate": sum(m.deletions for m in metrics) / sum(m.reference_words for m in metrics) * 100,
        "insertion_rate": sum(m.insertions for m in metrics) / sum(m.reference_words for m in metrics) * 100
    }

Real-Time Factor (RTF)

Measures processing speed relative to audio duration:

$$ \text{RTF} = \frac{\text{Processing Time}}{\text{Audio Duration}} $$

  • RTF < 1: Real-time capable
  • RTF < 0.5: Good for streaming (leaves headroom)
  • RTF < 0.1: Excellent, supports batching
import time
import numpy as np
from typing import List

def benchmark_rtf(model, audio_samples: List[np.ndarray], sample_rate: int = 16000) -> dict:
    """Benchmark Real-Time Factor for ASR model."""
    total_audio_duration = 0
    total_processing_time = 0
    
    for audio in audio_samples:
        audio_duration = len(audio) / sample_rate
        total_audio_duration += audio_duration
        
        start_time = time.perf_counter()
        _ = model.transcribe(audio)
        end_time = time.perf_counter()
        
        total_processing_time += (end_time - start_time)
    
    rtf = total_processing_time / total_audio_duration
    
    return {
        "rtf": rtf,
        "is_realtime": rtf < 1.0,
        "throughput_factor": 1.0 / rtf,
        "total_audio_hours": total_audio_duration / 3600,
        "processing_time_hours": total_processing_time / 3600
    }

Streaming-Specific Metrics

| Metric | Description | Target |
| --- | --- | --- |
| First Byte Latency | Time to first partial result | < 200ms |
| Partial WER | WER of unstable partials | < 30% |
| Final WER | WER of finalized text | < 10% |
| Word Stabilization Time | Time for a word to become final | < 2s |
| Endpoint Detection Latency | Time to detect end of utterance | < 500ms |
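
First-byte latency and word stabilization time can be measured client-side from the stream of partial results. A minimal tracker sketch, assuming partials arrive with wall-clock timestamps; the class name and metric definitions are illustrative.

import time
from typing import Dict, Optional, Tuple


class StreamingLatencyTracker:
    """Measure first-byte latency and word stabilization time from partial results."""

    def __init__(self, stream_start: float):
        self.stream_start = stream_start
        self.first_partial_at: Optional[float] = None
        # word position -> (current value, time that value first appeared)
        self.current_since: Dict[int, Tuple[str, float]] = {}

    def on_partial(self, timestamp: float, text: str) -> None:
        if self.first_partial_at is None and text.strip():
            self.first_partial_at = timestamp
        for i, word in enumerate(text.split()):
            prev = self.current_since.get(i)
            if prev is None or prev[0] != word:
                self.current_since[i] = (word, timestamp)  # word appeared or was revised

    def on_final(self, timestamp: float) -> dict:
        """Call when the transcript is finalized; returns the two streaming metrics."""
        stabilization = [timestamp - t for _, t in self.current_since.values()]
        return {
            "first_byte_latency_s": (self.first_partial_at or timestamp) - self.stream_start,
            "mean_word_stabilization_s": (
                sum(stabilization) / len(stabilization) if stabilization else 0.0
            ),
        }


# Usage: feed wall-clock timestamps as partials arrive over the WebSocket
tracker = StreamingLatencyTracker(stream_start=time.monotonic())
tracker.on_partial(time.monotonic(), "hello w")
tracker.on_partial(time.monotonic(), "hello world")
print(tracker.on_final(time.monotonic()))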

35.2.5. Handling Whisper Hallucinations

Whisper is notorious for hallucinating when fed silence or low-quality audio.

Common Hallucination Patterns

  1. “Thank you for watching” - YouTube training data artifact
  2. Repeated phrases - Getting stuck in loops
  3. Language switching - Random multilingual outputs
  4. Phantom speakers - Inventing conversation partners

Mitigation Strategies

from dataclasses import dataclass
from typing import Optional, List
import numpy as np

@dataclass
class HallucinationDetector:
    """Detect and filter ASR hallucinations."""
    
    # Known hallucination phrases
    KNOWN_HALLUCINATIONS = [
        "thank you for watching",
        "thanks for watching",
        "please subscribe",
        "like and subscribe",
        "see you in the next video",
        "don't forget to subscribe",
        "if you enjoyed this video",
    ]
    
    # Repetition detection
    MAX_REPEAT_RATIO = 0.7
    
    # Silence detection threshold
    SILENCE_THRESHOLD = 0.01
    
    def is_hallucination(
        self, 
        text: str, 
        audio: np.ndarray,
        previous_texts: List[str] = None
    ) -> tuple[bool, str]:
        """
        Check if transcription is likely a hallucination.
        Returns (is_hallucination, reason)
        """
        text_lower = text.lower().strip()
        
        # Check 1: Known phrases
        for phrase in self.KNOWN_HALLUCINATIONS:
            if phrase in text_lower:
                return True, f"known_hallucination:{phrase}"
        
        # Check 2: Audio is silence
        if self._is_silence(audio):
            return True, "silence_detected"
        
        # Check 3: Excessive repetition
        if self._has_repetition(text):
            return True, "excessive_repetition"
        
        # Check 4: Exact repeat of previous output
        if previous_texts and text_lower in [t.lower() for t in previous_texts[-3:]]:
            return True, "exact_repeat"
        
        return False, "valid"
    
    def _is_silence(self, audio: np.ndarray) -> bool:
        """Check if audio is effectively silence."""
        rms = np.sqrt(np.mean(audio ** 2))
        return rms < self.SILENCE_THRESHOLD
    
    def _has_repetition(self, text: str) -> bool:
        """Detect word-level repetition."""
        words = text.lower().split()
        if len(words) < 4:
            return False
        
        # Check for bigram repetition
        bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
        unique_bigrams = set(bigrams)
        
        repeat_ratio = 1 - (len(unique_bigrams) / len(bigrams))
        return repeat_ratio > self.MAX_REPEAT_RATIO


class RobustASR:
    """ASR with hallucination filtering."""
    
    def __init__(self, model):
        self.model = model
        self.detector = HallucinationDetector()
        self.history = []
    
    def transcribe(self, audio: np.ndarray) -> Optional[str]:
        # Step 1: Pre-filter with aggressive VAD
        if self._is_low_energy(audio):
            return None
        
        # Step 2: Transcribe
        segments, _ = self.model.transcribe(
            audio,
            no_speech_threshold=0.6,  # More aggressive filtering
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,  # Detect repetition
            condition_on_previous_text=False  # Reduce hallucination propagation
        )
        
        text = " ".join(s.text for s in segments).strip()
        
        # Step 3: Post-filter hallucinations
        is_hallucination, reason = self.detector.is_hallucination(
            text, audio, self.history
        )
        
        if is_hallucination:
            return None
        
        # Update history
        self.history.append(text)
        if len(self.history) > 10:
            self.history.pop(0)
        
        return text
    
    def _is_low_energy(self, audio: np.ndarray, threshold: float = 0.005) -> bool:
        return np.sqrt(np.mean(audio ** 2)) < threshold

35.2.6. Voice Activity Detection (VAD) Deep Dive

VAD is the first line of defense against wasted compute and hallucinations.

Silero VAD: Production Standard

import torch
import numpy as np
from typing import List, Tuple

class SileroVAD:
    """Production-ready VAD using Silero."""
    
    def __init__(
        self, 
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 100,
        sample_rate: int = 16000
    ):
        self.model, self.utils = torch.hub.load(
            'snakers4/silero-vad',
            'silero_vad',
            trust_repo=True
        )
        self.threshold = threshold
        self.min_speech_samples = int(sample_rate * min_speech_duration_ms / 1000)
        self.min_silence_samples = int(sample_rate * min_silence_duration_ms / 1000)
        self.sample_rate = sample_rate
        
        # State for streaming
        self.reset()
    
    def reset(self):
        """Reset internal state for new stream."""
        self.model.reset_states()
        self._in_speech = False
        self._speech_start = 0
        self._current_position = 0
    
    def process_chunk(self, audio: np.ndarray) -> List[Tuple[int, int]]:
        """
        Process audio chunk and return speech segments.
        Returns list of (start_sample, end_sample) tuples.
        """
        audio_tensor = torch.from_numpy(audio).float()
        
        # Process in 512-sample windows (~32 ms at 16 kHz); recent Silero VAD
        # releases require exactly 512 samples per call at 16 kHz
        window_size = 512
        segments = []
        
        for i in range(0, len(audio), window_size):
            window = audio_tensor[i:i + window_size]
            if len(window) < window_size:
                # Pad final window
                window = torch.nn.functional.pad(window, (0, window_size - len(window)))
            
            speech_prob = self.model(window, self.sample_rate).item()
            
            if speech_prob >= self.threshold:
                if not self._in_speech:
                    self._in_speech = True
                    self._speech_start = self._current_position + i
            else:
                if self._in_speech:
                    self._in_speech = False
                    speech_end = self._current_position + i
                    duration = speech_end - self._speech_start
                    
                    if duration >= self.min_speech_samples:
                        segments.append((self._speech_start, speech_end))
        
        self._current_position += len(audio)
        return segments
    
    def get_speech_timestamps(
        self, 
        audio: np.ndarray
    ) -> List[dict]:
        """Get speech timestamps for entire audio."""
        self.reset()
        segments = self.process_chunk(audio)
        
        return [
            {
                "start": start / self.sample_rate,
                "end": end / self.sample_rate,
                "duration": (end - start) / self.sample_rate
            }
            for start, end in segments
        ]


class EnhancedVAD:
    """VAD with additional features for production."""
    
    def __init__(self):
        self.vad = SileroVAD()
        self.energy_threshold = 0.01
        
    def is_speech_segment(self, audio: np.ndarray) -> dict:
        """Comprehensive speech detection."""
        # Energy check (fast, first filter)
        energy = np.sqrt(np.mean(audio ** 2))
        if energy < self.energy_threshold:
            return {"is_speech": False, "reason": "low_energy", "confidence": 0.0}
        
        # Zero-crossing rate (detect static noise)
        zcr = np.mean(np.abs(np.diff(np.sign(audio))))
        if zcr > 0.5:  # High ZCR often indicates noise
            return {"is_speech": False, "reason": "high_zcr", "confidence": 0.3}
        
        # Neural VAD
        segments = self.vad.get_speech_timestamps(audio)
        
        if not segments:
            return {"is_speech": False, "reason": "vad_reject", "confidence": 0.0}
        
        # Calculate speech ratio
        total_speech = sum(s["duration"] for s in segments)
        audio_duration = len(audio) / 16000
        speech_ratio = total_speech / audio_duration
        
        return {
            "is_speech": True,
            "reason": "speech_detected",
            "confidence": min(speech_ratio * 1.5, 1.0),
            "segments": segments,
            "speech_ratio": speech_ratio
        }

35.2.7. Production Infrastructure: Kubernetes Deployment

AWS EKS Architecture

# asr-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: streaming-asr
  namespace: ml-inference
spec:
  replicas: 3
  selector:
    matchLabels:
      app: streaming-asr
  template:
    metadata:
      labels:
        app: streaming-asr
    spec:
      nodeSelector:
        node.kubernetes.io/instance-type: g5.xlarge
      tolerations:
        - key: "nvidia.com/gpu"
          operator: "Exists"
          effect: "NoSchedule"
      containers:
        - name: asr-server
          image: 123456789.dkr.ecr.us-east-1.amazonaws.com/streaming-asr:v1.2.0
          ports:
            - containerPort: 8000
              name: websocket
            - containerPort: 50051
              name: grpc
          resources:
            requests:
              memory: "8Gi"
              cpu: "2"
              nvidia.com/gpu: "1"
            limits:
              memory: "16Gi"
              cpu: "4"
              nvidia.com/gpu: "1"
          env:
            - name: MODEL_SIZE
              value: "large-v3"
            - name: COMPUTE_TYPE
              value: "int8"
            - name: MAX_CONCURRENT_STREAMS
              value: "50"
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 60
            periodSeconds: 10
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 5
          volumeMounts:
            - name: model-cache
              mountPath: /root/.cache/huggingface
      volumes:
        - name: model-cache
          persistentVolumeClaim:
            claimName: model-cache-pvc
---
apiVersion: v1
kind: Service
metadata:
  name: streaming-asr
  namespace: ml-inference
spec:
  selector:
    app: streaming-asr
  ports:
    - name: websocket
      port: 80
      targetPort: 8000
    - name: grpc
      port: 50051
      targetPort: 50051
  type: ClusterIP
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: streaming-asr-ingress
  namespace: ml-inference
  annotations:
    kubernetes.io/ingress.class: "alb"
    alb.ingress.kubernetes.io/scheme: "internet-facing"
    alb.ingress.kubernetes.io/target-type: "ip"
    alb.ingress.kubernetes.io/healthcheck-path: "/health"
    alb.ingress.kubernetes.io/backend-protocol: "HTTP"
    # WebSocket support
    alb.ingress.kubernetes.io/load-balancer-attributes: "idle_timeout.timeout_seconds=3600"
spec:
  rules:
    - host: asr.example.com
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: streaming-asr
                port:
                  number: 80

GCP GKE with NVIDIA L4 GPUs

# gke-asr-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: streaming-asr-gpu
  namespace: ml-inference
spec:
  replicas: 2
  selector:
    matchLabels:
      app: streaming-asr-gpu
  template:
    metadata:
      labels:
        app: streaming-asr-gpu
    spec:
      nodeSelector:
        cloud.google.com/gke-accelerator: nvidia-l4
      containers:
        - name: asr-server
          image: gcr.io/my-project/streaming-asr:v1.2.0
          ports:
            - containerPort: 8000
          resources:
            requests:
              memory: "8Gi"
              cpu: "4"
              nvidia.com/gpu: "1"
            limits:
              memory: "16Gi"
              nvidia.com/gpu: "1"
          env:
            - name: GOOGLE_CLOUD_PROJECT
              value: "my-project"
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: streaming-asr-hpa
  namespace: ml-inference
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: streaming-asr-gpu
  minReplicas: 2
  maxReplicas: 20
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Pods
      pods:
        metric:
          name: active_websocket_connections
        target:
          type: AverageValue
          averageValue: "100"

Terraform Infrastructure

# asr_infrastructure.tf

# AWS EKS Node Group for ASR
resource "aws_eks_node_group" "asr_gpu" {
  cluster_name    = aws_eks_cluster.main.name
  node_group_name = "asr-gpu-nodes"
  node_role_arn   = aws_iam_role.eks_node.arn
  subnet_ids      = var.private_subnet_ids

  scaling_config {
    desired_size = 3
    max_size     = 10
    min_size     = 2
  }

  instance_types = ["g5.xlarge"]
  ami_type       = "AL2_x86_64_GPU"
  capacity_type  = "ON_DEMAND"

  labels = {
    workload = "asr-inference"
    gpu      = "true"
  }

  taint {
    key    = "nvidia.com/gpu"
    value  = "true"
    effect = "NO_SCHEDULE"
  }

  tags = {
    Environment = var.environment
    Service     = "streaming-asr"
  }
}

# ElastiCache for session state
resource "aws_elasticache_cluster" "asr_sessions" {
  cluster_id           = "asr-sessions"
  engine               = "redis"
  node_type            = "cache.r6g.large"
  num_cache_nodes      = 1  # Redis clusters are single-node; use aws_elasticache_replication_group for replicas
  parameter_group_name = "default.redis7"
  port                 = 6379
  
  subnet_group_name    = aws_elasticache_subnet_group.main.name
  security_group_ids   = [aws_security_group.redis.id]

  tags = {
    Service = "streaming-asr"
  }
}

# CloudWatch Dashboard for ASR metrics
resource "aws_cloudwatch_dashboard" "asr" {
  dashboard_name = "streaming-asr-metrics"

  dashboard_body = jsonencode({
    widgets = [
      {
        type   = "metric"
        x      = 0
        y      = 0
        width  = 12
        height = 6
        properties = {
          metrics = [
            ["ASR", "ActiveConnections", "Service", "streaming-asr"],
            [".", "TranscriptionsPerSecond", ".", "."],
            [".", "P99Latency", ".", "."]
          ]
          title = "ASR Performance"
          region = var.aws_region
        }
      },
      {
        type   = "metric"
        x      = 12
        y      = 0
        width  = 12
        height = 6
        properties = {
          metrics = [
            ["ASR", "GPU_Utilization", "Service", "streaming-asr"],
            [".", "GPU_Memory", ".", "."]
          ]
          title = "GPU Metrics"
          region = var.aws_region
        }
      }
    ]
  })
}

35.2.8. Speaker Diarization: “Who Said What?”

Transcription is useless for meetings without speaker attribution.

The Diarization Pipeline

graph LR
    A[Audio Input] --> B[VAD]
    B --> C[Speaker Embedding Extraction]
    C --> D[Clustering]
    D --> E[Speaker Assignment]
    E --> F[Merge with ASR]
    F --> G[Labeled Transcript]

pyannote.audio Implementation

from pyannote.audio import Pipeline
from dataclasses import dataclass
from typing import List, Dict, Optional
import torch

@dataclass
class SpeakerSegment:
    speaker: str
    start: float
    end: float
    text: str = ""

class SpeakerDiarization:
    """Speaker diarization using pyannote.audio."""
    
    def __init__(self, hf_token: str):
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=hf_token
        )
        
        # Move to GPU if available
        if torch.cuda.is_available():
            self.pipeline.to(torch.device("cuda"))
    
    def diarize(
        self, 
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> List[SpeakerSegment]:
        """Run speaker diarization on audio file."""
        
        # Configure speaker count
        if num_speakers:
            diarization = self.pipeline(
                audio_path,
                num_speakers=num_speakers
            )
        else:
            diarization = self.pipeline(
                audio_path,
                min_speakers=min_speakers,
                max_speakers=max_speakers
            )
        
        segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(SpeakerSegment(
                speaker=speaker,
                start=turn.start,
                end=turn.end
            ))
        
        return segments
    
    def merge_with_transcription(
        self,
        diarization_segments: List[SpeakerSegment],
        asr_words: List[Dict]
    ) -> List[SpeakerSegment]:
        """Merge ASR word timestamps with speaker labels."""
        
        result = []
        current_segment = None
        
        for word_info in asr_words:
            word_mid = (word_info["start"] + word_info["end"]) / 2
            
            # Find speaker for this word
            speaker = self._find_speaker_at_time(
                diarization_segments, word_mid
            )
            
            if current_segment is None or current_segment.speaker != speaker:
                if current_segment:
                    result.append(current_segment)
                current_segment = SpeakerSegment(
                    speaker=speaker,
                    start=word_info["start"],
                    end=word_info["end"],
                    text=word_info["word"]
                )
            else:
                current_segment.end = word_info["end"]
                current_segment.text += " " + word_info["word"]
        
        if current_segment:
            result.append(current_segment)
        
        return result
    
    def _find_speaker_at_time(
        self,
        segments: List[SpeakerSegment],
        time: float
    ) -> str:
        """Find which speaker was talking at given time."""
        for segment in segments:
            if segment.start <= time <= segment.end:
                return segment.speaker
        return "UNKNOWN"


# Usage example
async def transcribe_meeting(audio_path: str) -> str:
    # Step 1: Diarization
    diarizer = SpeakerDiarization(hf_token="hf_xxx")
    speaker_segments = diarizer.diarize(audio_path, max_speakers=4)
    
    # Step 2: ASR with word timestamps
    from faster_whisper import WhisperModel
    model = WhisperModel("large-v3", device="cuda")
    segments, _ = model.transcribe(
        audio_path,
        word_timestamps=True
    )
    
    # Collect words
    words = []
    for segment in segments:
        if segment.words:
            words.extend([
                {"word": w.word, "start": w.start, "end": w.end}
                for w in segment.words
            ])
    
    # Step 3: Merge
    labeled_segments = diarizer.merge_with_transcription(speaker_segments, words)
    
    # Format output
    output = []
    for seg in labeled_segments:
        output.append(f"[{seg.speaker}] ({seg.start:.1f}s): {seg.text.strip()}")
    
    return "\n".join(output)

35.2.9. Load Balancing for Stateful Streams

WebSockets are stateful: a long-lived connection pins a session (and its decoding state) to one backend, so plain round-robin load balancing doesn’t work.

Session Affinity Architecture

graph TB
    subgraph "Client Layer"
        C1[Client 1]
        C2[Client 2]
        C3[Client 3]
    end
    
    subgraph "Load Balancer"
        LB[ALB/Nginx]
        SS[(Session Store)]
        LB <--> SS
    end
    
    subgraph "ASR Pool"
        S1[Server 1]
        S2[Server 2]
        S3[Server 3]
    end
    
    C1 --> LB
    C2 --> LB
    C3 --> LB
    
    LB -->|"Session Affinity"| S1
    LB -->|"Session Affinity"| S2
    LB -->|"Session Affinity"| S3
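
If affinity is handled at the application layer rather than purely in the load balancer, each backend can record session ownership in Redis so that reconnects (or a smart router) find the right server. A minimal sketch using redis-py; the key naming, environment variable, and TTL are assumptions.

import os
import socket

import redis

# Assumed environment: REDIS_HOST points at the ElastiCache/Memorystore endpoint
r = redis.Redis(host=os.environ.get("REDIS_HOST", "localhost"), port=6379, decode_responses=True)

SESSION_TTL_S = 3600  # match the load balancer idle timeout


def claim_session(session_id: str) -> str:
    """Record which backend owns this streaming session; return the owner host."""
    key = f"asr:session:{session_id}"
    me = socket.gethostname()
    # SET NX: only claim the session if no other backend owns it yet
    if r.set(key, me, nx=True, ex=SESSION_TTL_S):
        return me
    return r.get(key)  # existing owner


def heartbeat(session_id: str) -> None:
    # Extend ownership while the WebSocket is still active
    r.expire(f"asr:session:{session_id}", SESSION_TTL_S)


def release_session(session_id: str) -> None:
    r.delete(f"asr:session:{session_id}")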

NGINX Configuration for WebSocket Sticky Sessions

# nginx.conf for ASR WebSocket load balancing

upstream asr_backend {
    # IP Hash for session affinity
    ip_hash;
    
    server asr-1.internal:8000 weight=1;
    server asr-2.internal:8000 weight=1;
    server asr-3.internal:8000 weight=1;
    
    # Reuse idle upstream connections
    keepalive 32;
}

map $http_upgrade $connection_upgrade {
    default upgrade;
    '' close;
}

server {
    listen 443 ssl http2;
    server_name asr.example.com;
    
    ssl_certificate /etc/nginx/ssl/cert.pem;
    ssl_certificate_key /etc/nginx/ssl/key.pem;
    
    # WebSocket timeout (1 hour for long calls)
    proxy_read_timeout 3600s;
    proxy_send_timeout 3600s;
    
    location /ws/ {
        proxy_pass http://asr_backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection $connection_upgrade;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        
        # Buffer settings for streaming
        proxy_buffering off;
        proxy_cache off;
    }
    
    location /health {
        proxy_pass http://asr_backend/health;
    }
}

Graceful Shutdown for Scaling

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import JSONResponse
import logging

logger = logging.getLogger(__name__)

class GracefulShutdown:
    """Manage graceful shutdown for WebSocket server."""
    
    def __init__(self):
        self.is_shutting_down = False
        self.active_connections = set()
        self.shutdown_event = asyncio.Event()
        
    def register_connection(self, connection_id: str):
        self.active_connections.add(connection_id)
        
    def unregister_connection(self, connection_id: str):
        self.active_connections.discard(connection_id)
        if self.is_shutting_down and not self.active_connections:
            self.shutdown_event.set()
    
    async def initiate_shutdown(self, timeout: int = 3600):
        """Start graceful shutdown process."""
        logger.info("Initiating graceful shutdown")
        self.is_shutting_down = True
        
        if not self.active_connections:
            return
        
        logger.info(f"Waiting for {len(self.active_connections)} connections to close")
        
        try:
            await asyncio.wait_for(
                self.shutdown_event.wait(),
                timeout=timeout
            )
            logger.info("All connections closed gracefully")
        except asyncio.TimeoutError:
            logger.warning(f"Shutdown timeout, {len(self.active_connections)} connections remaining")

shutdown_manager = GracefulShutdown()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logger.info("Starting ASR server")
    yield
    # Shutdown
    await shutdown_manager.initiate_shutdown()

app = FastAPI(lifespan=lifespan)

@app.get("/health")
async def health():
    if shutdown_manager.is_shutting_down:
        # Return 503 so the load balancer stops routing new connections here
        return JSONResponse(
            status_code=503,
            content={"status": "draining", "active_connections": len(shutdown_manager.active_connections)},
        )
    return {"status": "healthy"}

35.2.10. Observability and Monitoring

Prometheus Metrics

from prometheus_client import Counter, Histogram, Gauge, generate_latest
from fastapi import FastAPI, Response

app = FastAPI()  # or reuse the application instance defined earlier

# Metrics definitions
TRANSCRIPTION_REQUESTS = Counter(
    'asr_transcription_requests_total',
    'Total transcription requests',
    ['status', 'language']
)

TRANSCRIPTION_LATENCY = Histogram(
    'asr_transcription_latency_seconds',
    'Transcription latency in seconds',
    ['model_size'],
    buckets=[0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]
)

ACTIVE_CONNECTIONS = Gauge(
    'asr_active_websocket_connections',
    'Number of active WebSocket connections'
)

AUDIO_PROCESSED = Counter(
    'asr_audio_processed_seconds_total',
    'Total seconds of audio processed',
    ['language']
)

WER_SCORE = Histogram(
    'asr_wer_score',
    'Word Error Rate distribution',
    buckets=[0.01, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5]
)

GPU_MEMORY_USED = Gauge(
    'asr_gpu_memory_used_bytes',
    'GPU memory usage in bytes',
    ['gpu_id']
)


class MetricsCollector:
    """Collect and expose ASR metrics."""
    
    @staticmethod
    def record_transcription(
        language: str,
        latency: float,
        audio_duration: float,
        wer: float = None,
        success: bool = True
    ):
        status = "success" if success else "error"
        TRANSCRIPTION_REQUESTS.labels(status=status, language=language).inc()
        TRANSCRIPTION_LATENCY.labels(model_size="large-v3").observe(latency)
        AUDIO_PROCESSED.labels(language=language).inc(audio_duration)
        
        if wer is not None:
            WER_SCORE.observe(wer)
    
    @staticmethod
    def update_gpu_metrics():
        import pynvml
        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        
        for i in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            GPU_MEMORY_USED.labels(gpu_id=str(i)).set(mem_info.used)


@app.get("/metrics")
async def metrics():
    MetricsCollector.update_gpu_metrics()
    return Response(
        generate_latest(),
        media_type="text/plain"
    )

Grafana Dashboard JSON

{
  "dashboard": {
    "title": "Streaming ASR Monitoring",
    "panels": [
      {
        "title": "Active Connections",
        "type": "stat",
        "targets": [
          {
            "expr": "sum(asr_active_websocket_connections)"
          }
        ]
      },
      {
        "title": "Transcription Latency P99",
        "type": "graph",
        "targets": [
          {
            "expr": "histogram_quantile(0.99, rate(asr_transcription_latency_seconds_bucket[5m]))"
          }
        ]
      },
      {
        "title": "Audio Processed (hours/min)",
        "type": "graph",
        "targets": [
          {
            "expr": "rate(asr_audio_processed_seconds_total[5m]) * 60 / 3600"
          }
        ]
      },
      {
        "title": "GPU Memory Usage",
        "type": "graph",
        "targets": [
          {
            "expr": "asr_gpu_memory_used_bytes / 1024 / 1024 / 1024",
            "legendFormat": "GPU {{gpu_id}}"
          }
        ]
      }
    ]
  }
}

35.2.11. Cost Optimization Strategies

Cloud Cost Comparison

| Provider | Service | Cost (per hour of audio) | RTF | Notes |
| --- | --- | --- | --- | --- |
| AWS | Transcribe Streaming | $0.024 | N/A | Fully managed |
| GCP | Speech-to-Text | $0.024 | N/A | Fully managed |
| Azure | Speech Services | $0.016 | N/A | Cheaper tier |
| Self-hosted (g5.xlarge) | Whisper Large | ~$0.008 | 0.3 | At scale |
| Self-hosted (g4dn.xlarge) | Whisper Base | ~$0.002 | 0.5 | Budget option |
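
The self-hosted figures depend entirely on how many concurrent streams one GPU sustains (model size, batching, and RTF headroom). A back-of-the-envelope helper; the example inputs are hypothetical, not quoted prices.

def self_hosted_cost_per_audio_hour(
    instance_price_per_hour: float,
    concurrent_streams: int,
) -> float:
    """Approximate cost per audio hour: with N concurrent real-time streams,
    one instance-hour processes N hours of audio."""
    return instance_price_per_hour / concurrent_streams


# Hypothetical example: a ~$1.00/hr GPU instance serving 50 concurrent streams
print(f"${self_hosted_cost_per_audio_hour(1.00, 50):.4f} per audio hour")  # ~$0.02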

Spot Instance Strategy

# Spot instance handler for ASR workloads
import boto3
import time

class SpotInterruptionHandler:
    """Handle EC2 Spot interruption for ASR servers."""
    
    def __init__(self):
        self.metadata_url = "http://169.254.169.254/latest/meta-data"
        
    def check_for_interruption(self) -> bool:
        """Check if Spot interruption notice has been issued."""
        import requests
        try:
            response = requests.get(
                f"{self.metadata_url}/spot/instance-action",
                timeout=1
            )
            if response.status_code == 200:
                return True
        except requests.RequestException:
            pass
        return False
    
    async def handle_interruption(self, shutdown_manager):
        """Handle Spot interruption gracefully."""
        # 2-minute warning before termination
        await shutdown_manager.initiate_shutdown(timeout=90)
        
        # Persist any necessary state
        # Drain connections to other instances
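
The handler is typically driven by a background task that polls the instance metadata endpoint every few seconds. A sketch, assuming the shutdown_manager from the graceful-shutdown section is available.

import asyncio


async def watch_for_spot_interruption(shutdown_manager, poll_interval_s: float = 5.0):
    """Background task: poll IMDS and drain connections when an interruption notice appears."""
    handler = SpotInterruptionHandler()
    while True:
        # Run the blocking HTTP check off the event loop
        if await asyncio.to_thread(handler.check_for_interruption):
            # ~2 minutes remain; stop taking traffic and drain existing streams
            await handler.handle_interruption(shutdown_manager)
            break
        await asyncio.sleep(poll_interval_s)


# Started once at application startup, e.g. inside the FastAPI lifespan hook:
# asyncio.create_task(watch_for_spot_interruption(shutdown_manager))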

35.2.12. Summary Checklist for Streaming ASR Operations

Architecture

  • WebSocket/gRPC protocol based on client requirements
  • Session affinity configured in load balancer
  • Graceful shutdown for scaling events

Models

  • VAD (Silero) for silence filtering
  • Streaming-capable ASR (faster-whisper, Conformer)
  • Hallucination detection and filtering

Infrastructure

  • GPU nodes with appropriate instance types
  • Horizontal Pod Autoscaler on connection count
  • Redis for session state (if distributed)

Observability

  • Prometheus metrics for latency, throughput, errors
  • GPU memory and utilization monitoring
  • WER tracking on sampled data

Cost

  • Spot instances for non-critical traffic
  • Model quantization (INT8) for efficiency
  • Aggressive VAD to reduce GPU load

[End of Section 35.2]