Day 5: Stateless vs Stateful Inference
Learning Objectives
- Understand why LLM inference is inherently stateful (KV cache)
- Learn how chat state is managed across requests
- Build a conversation state manager
Theory (15 min)
The Core Problem
HTTP is stateless. LLMs need state. Every chat interaction requires:
- Conversation history: previous turns sent as context
- KV cache: precomputed attention keys/values for the prefix
- (Optional) User profile: preferences, facts from prior conversations
Types of State
| State | Size | Persistence | Updates |
|---|---|---|---|
| Conversation history | Variable (tokens) | Per-session | Append-only |
| KV cache | Large (GB per long context) | Ephemeral (memory) | Recomputed or extended each turn |
| User profile | Small (KB) | Persistent | Occasional edits |
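To make "GB per long context" concrete, here is a back-of-the-envelope estimate. The model shape below (32 layers, 32 KV heads, head dimension 128, fp16 values) is an assumption chosen for illustration, not a figure from this lesson; plug in the numbers for your own model.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_value: int = 2) -> int:
    """Approximate KV cache size for one sequence.

    Every layer stores one key vector and one value vector per KV head
    for every token, hence the leading factor of 2.
    """
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value * seq_len


# Assumed shape, for illustration only: 32 layers, 32 KV heads, head_dim 128, fp16.
size = kv_cache_bytes(seq_len=4096, n_layers=32, n_kv_heads=32, head_dim=128)
print(f"{size / 1024**3:.1f} GiB")  # prints 2.0 GiB
```

Under these assumptions each token costs 512 KiB of cache, so a single 4096-token conversation occupies about 2 GiB of GPU memory.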
How Chat State Works Under the Hood
Turn 1: User: "Hi" ──▶ [sys prompt] [user: Hi] ──▶ Compute KV cache ──▶ "Hello!"
Turn 2: User: "What's ML?" ──▶ [sys][u:Hi][a:Hello!][u:What's ML?] ──▶ Recompute KV cache ──▶ ...
Without cache reuse, each turn recomputes the KV cache for the entire conversation prefix, so prefill compute grows with the length of the history.
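A small sketch of how that cost adds up over a conversation. The function name and the token counts are made up for illustration; the point is only the difference between processing the whole prefix every turn and processing just the new tokens.

```python
from itertools import accumulate
from typing import List


def prefill_tokens_per_turn(new_tokens_per_turn: List[int],
                            reuse_cache: bool) -> List[int]:
    """Tokens the model must run prefill over at each turn.

    new_tokens_per_turn[i] is the number of tokens added to the prompt at
    turn i (the previous reply plus the new user message).
    """
    if reuse_cache:
        # The cached prefix is kept, so only the newly added tokens are processed.
        return list(new_tokens_per_turn)
    # No reuse: the whole prefix is processed again on every turn.
    return list(accumulate(new_tokens_per_turn))


turns = [50, 120, 80, 200]  # made-up token counts per turn
print(sum(prefill_tokens_per_turn(turns, reuse_cache=False)))  # prints 920
print(sum(prefill_tokens_per_turn(turns, reuse_cache=True)))   # prints 450
```

The gap widens with every turn: without reuse, total prefill work grows roughly quadratically with the number of turns.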
Managing State at Scale
Approach 1: Stateless server, state in DB
Server A (any instance)
│
├── Read history from Redis/Postgres
├── Build prompt with full context
├── Send to model
└── Store new turn
Pros: Any server can serve any request. Easy horizontal scaling.
Cons: The full history must be sent and prefilled again on every request (repeated tokens and compute).
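The four steps of Approach 1 can be sketched as a single request handler. `InMemoryHistoryStore`, `handle_request`, and `echo_model` are hypothetical names; the dict-backed store stands in for Redis or Postgres, and the model call is a stub.

```python
from typing import Callable, Dict, List

Turn = Dict[str, str]


class InMemoryHistoryStore:
    """Stand-in for Redis/Postgres: maps session_id -> list of turns."""

    def __init__(self) -> None:
        self._data: Dict[str, List[Turn]] = {}

    def load(self, session_id: str) -> List[Turn]:
        return list(self._data.get(session_id, []))

    def append(self, session_id: str, turn: Turn) -> None:
        self._data.setdefault(session_id, []).append(turn)


def handle_request(store: InMemoryHistoryStore, session_id: str, user_text: str,
                   generate: Callable[[List[Turn]], str]) -> str:
    """Stateless handler: everything it needs comes from the store."""
    history = store.load(session_id)                             # 1. read history
    prompt = history + [{"role": "user", "content": user_text}]  # 2. build prompt
    reply = generate(prompt)                                     # 3. send to model
    store.append(session_id, {"role": "user", "content": user_text})
    store.append(session_id, {"role": "assistant", "content": reply})  # 4. store turn
    return reply


def echo_model(turns: List[Turn]) -> str:
    """Stub model that reports how much context it was given."""
    return f"(saw {len(turns)} turns)"


store = InMemoryHistoryStore()
print(handle_request(store, "user-123", "Hi", echo_model))          # (saw 1 turns)
print(handle_request(store, "user-123", "What's ML?", echo_model))  # (saw 3 turns)
```

Because the handler keeps nothing between calls, any server instance could run it; the price is that the prompt it sends grows on every turn.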
Approach 2: Stateful server with KV cache reuse
Server A (pinned session)
│
├── KV cache preserved in GPU memory
├── Only compute new tokens
└── Append to existing KV cache
Pros: Much faster, because only the new tokens are computed.
Cons: Server must pin sessions (sticky routing). Memory grows with session length.
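One simple way to get sticky routing is to hash the session ID onto the server list. This is a minimal sketch, not how any particular load balancer does it; `pick_server` and the server names are hypothetical.

```python
import hashlib
from typing import List


def pick_server(session_id: str, servers: List[str]) -> str:
    """Map a session to a server deterministically.

    A stable hash (not Python's built-in hash(), which is randomised per
    process) means every router instance picks the same server.
    """
    digest = hashlib.sha256(session_id.encode("utf-8")).digest()
    index = int.from_bytes(digest[:8], "big") % len(servers)
    return servers[index]


servers = ["gpu-a", "gpu-b", "gpu-c"]  # hypothetical server names
print(pick_server("user-123", servers))
```

Plain modulo hashing remaps most sessions whenever the server list changes, which throws away their warm caches; consistent hashing is the usual way to reduce that churn.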
Approach 3: Hybrid (what most production systems use)
Router ─┬─▶ Server A (pinned, warm KV cache) ── Fast path
        │
        └─▶ Fallback: full recompute on any server
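The hybrid decision can be sketched as a small router: use the pinned server while it is healthy, otherwise fall back to another server and accept a full recompute. `HybridRouter` and its methods are hypothetical, and `min()` is only a deterministic placeholder for real load balancing.

```python
from typing import Dict, Iterable, Set, Tuple


class HybridRouter:
    """Prefer the server holding a warm KV cache; fall back to any healthy one."""

    def __init__(self, servers: Iterable[str]) -> None:
        self.healthy: Set[str] = set(servers)
        self.pinned: Dict[str, str] = {}  # session_id -> server with warm cache

    def mark_down(self, server: str) -> None:
        self.healthy.discard(server)

    def route(self, session_id: str) -> Tuple[str, str]:
        """Return (server, path) where path is "fast" or "recompute"."""
        server = self.pinned.get(session_id)
        if server in self.healthy:
            return server, "fast"  # warm KV cache, only new tokens computed
        if not self.healthy:
            raise RuntimeError("no healthy servers")
        # Fallback: pick any healthy server and rebuild the cache from history.
        server = min(self.healthy)  # placeholder for a real load-balancing policy
        self.pinned[session_id] = server
        return server, "recompute"


router = HybridRouter(["gpu-a", "gpu-b"])
print(router.route("user-123"))  # ('gpu-a', 'recompute')
print(router.route("user-123"))  # ('gpu-a', 'fast')
router.mark_down("gpu-a")
print(router.route("user-123"))  # ('gpu-b', 'recompute')
```

The fallback only works if the conversation history lives outside the server (as in Approach 1), so the new server has something to recompute from.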
Hands-on (15 min)
Build a Conversation State Manager
#!/usr/bin/env python3
"""conversation-state.py: manage chat history with pruning."""
import time
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Message:
    role: str  # "user" | "assistant" | "system"
    content: str
    timestamp: float = field(default_factory=time.time)

    def token_estimate(self) -> int:
        # Rough heuristic: about 1.3 tokens per whitespace-separated word.
        return int(len(self.content.split()) * 1.3)


class ConversationSession:
    """Manages a single conversation's state with automatic pruning."""

    def __init__(self, session_id: str, max_tokens: int = 4096, system_prompt: str = ""):
        self.session_id = session_id
        self.max_tokens = max_tokens
        self.messages: List[Message] = []
        if system_prompt:
            self.messages.append(Message("system", system_prompt))

    def add_message(self, role: str, content: str) -> None:
        self.messages.append(Message(role, content))
        self._prune_if_needed()

    def _prune_if_needed(self) -> None:
        """Remove oldest messages (preserving system prompt) when over limit."""
        while self._total_tokens() > self.max_tokens and len(self.messages) > 2:
            # Remove the oldest non-system message.
            for i, m in enumerate(self.messages):
                if m.role != "system":
                    self.messages.pop(i)
                    break
            else:
                # Only system messages remain; nothing left to prune.
                break

    def _total_tokens(self) -> int:
        return sum(m.token_estimate() for m in self.messages)

    def build_prompt(self) -> str:
        """Format as the prompt seen by the model."""
        parts = []
        for m in self.messages:
            if m.role == "system":
                parts.append(f"<system>\n{m.content}\n</system>")
            elif m.role == "user":
                parts.append(f"<user>\n{m.content}\n</user>")
            elif m.role == "assistant":
                parts.append(f"<assistant>\n{m.content}\n</assistant>")
        return "\n".join(parts)

    def summary(self) -> dict:
        return {
            "session_id": self.session_id,
            "messages": len(self.messages),
            "total_tokens": self._total_tokens(),
            "max_tokens": self.max_tokens,
        }


class SessionManager:
    """Manages multiple conversations in memory."""

    def __init__(self):
        self.sessions: Dict[str, ConversationSession] = {}

    def get_or_create(self, session_id: str, system_prompt: str = "") -> ConversationSession:
        if session_id not in self.sessions:
            self.sessions[session_id] = ConversationSession(session_id, system_prompt=system_prompt)
        return self.sessions[session_id]

    def remove(self, session_id: str) -> None:
        self.sessions.pop(session_id, None)

    def stats(self) -> dict:
        return {"active_sessions": len(self.sessions)}


if __name__ == "__main__":
    # Demo
    sm = SessionManager()

    # Simulate a conversation
    session = sm.get_or_create("user-123", "You are a helpful assistant.")
    session.add_message("user", "What is an LLM?")
    session.add_message("assistant", "An LLM is a large language model...")
    session.add_message("user", "How does attention work?")
    session.add_message("assistant", "Attention allows the model to weigh different parts of the input...")

    print("Session state:")
    for m in session.messages:
        ts = time.strftime("%H:%M:%S", time.localtime(m.timestamp))
        print(f"  [{ts}] {m.role:>9}: {m.content[:50]}...")

    print(f"\nSummary: {session.summary()}")
    print(f"\nFull prompt:\n{session.build_prompt()[:200]}...")
Run it:
cd /tmp
python3 conversation-state.py
Extension: Add a save()/load() method that serialises to a JSON file or Redis so state survives a server restart.
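If you want a starting point for the extension, here is one possible shape for the JSON-file variant. It is deliberately standalone: `StoredMessage`, `save_messages`, and `load_messages` are hypothetical names, and wiring them into `ConversationSession` as `save()`/`load()` methods is left to you.

```python
import json
from dataclasses import asdict, dataclass
from typing import List, Tuple


@dataclass
class StoredMessage:
    """Same fields as Message above, redefined so this sketch runs on its own."""
    role: str
    content: str
    timestamp: float


def save_messages(path: str, session_id: str, messages: List[StoredMessage]) -> None:
    """Write one session to a JSON file."""
    payload = {"session_id": session_id, "messages": [asdict(m) for m in messages]}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def load_messages(path: str) -> Tuple[str, List[StoredMessage]]:
    """Read a session back; returns (session_id, messages)."""
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    return payload["session_id"], [StoredMessage(**m) for m in payload["messages"]]
```

For Redis, the same `payload` dict could be stored as a JSON string under a per-session key instead of being written to a file.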
Key Takeaways
- LLM inference is inherently stateful (KV cache grows with context)
- Stateless servers: easy to scale, but recompute KV cache every turn
- Stateful servers: fast (reuse KV cache), but need sticky sessions + memory management
- Pruning strategy determines how many conversation turns you retain
- For production, use Redis for session state and pin sessions with sticky routing
References
- KV Cache Explained
- PagedAttention (vLLM): how stateful serving works at scale