Day 12: Checkpointing & Fault Tolerance
Learning Objectives
- Understand why training is non-deterministic and failure-prone
- Learn checkpoint strategies (full, periodic + best-only, ring buffer)
- Build a checkpoint-aware training script
Theory (15 min)
Why Training Fails
Long-running training jobs fail. It's not if, but when:
| Failure | Frequency | Impact |
|---|---|---|
| GPU OOM | Weekly | Lose current step |
| Node failure (spot instance) | Daily | Lose hours of work without checkpoint |
| Loss spikes (NaN) | Monthly | May corrupt entire training state |
| Library version mismatch | Rare | Full restart |
Checkpointing Strategies
1. Full checkpointing (save everything)
Save: model weights + optimizer state + dataloader state + epoch/step + seed and RNG state
Size: roughly 3x the weights alone with Adam (weights + momentum + variance at the same precision); about 2x with plain SGD momentum
Frequency: every N steps or N minutes
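As a rough worked example of the size estimate, the sketch below counts bytes for the weights plus the optimizer buffers. It assumes everything is stored at one precision and that Adam keeps two buffers per parameter (momentum and variance). The function name and the 7B-parameter figure are illustrative, not part of the course material.

```python
def checkpoint_size_gb(num_params, bytes_per_value=4, optimizer_buffers=2):
    """Estimate checkpoint size in GB (10**9 bytes).

    Assumes weights and optimizer buffers share one precision.
    optimizer_buffers: 2 for Adam (momentum + variance), 1 for SGD with
    momentum, 0 for plain SGD.
    """
    copies = 1 + optimizer_buffers  # the weights plus one copy per buffer
    return num_params * bytes_per_value * copies / 1e9


if __name__ == "__main__":
    # Hypothetical 7B-parameter model stored in fp32 (4 bytes per value).
    print(f"weights only: {checkpoint_size_gb(7e9, optimizer_buffers=0):.0f} GB")
    print(f"with Adam:    {checkpoint_size_gb(7e9):.0f} GB")
```

Mixed-precision setups that keep fp32 master weights alongside lower-precision copies will come out larger than this estimate, so treat it as a floor for capacity planning.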
2. Periodic + best-only
Save every 1000 steps (for resume)
Save best validation checkpoint (for use)
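One way to express this policy is a small function that decides which checkpoints to write at a given step. This is a minimal sketch: the function name, the `save_every` parameter, and the convention that `val_loss` is `None` when no validation ran are all assumptions made for illustration.

```python
def checkpoints_to_write(step, val_loss, best_val_loss, save_every=1000):
    """Return the checkpoint names to write at this step.

    "latest" exists for resuming; "best" is the one you would actually use.
    val_loss is None on steps where no validation was run.
    """
    names = []
    if step > 0 and step % save_every == 0:
        names.append("latest")
    if val_loss is not None and val_loss < best_val_loss:
        names.append("best")
    return names


if __name__ == "__main__":
    print(checkpoints_to_write(1000, None, 0.50))  # periodic save only
    print(checkpoints_to_write(1500, 0.40, 0.50))  # validation improved
    print(checkpoints_to_write(2000, 0.30, 0.40))  # both conditions hold
```

Keeping the two decisions separate means a validation improvement between periodic saves is still captured, and a periodic save never overwrites the best model.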
3. Ring buffer
Keep last K checkpoints, oldest deleted.
Helps if a bad checkpoint precedes loss divergence: roll back to an earlier one.
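The retention rule can be sketched as a pruning step that runs after every save. The file naming scheme (`step-00001000.json`), the function name, and the default of three kept checkpoints are assumptions for illustration; the zero-padded step number is what lets a plain sort match step order.

```python
from pathlib import Path
import tempfile


def prune_checkpoints(checkpoint_dir, keep_last=3, pattern="step-*.json"):
    """Delete all but the newest `keep_last` checkpoints; return deleted names.

    Assumes zero-padded names such as step-00001000.json, so that
    lexicographic order is the same as step order.
    """
    checkpoints = sorted(Path(checkpoint_dir).glob(pattern))
    stale = checkpoints[:-keep_last] if keep_last > 0 else checkpoints
    for path in stale:
        path.unlink()
    return [path.name for path in stale]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for step in (1000, 2000, 3000, 4000, 5000):
            (Path(tmp) / f"step-{step:08d}.json").write_text("{}")
        print("deleted:", prune_checkpoints(tmp, keep_last=3))
        print("kept:", sorted(p.name for p in Path(tmp).iterdir()))
```

Choose K so that the oldest kept checkpoint predates the point where a divergence would typically be noticed; otherwise every checkpoint in the buffer may already be affected.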
Spot Instances
Cloud GPUs are expensive. Spot instances cost 60-90% less, but you can be preempted at any time.
Design for preemption:
Train → checkpoint every 5 min → preempted → restore last checkpoint → continue
Checkpoint-and-resume loops like this are common practice for large training runs on preemptible hardware.
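The loop above can be sketched with a time-based save trigger plus a signal handler. This assumes the provider (or your job launcher) delivers SIGTERM shortly before reclaiming the instance; check your platform's preemption notice mechanism. `PreemptionGuard`, `train`, `save_fn`, and `step_fn` are hypothetical names, and the 300-second default mirrors the 5-minute interval mentioned above.

```python
import signal
import time


class PreemptionGuard:
    """Records that SIGTERM arrived so the loop can checkpoint and exit cleanly."""

    def __init__(self, install=True):
        self.requested = False
        if install:
            # Must be called from the main thread of the main interpreter.
            signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum, frame):
        # Only set a flag here; the training loop performs the actual save.
        self.requested = True


def train(total_steps, save_fn, guard, checkpoint_interval_s=300.0, step_fn=None):
    """Run steps, saving every `checkpoint_interval_s` seconds and on preemption.

    Returns the number of steps completed.
    """
    last_save = time.monotonic()
    for step in range(1, total_steps + 1):
        if step_fn is not None:
            step_fn(step)  # placeholder for the real forward/backward pass
        if guard.requested:
            save_fn(step)  # final save before the instance disappears
            return step
        if time.monotonic() - last_save >= checkpoint_interval_s:
            save_fn(step)
            last_save = time.monotonic()
    save_fn(total_steps)
    return total_steps


if __name__ == "__main__":
    guard = PreemptionGuard()
    done = train(5, lambda s: print(f"saved at step {s}"), guard)
    print(f"stopped after {done} steps")
```

The handler only flips a flag because writing files inside a signal handler can interrupt a save that is already in progress. The save interval also needs to be comfortably longer than the time one save takes, and the final save must fit inside the notice window.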
What to Save
checkpoint/epoch-10/
├── model.pt          # model state_dict
├── optimizer.pt      # Adam momentum/variances
├── scheduler.pt      # learning rate schedule position
├── dataloader.pt     # which batch index (deterministic resume)
├── config.yaml       # hyperparameters (restore experiment)
├── train_state.json  # step, epoch, best_metric, rng_state
└── metrics.csv       # loss curve for visualisation
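To make the `train_state.json` entry concrete, here is a minimal sketch that stores step, epoch, best metric, and RNG state, and writes the file atomically. It covers only Python's built-in `random` module; a real run would also need to capture NumPy, PyTorch, and CUDA generator state. The function names are hypothetical.

```python
import json
import os
import random
import tempfile
from pathlib import Path


def save_train_state(directory, step, epoch, best_metric):
    """Write train_state.json atomically, including Python's RNG state."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    version, internal, gauss_next = random.getstate()
    state = {
        "step": step,
        "epoch": epoch,
        "best_metric": best_metric,
        "rng_state": [version, list(internal), gauss_next],
    }
    final_path = directory / "train_state.json"
    tmp_path = directory / "train_state.json.tmp"
    tmp_path.write_text(json.dumps(state))
    # Rename over the old file so a crash never leaves a half-written state.
    os.replace(tmp_path, final_path)
    return final_path


def load_train_state(directory):
    """Read train_state.json and restore Python's RNG state."""
    state = json.loads((Path(directory) / "train_state.json").read_text())
    version, internal, gauss_next = state["rng_state"]
    # JSON turns tuples into lists; setstate needs the tuple form back.
    random.setstate((version, tuple(internal), gauss_next))
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        random.seed(0)
        save_train_state(tmp, step=120, epoch=3, best_metric=0.42)
        before = random.random()
        load_train_state(tmp)
        print("same draw after restore:", before == random.random())
```

Restoring the generator state, rather than reseeding, is what lets a resumed run continue the random sequence where it stopped instead of replaying it from the start.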
Hands-on (15 min)
Write a Fault-Tolerant Training Script
#!/usr/bin/env python3
"""checkpoint-training.py: save/resume training from checkpoint."""
import json
import time
from pathlib import Path

# Stub: Ayva will expand with:
# - Actual model training (simple neural net or LoRA fine-tuning)
# - Optimizer state persistence
# - RNG state save/restore for deterministic resume
# - Integration with MLflow for experiment tracking
# - Handling of preemption signals (SIGTERM handlers)
# - Resume verification (integrity check on restore)

CHECKPOINT_DIR = Path("./checkpoints")


class FaultTolerantTrainer:
    def __init__(self, max_epochs=10):
        self.max_epochs = max_epochs
        self.current_epoch = 0  # number of epochs fully completed
        self.current_step = 0
        self.best_loss = float("inf")
        self.loss_history = []

    def save_checkpoint(self, name="latest"):
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        state = {
            "epoch": self.current_epoch,
            "step": self.current_step,
            "best_loss": self.best_loss,
            "loss_history": self.loss_history[-100:],  # keep the file small
        }
        path = CHECKPOINT_DIR / f"{name}.json"
        path.write_text(json.dumps(state, indent=2))
        print(f"💾 Saved checkpoint to {path}")

    def load_checkpoint(self, name="latest"):
        path = CHECKPOINT_DIR / f"{name}.json"
        if not path.exists():
            print("No checkpoint found, starting fresh")
            return False
        state = json.loads(path.read_text())
        self.current_epoch = state["epoch"]
        self.current_step = state["step"]
        self.best_loss = state["best_loss"]
        self.loss_history = state.get("loss_history", [])
        print(f"Resumed after epoch {self.current_epoch}, step {self.current_step}")
        return True

    def train_epoch(self):
        """Simulate one training epoch."""
        print(f"\nEpoch {self.current_epoch + 1}/{self.max_epochs}")
        for step in range(5):
            self.current_step += 1
            # Fake loss curve: decays with the global step, floored at 0.1.
            loss = max(0.1, 2.0 / (self.current_step * 0.5 + 1))
            self.loss_history.append(loss)
            if loss < self.best_loss:
                self.best_loss = loss
            marker = " ✅ best" if loss == self.best_loss else ""
            print(f"  Step {self.current_step}: loss = {loss:.4f}{marker}")
            time.sleep(0.2)

    def run(self):
        self.load_checkpoint()
        for epoch in range(self.current_epoch, self.max_epochs):
            self.current_epoch = epoch
            self.train_epoch()
            # Mark the epoch as completed before saving; otherwise a resume
            # would repeat the epoch that had already finished.
            self.current_epoch = epoch + 1
            self.save_checkpoint("latest")
            # Save best separately
            if self.loss_history[-1] <= self.best_loss:
                self.save_checkpoint("best")
                print("New best!")
        print(f"\n✅ Training complete. Best loss: {self.best_loss:.4f}")


if __name__ == "__main__":
    trainer = FaultTolerantTrainer(max_epochs=5)
    trainer.run()

    print("\nSimulating crash and resume...")
    # Simulate a crash: a fresh trainer has no in-memory state, so it has to
    # pick up from the checkpoint on disk and continue with epoch 6.
    trainer2 = FaultTolerantTrainer(max_epochs=10)
    trainer2.run()
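The stub's to-do list mentions resume verification. One possible approach, sketched here under the assumption that checkpoints are JSON files as in the script, is to store a SHA-256 digest next to each checkpoint and refuse to load a file whose digest does not match. The function names and the `.sha256` sidecar convention are hypothetical.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def write_with_checksum(path, state):
    """Write `state` as JSON plus a sidecar file holding its SHA-256 digest."""
    path = Path(path)
    data = json.dumps(state, indent=2).encode("utf-8")
    path.write_bytes(data)
    digest = hashlib.sha256(data).hexdigest()
    path.with_name(path.name + ".sha256").write_text(digest)


def load_verified(path):
    """Return the state dict, or None if the file is missing or corrupted."""
    path = Path(path)
    digest_path = path.with_name(path.name + ".sha256")
    if not path.exists() or not digest_path.exists():
        return None
    data = path.read_bytes()
    if hashlib.sha256(data).hexdigest() != digest_path.read_text().strip():
        return None  # truncated or modified since it was written
    return json.loads(data)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "latest.json"
        write_with_checksum(ckpt, {"epoch": 5, "step": 25})
        print("intact:", load_verified(ckpt))
        ckpt.write_bytes(ckpt.read_bytes()[:10])  # simulate a torn write
        print("corrupted:", load_verified(ckpt))
```

A trainer that gets `None` back could fall through to the next-oldest checkpoint, which is where keeping more than one checkpoint pays off.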
Questions for Ayva:
- What's the optimal checkpoint frequency (steps vs time)?
- How to handle NCCL timeout errors in distributed training?
- What's the best strategy for spot instance preemption handling?
Key Takeaways
- Training failures are normal; checkpointing is non-negotiable
- Save enough state to resume exactly (model + optimizer + dataloader + RNG)
- Ring buffer of checkpoints protects against bad-checkpoint divergence
- Spot instance training is standard practice; checkpoint every 5-10 min