🧠 AI System Design

Day 3: Caching Strategies


Learning Objectives

  • Understand why caching is critical in AI systems (high inference cost)
  • Distinguish response caching, semantic caching, and KV-cache reuse
  • Build an LRU response cache for your inference endpoint

Theory (15 min)

Inference is expensive — both in latency and tokens/compute. Caching is the cheapest way to reduce both.

Level 1: Response Cache (Exact Match)

The simplest cache: if someone asks the exact same prompt, return the exact same answer.

Client: "What is attention?"
Cache: HIT → return stored response (0ms inference)
Cache: MISS → forward to model, store response

Best for: Repetitive queries, system prompts, health checks, retries.
Hit rate: Low (5-15%) unless queries are very repetitive.
Storage: {prompt_hash → (response_text, timestamp, token_count)}

Level 2: Semantic Cache (Similarity Match)

Cache by meaning, not exact text. Two prompts that mean the same thing get the same answer.

Query: "Explain transformer attention"
Query: "How does attention work in transformers?"  → SAME CACHED ANSWER (cosine > 0.92)

How it works:

  1. Embed the query (small, fast embedding model)
  2. Search the vector DB for similar cached queries
  3. If cosine similarity > threshold (e.g. 0.92), return the cached response
  4. Otherwise, generate and cache

Best for: FAQ bots, customer support, internal knowledge bases.
Hit rate: 30-60% depending on threshold.
Storage: Vector DB + response store.
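
The four steps above can be sketched in a few lines. This is a minimal illustration under loud assumptions, not a production design: `toy_embed` is a hypothetical stand-in (a bag-of-words count vector) for a real embedding model, a Python list with a linear scan stands in for the vector DB, and `answer` and `generate` are illustrative names. Because the toy embedding only sees shared words, it catches differences in case and punctuation but not true paraphrases; a real embedding model is what makes the 0.92 threshold meaningful.

```python
import math
import re
from collections import Counter


def toy_embed(text: str) -> Counter:
    """Hypothetical stand-in for an embedding model: bag-of-words counts."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


class SemanticCache:
    def __init__(self, threshold: float = 0.92):
        self.threshold = threshold
        self.entries = []  # (embedding, response) pairs; a vector DB in practice

    def lookup(self, query: str):
        """Return the response of the most similar cached query, or None."""
        q = toy_embed(query)                       # step 1: embed
        best_score, best_response = 0.0, None
        for emb, response in self.entries:         # step 2: search
            score = cosine(q, emb)
            if score > best_score:
                best_score, best_response = score, response
        # step 3: only accept matches above the threshold
        return best_response if best_score > self.threshold else None

    def store(self, query: str, response: str) -> None:
        self.entries.append((toy_embed(query), response))


def answer(cache: SemanticCache, query: str, generate):
    """Return (response, was_hit); `generate` is any callable that asks the model."""
    cached = cache.lookup(query)
    if cached is not None:
        return cached, True
    response = generate(query)                     # step 4: generate and cache
    cache.store(query, response)
    return response, False
```

With this sketch, `answer(cache, "What is attention?", generate)` followed by `answer(cache, "what is ATTENTION", generate)` calls `generate` only once. Lowering the threshold raises the hit rate but also the risk of returning an answer to a different question.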

Level 3: KV-Cache Reuse (System prompt prefix)

When the first N tokens of every request are the same (system prompt, few-shot examples, RAG context), reuse the computed key-value cache for those tokens.

All requests share: "You are a helpful assistant. Answer concisely."
↓
Compute KV cache once for the prefix, append only the new user query
↓
Saves 30-70% on prefill compute
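
The hit detection for this level is a prefix match over token IDs rather than a hash or vector lookup. A minimal sketch of that check follows; the token IDs are made up for illustration, and real serving engines do this internally and may match at a coarser granularity than single tokens.

```python
def shared_prefix_len(cached_tokens, request_tokens):
    """Number of leading tokens the request shares with a cached sequence."""
    n = 0
    for a, b in zip(cached_tokens, request_tokens):
        if a != b:
            break
        n += 1
    return n


# Hypothetical token IDs: a 6-token system prompt followed by different user queries.
system_prompt = [101, 7, 42, 42, 9, 13]
request_a = system_prompt + [55, 56, 57]
request_b = system_prompt + [80, 81]

reused = shared_prefix_len(request_a, request_b)
print(reused, "tokens reused;", len(request_b) - reused, "tokens left to prefill")
# 6 tokens reused; 2 tokens left to prefill
```

Only the tokens after the shared prefix need a prefill pass, which is why the benefit grows with the length of the shared prefix relative to the user query.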

Best for: Applications with long shared prefixes (agents, RAG, system prompts).
Storage: GPU memory / host memory (large: roughly 2 (keys and values) × layers × heads × head_dim × prefix_len × 2 bytes per value in fp16).
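
To get a feel for why this storage is large, the size can be estimated by counting one key vector and one value vector per layer, per head, per prefix token. A rough sketch, where the function name and the model shape are hypothetical and chosen only for illustration:

```python
def kv_cache_bytes(layers, kv_heads, head_dim, prefix_len, bytes_per_value=2):
    """Bytes needed to hold cached keys and values for one prefix.

    The leading 2 accounts for storing both K and V;
    bytes_per_value=2 assumes fp16/bf16 storage.
    """
    return 2 * layers * kv_heads * head_dim * prefix_len * bytes_per_value


# Hypothetical model shape, not taken from any specific model.
size = kv_cache_bytes(layers=32, kv_heads=32, head_dim=128, prefix_len=1000)
print(f"{size / 2**20:.0f} MiB")  # 500 MiB
```

At this shape every cached token costs 512 KiB, so a 1000-token shared prefix occupies about 500 MiB. Models that share key/value heads across query heads would use a smaller `kv_heads` and proportionally less memory.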

Cache Comparison

Type               | Hit Detection | Storage     | Savings              | Complexity
Response (exact)   | Hash lookup   | RAM / Redis | 100% inference       | Very low
Semantic (similar) | Vector search | Vector DB   | 100% inference       | Medium
KV-cache (prefix)  | Prefix match  | GPU memory  | Prefill compute only | High

The Cache Stack in Practice

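Production AI systems often layer all three. The response and semantic caches sit in front of the model; KV-cache reuse happens inside the model server on whatever reaches it: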

Request → Response Cache (fast) → Semantic Cache (medium) → Model (slow)
                                       ↓ fallback             ↓ update caches

Hands-on (15 min)

Add an LRU Response Cache

#!/usr/bin/env python3
"""lru-inference-cache.py — LRU cache in front of the LLM."""
import hashlib
import time
from collections import OrderedDict

import httpx

LLM_URL = "http://localhost:8080/v1/completions"

class LRUInferenceCache:
    def __init__(self, capacity: int = 128):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.hits = 0
        self.misses = 0

    def _key(self, prompt: str, max_tokens: int, temperature: float) -> str:
        return hashlib.sha256(
            f"{prompt}|{max_tokens}|{temperature}".encode()
        ).hexdigest()

    def get(self, prompt: str, max_tokens: int = 100, temperature: float = 0.0):
        # Demo cap on output length, applied before keying so the
        # cache key matches the request that is actually sent.
        max_tokens = min(max_tokens, 50)
        key = self._key(prompt, max_tokens, temperature)

        if key in self.cache:
            self.cache.move_to_end(key)  # LRU refresh
            self.hits += 1
            return self.cache[key]["response"]

        self.misses += 1
        response = self._call_llm(prompt, max_tokens, temperature)
        if response is None:
            return None  # do not cache failed calls

        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # evict LRU

        self.cache[key] = {"response": response, "prompt": prompt}
        return response

    def _call_llm(self, prompt: str, max_tokens: int, temperature: float):
        """Return the completion text, or None if the request fails."""
        try:
            resp = httpx.post(LLM_URL, json={
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }, timeout=30)
            resp.raise_for_status()
            return resp.json()["choices"][0]["text"]
        except (httpx.HTTPError, KeyError, IndexError, ValueError) as e:
            print(f"[error: {e}]")
            return None

    def stats(self):
        total = self.hits + self.misses
        hit_rate = self.hits / total * 100 if total else 0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": f"{hit_rate:.1f}%",
            "cache_size": len(self.cache),
            "capacity": self.capacity,
        }

# Demo
cache = LRUInferenceCache(capacity=16)

queries = [
    "What is a transformer?",
    "Explain batching.",
    "What is a transformer?",   # duplicate
    "Define attention.",
    "Define attention.",        # duplicate
    "What is a transformer?",   # duplicate
    "Explain caching.",
    "Define attention.",        # duplicate
]

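print("Testing LRU cache...")
for q in queries:
    hits_before = cache.hits
    start = time.perf_counter()
    result = cache.get(q)
    elapsed = time.perf_counter() - start
    # Use the hit counter, not latency, to tell hits from misses:
    # a fast failure (e.g. connection refused) would otherwise look like a hit.
    marker = "🟢 CACHE" if cache.hits > hits_before else "🔴 MODEL"
    print(f"  {marker} [{elapsed*1000:.0f}ms] {q}")

print(f"\n📊 Cache stats: {cache.stats()}")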

Run it:

cd /tmp
python3 lru-inference-cache.py

Observations:

  • First calls take ~1s+ (model inference)
  • Duplicate calls take <1ms (cache hit)
  • Hit rate should be 50% with the pattern above (4 of the 8 queries repeat an earlier one)

Extension (if time): Add TTL expiry so cached responses expire after N minutes.
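
One possible starting point for the extension is sketched below. It is a standalone illustration rather than a drop-in change to the class above: `TTLLRUCache`, `ttl_seconds`, and the injectable `clock` parameter are all hypothetical names. Each entry stores the time it was written, and a lookup treats anything older than the TTL as a miss.

```python
import time
from collections import OrderedDict


class TTLLRUCache:
    def __init__(self, capacity=128, ttl_seconds=300.0, clock=time.monotonic):
        self.capacity = capacity
        self.ttl = ttl_seconds
        self.clock = clock            # injectable so expiry can be tested without sleeping
        self.cache = OrderedDict()    # key -> (value, stored_at)

    def get(self, key):
        """Return the cached value, or None if it is missing or expired."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if self.clock() - stored_at > self.ttl:
            del self.cache[key]       # expired: drop it and report a miss
            return None
        self.cache.move_to_end(key)   # LRU refresh
        return value

    def put(self, key, value):
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # evict LRU
        self.cache[key] = (value, self.clock())
```

Expired entries are removed lazily, only when someone asks for them; that keeps the code short, at the cost of stale entries occupying slots until they are read or evicted.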


Key Takeaways

  • Caching is the highest-ROI optimisation: zero inference cost on hits
  • Three levels of caching serve different purposes: exact, semantic, prefix
  • LRU eviction is a sensible default when recently asked queries are the ones most likely to recur
  • Semantic caching (via embeddings) catches paraphrased queries you'd miss with exact match
