16.3 Caching: Semantic Caching for LLMs and Beyond

16.3.1 Introduction: The Economics of Inference Caching

The fastest inference is the one you don’t have to run. The cheapest GPU is the one you don’t have to provision. Caching is the ultimate optimization—it turns $0.01 inference costs into $0.0001 cache lookups, a 100x reduction.

The ROI of Caching

Consider a customer support chatbot serving 1 million queries per month:

Without Caching:

  • Model: GPT-4 class (via API or self-hosted)
  • Cost: $0.03 per 1k tokens (input) + $0.06 per 1k tokens (output)
  • Average query: 100 input tokens, 200 output tokens
  • Monthly cost: 1M × ($0.003 + $0.012) = $15,000

With 60% Cache Hit Rate:

  • Cache hits: 600K × $0.00001 (Redis lookup) = $6
  • Cache misses: 400K × $0.015 = $6,000
  • Total: $6,006 (60% reduction)

With 80% Cache Hit Rate (achievable for FAQs):

  • Cache hits: 800K × $0.00001 = $8
  • Cache misses: 200K × $0.015 = $3,000
  • Total: $3,008 (80% reduction)

The ROI is astronomical, especially for conversational AI where users ask variations of the same questions.
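The arithmetic above folds neatly into a small cost model you can rerun with your own traffic and pricing assumptions; the defaults below are simply the illustrative numbers from this section, not universal constants.

def monthly_cost(queries: int, input_tokens: int, output_tokens: int,
                 hit_rate: float = 0.0,
                 in_price_per_1k: float = 0.03, out_price_per_1k: float = 0.06,
                 cache_lookup_cost: float = 0.00001) -> float:
    """Estimated monthly spend for a given cache hit rate."""
    per_query = (input_tokens / 1000) * in_price_per_1k + (output_tokens / 1000) * out_price_per_1k
    hits = queries * hit_rate
    misses = queries - hits
    return hits * cache_lookup_cost + misses * per_query

# Reproduce the scenarios above: 1M queries/month, 100 input / 200 output tokens
for rate in (0.0, 0.6, 0.8):
    print(f"hit rate {rate:.0%}: ${monthly_cost(1_000_000, 100, 200, hit_rate=rate):,.0f}")
# hit rate 0%: $15,000 | hit rate 60%: $6,006 | hit rate 80%: $3,008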


16.3.2 Caching Paradigms: Exact vs. Semantic

Traditional web caching is exact match: cache GET /api/users/123 and serve identical responses for identical URLs. ML inference requires a paradigm shift.

The Problem with Exact Match for NLP

Query 1: "How do I reset my password?"
Query 2: "How can I reset my password?"
Query 3: "password reset instructions"

These are semantically identical but lexically different. Exact match caching treats them as three separate queries, wasting 2 LLM calls.
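A quick sanity check makes the waste concrete: hashing the raw text (the usual exact-match key) gives three unrelated keys, so each paraphrase misses the cache.

import hashlib

queries = [
    "How do I reset my password?",
    "How can I reset my password?",
    "password reset instructions",
]

# Exact-match caching keys on a hash of the raw text, so every paraphrase is a new key
for q in queries:
    print(hashlib.sha256(q.encode()).hexdigest()[:16], q)
# Three unrelated keys -> three cache entries -> two avoidable LLM calls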

Semantic Caching

Approach: Embed the query into a vector space, then search for semantically similar cached queries.

graph LR
    Query["User Query:<br/>'How to reset password?'"]
    Embed[Embedding Model<br/>all-MiniLM-L6-v2]
    Vector["Vector: [0.12, -0.45, ...]"]
    VectorDB[(Vector DB<br/>Redis/Qdrant)]
    
    Query-->Embed
    Embed-->Vector
    Vector-->|Similarity Search|VectorDB
    VectorDB-->|Sim > 0.95?|Decision{Hit?}
    Decision-->|Yes|CachedResponse[Return Cached]
    Decision-->|No|LLM[Call LLM]
    LLM-->|Store|VectorDB

Algorithm (sketched in code after the list):

  1. Embed the incoming query using a fast local model (e.g., sentence-transformers/all-MiniLM-L6-v2).
  2. Search the vector database for the top-k most similar cached queries.
  3. Threshold: If max_similarity > 0.95, return the cached response.
  4. Miss: Call the LLM, store the (query_embedding, response) pair.
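In code, the four steps look roughly like this. This is a minimal sketch assuming a local sentence-transformers model and the same hypothetical vector_db / call_llm helpers used later in this section; substitute your actual vector store client.

from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
SIMILARITY_THRESHOLD = 0.95

def answer(query: str) -> str:
    # 1. Embed the incoming query locally (fast, no API call)
    embedding = embedder.encode(query, normalize_embeddings=True)

    # 2. Search the vector DB for the closest cached query
    hits = vector_db.search(embedding, top_k=1)  # hypothetical vector store client

    # 3. Above the threshold: serve the cached response
    if hits and hits[0]["score"] >= SIMILARITY_THRESHOLD:
        return hits[0]["response"]

    # 4. Miss: call the LLM, then store the (query_embedding, response) pair
    response = call_llm(query)  # assumed helper, as elsewhere in this chapter
    vector_db.insert(embedding, response)
    return response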

16.3.3 Implementation: GPTCache

GPTCache is a widely used open-source library for semantic caching.

Installation

pip install gptcache
pip install gptcache[onnx]  # For local embedding
pip install redis

Basic Setup

from gptcache import cache
from gptcache.adapter import openai
from gptcache.config import Config
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.processor.pre import last_content
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

# 1. Configure the embedding model (a small ONNX model that runs locally)
onnx_embedding = Onnx()

# 2. Configure the vector store (Redis)
redis_vector = VectorBase(
    "redis",
    host="localhost",
    port=6379,
    dimension=onnx_embedding.dimension,  # dimension reported by the embedding model
    collection="llm_cache"
)

# 3. Configure the metadata store (SQLite for development, Postgres for production)
data_manager = get_data_manager(
    cache_base=CacheBase("sqlite"),  # stores query text and response text
    vector_base=redis_vector
)

# 4. Initialize the cache
cache.init(
    pre_embedding_func=last_content,              # extract the last chat message as the cache key text
    embedding_func=onnx_embedding.to_embeddings,  # embed that text for similarity search
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),

    # Tuning parameters
    config=Config(similarity_threshold=0.95)  # require 95% similarity for a hit
)

# 5. Use the OpenAI adapter (drop-in caching layer)
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "user", "content": "How do I reset my password?"}
    ]
)

# Index access works for both live API responses and cached responses
print(response["choices"][0]["message"]["content"])
print(f"Cache hit: {response.get('gptcache', False)}")

On the second call with a similar query:

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "user", "content": "password reset instructions"}
    ]
)

# This will be a cache hit if similarity > 0.95
print(f"Cache hit: {response.get('gptcache', False)}")  # Likely True for this paraphrase

Advanced Configuration

Production Setup (PostgreSQL + Redis):

import os

from gptcache import cache
from gptcache.config import Config
from gptcache.embedding import Onnx
from gptcache.manager import get_data_manager, CacheBase, VectorBase
from gptcache.processor.pre import last_content
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

onnx_embedding = Onnx()

data_manager = get_data_manager(
    cache_base=CacheBase(
        "postgresql",
        sql_url=os.environ['DATABASE_URL']  # postgresql://user:pass@host:5432/dbname
    ),
    vector_base=VectorBase(
        "redis",
        host=os.environ['REDIS_HOST'],
        port=6379,
        password=os.environ.get('REDIS_PASSWORD'),
        dimension=onnx_embedding.dimension,
        collection="llm_cache_prod",
        top_k=5  # consider the top 5 similar queries
    ),
    # Capacity tuning
    max_size=1_000_000,  # max cache entries
    eviction="LRU"       # Least Recently Used eviction
)

cache.init(
    pre_embedding_func=last_content,
    embedding_func=onnx_embedding.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    config=Config(similarity_threshold=0.95)
)

16.3.4 Architecture: Multi-Tier Caching

For high-scale systems (millions of users), a single Redis instance isn’t enough. Implement a tiered strategy.

The L1/L2/L3 Pattern

graph TD
    Request[User Request]
    L1[L1: In-Memory LRU<br/>Python functools.lru_cache<br/>Latency: 0.001ms]
    L2[L2: Redis Cluster<br/>Distributed Cache<br/>Latency: 5-20ms]
    L3[L3: S3/GCS<br/>Large Artifacts<br/>Latency: 200ms]
    LLM[LLM API<br/>Latency: 2000ms]
    
    Request-->|Check|L1
    L1-->|Miss|L2
    L2-->|Miss|L3
    L3-->|Miss|LLM
    
    LLM-->|Store|L3
    L3-->|Store|L2
    L2-->|Store|L1

Implementation:

import hashlib
from collections import OrderedDict

import boto3
import redis

# Assumed helpers from earlier in this chapter: vector_db (semantic index) and call_llm()

# L1: small in-process LRU cache (per container/pod)
L1_MAX_ENTRIES = 1000
l1_cache = OrderedDict()  # query_hash -> response

def l1_get(query_hash: str):
    if query_hash in l1_cache:
        l1_cache.move_to_end(query_hash)  # mark as most recently used
        return l1_cache[query_hash]
    return None

def l1_put(query_hash: str, response: str):
    l1_cache[query_hash] = response
    l1_cache.move_to_end(query_hash)
    if len(l1_cache) > L1_MAX_ENTRIES:
        l1_cache.popitem(last=False)  # evict the least recently used entry

# L2: Redis (vector search), L3: object storage
redis_client = redis.Redis(host='redis-cluster', port=6379)
s3_client = boto3.client('s3')

def get_cached_response(query: str, embedding_func) -> str:
    # Compute query hash
    query_hash = hashlib.sha256(query.encode()).hexdigest()

    # L1 check
    result = l1_get(query_hash)
    if result:
        print("L1 HIT")
        return result

    # L2 check (vector search)
    embedding = embedding_func(query)
    similar_queries = vector_db.search(embedding, top_k=1)

    if similar_queries and similar_queries[0]['score'] > 0.95:
        print("L2 HIT")
        cached_response = similar_queries[0]['response']

        # Populate L1
        l1_put(query_hash, cached_response)

        return cached_response

    # L3 check (for large responses, e.g., generated images)
    s3_key = f"responses/{query_hash}"
    try:
        obj = s3_client.get_object(Bucket='llm-cache', Key=s3_key)
        response = obj['Body'].read().decode()
        print("L3 HIT")
        return response
    except s3_client.exceptions.NoSuchKey:
        pass

    # Cache miss: call LLM
    print("CACHE MISS")
    response = call_llm(query)

    # Store in all tiers
    l1_put(query_hash, response)
    vector_db.insert(embedding, response)
    s3_client.put_object(Bucket='llm-cache', Key=s3_key, Body=response)

    return response

16.3.5 Exact Match Caching for Deterministic Workloads

For non-LLM workloads (image generation, video processing), exact match caching is sufficient and simpler.

Use Case: Stable Diffusion Image Generation

If a user requests:

prompt="A sunset on Mars"
seed=42
steps=50
guidance_scale=7.5

The output is deterministic (given the same hardware/drivers). Re-generating it is wasteful.

Implementation with Redis:

import hashlib
import json
import redis

redis_client = redis.Redis(host='localhost', port=6379)

def cache_key(prompt: str, seed: int, steps: int, guidance_scale: float) -> str:
    """
    Generate a deterministic cache key.
    """
    payload = {
        "prompt": prompt,
        "seed": seed,
        "steps": steps,
        "guidance_scale": guidance_scale
    }
    # Sort keys to ensure {"a":1,"b":2} == {"b":2,"a":1}
    canonical_json = json.dumps(payload, sort_keys=True)
    return hashlib.sha256(canonical_json.encode()).hexdigest()

def generate_image(prompt: str, seed: int = 42, steps: int = 50, guidance_scale: float = 7.5):
    key = cache_key(prompt, seed, steps, guidance_scale)
    
    # Check cache
    cached_image = redis_client.get(key)
    if cached_image:
        print("CACHE HIT")
        return cached_image  # Returns bytes (PNG)
    
    # Cache miss: generate
    print("CACHE MISS - Generating...")
    image_bytes = run_stable_diffusion(prompt, seed, steps, guidance_scale)
    
    # Store with 7-day TTL
    redis_client.setex(key, 604800, image_bytes)
    
    return image_bytes

# Usage
image1 = generate_image("A sunset on Mars", seed=42)  # MISS
image2 = generate_image("A sunset on Mars", seed=42)  # HIT (instant)

Cache Eviction Policies

Redis supports multiple eviction policies:

  1. noeviction: Return error when max memory is reached (not recommended)
  2. allkeys-lru: Evict least recently used keys (most common)
  3. volatile-lru: Evict least recently used keys with TTL set
  4. allkeys-lfu: Evict least frequently used keys (better for hot/cold data)

Configuration (redis.conf):

maxmemory 10gb
maxmemory-policy allkeys-lru
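If editing redis.conf is not convenient, the same limits can be applied at runtime through redis-py's CONFIG SET wrapper; note that runtime changes are not persisted across restarts unless you also rewrite the config.

import redis

r = redis.Redis(host="localhost", port=6379)

# Same settings as redis.conf above; takes effect immediately but is not persisted across restarts
r.config_set("maxmemory", "10gb")
r.config_set("maxmemory-policy", "allkeys-lru")

print(r.config_get("maxmemory-policy"))  # {'maxmemory-policy': 'allkeys-lru'}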

16.3.6 Cache Invalidation: The Hard Problem

“There are only two hard things in Computer Science: cache invalidation and naming things.” – Phil Karlton

Problem 1: Model Updates

You deploy model-v2 which generates different responses. Cached responses from model-v1 are now stale.

Solution: Version Namespacing

def cache_key(query: str, model_version: str) -> str:
    payload = {
        "query": query,
        "model_version": model_version
    }
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()

# When calling the model
response = get_cached_response(query, model_version="v2.3.1")

When you deploy v2.3.2, the cache key changes, so old responses aren’t served.

Trade-off: You lose the cache on every deployment. For frequently updated models, this defeats the purpose.

Alternative: Dual Write

During a migration period (sketched in code after the list):

  1. Read from both v1 and v2 caches.
  2. Write to v2 cache only.
  3. After 7 days (typical cache TTL), all v1 entries expire naturally.
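A minimal sketch of the dual-read, single-write pattern, assuming the versioned cache_key() helper above, the chapter's redis_client, and a call_llm() helper; the version strings are illustrative.

OLD_VERSION = "v2.3.1"
NEW_VERSION = "v2.3.2"
CACHE_TTL = 7 * 24 * 3600  # 7 days, matching the typical cache TTL above

def get_response_during_migration(query: str) -> str:
    # Read: try the new cache first, then fall back to the old one
    for version in (NEW_VERSION, OLD_VERSION):
        cached = redis_client.get(cache_key(query, model_version=version))
        if cached:
            return cached.decode()

    # Miss: call the new model and write only to the new cache
    response = call_llm(query)  # served by the new model version
    redis_client.setex(cache_key(query, model_version=NEW_VERSION), CACHE_TTL, response)
    return response

# Old v2.3.1 entries are never refreshed, so they age out within the TTL window.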

Problem 2: Fact Freshness (RAG Systems)

A RAG (Retrieval-Augmented Generation) system answers questions based on a knowledge base.

Scenario:

  • User asks: “What is our Q3 revenue?”
  • Document financial-report-q3.pdf is indexed.
  • LLM response is cached.
  • Document is updated (revised earnings).
  • Cached response is now stale.

Solution 1: TTL (Time To Live)

Set a short TTL on cache entries for time-sensitive topics.

redis_client.setex(
    key,
    86400,  # TTL in seconds (24 hours)
    response
)

Solution 2: Document-Based Invalidation

Tag cache entries with the document IDs they reference.

import hashlib
import json

# When caching
query = "What is our Q3 revenue?"
query_hash = hashlib.sha256(query.encode()).hexdigest()

cache_entry = {
    "query": query,
    "response": "Our Q3 revenue was $100M",
    # Redis hash fields must be scalars, so store the list as JSON
    "document_ids": json.dumps(["financial-report-q3.pdf"])
}

redis_client.hset(f"cache:{query_hash}", mapping=cache_entry)
redis_client.sadd("doc_index:financial-report-q3.pdf", query_hash)

# When document is updated
def invalidate_document(document_id: str):
    # Find all cache entries referencing this document
    query_hashes = redis_client.smembers(f"doc_index:{document_id}")
    
    # Delete them
    for qh in query_hashes:
        redis_client.delete(f"cache:{qh.decode()}")
    
    # Clear the index
    redis_client.delete(f"doc_index:{document_id}")

16.3.7 Monitoring Cache Performance

Key Metrics

  1. Hit Rate:

    hit_rate = cache_hits / (cache_hits + cache_misses)

    Target: > 60% for general chatbots, > 80% for FAQ bots.

  2. Latency Reduction:

    avg_latency_with_cache = (hit_rate × cache_latency) + ((1 - hit_rate) × llm_latency)

    Example:

    • Cache latency: 10ms
    • LLM latency: 2000ms
    • Hit rate: 70%

    avg_latency = (0.7 × 10) + (0.3 × 2000) = 7 + 600 = 607ms

    vs. without cache: 2000ms (3.3x faster)

  3. Cost Savings:

    monthly_savings = (cache_hits × llm_cost_per_request) - (cache_hits × cache_cost_per_request)
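The three formulas above are easy to keep honest with a few lines of Python; the calls below just replay the example numbers from this section.

def hit_rate(cache_hits: int, cache_misses: int) -> float:
    return cache_hits / (cache_hits + cache_misses)

def avg_latency_ms(hit_rate: float, cache_latency_ms: float, llm_latency_ms: float) -> float:
    return hit_rate * cache_latency_ms + (1 - hit_rate) * llm_latency_ms

def monthly_savings(cache_hits: int, llm_cost_per_request: float, cache_cost_per_request: float) -> float:
    return cache_hits * (llm_cost_per_request - cache_cost_per_request)

print(avg_latency_ms(0.7, 10, 2000))             # 607.0 ms, as in the example above
print(monthly_savings(600_000, 0.015, 0.00001))  # 8994.0 -> ~$9k/month saved at a 60% hit rate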
    

Instrumentation

import time

import redis
from prometheus_client import Counter, Histogram

redis_client = redis.Redis(host='localhost', port=6379)

cache_hits = Counter('cache_hits_total', 'Total cache hits')
cache_misses = Counter('cache_misses_total', 'Total cache misses')
cache_latency = Histogram('cache_lookup_latency_seconds', 'Cache lookup latency')
llm_latency = Histogram('llm_call_latency_seconds', 'LLM call latency')

# call_llm() is the model-invocation helper used throughout this section

def get_response(query: str):
    start = time.time()
    
    # Check cache
    cached = redis_client.get(query)
    
    if cached:
        cache_hits.inc()
        cache_latency.observe(time.time() - start)
        return cached.decode()
    
    cache_misses.inc()
    
    # Call LLM
    llm_start = time.time()
    response = call_llm(query)
    llm_latency.observe(time.time() - llm_start)
    
    # Store in cache
    redis_client.setex(query, 3600, response)
    
    return response

Grafana Dashboard Queries:

# Hit rate
rate(cache_hits_total[5m]) / (rate(cache_hits_total[5m]) + rate(cache_misses_total[5m]))

# Average latency
(rate(cache_lookup_latency_seconds_sum[5m]) + rate(llm_call_latency_seconds_sum[5m])) /
(rate(cache_lookup_latency_seconds_count[5m]) + rate(llm_call_latency_seconds_count[5m]))

16.3.8 Advanced: Proactive Caching

Instead of waiting for a cache miss, predict what users will ask and pre-warm the cache.

Use Case: Documentation Chatbot

Analyze historical queries:

Top 10 queries:
1. "How do I install the SDK?" (452 hits)
2. "What is the API rate limit?" (389 hits)
3. "How to authenticate?" (301 hits)
...

Pre-warm strategy:

import schedule

def prewarm_cache():
    """
    Run nightly to refresh top queries.
    """
    top_queries = get_top_queries_from_analytics(limit=100)
    
    for query in top_queries:
        # Check if cached
        embedding = embed(query)
        cached = vector_db.search(embedding, top_k=1)
        
        if not cached or cached[0]['score'] < 0.95:
            # Generate and store
            response = call_llm(query)
            vector_db.insert(embedding, response)
            print(f"Pre-warmed: {query}")

# Schedule for 2 AM daily (a worker process must call schedule.run_pending() in a loop)
schedule.every().day.at("02:00").do(prewarm_cache)

16.3.9 Security Considerations

Cache Poisoning

An attacker could pollute the cache with malicious responses.

Attack:

  1. Attacker submits: “What is the admin password?”
  2. Cache stores: “The admin password is hunter2”
  3. Legitimate user asks the same question → gets the poisoned response.

Mitigation:

  1. Input Validation: Reject queries with suspicious patterns.
  2. Rate Limiting: Limit cache writes per user/IP (a sketch follows this list).
  3. TTL: Short TTL limits the damage window.
  4. Audit Logging: Log all cache writes with user context.
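Of these, rate limiting is the simplest to retrofit. A minimal sketch using a per-user Redis counter with a rolling one-hour window; the key format and the 20-writes-per-hour budget are illustrative, not recommendations.

import redis

redis_client = redis.Redis(host="localhost", port=6379)
MAX_CACHE_WRITES_PER_HOUR = 20  # illustrative budget

def allow_cache_write(user_id: str) -> bool:
    """Return True if this user is still under their hourly cache-write budget."""
    quota_key = f"cache_write_quota:{user_id}"
    count = redis_client.incr(quota_key)      # atomic increment
    if count == 1:
        redis_client.expire(quota_key, 3600)  # start the 1-hour window on the first write
    return count <= MAX_CACHE_WRITES_PER_HOUR

# At the cache-write site:
#   if allow_cache_write(user_id):
#       redis_client.setex(key, ttl, response)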

PII (Personally Identifiable Information)

Cached responses may contain sensitive data.

Example:

Query: "What is my account balance?"
Response: "Your account balance is $5,234.12" (cached)

If cache is shared across users, this leaks data!

Solution: User-Scoped Caching

def cache_key(query: str, user_id: str) -> str:
    payload = {"query": query, "user_id": user_id}
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()

This ensures User A’s cached response is never served to User B.


16.3.10 Case Study: Hugging Face’s Inference API

Hugging Face serves millions of inference requests daily. Their caching strategy:

  1. Model-Level Caching: For identical inputs to the same model, serve cached outputs.
  2. Embedding Similarity: For text-generation tasks, use semantic similarity (threshold: 0.98).
  3. Regional Caches: Deploy Redis clusters in us-east-1, eu-west-1, ap-southeast-1 for low latency.
  4. Tiered Storage: Hot cache (Redis, 1M entries) → Warm cache (S3, 100M entries).

Results:

  • 73% hit rate on average.
  • P50 latency reduced from 1200ms to 45ms.
  • Estimated $500k/month savings in compute costs.

16.3.11 Conclusion

Caching is the highest-ROI optimization in ML inference. It requires upfront engineering effort—embedding models, vector databases, invalidation logic—but the returns are extraordinary:

  • 10-100x cost reduction for high-traffic systems.
  • 10-50x latency improvement for cache hits.
  • Scalability: Serve 10x more users without adding GPU capacity.

Best Practices:

  1. Start with exact match for deterministic workloads.
  2. Graduate to semantic caching for NLP/LLMs.
  3. Instrument everything: Hit rate, latency, cost savings.
  4. Plan for invalidation from day one.
  5. Security: User-scoped keys, rate limiting, audit logs.

Master caching, and you’ll build the fastest, cheapest inference systems on the planet.