🧠 AI System Design

Day 12: Checkpointing & Fault Tolerance

📂 Data & Training 📖 15 min read Needs expansion

Learning Objectives

  • Understand why training is non-deterministic and failure-prone
  • Learn checkpoint strategies (full, periodic, incremental)
  • Build a checkpoint-aware training script

Theory (15 min)

Why Training Fails

Long-running training jobs fail. It's not if, but when:

| Failure | Typical frequency | Impact |
|---|---|---|
| GPU OOM | Weekly | Lose the current step |
| Node failure (spot instance) | Daily | Lose hours of work without a checkpoint |
| Loss spike (NaN) | Monthly | May corrupt the entire training state |
| Library version mismatch | Rare | Full restart |
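For the NaN row, one way to keep a bad loss from ever reaching disk is to check each loss before a checkpoint is written. The following is a minimal sketch built on a simple heuristic, which is an assumption rather than a standard rule: a loss is rejected if it is non-finite, or if it exceeds `spike_factor` times the median of recent healthy losses. `LossSpikeGuard` and its parameters are hypothetical names.

```python
import math
from collections import deque
from statistics import median


class LossSpikeGuard:
    """Flags non-finite losses and sudden spikes relative to recent healthy losses.

    Heuristic (an assumption, tune per run): a loss is a spike if it exceeds
    `spike_factor` times the median of the last `window` healthy losses.
    """

    def __init__(self, window=50, spike_factor=5.0, min_history=10):
        self.recent = deque(maxlen=window)  # rolling window of healthy losses
        self.spike_factor = spike_factor
        self.min_history = min_history      # don't judge spikes until we have a baseline

    def is_healthy(self, loss):
        if not math.isfinite(loss):  # NaN or +/-inf
            return False
        if (len(self.recent) >= self.min_history
                and loss > self.spike_factor * median(self.recent)):
            return False
        self.recent.append(loss)  # only healthy losses shape the baseline
        return True


if __name__ == "__main__":
    guard = LossSpikeGuard(window=20)
    for _ in range(10):
        guard.is_healthy(1.0)
    print(guard.is_healthy(float("nan")), guard.is_healthy(6.0), guard.is_healthy(1.2))
```

In a training loop, write the checkpoint only when `is_healthy(loss)` returns True. On False, reload the last good checkpoint instead of overwriting it with a corrupted state.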

Checkpointing Strategies

1. Full checkpointing (save everything)

Save: model weights + optimizer state + LR scheduler state + dataloader state + epoch/step + RNG state (seed)
Size: roughly 3x the weights with Adam (weights + first and second moment estimates); about 2x with SGD + momentum
Frequency: every N steps or N minutes
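With Adam, the optimizer keeps two buffers (first and second moment estimates) that have the same shape as the weights, so a full checkpoint is roughly three times the size of the weights alone. The rough estimator below assumes every tensor is stored in fp32 (4 bytes per value) and ignores the small dataloader, scheduler, and RNG state. Mixed-precision setups that also store fp32 master weights will be larger. `checkpoint_size_bytes` is a hypothetical helper.

```python
def checkpoint_size_bytes(num_params, bytes_per_value=4, optimizer_states=2):
    """Rough full-checkpoint size: weights plus per-parameter optimizer buffers.

    optimizer_states=2 models Adam (first and second moments);
    use 1 for SGD with momentum and 0 for plain SGD.
    """
    return num_params * bytes_per_value * (1 + optimizer_states)


if __name__ == "__main__":
    size = checkpoint_size_bytes(7_000_000_000)  # 7B params, fp32, Adam
    print(f"{size / 1e9:.0f} GB")  # 84 GB
```

At this size, writing a checkpoint takes real time and disk bandwidth. That cost is what you trade against lost work when choosing N steps or N minutes.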

2. Periodic + best-only

Save every 1000 steps (for resume)
Save the best validation checkpoint separately (for evaluation and deployment)
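The periodic + best-only policy comes down to two independent checks at each step. The sketch below is minimal. `checkpoint_actions`, its arguments, and the `"latest"`/`"best"` labels are hypothetical, and the code assumes validation runs only on some steps (`val_loss=None` otherwise).

```python
from typing import List, Optional


def checkpoint_actions(step: int, val_loss: Optional[float], best_val_loss: float,
                       every: int = 1000) -> List[str]:
    """Decide which checkpoints to write at this step.

    - "latest": every `every` steps, so training can resume after a failure
    - "best": whenever validation loss beats the best seen so far
    """
    actions = []
    if step > 0 and step % every == 0:
        actions.append("latest")
    if val_loss is not None and val_loss < best_val_loss:
        actions.append("best")
    return actions


if __name__ == "__main__":
    print(checkpoint_actions(2000, 0.8, 0.9))  # ['latest', 'best']
```

Keeping the two checks independent means a "best" checkpoint can be written between periodic saves, and a periodic save never overwrites the best one.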

3. Ring buffer

Keep the last K checkpoints and delete the oldest whenever a new one is written.
If the loss starts to diverge and the newest checkpoints are already affected, you can roll back to an earlier, healthy one.
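The sketch below is a minimal file-based ring buffer. It assumes checkpoints are opaque bytes named `step-<8-digit step>.ckpt` in one directory. `CheckpointRing` and the naming scheme are hypothetical. On construction, the buffer is rebuilt from the files already on disk, so it survives a restart.

```python
import tempfile
from collections import deque
from pathlib import Path


class CheckpointRing:
    """Keep only the `keep` most recent checkpoint files in `directory`."""

    def __init__(self, directory, keep=3):
        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)
        self.keep = keep
        # Rebuild from disk; zero-padded step numbers make name order match step order
        self.paths = deque(sorted(self.directory.glob("step-*.ckpt")))

    def save(self, step, payload: bytes):
        path = self.directory / f"step-{step:08d}.ckpt"
        if path in self.paths:  # re-saving a step: don't track it twice
            self.paths.remove(path)
        path.write_bytes(payload)
        self.paths.append(path)
        # Evict the oldest checkpoints once more than `keep` are on disk
        while len(self.paths) > self.keep:
            self.paths.popleft().unlink(missing_ok=True)
        return path

    def rollback(self, steps_back=1):
        """Path of the checkpoint `steps_back` saves before the newest one."""
        return self.paths[-1 - steps_back]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        ring = CheckpointRing(d, keep=3)
        for step in (1000, 2000, 3000, 4000):
            ring.save(step, b"fake checkpoint bytes")
        print([p.name for p in ring.paths])  # only the three newest survive
        print(ring.rollback(1).name)         # step-00003000.ckpt
```

`rollback` raises `IndexError` if you ask for more history than the buffer holds. K is therefore a trade-off between disk space and how far back you can recover.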

Spot Instances

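Cloud GPUs are expensive. Spot instances typically cost 60-90% less than on-demand instances, but the provider can reclaim (preempt) them at any time, often with only a short warning (about two minutes on AWS).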

Design for preemption:

Train → checkpoint every 5 min → preempted → restore last checkpoint → continue

This restart-from-checkpoint pattern is standard practice for large training runs, which must survive hardware failures as well as preemption.
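The preemption loop above can be sketched with a SIGTERM handler that only sets a flag, plus a training loop that checks the flag between steps. The loop also checkpoints every 5 minutes by default. This assumes your scheduler or orchestrator delivers preemption to the process as SIGTERM (Kubernetes, for example, sends SIGTERM before killing a pod). `PreemptionHandler`, `train`, and `save_fn` are hypothetical names.

```python
import signal
import time


class PreemptionHandler:
    """Turns SIGTERM into a flag that the training loop polls between steps."""

    def __init__(self):
        self.requested = False

    def install(self):
        # signal.signal() may only be called from the main thread
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        # Do as little as possible inside the handler; the loop does the saving
        self.requested = True


def train(total_steps, save_fn, handler, checkpoint_every_s=300.0, start_step=0):
    """Run steps from start_step; checkpoint on a timer, on preemption, and at the end.

    Returns the number of completed steps.
    """
    last_save_time = time.monotonic()
    last_saved_step = start_step
    step = start_step
    while step < total_steps:
        step += 1  # a real training step (forward, backward, optimizer) would run here
        timer_due = time.monotonic() - last_save_time >= checkpoint_every_s
        if handler.requested or timer_due:
            save_fn(step)
            last_saved_step = step
            last_save_time = time.monotonic()
            if handler.requested:
                return step  # exit cleanly before the machine is reclaimed
    if last_saved_step != step:
        save_fn(step)  # final checkpoint at the end of training
    return step


if __name__ == "__main__":
    handler = PreemptionHandler()  # in a real job, call handler.install() at startup
    saved = []
    handler._on_sigterm(signal.SIGTERM, None)  # simulate a preemption notice
    print(train(10, saved.append, handler, start_step=4), saved)  # 5 [5]
```

Saving inside the signal handler itself is tempting, but a handler can fire in the middle of an optimizer update. Deferring the save to a step boundary keeps the checkpoint consistent.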

What to Save

```
checkpoint/epoch-10/
├── model.pt           # model state_dict
├── optimizer.pt       # Adam momentum/variances
├── scheduler.pt       # learning rate schedule position
├── dataloader.pt      # which batch index (deterministic resume)
├── config.yaml        # hyperparameters (restore experiment)
├── train_state.json   # step, epoch, best_metric, rng_state
└── metrics.csv        # loss curve for visualisation
```
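The `rng_state` entry in `train_state.json` is what makes a resumed run draw the same random numbers it would have drawn without the crash. The sketch below is minimal and covers only Python's built-in `random` module. A real PyTorch job would also save the NumPy and torch generator states (for example via `np.random.get_state()` and `torch.get_rng_state()`). `save_train_state` and `load_train_state` are hypothetical helpers.

```python
import json
import random
import tempfile
from pathlib import Path


def save_train_state(path, step, epoch, best_metric):
    """Write train_state.json, including Python's global RNG state."""
    version, internal, gauss_next = random.getstate()
    state = {
        "step": step,
        "epoch": epoch,
        "best_metric": best_metric,
        # getstate() returns (version, tuple of ints, gauss_next); JSON needs lists
        "rng_state": [version, list(internal), gauss_next],
    }
    Path(path).write_text(json.dumps(state))


def load_train_state(path):
    """Read train_state.json and restore the RNG so random draws continue identically."""
    state = json.loads(Path(path).read_text())
    version, internal, gauss_next = state["rng_state"]
    random.setstate((version, tuple(internal), gauss_next))  # setstate needs a tuple
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        path = Path(d) / "train_state.json"
        random.seed(0)
        save_train_state(path, step=1200, epoch=3, best_metric=0.41)
        before = [random.random() for _ in range(3)]
        random.seed(12345)  # pretend the process restarted with a different seed
        load_train_state(path)
        after = [random.random() for _ in range(3)]
        print(before == after)  # True
```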

Hands-on (15 min)

Write a Fault-Tolerant Training Script

```python
#!/usr/bin/env python3
"""checkpoint-training.py: save/resume training from a checkpoint."""
import json
import os
import shutil
import time
from pathlib import Path

# Stub: Ayva will expand with:
# - Actual model training (simple neural net or LoRA fine-tuning)
# - Optimizer state persistence
# - RNG state save/restore for deterministic resume
# - Integration with MLflow for experiment tracking
# - Handling of preemption signals (SIGTERM handlers)
# - Resume verification (integrity check on restore)

CHECKPOINT_DIR = Path("./checkpoints")


class FaultTolerantTrainer:
    def __init__(self, max_epochs=10):
        self.max_epochs = max_epochs
        self.current_epoch = 0  # number of completed epochs = index of the next one
        self.current_step = 0
        self.best_loss = float("inf")
        self.loss_history = []

    def save_checkpoint(self, name="latest"):
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        state = {
            "epoch": self.current_epoch,
            "step": self.current_step,
            "best_loss": self.best_loss,
            "loss_history": self.loss_history[-100:],  # keep only recent history
        }
        path = CHECKPOINT_DIR / f"{name}.json"
        # Write to a temp file, then rename it over the old checkpoint
        # (atomic on POSIX), so a crash mid-write can't leave a half-written file.
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(json.dumps(state, indent=2))
        os.replace(tmp_path, path)
        print(f"💾 Saved checkpoint to {path}")

    def load_checkpoint(self, name="latest"):
        path = CHECKPOINT_DIR / f"{name}.json"
        if not path.exists():
            print("📭 No checkpoint found, starting fresh")
            return False
        state = json.loads(path.read_text())
        self.current_epoch = state["epoch"]
        self.current_step = state["step"]
        self.best_loss = state["best_loss"]
        self.loss_history = state.get("loss_history", [])
        print(f"📂 Resumed after epoch {self.current_epoch}, step {self.current_step}")
        return True

    def train_epoch(self):
        """Simulate one training epoch of 5 steps with a decaying loss."""
        print(f"\nEpoch {self.current_epoch + 1}/{self.max_epochs}")
        for _ in range(5):
            self.current_step += 1
            loss = max(0.1, 2.0 / (self.current_step * 0.5 + 1))
            self.loss_history.append(loss)
            is_best = loss < self.best_loss
            if is_best:
                self.best_loss = loss
            print(f"  Step {self.current_step}: loss = {loss:.4f}"
                  + (" ✅ best" if is_best else ""))
            time.sleep(0.2)  # stand-in for real compute

    def run(self):
        self.load_checkpoint()

        while self.current_epoch < self.max_epochs:
            best_before = self.best_loss
            self.train_epoch()
            # Record the next epoch to run, so a resume doesn't repeat this one
            self.current_epoch += 1
            self.save_checkpoint("latest")

            # Save "best" separately, but only if this epoch actually improved it
            if self.best_loss < best_before:
                self.save_checkpoint("best")
                print("🏆 New best!")

        print(f"\n✅ Training complete. Best loss: {self.best_loss:.4f}")


if __name__ == "__main__":
    # Start the demo from a clean slate so reruns behave the same way
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    trainer = FaultTolerantTrainer(max_epochs=5)
    trainer.run()

    print("\nSimulating crash and resume...")
    # A fresh trainer (as after a crash/restart) picks up from checkpoints/latest.json
    trainer2 = FaultTolerantTrainer(max_epochs=10)
    trainer2.run()
```
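The last stub item, resume verification, can be handled with a checksum. The following is a minimal sketch that assumes a SHA-256 sidecar file next to each JSON checkpoint. `save_verified`, `load_verified`, and the `.sha256` suffix are hypothetical names, not part of the script above.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def _checksum_path(path):
    return path.with_name(path.name + ".sha256")


def save_verified(path, state):
    """Write a JSON checkpoint, then a sidecar file holding its SHA-256 digest."""
    path = Path(path)
    data = json.dumps(state, sort_keys=True).encode("utf-8")
    path.write_bytes(data)
    _checksum_path(path).write_text(hashlib.sha256(data).hexdigest())


def load_verified(path):
    """Load a checkpoint, refusing to resume if its bytes don't match the digest."""
    path = Path(path)
    data = path.read_bytes()
    expected = _checksum_path(path).read_text().strip()
    if hashlib.sha256(data).hexdigest() != expected:
        raise ValueError(f"checksum mismatch for {path}; fall back to an older checkpoint")
    return json.loads(data)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        p = Path(d) / "latest.json"
        save_verified(p, {"epoch": 5, "step": 25})
        print(load_verified(p))
        p.write_bytes(p.read_bytes()[:-3])  # simulate a truncated write
        try:
            load_verified(p)
        except ValueError as err:
            print("refused:", err)
```

Because the digest is written after the data, a crash between the two writes produces an error on restore rather than a silent resume from a bad file.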

Questions for Ayva:

  • What's the optimal checkpoint frequency (steps vs time)?
  • How should NCCL timeout errors in distributed training be handled?
  • What's the best strategy for handling spot instance preemption?


Key Takeaways

  • Training failures are normal — checkpointing is non-negotiable
  • Save enough state to resume exactly (model + optimiser + scheduler + dataloader + RNG)
  • A ring buffer of recent checkpoints lets you roll back if training diverges
  • Spot instance training is standard practice — checkpoint every 5-10 min

References