🧠 AI System Design

Day 5: Stateless vs Stateful Inference

📂 Foundations  📖 15 min read

Learning Objectives

  • Understand why LLM inference is inherently stateful (KV cache)
  • Learn how chat state is managed across requests
  • Build a conversation state manager

Theory (15 min)

The Core Problem

HTTP is stateless, but LLM chat is not. Every chat turn depends on:

  1. Conversation history — previous turns sent as context
  2. KV cache — precomputed attention keys/values for the prefix
  3. (Optional) User profile — preferences, facts from prior conversations

Types of State

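| State | Size | Persistence | Updates |
|---|---|---|---|
| Conversation history | Variable (grows with tokens) | Per-session | Append-only (until pruned) |
| KV cache | Large (can reach GBs for long contexts) | Ephemeral (GPU memory) | Recomputed each turn, or extended in place when reused |
| User profile | Small (KB) | Persistent | Occasional edits |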
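To see why the KV cache lands in the GB range: its size is roughly 2 (keys and values) × layers × KV heads × head dimension × bytes per element × tokens. The sketch below plugs in an assumed config shaped like Llama-2-7B in fp16 (32 layers, 32 KV heads, head dimension 128, 2 bytes per element); substitute your own model's numbers.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """Estimate KV cache size: one key and one value vector per token, per layer, per KV head."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2 = keys + values
    return per_token * seq_len * batch

# Assumed config, roughly Llama-2-7B-shaped, stored in fp16 (2 bytes per element)
cfg = dict(n_layers=32, n_kv_heads=32, head_dim=128, bytes_per_elem=2)

for tokens in (1_024, 4_096, 32_768):
    gib = kv_cache_bytes(tokens, **cfg) / 2**30
    print(f"{tokens:>6} tokens -> {gib:.2f} GiB")  # 0.50, 2.00, 16.00 GiB
```

Under these assumptions each token costs 0.5 MiB, so a single 4,096-token conversation holds 2 GiB of cache. Models with grouped-query attention have fewer KV heads and a proportionally smaller cache.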

How Chat State Works Under the Hood

Turn 1: User: "Hi" ──▶ [sys prompt] [user: Hi] ──▶ Compute KV cache ──▶ "Hello!"
Turn 2: User: "What's ML?" ──▶ [sys][u:Hi][a:Hello!][u:What's ML?] ──▶ Recompute KV cache ──▶ ...

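Without cache reuse, each turn recomputes the KV cache for the entire conversation prefix. Prefill work therefore grows with every turn: the longer the history, the more compute (and latency) before the first token of the reply.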
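A rough way to see the cost is to count how many prompt tokens must be prefilled on each turn, with and without reusing the cache. The sketch below uses made-up token counts and simplified accounting (it treats every token of the prior conversation, including generated replies, as already cached when reuse is on); it illustrates the trend and is not a benchmark.

```python
from typing import List, Tuple

def prefill_per_turn(system_tokens: int, turns: List[Tuple[int, int]],
                     reuse_cache: bool) -> List[int]:
    """Return the number of prompt tokens prefilled on each turn.

    turns[i] = (user_tokens, reply_tokens) for turn i.
    Simplified: a reused cache is assumed to hold every prior conversation token.
    """
    history = system_tokens  # tokens in the conversation so far
    cached = 0               # tokens whose keys/values are already in the KV cache
    result = []
    for user_tokens, reply_tokens in turns:
        prompt_len = history + user_tokens
        # Stateless: prefill the whole prompt. Reuse: prefill only the uncached suffix.
        result.append(prompt_len - cached if reuse_cache else prompt_len)
        history = prompt_len + reply_tokens
        # Decoding appends the reply's keys/values, so a reused cache covers the whole history
        cached = history if reuse_cache else 0
    return result

# Made-up example: 100-token system prompt, three short turns
turns = [(20, 80), (30, 120), (25, 60)]
print(prefill_per_turn(100, turns, reuse_cache=False))  # [120, 230, 375]
print(prefill_per_turn(100, turns, reuse_cache=True))   # [120, 30, 25]
```

Without reuse, the per-turn cost tracks the whole history and keeps climbing; with reuse, it tracks only the newest message.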

Managing State at Scale

Approach 1: Stateless server, state in DB

Server A (any instance)
  │
  ├── Read history from Redis/Postgres
  ├── Build prompt with full context
  ├── Send to model
  └── Store new turn

Pros: Any server can handle any request, so horizontal scaling is easy.
Cons: The full history is sent and prefilled again on every turn, wasting tokens and compute.
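A minimal sketch of Approach 1, assuming a plain dict (the hypothetical `SHARED_STORE`) as a stand-in for Redis or Postgres and a hypothetical `echo_model` stub in place of a real model call. Because the handler keeps nothing between calls, any server instance could run it.

```python
import json
from typing import Callable, Dict, List

# Hypothetical stand-in for Redis/Postgres: a store every server instance can reach
SHARED_STORE: Dict[str, str] = {}

Message = Dict[str, str]

def handle_request(session_id: str, user_msg: str,
                   model: Callable[[List[Message]], str]) -> str:
    """Stateless handler: all conversation state lives in SHARED_STORE."""
    # 1. Read history (serialized as JSON, as it might be in a Redis string value)
    history: List[Message] = json.loads(SHARED_STORE.get(session_id, "[]"))
    # 2. Build the prompt with full context
    history.append({"role": "user", "content": user_msg})
    # 3. Send to model: the entire history goes out on every turn
    reply = model(history)
    # 4. Store the new turn
    history.append({"role": "assistant", "content": reply})
    SHARED_STORE[session_id] = json.dumps(history)
    return reply

def echo_model(messages: List[Message]) -> str:
    """Hypothetical model stub that reports how much context it received."""
    return f"messages in context: {len(messages)}"

print(handle_request("user-123", "Hi", echo_model))          # messages in context: 1
print(handle_request("user-123", "What's ML?", echo_model))  # messages in context: 3
```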

Approach 2: Stateful server with KV cache reuse

Server A (pinned session)
  │
  ├── KV cache preserved in GPU memory
  ├── Only compute new tokens
  └── Append to existing KV cache

Pros: Much faster, because only the new tokens need to be computed.
Cons: Sessions must be pinned to one server (sticky routing), and KV cache memory grows with session length.
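The core of Approach 2 is a per-session cache that is only extended when the new prompt starts with exactly the tokens already cached. The toy sketch below (a hypothetical `StatefulWorker`) represents the KV cache as the list of token ids whose keys/values have been computed, so it shows the bookkeeping rather than real tensors.

```python
from typing import Dict, List

class StatefulWorker:
    """Toy server that keeps each session's "KV cache" in memory.

    The cache is simulated as the token ids whose keys/values have been computed.
    """

    def __init__(self) -> None:
        self.kv_cache: Dict[str, List[int]] = {}
        self.tokens_computed = 0  # total prefill work done by this worker

    def prefill(self, session_id: str, prompt_tokens: List[int]) -> int:
        """Compute K/V only for tokens not already cached; return how many were computed."""
        cached = self.kv_cache.get(session_id, [])
        if prompt_tokens[:len(cached)] == cached:
            new_tokens = prompt_tokens[len(cached):]  # cache hit: append to existing cache
        else:
            cached, new_tokens = [], prompt_tokens    # prefix changed: recompute from scratch
        self.tokens_computed += len(new_tokens)
        self.kv_cache[session_id] = cached + new_tokens
        return len(new_tokens)

    def evict(self, session_id: str) -> None:
        """Free a session's cache, e.g. when GPU memory runs low."""
        self.kv_cache.pop(session_id, None)

w = StatefulWorker()
print(w.prefill("s1", [1, 2, 3]))        # 3: cold start
print(w.prefill("s1", [1, 2, 3, 4, 5]))  # 2: only the new tokens
print(w.prefill("s1", [2, 3, 4, 5, 6]))  # 5: oldest token dropped, prefix no longer matches
```

The last call shows a design consequence: dropping the oldest messages from the history changes the prefix, so a stateful server has to rebuild the cache from scratch.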

Approach 3: Hybrid (a common production pattern)

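Router ─┬─▶ Server A (pinned, warm KV cache) ── Fast path
        │
        └─▶ Fallback: full recompute on any server

The router sends each request to the server that already holds the session's warm KV cache. If that server is down or has evicted the cache, any server can rebuild it from the stored history: slower, but the request still succeeds.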
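One way to sketch the Approach 3 router, assuming hypothetical server names and a simple round-robin choice on the fallback path (a real router might pick by load or cache locality instead). It only decides where a request goes; the cache itself lives on the servers.

```python
from typing import Dict, List, Set, Tuple

class HybridRouter:
    """Sticky routing with a fallback: a toy sketch, not a production load balancer."""

    def __init__(self, servers: List[str]) -> None:
        self.servers = list(servers)
        self.healthy: Set[str] = set(servers)
        self.pins: Dict[str, str] = {}  # session_id -> server holding its warm KV cache
        self._rr = 0                    # round-robin counter for fallback placement

    def route(self, session_id: str) -> Tuple[str, str]:
        pinned = self.pins.get(session_id)
        if pinned is not None and pinned in self.healthy:
            return pinned, "fast"  # warm KV cache: prefill only the new tokens
        candidates = [s for s in self.servers if s in self.healthy]
        if not candidates:
            raise RuntimeError("no healthy servers")
        # Fallback: any healthy server rebuilds the cache from stored history
        server = candidates[self._rr % len(candidates)]
        self._rr += 1
        self.pins[session_id] = server  # re-pin so the next turn hits the new warm cache
        return server, "recompute"

    def mark_down(self, server: str) -> None:
        self.healthy.discard(server)

router = HybridRouter(["gpu-a", "gpu-b"])
print(router.route("user-123"))  # ('gpu-a', 'recompute'): first turn, no cache yet
print(router.route("user-123"))  # ('gpu-a', 'fast')
router.mark_down("gpu-a")
print(router.route("user-123"))  # ('gpu-b', 'recompute'): fallback after failure
```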

Hands-on (15 min)

Build a Conversation State Manager

#!/usr/bin/env python3
"""conversation-state.py — manage chat history with pruning."""
import time
from dataclasses import dataclass, field
from typing import List

@dataclass
class Message:
    role: str      # "user" | "assistant" | "system"
    content: str
    timestamp: float = field(default_factory=time.time)

    def token_estimate(self) -> int:
        # Rough heuristic: ~1.3 tokens per whitespace-separated word
        return int(len(self.content.split()) * 1.3)

class ConversationSession:
    """Manages a single conversation's state with automatic pruning."""

    def __init__(self, session_id: str, max_tokens: int = 4096, system_prompt: str = ""):
        self.session_id = session_id
        self.max_tokens = max_tokens
        self.messages: List[Message] = []
        if system_prompt:
            self.messages.append(Message("system", system_prompt))

    def add_message(self, role: str, content: str):
        self.messages.append(Message(role, content))
        self._prune_if_needed()

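    def _prune_if_needed(self):
        """Drop the oldest non-system messages until under budget.

        System messages and the newest message are always kept.
        """
        while self._total_tokens() > self.max_tokens:
            # Index of the oldest non-system message, excluding the one just added
            idx = next((i for i, m in enumerate(self.messages[:-1]) if m.role != "system"), None)
            if idx is None:
                break  # nothing left that may be dropped; stop instead of looping forever
            self.messages.pop(idx)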

    def _total_tokens(self) -> int:
        return sum(m.token_estimate() for m in self.messages)

    def build_prompt(self) -> str:
        """Format as the prompt seen by the model."""
        parts = []
        for m in self.messages:
            if m.role == "system":
                parts.append(f"<system>\n{m.content}\n</system>")
            elif m.role == "user":
                parts.append(f"<user>\n{m.content}\n</user>")
            elif m.role == "assistant":
                parts.append(f"<assistant>\n{m.content}\n</assistant>")
        return "\n".join(parts)

    def summary(self) -> dict:
        return {
            "session_id": self.session_id,
            "messages": len(self.messages),
            "total_tokens": int(self._total_tokens()),
            "max_tokens": self.max_tokens,
        }


class SessionManager:
    """Manages multiple conversations in memory."""

    def __init__(self):
        self.sessions = {}

    def get_or_create(self, session_id: str, system_prompt: str = "") -> ConversationSession:
        if session_id not in self.sessions:
            self.sessions[session_id] = ConversationSession(session_id, system_prompt=system_prompt)
        return self.sessions[session_id]

    def remove(self, session_id: str):
        self.sessions.pop(session_id, None)

    def stats(self) -> dict:
        return {"active_sessions": len(self.sessions)}


# Demo
sm = SessionManager()

# Simulate a conversation
session = sm.get_or_create("user-123", "You are a helpful assistant.")
session.add_message("user", "What is an LLM?")
session.add_message("assistant", "An LLM is a large language model...")
session.add_message("user", "How does attention work?")
session.add_message("assistant", "Attention allows the model to weigh different parts of the input...")

print("Session state:")
for m in session.messages:
    ts = time.strftime("%H:%M:%S", time.localtime(m.timestamp))
    print(f"  [{ts}] {m.role:>9}: {m.content[:50]}...")

print(f"\n📊 Summary: {session.summary()}")
print(f"\n📋 Full prompt:\n{session.build_prompt()[:200]}...")

Run it (assuming the script is saved as /tmp/conversation-state.py):

cd /tmp
python3 conversation-state.py

Extension: Add save() and load() methods that serialise the session to a JSON file or to Redis, so state survives a server restart.
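One possible approach to the extension, sketched as standalone functions over a copy of the `Message` dataclass so it runs on its own; the function names `save_session`/`load_session` and the JSON layout are assumptions. Writing to a temporary file and then calling `os.replace` avoids leaving a half-written session file behind if the process dies mid-write.

```python
import json
import os
import tempfile
import time
from dataclasses import asdict, dataclass, field
from typing import List, Tuple

@dataclass
class Message:
    role: str
    content: str
    timestamp: float = field(default_factory=time.time)

def save_session(path: str, session_id: str, max_tokens: int,
                 messages: List[Message]) -> None:
    """Serialize a session to JSON at `path`."""
    data = {
        "session_id": session_id,
        "max_tokens": max_tokens,
        "messages": [asdict(m) for m in messages],
    }
    # Write to a temp file in the same directory, then swap it into place
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)))
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(data, f)
    os.replace(tmp_path, path)

def load_session(path: str) -> Tuple[str, int, List[Message]]:
    """Inverse of save_session."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    messages = [Message(**m) for m in data["messages"]]
    return data["session_id"], data["max_tokens"], messages

# Round trip through a temporary directory
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "user-123.json")
    msgs = [Message("system", "You are a helpful assistant."), Message("user", "Hi")]
    save_session(path, "user-123", 4096, msgs)
    sid, max_tokens, restored = load_session(path)
    print(sid, max_tokens, restored == msgs)  # user-123 4096 True
```

For Redis, the same JSON string could be stored under a per-session key with redis-py's `set()` and read back with `get()`.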


Key Takeaways

  • LLM inference is inherently stateful (KV cache grows with context)
  • Stateless servers: easy to scale, but recompute KV cache every turn
  • Stateful servers: fast (reuse KV cache), but need sticky sessions + memory management
  • Pruning strategy determines how many conversation turns you retain
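  • For production, a common pattern is to keep session state in a shared store such as Redis, use sticky routing to reuse warm KV caches, and fall back to full recompute when the pinned server is unavailable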
