§ 5 · Optimization for ML
Convex vs non-convex loss landscapes, gradient descent variants (SGD → Momentum → Adam → AdamW → Lion), learning-rate schedules (cosine, WSD), second-order intuition (K-FAC, Shampoo), and implicit regularisation — everything that drives a 7B parameter model from random weights to a working language model.
1. Overview
Optimization is the engine of ML: given a differentiable loss function over billions of parameters, it finds settings that make predictions accurate. For LLMs the loss is negative log-likelihood (NLL) on next-token prediction, accumulated across trillions of tokens. Every design decision (optimizer choice, learning-rate schedule, gradient clipping threshold, weight-decay coefficient) affects whether a 70B model converges within its compute budget or diverges at step 50,000. The diagram below shows where optimization sits in the training loop.
2. Key Concepts
Loss Landscape Topology
The loss surface is a function from the parameter space ℝᵈ (d ≈ 7×10⁹ for LLaMA-7B) to a scalar. Classical optimization theory assumes convexity (one global minimum, gradient always helpful). Deep learning breaks every convexity assumption — yet empirically, SGD-family methods find excellent solutions. The key insight from Goodfellow et al. (2015) and Li et al. (2018): high-dimensional non-convex surfaces have far fewer bad local minima than low-dimensional ones, and near-zero-gradient points are overwhelmingly saddle points, not local minima.
Condition number: κ = λ_max / λ_min, the ratio of largest to smallest curvature (Hessian eigenvalue)
κ >> 1 → ill-conditioned → slow convergence
For a quadratic f(x) = x^T A x / 2 with the eigenvalues of A in [λ_min, λ_max]:
Gradient descent with step α contracts the error per step by max(|1 - α λ_min|, |1 - α λ_max|) → geometric in t
Optimal step: α* = 2 / (λ_max + λ_min)
Best contraction factor: (κ-1)/(κ+1) per step → close to 1 (slow) when κ >> 1
Transformers: embedding rows have κ ~ 10^4 vs attention weights κ ~ 10^2
→ per-param adaptive LR (Adam) crucial; uniform SGD LR cannot handle both
Optimizer Family
Every modern optimizer descends from plain SGD by adding one or two ideas: momentum (accumulate velocity to accelerate consistent directions), adaptive per-parameter learning rates (smaller steps for parameters with large gradient magnitudes, larger steps for parameters with small ones), or second-order curvature (precondition by the inverse Hessian or Fisher information). The table below summarises the family.
| Optimizer | Hyperparams | Optimizer state / param | Used by |
|---|---|---|---|
| SGD | lr | 0 scalars | Image training (ResNet, ViT); rarely for LLMs |
| SGD+Momentum | lr, β=0.9 | 1 scalar (velocity v) | CV training; Muon uses this variant |
| Adam | lr, β₁=0.9, β₂=0.999, ε=1e-8 | 2 scalars (m, v) + step counter | BERT, T5, GPT-2; baseline for most research |
| AdamW | lr, β₁, β₂, ε, wd=0.1 | 2 scalars (same as Adam) | LLaMA, Mistral, GPT-3/4, Falcon — industry default |
| Lion | lr, β₁=0.9, β₂=0.99, wd | 1 scalar (m only, no v) | DeepMind / Google experiments; optimizer state is half of Adam's |
| Shampoo | lr, β, update_freq | L + R matrices per layer (large!) | Google (JAX), Meta (Distributed Shampoo) |
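To make the Lion row concrete, here is a minimal sketch of one Lion update in plain Python, following the update rule in Chen et al. (2023). The function name `lion_step`, the list-of-floats representation, and the toy gradient values are illustrative assumptions, not part of any library.

```python
def sign(x):
    """Return -1, 0, or +1 depending on the sign of x."""
    return (x > 0) - (x < 0)

def lion_step(theta, m, grad, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion update on lists of floats; returns (new_theta, new_m)."""
    new_theta, new_m = [], []
    for p, mi, g in zip(theta, m, grad):
        # The update direction is only the sign of an interpolation of m and g.
        update = sign(beta1 * mi + (1 - beta1) * g)
        # Decoupled weight decay, applied the same way as in AdamW.
        new_theta.append(p - lr * (update + wd * p))
        # The momentum buffer is the only optimizer state Lion stores.
        new_m.append(beta2 * mi + (1 - beta2) * g)
    return new_theta, new_m

theta, m = [1.0, -2.0, 0.5], [0.0, 0.0, 0.0]
grad = [1e-3, -50.0, 0.2]            # hypothetical gradients of very different size
theta, m = lion_step(theta, m, grad, lr=0.01)
print(theta)                          # every parameter moved by exactly lr
```

Because the update is a sign, every parameter moves by the same magnitude `lr` regardless of gradient scale, which is why Lion is usually run with a smaller learning rate than AdamW.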
Momentum and Nesterov Acceleration
Plain SGD in a ravine (loss curves steeply across one dimension, shallowly along the other) oscillates back and forth across the steep direction while crawling along the shallow direction. Momentum damps oscillations by accumulating a velocity vector — consistent gradient signals build up speed while opposing signals cancel. Nesterov adds a look-ahead correction: it peeks at where momentum would take the parameter next, evaluates the gradient there, and uses that as the corrective signal, which provides better convergence on convex functions.
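As a rough illustration of the ravine picture above, the sketch below runs gradient descent with and without momentum on a two-dimensional quadratic whose curvature is 100 in one direction and 1 in the other. The quadratic, the starting point, the step size 0.018, and the step budget are all assumptions chosen for the example.

```python
def loss(x, y, steep=100.0, shallow=1.0):
    """Ravine-shaped quadratic: steep along x, shallow along y."""
    return 0.5 * (steep * x * x + shallow * y * y)

def grad(x, y, steep=100.0, shallow=1.0):
    return steep * x, shallow * y

def run(beta, lr=0.018, steps=200):
    """Momentum SGD with v = beta*v + g and theta -= lr*v; beta=0 is plain GD."""
    x, y = 1.0, 1.0
    vx, vy = 0.0, 0.0
    for _ in range(steps):
        gx, gy = grad(x, y)
        vx = beta * vx + gx
        vy = beta * vy + gy
        x -= lr * vx
        y -= lr * vy
    return loss(x, y)

loss_gd = run(beta=0.0)
loss_momentum = run(beta=0.9)
print(f"plain GD: {loss_gd:.3e}   momentum: {loss_momentum:.3e}")
```

With the same learning rate, plain gradient descent is limited by the slow crawl along the shallow direction, while momentum builds up speed there and ends with a much lower loss in this setup.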
Momentum update (standard form):
v_t = β v_{t-1} + ∇L(θ_t) # β=0.9 typical
θ_{t+1} = θ_t - α v_t
Nesterov (look-ahead form; PyTorch nesterov=True uses a reformulation that evaluates the gradient at θ_t):
θ_look = θ_t - α β v_{t-1} # look-ahead position: where the old velocity alone would move θ
v_t = β v_{t-1} + ∇L(θ_look) # gradient at look-ahead
θ_{t+1} = θ_t - α v_t
Convergence on L-smooth, μ-strongly convex f (κ = L/μ):
Gradient descent: O((1 - μ/L)^t)
Momentum (heavy ball): O(((√κ - 1)/(√κ + 1))^t) on quadratics; no accelerated guarantee for general convex f
Nesterov: O((1 - √(μ/L))^t) (optimal order for first-order methods; Nesterov 1983)
Adam: Adaptive Moment Estimation
Adam (Kingma & Ba, 2014) combines momentum with adaptive per-parameter learning rates. The key idea: maintain a running estimate of each parameter's gradient magnitude (second moment), and divide the gradient by that estimate. Parameters with historically large gradients get a smaller effective step; parameters with small gradients get a larger step. This is critical for transformers where embedding parameters have gradients orders of magnitude larger than, say, layer-norm bias terms.
Adam algorithm (Kingma & Ba 2014, Algorithm 1):
Hyperparams: α (lr), β₁=0.9, β₂=0.999, ε=1e-8
Init: m₀ = v₀ = 0
For each step t = 1, 2, ...:
g_t ← ∇L(θ_t) # gradient (same shape as θ)
m_t ← β₁ m_{t-1} + (1-β₁) g_t # biased 1st moment
v_t ← β₂ v_{t-1} + (1-β₂) g_t² # biased 2nd moment (element-wise g²)
m̂_t ← m_t / (1 - β₁ᵗ) # bias-corrected 1st moment
v̂_t ← v_t / (1 - β₂ᵗ) # bias-corrected 2nd moment
θ_{t+1} ← θ_t - α m̂_t / (√v̂_t + ε) # update (element-wise)
Memory: 2 fp32 tensors of shape θ (m and v) + 1 int (t)
→ At 7B params × 4 bytes × 3 copies (θ, m, v) = 84 GB at minimum
AdamW: Decoupled Weight Decay
Loshchilov & Hutter (2017) showed that L2 regularisation and weight decay are not equivalent under Adam. Standard Adam with L2 adds λθ to the gradient before the adaptive step, which means high-gradient parameters get less weight decay than low-gradient ones (the adaptive divisor shrinks both the gradient and the L2 term). AdamW applies weight decay directly to the parameters, bypassing the adaptive scaling, giving uniform regularisation:
Adam + L2: g_t ← ∇L(θ_t) + λ θ_t # wrong: WD gets scaled by 1/√v̂
θ_{t+1} ← θ_t - α m̂_t / (√v̂_t + ε)
AdamW: g_t ← ∇L(θ_t) # gradient without WD
θ_{t+1} ← (1 - α λ) θ_t - α m̂_t / (√v̂_t + ε)
# ↑ decoupled WD ↑ adaptive gradient step
Typical: λ (wd) = 0.1 → shrinks weights by α·λ = 1e-4, i.e. 0.01% per step at lr=1e-3
Not applied to: bias terms, LayerNorm/RMSNorm scale/shift (common practice)
Learning-Rate Schedules
No fixed learning rate works well across all training phases. Too high at the start → divergence before the model learns anything useful. Too high at the end → the model oscillates near the minimum rather than settling. LR schedules divide training into phases: a warmup ramp, a sustained high-LR phase, and a decay tail.
| Schedule | Formula (post-warmup) | Properties | Used by |
|---|---|---|---|
| Cosine | lr_min + (lr_max−lr_min)·½·(1+cos(πt/T)) | Smooth; no kinks; most common | GPT-3, LLaMA-1/2/3, Mistral-7B |
| WSD | Flat at lr_max → linear decay in tail | Checkpoint-friendly; easy restart | MiniCPM, Falcon-180B, Mistral-v0.2 |
| Linear | lr_max·(1 − t/T) + lr_min·t/T | Simple; less area under the curve than WSD | Fine-tuning; short-run experiments |
| Cosine restarts | Cosine within each cycle; T_i × mult at restart | Escapes local minima; SGDR | Some fine-tuning; curriculum learning |
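The formulas in the table can be sketched as a single function. This is a minimal illustration, assuming linear warmup from 0, a WSD decay tail that covers the last 10% of training, and the example values 1000 total steps, 100 warmup steps, and a peak LR of 3e-4; the name `lr_at` and the `decay_frac` parameter are hypothetical.

```python
import math

def lr_at(step, total_steps, warmup_steps, peak_lr, kind, min_lr=0.0, decay_frac=0.1):
    """Learning rate at `step` (0-indexed) for kind in {'cosine', 'linear', 'wsd'}."""
    if step < warmup_steps:                        # linear warmup from 0 to peak_lr
        return peak_lr * step / warmup_steps
    t = step - warmup_steps                        # steps since warmup ended
    T = total_steps - warmup_steps                 # length of the post-warmup phase
    if kind == "cosine":
        return min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * t / T))
    if kind == "linear":
        return peak_lr * (1 - t / T) + min_lr * t / T
    if kind == "wsd":
        decay_steps = max(1, int(decay_frac * total_steps))   # assumed tail length
        decay_start = total_steps - decay_steps
        if step < decay_start:
            return peak_lr                         # stable phase at peak LR
        frac = (step - decay_start) / decay_steps
        return peak_lr * (1 - frac) + min_lr * frac            # linear decay tail
    raise ValueError(f"unknown schedule: {kind}")

total, warm, peak = 1000, 100, 3e-4
kinds = ("cosine", "linear", "wsd")
for s in (0, 100, 500, 950):
    print(s, {k: round(lr_at(s, total, warm, peak, k), 6) for k in kinds})

# Area under each curve: the sum of the learning rate over all steps.
area = {k: sum(lr_at(s, total, warm, peak, k) for s in range(total)) for k in kinds}
print(area)
```

Under these assumptions WSD spends far longer at the peak LR, so its area under the curve is the largest, while cosine and linear decay cover the same area when they share the same endpoints.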
Second-Order Methods
First-order methods (SGD, Adam) use only the gradient — the first derivative of the loss. Second-order methods additionally use the Hessian H = ∂²L/∂θ² or its approximation. The natural gradient direction F⁻¹g (where F is the Fisher information matrix ≈ H near the minimum) is invariant to reparameterisation and converges in far fewer steps. The practical barrier: for a 7B model, H is a 7B×7B matrix (5×10¹⁹ entries) — completely intractable.
| Method | Approximation | Memory | Benefit |
|---|---|---|---|
| Newton | Exact H⁻¹g | O(d²) | Quadratic convergence; totally impractical for LLMs |
| K-FAC | F ≈ A⊗B per layer (Kronecker) | O(d_in² + d_out²) | 2–5× fewer steps; tractable for mid-size models |
| Shampoo | Kronecker-factored approximation of the full-matrix AdaGrad preconditioner (no Fisher) | O(d_in² + d_out²) per layer | Used in production at Google; Distributed Shampoo at Meta |
| Muon | Orthogonalise grad via Newton-Schulz | 1 momentum buffer (like SGD) | Cheap approx of Shampoo; gaining traction in 2024 |
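To illustrate the Newton row, the sketch below compares a single Newton step with 100 gradient-descent steps on a small ill-conditioned quadratic, and also reproduces the Hessian entry count quoted above. The 2×2 matrix (eigenvalues 100 and 1) and the starting point are assumptions picked so the inverse can be written in closed form.

```python
A = [[50.5, 49.5], [49.5, 50.5]]   # symmetric, eigenvalues 100 and 1 (kappa = 100)

def matvec(M, v):
    return [M[0][0] * v[0] + M[0][1] * v[1], M[1][0] * v[0] + M[1][1] * v[1]]

def inverse_2x2(M):
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    return [[M[1][1] / det, -M[0][1] / det], [-M[1][0] / det, M[0][0] / det]]

def loss(theta):
    """f(theta) = theta^T A theta / 2, so the gradient is A theta and the Hessian is A."""
    g = matvec(A, theta)
    return 0.5 * (theta[0] * g[0] + theta[1] * g[1])

theta0 = [1.0, -0.5]

# Newton: precondition the gradient with the inverse Hessian, one step.
g = matvec(A, theta0)
newton_dir = matvec(inverse_2x2(A), g)
theta_newton = [theta0[0] - newton_dir[0], theta0[1] - newton_dir[1]]

# Gradient descent with the optimal fixed step 2 / (lambda_max + lambda_min).
lr = 2.0 / (100.0 + 1.0)
theta_gd = list(theta0)
for _ in range(100):
    g = matvec(A, theta_gd)
    theta_gd = [theta_gd[0] - lr * g[0], theta_gd[1] - lr * g[1]]

print("Newton, 1 step :", loss(theta_newton))
print("GD, 100 steps  :", loss(theta_gd))

# Why this does not scale: the dense Hessian of a 7B-parameter model.
hessian_entries = (7 * 10**9) ** 2
print("Hessian entries for d = 7e9:", hessian_entries)
```

On a quadratic the Newton step lands on the minimum up to rounding error, which is the behaviour K-FAC and Shampoo try to approximate per layer at a fraction of the memory.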
Shampoo per-layer update (simplified):
G_t ← layer gradient [d_out × d_in]
L_t ← L_{t-1} + G_t G_tᵀ [d_out × d_out] left Kronecker factor
R_t ← R_{t-1} + G_tᵀ G_t [d_in × d_in] right Kronecker factor
Preconditioned step ← L_t^{-¼} G_t R_t^{-¼}
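Why 4th root? Full-matrix AdaGrad preconditions with (Σ g gᵀ)^{-½}. Shampoo approximates Σ g gᵀ by L^{½}⊗R^{½} (Gupta et al. 2018), and
(L^{½}⊗R^{½})^{-½} = L^{-¼}⊗R^{-¼}. The inverse roots are recomputed only every k steps (e.g. k=100) to amortise cost.
3. Core Mechanism: Adam Step-by-Step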
Background
Imagine training a 7B-parameter LLM with plain SGD at a fixed learning rate α = 1e-3. The token embedding table receives gradient spikes when rare tokens appear (large, sparse gradients), while the layer-norm scale parameters receive tiny but consistent gradients. A single α that prevents embedding divergence will make layer-norm progress infinitesimally slowly — and vice versa. Adam solves this by giving each parameter its own effective learning rate, automatically discovered from gradient history.
Plan
- Maintain m_t = exponential moving average of past gradients (bias-corrected by dividing by 1−β₁ᵗ). This is the “direction to move” signal — smoother than raw gradient due to averaging.
- Maintain v_t = exponential moving average of past squared gradients (bias-corrected by 1−β₂ᵗ). This estimates the gradient's typical magnitude per parameter.
- Divide m̂_t by √v̂_t to get a normalised update. The effective per-parameter step size is approximately α regardless of gradient scale — that's the adaptive property.
- Add ε = 1e-8 to prevent division by zero for parameters that receive zero gradient (e.g., embeddings for tokens not in the current batch).
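The four steps above can be sketched as a small function. This is a minimal, list-based illustration rather than a production implementation; the name `adam_step` and the constant toy gradients are assumptions made for the example.

```python
import math

def adam_step(theta, m, v, grad, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on lists of floats; t is the 1-based step counter."""
    new_theta, new_m, new_v = [], [], []
    for p, mi, vi, g in zip(theta, m, v, grad):
        mi = beta1 * mi + (1 - beta1) * g          # EMA of gradients
        vi = beta2 * vi + (1 - beta2) * g * g      # EMA of squared gradients
        m_hat = mi / (1 - beta1 ** t)              # bias correction
        v_hat = vi / (1 - beta2 ** t)
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_theta, new_m, new_v

# Four parameters whose (constant) gradients span eight orders of magnitude,
# plus one that never receives a gradient.
grads = [1e-4, 1.0, 1e4, 0.0]
theta, m, v = [0.0] * 4, [0.0] * 4, [0.0] * 4
for t in range(1, 11):
    theta, m, v = adam_step(theta, m, v, grads, t)
print(theta)
```

After ten steps the first three parameters have each moved by roughly 10 × lr despite the very different gradient scales, and the parameter with zero gradient has not moved because ε keeps the division well defined.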
Walkthrough — Single Parameter, 5 Steps
Initial conditions: one scalar parameter θ₀ = 2.0, loss = θ²/2 (parabola), gradient = θ (true minimum at θ=0). Adam hyperparams: α = 0.1, β₁ = 0.9, β₂ = 0.999, ε = 1e-8. Track m, v, m̂, v̂, and the resulting update per step.
Init: θ=2.0 m=0 v=0
t=1: g=2.0
m₁ = 0.9·0 + 0.1·2.0 = 0.2
v₁ = 0.999·0 + 0.001·4.0 = 0.004
m̂₁ = 0.2/0.1 = 2.0
v̂₁ = 0.004/0.001 = 4.0
step = 0.1 · 2.0 / (√4.0 + ε) ≈ 0.100
θ₁ = 2.0 - 0.100 = 1.900
t=2: g=1.9
m₂ = 0.9·0.2 + 0.1·1.9 = 0.370
v₂ = 0.999·0.004 + 0.001·3.61 = 0.007606
m̂₂ = 0.370/0.19 = 1.947
v̂₂ = 0.007606/0.001999 = 3.805
step = 0.1 · 1.947 / (√3.805 + ε) ≈ 0.0998
θ₂ = 1.900 - 0.0998 ≈ 1.800
t=3: g=1.800
m₃ = 0.9·0.370 + 0.1·1.800 = 0.513
v₃ = 0.999·0.007606 + 0.001·3.241 = 0.010839
m̂₃ = 0.513/0.271 = 1.893
v̂₃ = 0.010839/0.002997 = 3.617
step = 0.1 · 1.893 / (√3.617 + ε) ≈ 0.0995
θ₃ ≈ 1.701
Continuing, each step stays close to α = 0.1 even though the gradient shrinks, because m̂ and √v̂ shrink together and m̂/√v̂ stays near 1. The ratio drifts slightly below 1 because v̂, with β₂ = 0.999, still remembers the larger early gradients. Key observation: Adam's step size is set by α rather than by the gradient magnitude, whereas a plain gradient-descent step α·g shrinks in proportion to the gradient as θ approaches the minimum.
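The walkthrough can be reproduced with a few lines of Python. This is a sketch under the same initial conditions and hyperparameters; the function name `adam_trace` is hypothetical, and the printed values may differ from hand-rounded figures in the last digit.

```python
import math

def adam_trace(theta=2.0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=3):
    """Run scalar Adam on loss = theta**2 / 2 and record every intermediate value."""
    m = v = 0.0
    rows = []
    for t in range(1, steps + 1):
        g = theta                                   # gradient of theta**2 / 2
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        step = lr * m_hat / (math.sqrt(v_hat) + eps)
        theta -= step
        rows.append((t, g, m_hat, v_hat, step, theta))
    return rows

rows = adam_trace()
for t, g, m_hat, v_hat, step, theta in rows:
    print(f"t={t}  g={g:.4f}  m_hat={m_hat:.4f}  v_hat={v_hat:.4f}  "
          f"step={step:.4f}  theta={theta:.4f}")
```

Raising `steps` shows how long the update stays near α before θ gets close to the minimum.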
4. Minimal Demos
Demo 1 — Optimizer Race on Rosenbrock
The Rosenbrock “banana” function f(x,y) = (1−x)² + 100(y−x²)² has a narrow curved valley leading to a minimum at (1,1). It is a classic benchmark because SGD oscillates across the valley while adaptive methods navigate it efficiently. The demo takes two inputs, steps and lr (e.g., 500 0.001). Try lr=0.001 (Adam wins), then lr=0.01 (SGD diverges, Adam stable).
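A minimal offline version of this race can be sketched in plain Python. The starting point (−1, 1), the helper names, and the use of full-batch gradient descent as a stand-in for SGD are assumptions of this sketch; which method ends with the lower loss at the smaller learning rate depends on the start point and step budget, so no winner is hard-coded.

```python
import math

def rosenbrock(x, y):
    return (1 - x) * (1 - x) + 100 * (y - x * x) * (y - x * x)

def rosenbrock_grad(x, y):
    return (-2 * (1 - x) - 400 * x * (y - x * x), 200 * (y - x * x))

def run_gd(lr, steps, start=(-1.0, 1.0)):
    """Plain gradient descent; returns the final loss, or inf if the iterates blow up."""
    x, y = start
    for _ in range(steps):
        gx, gy = rosenbrock_grad(x, y)
        x, y = x - lr * gx, y - lr * gy
        if not (math.isfinite(x) and math.isfinite(y)):
            return float("inf")                     # diverged
    return rosenbrock(x, y)

def run_adam(lr, steps, start=(-1.0, 1.0), beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with bias correction on the same function; returns the final loss."""
    theta, m, v = list(start), [0.0, 0.0], [0.0, 0.0]
    for t in range(1, steps + 1):
        g = rosenbrock_grad(theta[0], theta[1])
        for i in range(2):
            m[i] = beta1 * m[i] + (1 - beta1) * g[i]
            v[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]
            m_hat = m[i] / (1 - beta1 ** t)
            v_hat = v[i] / (1 - beta2 ** t)
            theta[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return rosenbrock(theta[0], theta[1])

for lr in (0.001, 0.01):
    print(f"lr={lr}: GD loss={run_gd(lr, 500)}  Adam loss={run_adam(lr, 500)}")
```

At lr=0.01 the gradient-descent iterates overflow within a handful of steps because the step exceeds the stability limit set by the valley's curvature, while Adam's normalised update keeps every step bounded.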
Demo 2 — LR Schedule Visualizer
Compares cosine decay, WSD (Warmup-Stable-Decay), and linear decay at checkpoints throughout training. The demo takes three inputs: total_steps, warmup_steps, and peak_lr (e.g., 1000 100 3e-4). Notice how WSD has more area under its curve (longer time at peak LR) and how the short decay tail in WSD can be applied to any checkpoint saved during the stable phase, a key advantage for data-efficient continued pre-training.
5. Production & Source Pointers
| Concept | PyTorch / library | Source |
|---|---|---|
| AdamW (fused CUDA) | torch.optim.AdamW(fused=True) | torch/optim/adamw.py → _fused_adamw() |
| Gradient clipping | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) | torch/nn/utils/clip_grad.py |
| Cosine LR schedule | torch.optim.lr_scheduler.CosineAnnealingLR | torch/optim/lr_scheduler.py → CosineAnnealingLR |
| Warmup + cosine (HF) | get_cosine_schedule_with_warmup() | transformers/optimization.py |
| Lion optimizer | lion-pytorch (pip) | lucidrains/lion-pytorch/lion_pytorch/lion_pytorch.py |
| Distributed Shampoo | distributed-shampoo (pip) | facebookresearch/optimizers/distributed_shampoo/ |
| Megatron optimizer | DistributedOptimizer in Megatron-LM | megatron/optimizer/distrib_optimizer.py |
nanoGPT's configure_optimizers() (listed under Code / Repos) is a compact reference implementation.
6. References
Papers
- Kingma & Ba 2014 — Adam: A Method for Stochastic Optimization (arXiv:1412.6980) — the Adam paper; bias correction derivation
- Loshchilov & Hutter 2017 — Decoupled Weight Decay Regularization (AdamW, arXiv:1711.05101) — shows why L2 in Adam is wrong
- Chen et al. 2023 — Symbolic Discovery of Optimization Algorithms (Lion, arXiv:2302.06675) — signed update, 1.5× memory savings over Adam
- Jordan et al. 2024 — Muon: MomentUM Orthogonalized by Newton-Schulz (arXiv:2409.20325) — cheap Shampoo approximation
- Li et al. 2018 — Visualizing the Loss Landscape of Neural Nets (arXiv:1712.09913) — filter normalisation, sharp vs flat minima
- Martens & Grosse 2015 — Optimizing Neural Networks with Kronecker-factored Approximate Curvature (K-FAC, arXiv:1503.05671)
- Gupta et al. 2018 — Shampoo: Preconditioned Stochastic Tensor Optimization (arXiv:1802.09568)
- Hu et al. 2024 — MiniCPM: Unveiling the Potential of Small LMs (arXiv:2404.06395) — WSD schedule ablation
- Loshchilov & Hutter 2016 — SGDR: Stochastic Gradient Descent with Warm Restarts (arXiv:1608.03983)
- Nesterov 1983 — A method for solving the convex programming problem with convergence rate O(1/k²)
Lectures
- Stanford CS229 — Andrew Ng, Lecture Notes on Optimization (cs229.stanford.edu/notes/) — SGD, convergence, convexity
- Stanford CS231n — Fei-Fei Li, Optimization and SGD variants (Lecture 3, 2024 version) — momentum, RMSProp, Adam
- MIT 6.S191 — Alexander Amini, Deep Learning Optimization (2024) — LR schedules, adaptive methods
- CMU 10-708 — Eric Xing, Optimization for Deep Learning — second-order methods, K-FAC
- NYU DS-GA 1008 — Yann LeCun, Energy-Based Models and Optimization — loss landscape geometry
- Oxford ML — Phil Blunsom, Gradient-Based Optimisation for Deep NLP Models (Hilary term lectures)
- DeepMind × UCL — Lecture 3: Function Approximation with Neural Networks — gradient descent basics
- fast.ai — Jeremy Howard, Lesson 2: SGD from Scratch — practical intuition for LR and momentum
Textbooks
- Boyd & Vandenberghe — Convex Optimization (free PDF at web.stanford.edu/~boyd/cvxbook/) — Chapters 9–10: gradient and Newton methods
- Goodfellow, Bengio & Courville — Deep Learning (2016) — Chapter 8: Optimization for Training Deep Models
- Murphy — Probabilistic Machine Learning: Advanced Topics (Vol 2, free draft at probml.ai) — Ch 8: optimization
- Nocedal & Wright — Numerical Optimization (2nd ed.) — rigorous convergence theory for first and second-order methods
Code / Repos
- pytorch/pytorch: torch/optim/adamw.py (fused CUDA Adam), torch/optim/lr_scheduler.py
- lucidrains/lion-pytorch: clean Lion implementation (<50 lines)
- facebookresearch/optimizers: Distributed Shampoo at production scale
- karpathy/nanoGPT: train.py:configure_optimizers() (AdamW setup, parameter grouping, LR decay)
- huggingface/transformers: optimization.py (get_cosine_schedule_with_warmup and WSD)
Blog Posts
- Sebastian Ruder — An overview of gradient descent optimisation algorithms (ruder.io) — clearest single-page survey
- distill.pub — Why Momentum Really Works — variance reduction and quadratic convergence analysis
- Lilian Weng — An Overview of Optimization Algorithms (lilianweng.github.io) — momentum, Adam, natural gradient
- Andrej Karpathy — A Recipe for Training Neural Networks — practical Adam tuning advice from a practitioner
- Tim Dettmers — Optimizers and Learning Rates in Large-Scale Deep Learning — memory analysis of optimizer states
7. Interview Prep
Questions from ML / LLM interviews at Anthropic, OpenAI, DeepMind, Meta AI, and NVIDIA — optimization round.
Q1. Why does Adam need bias correction? What goes wrong without it?
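Both m and v are initialised to 0, so early in training the moving averages are biased toward zero: m_t underestimates the mean gradient by a factor (1−β₁ᵗ) and v_t underestimates the mean squared gradient by (1−β₂ᵗ). At step 1, m₁ = (1−β₁)g₁ = 0.1g₁ and v₁ = (1−β₂)g₁² = 0.001g₁². Without correction the first update has magnitude |m₁|/√v₁ = 0.1/√0.001 ≈ 3.2, so the step is about 3.2α instead of α. The mismatch persists because the two biases decay at different speeds: with β₁ = 0.9, m is essentially unbiased after a few tens of steps, while with β₂ = 0.999, v stays too small for on the order of a thousand steps, so uncorrected updates are inflated at the start of training, when the model is most fragile. With correction, m̂₁/√v̂₁ = g₁/|g₁|, so the first step has magnitude α and both estimates are on the right scale from the start.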
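The size of the effect is easy to compute for the idealised case of a constant gradient g = 1, where both moving averages have a closed form. The sketch below is only an illustration under that assumption; the function name `update_ratios` is hypothetical.

```python
import math

def update_ratios(t, beta1=0.9, beta2=0.999):
    """Return (uncorrected, corrected) values of |m / sqrt(v)| at step t for g = 1."""
    m = 1 - beta1 ** t            # EMA of a constant gradient 1, started from m0 = 0
    v = 1 - beta2 ** t            # EMA of the constant squared gradient 1
    uncorrected = m / math.sqrt(v)
    corrected = (m / (1 - beta1 ** t)) / math.sqrt(v / (1 - beta2 ** t))
    return uncorrected, corrected

for t in (1, 10, 100, 1000, 5000):
    raw, fixed = update_ratios(t)
    print(f"t={t:5d}  uncorrected step = {raw:.3f} x lr   corrected step = {fixed:.3f} x lr")
```

In this idealised setting the uncorrected step starts at about 3.2 × lr, grows to several times lr within the first tens of steps, and only approaches lr after a few thousand steps, while the corrected step equals lr throughout.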
Q2. What is the difference between Adam and AdamW? Why does it matter?
Adam + L2 adds λθ to the gradient before the adaptive step, so the weight-decay term is divided by √v̂ alongside the gradient, giving less decay to parameters with large gradients. AdamW applies weight decay directly to the parameters (θ ← (1−αλ)θ − α·adam_step), bypassing the adaptive scaling. Result: all parameters are regularised in proportion to their magnitude, not their gradient history. Loshchilov & Hutter reported that AdamW generalises better than Adam+L2 in their experiments and makes the best weight-decay value less dependent on the learning rate; it is now the default for LLM training.
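The difference shows up clearly in a toy experiment. The sketch below tracks a single weight whose loss gradient is zero-mean noise (alternating ±noise_scale), so any systematic shrinkage comes from the regulariser alone. The noise model, the function name `train`, and the values lr=1e-3, wd=0.1, and 1000 steps are assumptions made for the illustration.

```python
import math

def train(noise_scale, decoupled, steps=1000, lr=1e-3, wd=0.1,
          beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the final value of one weight, started at 1.0, under AdamW or Adam + L2."""
    theta, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = noise_scale * (-1) ** t               # zero-mean stand-in for the loss gradient
        if not decoupled:
            g = g + wd * theta                    # Adam + L2: decay enters the gradient
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        if decoupled:
            theta = (1 - lr * wd) * theta         # AdamW: decay applied to the weight itself
        theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta

for name, decoupled in (("Adam + L2", False), ("AdamW", True)):
    big = train(noise_scale=10.0, decoupled=decoupled)
    small = train(noise_scale=0.01, decoupled=decoupled)
    print(f"{name:10s} large-gradient weight: {big:.4f}   small-gradient weight: {small:.4f}")
```

Under Adam + L2 the weight with large gradients barely shrinks while the weight with small gradients decays heavily; under AdamW both shrink by nearly the same factor, close to (1 − αλ)^1000 ≈ 0.905.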
Q3. What is the condition number of a matrix and why does it matter for optimizer design?
κ = λ_max/λ_min is the ratio of the largest to smallest eigenvalue of the Hessian. For gradient descent on a quadratic, reaching error ε takes t = O(κ log 1/ε) steps. When κ ≫ 1 (ill-conditioned), the iterates oscillate along the high-curvature direction while making tiny progress along the low-curvature direction. Adam's per-parameter adaptive LR rescales each coordinate by 1/√v̂, which acts as a diagonal preconditioner. This is only a heuristic stand-in for the diagonal of H⁻¹ (it is built from gradient magnitudes, not curvature), but when the ill-conditioning is mostly axis-aligned it can reduce the effective condition number substantially. Shampoo and K-FAC also capture off-diagonal structure.
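The O(κ log 1/ε) scaling can be checked numerically. The sketch below counts gradient-descent steps on a two-dimensional diagonal quadratic with eigenvalues 1 and κ, using the optimal fixed step 2/(λ_max + λ_min); the function name, tolerance, and starting point are assumptions of the example.

```python
import math

def steps_to_converge(kappa, tol=1e-6, max_steps=100_000):
    """Steps of gradient descent needed to shrink the error norm by a factor tol."""
    lam_min, lam_max = 1.0, float(kappa)
    lr = 2.0 / (lam_max + lam_min)                 # optimal fixed step for this quadratic
    x = [1.0, 1.0]
    start_norm = math.hypot(x[0], x[1])
    for step in range(1, max_steps + 1):
        # The gradient of 0.5 * (lam_min * x0^2 + lam_max * x1^2) is (lam_min*x0, lam_max*x1).
        x = [x[0] - lr * lam_min * x[0], x[1] - lr * lam_max * x[1]]
        if math.hypot(x[0], x[1]) <= tol * start_norm:
            return step
    return max_steps

for kappa in (10, 100, 1000):
    print(f"kappa={kappa:5d}  steps={steps_to_converge(kappa)}")
```

The step count grows roughly in proportion to κ, matching the per-step contraction factor (κ−1)/(κ+1).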
Q4. A training run diverges at step 20,000. Walk through your diagnosis checklist.
(1) Plot the gradient norm: if it spikes just before divergence, enable or tighten gradient clipping (typical: max_norm=1.0). (2) Check the LR schedule: is the warmup long enough? A peak LR that is too high after warmup is a very common cause. (3) Check for NaN/Inf in the loss at that step: it may be numerical instability in log or softmax (use log-sum-exp). (4) Check the data batch: corrupted or anomalous data can cause sudden loss spikes. (5) If using FP16 mixed precision, check loss scaling: FP16 underflow silently zeroes small gradients, and overflow produces Inf/NaN. BF16 has the same exponent range as FP32 and normally needs no loss scaling.
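For step (1), global-norm clipping can be sketched in a few lines. This follows the same idea as `torch.nn.utils.clip_grad_norm_` (one shared scale factor computed from the norm over all gradients) but is a plain-Python illustration; the function name `clip_grad_norm` and the toy gradients are assumptions.

```python
import math

def clip_grad_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale a list of gradient vectors so their global L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm), where total_norm is the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))   # never scales gradients up
    return [[g * scale for g in vec] for vec in grads], total_norm

grads = [[3.0, 4.0], [0.0, 12.0]]                     # global norm = 13
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, norm_after)
```

Logging the pre-clipping norm returned here at every step gives the gradient-norm plot that the checklist starts from.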
Q5. Explain momentum intuitively. When does it help and when does it hurt?
Momentum accumulates a velocity vector v, an exponentially weighted sum of past gradients. In a ravine (narrow valley), gradients consistently point along the valley floor (constructive interference → v accelerates) but alternate in sign across the valley (destructive interference → oscillations cancel). This accelerates convergence on ill-conditioned quadratics by a factor of O(√κ). It hurts when the learning rate is too high: momentum overshoots the minimum and “bounces back” because the velocity carries the parameter past the target. Near convergence, either the momentum or the effective LR (via the schedule) should be ramped down.
Q6. Why do large-batch LLM runs use a higher peak learning rate? What is the linear scaling rule and when does it break?
Goyal et al. (2017) linear scaling rule: if you multiply batch size by k, multiply LR by k. Intuition: a batch of size kB has gradient variance k times lower than a batch of size B, so one large step can stand in for k small ones. Warmup becomes essential because the rule breaks at initialisation: a large LR with random weights causes divergence before the model has meaningful activations. The rule also breaks down at very large batches, beyond a critical batch size that depends on the model and the stage of training, where the gradient signal saturates and diminishing returns set in. In practice, labs tune LR empirically.
Q7. What does implicit regularisation of SGD mean? Does Adam have it?
SGD's gradient noise biases it toward flat minima: parameters near a sharp minimum are easily perturbed out of it, while flat minima absorb the noise and remain stable. Smith & Le (2018) characterise this through a noise scale that grows with the ratio α/B of learning rate to batch size, so a larger α or a smaller B gives a stronger preference for flat regions. Adam does not share the same implicit bias: its adaptive preconditioning changes the noise structure, and empirically Adam tends to find sharper minima than SGD. This is one reason AdamW is normally paired with a decaying schedule such as cosine rather than a constant LR: lowering the LR at the end reduces the noise and lets the iterate settle.
Q8. How does optimizer state interact with ZeRO (Zero Redundancy Optimizer)? What memory does a 7B model need for AdamW?
AdamW for a 7B model in BF16 training: model weights 14 GB (7B × 2 bytes), optimizer states (m, v in FP32) 56 GB (7B × 2 tensors × 4 bytes), gradient buffer 14 GB = ~84 GB total minimum, not counting activations. Mixed-precision setups that also keep an FP32 master copy of the weights add another 28 GB. ZeRO-1 shards optimizer states across GPUs; ZeRO-2 additionally shards gradients; ZeRO-3 shards parameters too. On 8×H100, any ZeRO stage brings the optimizer state per GPU down to ≈ 56 GB / 8 = 7 GB. DeepSpeed implements the ZeRO stages, and PyTorch FSDP provides ZeRO-3-style sharding. The key insight: m and v are normally kept in FP32 even when training in BF16, because low-precision accumulation loses the small updates these running averages depend on.
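The arithmetic above can be written as a small calculator. This is a back-of-the-envelope sketch that ignores activations, buffers, and fragmentation; the function names and the even-sharding assumption are illustrative, not taken from DeepSpeed or FSDP.

```python
def adamw_memory_gb(n_params, weight_bytes=2, grad_bytes=2, state_bytes=4,
                    master_weights=False):
    """Training-state memory in GB (1 GB = 1e9 bytes) for AdamW, excluding activations."""
    gb = 1e9
    mem = {
        "weights": n_params * weight_bytes / gb,
        "gradients": n_params * grad_bytes / gb,
        "optimizer_states": 2 * n_params * state_bytes / gb,   # m and v
    }
    if master_weights:
        mem["master_weights"] = n_params * 4 / gb              # optional FP32 copy
    mem["total"] = sum(mem.values())
    return mem

def zero_per_gpu_gb(mem, n_gpus, stage):
    """Per-GPU memory under ZeRO stage 1, 2, or 3, assuming perfectly even sharding."""
    opt = (mem["optimizer_states"] + mem.get("master_weights", 0.0)) / n_gpus
    grads = mem["gradients"] / n_gpus if stage >= 2 else mem["gradients"]
    weights = mem["weights"] / n_gpus if stage >= 3 else mem["weights"]
    return weights + grads + opt

mem = adamw_memory_gb(7e9)
print(mem)
for stage in (1, 2, 3):
    print(f"ZeRO-{stage} on 8 GPUs: {zero_per_gpu_gb(mem, 8, stage):.2f} GB per GPU")
```

Passing `master_weights=True` shows the common 16-bytes-per-parameter configuration, which raises the unsharded total from 84 GB to 112 GB for 7B parameters.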