Part II — Mathematical Foundations

§ 5 · Optimization for ML

Convex vs non-convex loss landscapes, gradient descent variants (SGD → Momentum → Adam → AdamW → Lion), learning-rate schedules (cosine, WSD), second-order intuition (K-FAC, Shampoo), and implicit regularisation — everything that drives a 7B parameter model from random weights to a working language model.

1. Overview

Optimization is the engine of ML: given a differentiable loss function over billions of parameters, it finds settings that make predictions accurate. For LLMs the loss is negative log-likelihood (NLL) on next-token prediction, accumulated across trillions of tokens. Every design decision — optimizer choice, learning-rate schedule, gradient clipping threshold, weight-decay coefficient — directly determines whether a 70B model converges in 10k GPU-hours or diverges spectacularly at step 50,000. The diagram below shows where optimization sits in the training loop.

Scale note. GPT-4's training configuration is not public, so treat these as order-of-magnitude estimates: ~300B tokens × ~1T parameters × ~6 FLOPs/param/token ≈ 2×10²⁴ floating-point operations. With one update per ~1M-token batch, that is ~300k optimizer steps, each involving computing, clipping, and applying gradients for every one of the trillion parameters. Adam's two moment buffers add two more values per parameter on top of the weights themselves, which is why ZeRO (see §XI) shards optimizer state across GPUs.

2. Key Concepts

Loss Landscape Topology

The loss surface is a function from the parameter space ℝᵈ (d ≈ 7×10⁹ for LLaMA-7B) to a scalar. Classical optimization theory assumes convexity (one global minimum, gradient always helpful). Deep learning breaks every convexity assumption — yet empirically, SGD-family methods find excellent solutions. The key insight from Goodfellow et al. (2015) and Li et al. (2018): high-dimensional non-convex surfaces have far fewer bad local minima than low-dimensional ones, and near-zero-gradient points are overwhelmingly saddle points, not local minima.

Condition number  κ = λ_max / λ_min         ratio of largest to smallest curvature
                                              κ >> 1 → ill-conditioned → slow convergence

For a quadratic f(x) = x^T A x / 2 (gradient descent with step α):
  Per-step error contraction: ρ(α) = max(|1 - αλ_min|, |1 - αλ_max|)   → error ∝ ρ^t, geometric in t
  Optimal step: α* = 2 / (λ_max + λ_min)
  Best contraction: ρ(α*) = (κ-1)/(κ+1)               → slow when κ >> 1

Transformers: embedding rows have κ ~ 10^4 vs attention weights κ ~ 10^2
→ per-param adaptive LR (Adam) crucial; uniform SGD LR cannot handle both
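
The effect of the condition number can be checked numerically. The sketch below runs plain gradient descent on a diagonal quadratic (so the eigenvalues are simply the diagonal entries); the function names are illustrative and not taken from any library.

```python
def gd_contraction(lam_min, lam_max, alpha):
    """Worst-case per-step error contraction of GD on f(x) = x^T A x / 2."""
    return max(abs(1 - alpha * lam_min), abs(1 - alpha * lam_max))


def gd_error_after(lams, alpha, steps):
    """Run GD on a diagonal quadratic from x = (1, ..., 1); return max |x_i|."""
    x = [1.0] * len(lams)
    for _ in range(steps):
        # The gradient of sum(lam_i * x_i^2) / 2 is lam_i * x_i per coordinate.
        x = [xi - alpha * lam * xi for xi, lam in zip(x, lams)]
    return max(abs(xi) for xi in x)


if __name__ == "__main__":
    for kappa in (10, 100, 10_000):
        lam_min, lam_max = 1.0, float(kappa)
        alpha_star = 2 / (lam_max + lam_min)
        rho = gd_contraction(lam_min, lam_max, alpha_star)
        print(f"kappa={kappa:>6}  contraction={rho:.6f}  (k-1)/(k+1)={(kappa - 1) / (kappa + 1):.6f}")
```

At the optimal step both extreme eigen-directions contract at the same rate, which is why the two printed columns agree.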

Optimizer Family

Every modern optimizer descends from plain SGD by adding one or two ideas: momentum (accumulate velocity to accelerate consistent directions), adaptive per-parameter learning rates (slow down for high-variance params, speed up for low-variance ones), or second-order curvature (precondition by the inverse Hessian or Fisher information). The tree below maps the genealogy.

Optimizer | Hyperparams | Optimizer state / param | Used by
SGD | lr | 0 scalars | Image training (ResNet, ViT); rarely for LLMs
SGD+Momentum | lr, β=0.9 | 1 scalar (velocity v) | CV training; Muon uses this variant
Adam | lr, β₁=0.9, β₂=0.999, ε=1e-8 | 2 scalars (m, v) + step counter | BERT, T5, GPT-2; baseline for most research
AdamW | lr, β₁, β₂, ε, wd=0.1 | 2 scalars (same as Adam) | LLaMA, Mistral, GPT-3/4, Falcon; industry default
Lion | lr, β₁=0.9, β₂=0.99, wd | 1 scalar (m only, no v) | DeepMind / Google experiments; 1.5× memory cheaper
Shampoo | lr, β, update_freq | L + R matrices per layer (large!) | Google (JAX), Meta (Distributed Shampoo)

Momentum and Nesterov Acceleration

Plain SGD in a ravine (loss curves steeply across one dimension, shallowly along the other) oscillates back and forth across the steep direction while crawling along the shallow direction. Momentum damps oscillations by accumulating a velocity vector — consistent gradient signals build up speed while opposing signals cancel. Nesterov adds a look-ahead correction: it peeks at where momentum would take the parameter next, evaluates the gradient there, and uses that as the corrective signal, which provides better convergence on convex functions.

Momentum update (standard form):
  v_t = β v_{t-1} + ∇L(θ_t)        # β=0.9 typical
  θ_{t+1} = θ_t - α v_t

Nesterov (look-ahead form; PyTorch nesterov=True uses a re-parameterised form that evaluates the gradient at θ_t):
  θ_look = θ_t - α β v_{t-1}        # look-ahead position
  v_t    = β v_{t-1} + ∇L(θ_look)  # gradient at look-ahead
  θ_{t+1} = θ_t - α v_t

Convergence for L-smooth, μ-strongly convex f (κ = L/μ):
  GD:        O((1 - μ/L)^t)         (baseline, step size 1/L)
  Momentum:  O((1 - √(μ/L))^t)     (heavy-ball; guaranteed on quadratics, not for general convex f)
  Nesterov:  O((1 - √(μ/L))^t)     (optimal first-order rate; Nesterov 1983)
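
The gap between plain gradient descent and momentum is easy to measure on an ill-conditioned quadratic. The sketch below uses the standard-form momentum update shown above with Polyak's tuned heavy-ball settings for quadratics; the function names and the tolerance are illustrative choices.

```python
import math


def steps_to_converge(lams, alpha, beta, tol=1e-6, max_steps=100_000):
    """Momentum on f(x) = sum(lam_i * x_i^2) / 2, starting from x = (1, ..., 1).

    beta = 0 reduces to plain gradient descent.
    Returns the number of steps until max |x_i| < tol.
    """
    x = [1.0] * len(lams)
    v = [0.0] * len(lams)
    for t in range(1, max_steps + 1):
        g = [lam * xi for lam, xi in zip(lams, x)]
        v = [beta * vi + gi for vi, gi in zip(v, g)]   # v_t = beta * v_{t-1} + grad
        x = [xi - alpha * vi for xi, vi in zip(x, v)]  # theta_{t+1} = theta_t - alpha * v_t
        if max(abs(xi) for xi in x) < tol:
            return t
    return max_steps


def compare(mu, L):
    """Return (gd_steps, momentum_steps) with each method's tuned settings for a quadratic."""
    gd = steps_to_converge([mu, L], alpha=2 / (L + mu), beta=0.0)
    sq_mu, sq_L = math.sqrt(mu), math.sqrt(L)
    hb = steps_to_converge(
        [mu, L],
        alpha=4 / (sq_L + sq_mu) ** 2,
        beta=((sq_L - sq_mu) / (sq_L + sq_mu)) ** 2,
    )
    return gd, hb


if __name__ == "__main__":
    gd_steps, hb_steps = compare(1.0, 100.0)  # condition number 100
    print("GD steps:", gd_steps, " momentum steps:", hb_steps)
```

With κ = 100 the momentum run needs several times fewer steps, in line with the √κ speed-up discussed for quadratics.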

Adam — Adaptive Moment Estimation

Adam (Kingma & Ba, 2014) combines momentum with adaptive per-parameter learning rates. The key idea: maintain a running estimate of each parameter's gradient magnitude (second moment), and divide the gradient by that estimate. Parameters with historically large gradients get a smaller effective step; parameters with small gradients get a larger step. This is critical for transformers where embedding parameters have gradients orders of magnitude larger than, say, layer-norm bias terms.

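Why bias correction? Both m and v are initialised to 0. At step 1, m₁ = (1−β₁)·g₁ = 0.1·g₁, only 10% of the true gradient, and v₁ = (1−β₂)·g₁² = 0.001·g₁². Dividing by (1−β₁ᵗ) = 0.1 at t=1 restores m̂₁ = g₁, and dividing by (1−β₂ᵗ) = 0.001 restores v̂₁ = g₁². Without correction the two biases do not cancel: m₁/√v₁ = 0.1/√0.001 ≈ 3.2, so the first steps would be roughly 3× larger than intended. The first-moment correction fades quickly (β₁¹⁰⁰ ≈ 2.7×10⁻⁵ ≈ 0), while the second-moment correction persists for a few thousand steps (β₂¹⁰⁰⁰ ≈ 0.37).
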
Adam algorithm (Kingma & Ba 2014, Algorithm 1):

Hyperparams: α (lr), β₁=0.9, β₂=0.999, ε=1e-8
Init: m₀ = v₀ = 0

For each step t = 1, 2, ...:
  g_t   ← ∇L(θ_t)                          # gradient (same shape as θ)
  m_t   ← β₁ m_{t-1} + (1-β₁) g_t          # biased 1st moment
  v_t   ← β₂ v_{t-1} + (1-β₂) g_t²         # biased 2nd moment (element-wise g²)
  m̂_t  ← m_t / (1 - β₁ᵗ)                   # bias-corrected 1st moment
  v̂_t  ← v_t / (1 - β₂ᵗ)                   # bias-corrected 2nd moment
  θ_{t+1} ← θ_t - α m̂_t / (√v̂_t + ε)     # update (element-wise)

Memory: 2 fp32 tensors of shape θ (m and v) + 1 int (t)
→ At 7B params × 4 bytes × 3 copies (θ, m, v) = 84 GB at minimum
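
The algorithm translates almost line for line into code. The sketch below is a minimal pure-Python version operating on lists of floats (a real implementation would use tensors); `adam_step` is an illustrative name, not a library function.

```python
import math


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on lists of floats; t is the 1-based step count.

    Returns the new (theta, m, v).
    """
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, grad)]       # biased 1st moment
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]  # biased 2nd moment
    m_hat = [mi / (1 - beta1 ** t) for mi in m]                          # bias correction
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    theta = [
        th - lr * mh / (math.sqrt(vh) + eps)
        for th, mh, vh in zip(theta, m_hat, v_hat)
    ]
    return theta, m, v


if __name__ == "__main__":
    # Two parameters whose gradients differ by six orders of magnitude.
    theta, m, v = adam_step([1.0, 1.0], [1000.0, 0.001], [0.0, 0.0], [0.0, 0.0], t=1)
    print([1.0 - th for th in theta])  # both parameters move by roughly lr
```

Both parameters take a step of about `lr` despite the huge difference in gradient scale, which is the adaptive property the section describes.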

AdamW: Decoupled Weight Decay

Loshchilov & Hutter (2017) showed that adding L2 regularisation to Adam is wrong. Standard Adam with L2 adds λθ to the gradient before the adaptive step, which means high-gradient parameters get less weight decay than low-gradient ones (the adaptive divisor shrinks both the gradient and the L2 term). AdamW applies weight decay directly to the parameters, bypassing the adaptive scaling, giving uniform regularisation:

Adam   + L2:  g_t ← ∇L(θ_t) + λ θ_t        # wrong: WD gets scaled by 1/√v̂
              θ_{t+1} ← θ_t - α m̂_t / (√v̂_t + ε)

AdamW:         g_t ← ∇L(θ_t)                 # gradient without WD
              θ_{t+1} ← (1 - α λ) θ_t - α m̂_t / (√v̂_t + ε)
              #           ↑ decoupled WD      ↑ adaptive gradient step

Typical:  λ (wd) = 0.1   →   shrinks weights by 0.01% per step at lr=1e-3   (α·λ = 1e-4)
Not applied to: bias terms, LayerNorm/RMSNorm scale/shift (common practice)
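
One way to see the difference is to isolate the decay term. The sketch below sets the loss gradient to zero (a deliberately artificial setting) so that only weight decay acts on a single scalar weight; `run_decay_only` is an illustrative helper, not part of any library.

```python
import math


def run_decay_only(decoupled, theta=1.0, lr=1e-3, wd=0.1, steps=100,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply `steps` updates with zero loss gradient so only weight decay acts.

    decoupled=True  -> AdamW: decay applied directly to the weight.
    decoupled=False -> Adam + L2: decay added to the gradient, then rescaled by 1/sqrt(v_hat).
    """
    m = v = 0.0
    for t in range(1, steps + 1):
        loss_grad = 0.0
        g = loss_grad if decoupled else loss_grad + wd * theta
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        if decoupled:
            theta *= 1 - lr * wd
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


if __name__ == "__main__":
    print("AdamW          :", run_decay_only(decoupled=True))
    print("Adam+L2 wd=0.1 :", run_decay_only(decoupled=False, wd=0.1))
    print("Adam+L2 wd=0.01:", run_decay_only(decoupled=False, wd=0.01))
```

With AdamW the weight shrinks by the factor (1 − αλ) each step. With Adam + L2 the decay term is normalised by its own running magnitude, so the weight moves by about α per step and changing λ by 10× makes almost no difference.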

Learning-Rate Schedules

No fixed learning rate works well across all training phases. Too high at the start → divergence before the model learns anything useful. Too high at the end → the model oscillates near the minimum rather than settling. LR schedules divide training into phases: a warmup ramp, a sustained high-LR phase, and a decay tail.

Schedule | Formula (post-warmup) | Properties | Used by
Cosine | lr_min + (lr_max−lr_min)·½·(1+cos(πt/T)) | Smooth; no kinks; most common | GPT-3, LLaMA-1/2/3, Mistral-7B
WSD | Flat at lr_max → linear decay in tail | Checkpoint-friendly; easy restart | MiniCPM, Falcon-180B, Mistral-v0.2
Linear | lr_max·(1 − t/T) + lr_min·t/T | Simple; less area under the curve | Fine-tuning; short-run experiments
Cosine restarts | Cosine within each cycle; T_i × mult at restart | Escapes local minima; SGDR | Some fine-tuning; curriculum learning
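
The first three schedules are short enough to write out. The sketch below assumes a linear warmup shared by all three and, for WSD, a decay tail covering the last 10% of training; both choices, and the function names, are assumptions for illustration rather than settings from any particular model.

```python
import math


def _warmup(step, warmup, lr_max):
    """Linear warmup that reaches lr_max on the last warmup step."""
    return lr_max * (step + 1) / warmup


def cosine_lr(step, total, warmup, lr_max, lr_min=0.0):
    if step < warmup:
        return _warmup(step, warmup, lr_max)
    progress = (step - warmup) / max(1, total - warmup)
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + math.cos(math.pi * progress))


def wsd_lr(step, total, warmup, lr_max, lr_min=0.0, decay_frac=0.1):
    """Warmup-Stable-Decay: flat at lr_max, then a linear decay over the final decay_frac."""
    if step < warmup:
        return _warmup(step, warmup, lr_max)
    decay_start = total - int(decay_frac * total)
    if step < decay_start:
        return lr_max
    progress = (step - decay_start) / max(1, total - decay_start)
    return lr_max + (lr_min - lr_max) * progress


def linear_lr(step, total, warmup, lr_max, lr_min=0.0):
    if step < warmup:
        return _warmup(step, warmup, lr_max)
    progress = (step - warmup) / max(1, total - warmup)
    return lr_max * (1 - progress) + lr_min * progress


if __name__ == "__main__":
    total, warmup, peak = 1000, 100, 3e-4
    for s in (0, 100, 550, 900, 950, 1000):
        print(s, cosine_lr(s, total, warmup, peak), wsd_lr(s, total, warmup, peak),
              linear_lr(s, total, warmup, peak))
```

Summing each schedule over all steps shows WSD spending far more of the run at the peak rate than cosine does.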

Second-Order Methods

First-order methods (SGD, Adam) use only the gradient — the first derivative of the loss. Second-order methods additionally use the Hessian H = ∂²L/∂θ² or its approximation. The natural gradient direction F⁻¹g (where F is the Fisher information matrix ≈ H near the minimum) is invariant to reparameterisation and converges in far fewer steps. The practical barrier: for a 7B model, H is a 7B×7B matrix (5×10¹⁹ entries) — completely intractable.

Method | Approximation | Memory | Benefit
Newton | Exact H⁻¹g | O(d²) | Quadratic convergence; totally impractical for LLMs
K-FAC | F ≈ A⊗B per layer (Kronecker) | O(d_in² + d_out²) | 2–5× fewer steps; tractable for mid-size models
Shampoo | Full matrix preconditioner (no Fisher) | O(d_in² + d_out²) per layer | Used in production at Google; Distributed Shampoo at Meta
Muon | Orthogonalise grad via Newton-Schulz | 1 momentum buffer (like SGD) | Cheap approx of Shampoo; gaining traction in 2024

Shampoo per-layer update (simplified):
  G_t ← layer gradient  [d_out × d_in]
  L_t ← L_{t-1} + G_t G_tᵀ              [d_out × d_out]  left Kronecker factor
  R_t ← R_{t-1} + G_tᵀ G_t              [d_in  × d_in]   right Kronecker factor
  Preconditioned step ← L_t^{-¼} G_t R_t^{-¼}

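Why 4th root? L_t and R_t are each already quadratic in the gradient, so L⊗R is
fourth-order in G. Shampoo approximates the full-matrix AdaGrad statistic
H = Σ vec(G)vec(G)ᵀ by L^{½}⊗R^{½}, so the AdaGrad preconditioner H^{-½} becomes
(L^{½}⊗R^{½})^{-½} = L^{-¼}⊗R^{-¼}.  Inverse roots are recomputed every k=100 steps to amortise cost.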

3. Core Mechanism: Adam Step-by-Step

Background

Imagine training a 7B-parameter LLM with plain SGD at a fixed learning rate α = 1e-3. The token embedding table receives gradient spikes when rare tokens appear (large, sparse gradients), while the layer-norm scale parameters receive tiny but consistent gradients. A single α that prevents embedding divergence will make layer-norm progress infinitesimally slowly — and vice versa. Adam solves this by giving each parameter its own effective learning rate, automatically discovered from gradient history.

Plan

  1. Maintain m_t = exponential moving average of past gradients (bias-corrected by dividing by 1−β₁ᵗ). This is the “direction to move” signal — smoother than raw gradient due to averaging.
  2. Maintain v_t = exponential moving average of past squared gradients (bias-corrected by 1−β₂ᵗ). This estimates the gradient's typical magnitude per parameter.
  3. Divide m̂_t by √v̂_t to get a normalised update. The effective per-parameter step size is approximately α regardless of gradient scale — that's the adaptive property.
  4. Add ε = 1e-8 to prevent division by zero for parameters that receive zero gradient (e.g., embeddings for tokens not in the current batch).

Walkthrough — Single Parameter, 5 Steps

Initial conditions: one scalar parameter θ₀ = 2.0, loss = θ²/2 (parabola), gradient = θ (true minimum at θ=0). Adam hyperparams: α = 0.1, β₁ = 0.9, β₂ = 0.999, ε = 1e-8. Track m, v, m̂, v̂, and the resulting update per step.

Init: θ=2.0  m=0  v=0

t=1: g=2.0
  m₁ = 0.9·0 + 0.1·2.0  = 0.2        v₁ = 0.999·0 + 0.001·4.0 = 0.004
  m̂₁ = 0.2/0.1          = 2.0        v̂₁ = 0.004/0.001         = 4.0
  step = 0.1 · 2.0 / (√4.0 + ε)    ≈ 0.100
  θ₁  = 2.0 - 0.100 = 1.900

t=2: g=1.9
  m₂ = 0.9·0.2 + 0.1·1.9 = 0.370    v₂ = 0.999·0.004 + 0.001·3.61 = 0.007606
  m̂₂ = 0.370/0.19        = 1.947    v̂₂ = 0.007606/0.001999       = 3.805
  step = 0.1 · 1.947 / (√3.805 + ε) ≈ 0.0998
  θ₂  = 1.900 - 0.0998 ≈ 1.800

t=3: g=1.800
  m₃ = 0.9·0.370+0.1·1.800 = 0.513  v₃ = 0.999·0.007606+0.001·3.241 = 0.010839
  m̂₃ = 0.513/0.271       = 1.893    v̂₃ = 0.010839/0.002997        = 3.617
  step = 0.1 · 1.893 / (√3.617 + ε) ≈ 0.0995
  θ₃  ≈ 1.701

...continuing, each step ≈ α=0.1 regardless of the shrinking gradient magnitude.

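Key observation: as g shrinks (param approaches 0), v̂ also shrinks → m̂/√v̂ stays ~1.
Adam therefore keeps a step of ≈ α regardless of gradient scale, whereas the SGD step
α·g is proportional to the gradient: larger than Adam's while |g| > 1, then shrinking
geometrically as θ → 0. The flip side: Adam does not slow down on its own, so near the
minimum it overshoots unless α is decayed.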
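
The walkthrough can be reproduced in a few lines. This is a minimal sketch for the scalar problem above (loss = θ²/2, so the gradient equals θ); `adam_trace` is an illustrative name.

```python
import math


def adam_trace(theta=2.0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=5):
    """Run scalar Adam on loss = theta^2 / 2 and record every intermediate quantity.

    Returns a list of tuples (t, g, m_hat, v_hat, step, theta_after).
    """
    m = v = 0.0
    rows = []
    for t in range(1, steps + 1):
        g = theta                                 # d/dtheta (theta^2 / 2) = theta
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        step = lr * m_hat / (math.sqrt(v_hat) + eps)
        theta -= step
        rows.append((t, g, m_hat, v_hat, step, theta))
    return rows


if __name__ == "__main__":
    for t, g, m_hat, v_hat, step, theta in adam_trace():
        print(f"t={t}  g={g:.4f}  m_hat={m_hat:.4f}  v_hat={v_hat:.4f}  step={step:.4f}  theta={theta:.4f}")
```

Raising `steps` to 30 or so shows the step size staying close to α until θ reaches the neighbourhood of zero.
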
Implicit regularisation. Adam's second-moment estimate v̂ shapes which parameters move: parameters with large recent gradients carry a larger v̂, which lowers their effective LR until the moving average decays (time scale ≈ 1/(1−β₂) ≈ 1000 steps). With Adam + L2 the decay term is divided by the same √v̂, so exactly those parameters are also under-regularised. AdamW's explicit weight decay bypasses v̂ and shrinks every parameter at the same relative rate, which is why it tends to generalise better.

4. Minimal Demos

Demo 1 — Optimizer Race on Rosenbrock

The Rosenbrock “banana” function f(x,y) = (1−x)² + 100(y−x²)² has a narrow curved valley leading to a minimum at (1,1). It is a classic benchmark because SGD oscillates across the valley while adaptive methods navigate it efficiently. Enter steps lr (e.g., 500 0.001). Try lr=0.001 (Adam wins), then lr=0.01 (SGD diverges, Adam stable).
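
A rough Python stand-in for the race is sketched below. The start point (−1, 1), the Adam defaults, and the function names are assumptions made for illustration; the outcome at small learning rates depends on the start point and the step budget.

```python
import math


def rosenbrock(x, y):
    return (1 - x) * (1 - x) + 100 * (y - x * x) * (y - x * x)


def rosenbrock_grad(x, y):
    return (-2 * (1 - x) - 400 * x * (y - x * x), 200 * (y - x * x))


def run_gd(lr, steps, start=(-1.0, 1.0)):
    """Plain gradient descent; returns the final loss, or inf if the iterates blow up."""
    x, y = start
    for _ in range(steps):
        gx, gy = rosenbrock_grad(x, y)
        x, y = x - lr * gx, y - lr * gy
        if not (math.isfinite(x) and math.isfinite(y)):
            return float("inf")
    return rosenbrock(x, y)


def run_adam(lr, steps, start=(-1.0, 1.0), beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with bias correction; returns the final loss."""
    p = list(start)
    m = [0.0, 0.0]
    v = [0.0, 0.0]
    for t in range(1, steps + 1):
        g = rosenbrock_grad(p[0], p[1])
        for i in range(2):
            m[i] = beta1 * m[i] + (1 - beta1) * g[i]
            v[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]
            m_hat = m[i] / (1 - beta1 ** t)
            v_hat = v[i] / (1 - beta2 ** t)
            p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return rosenbrock(p[0], p[1])


if __name__ == "__main__":
    for lr in (0.001, 0.01):
        print(f"lr={lr}: GD loss={run_gd(lr, 500)}  Adam loss={run_adam(lr, 500)}")
```

At lr=0.01 the gradient-descent iterates leave the valley within a handful of steps and overflow, while Adam's normalised steps keep it bounded.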

Demo 2 — LR Schedule Visualizer

Compares cosine decay, WSD (Warmup-Stable-Decay), and linear decay at checkpoints throughout training. Enter total_steps warmup_steps peak_lr (e.g., 1000 100 3e-4). Notice how WSD has more area under its curve (longer time at peak LR) and how the sudden decay tail in WSD can be applied retroactively to any checkpoint — a key advantage for data-efficient continued pre-training.

5. Production & Source Pointers

Concept | PyTorch / library | Source
AdamW (fused CUDA) | torch.optim.AdamW(fused=True) | torch/optim/adamw.py → _fused_adamw()
Gradient clipping | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) | torch/nn/utils/clip_grad.py
Cosine LR schedule | torch.optim.lr_scheduler.CosineAnnealingLR | torch/optim/lr_scheduler.py → CosineAnnealingLR
Warmup + cosine (HF) | get_cosine_schedule_with_warmup() | transformers/optimization.py
Lion optimizer | lion-pytorch (pip) | lucidrains/lion-pytorch/lion_pytorch/lion_pytorch.py
Distributed Shampoo | distributed-shampoo (pip) | facebookresearch/optimizers/distributed_shampoo/
Megatron optimizer | DistributedOptimizer in Megatron-LM | megatron/optimizer/distrib_optimizer.py

Common pitfall: forgetting to exclude certain params from weight decay. LayerNorm and RMSNorm scale/shift parameters and bias terms should typically have wd=0; some recipes exempt embedding tables as well. Always split parameters into two groups before constructing the optimizer: decayed params (weight matrices) and non-decayed params (bias, norm scale/shift). Getting this wrong can degrade performance measurably. nanoGPT's configure_optimizers() is a widely used reference: it decays every tensor with two or more dimensions (which includes embeddings) and nothing else.
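
The grouping logic itself is small. The sketch below works on a plain dict of parameter names and shapes so it runs without any framework; the function name, the `no_decay_keywords` option, and the example parameter names are hypothetical. It follows the dimension-based rule (tensors with two or more dimensions are decayed), so embedding tables are decayed unless a keyword exempts them.

```python
def split_weight_decay_groups(named_shapes, weight_decay=0.1, no_decay_keywords=()):
    """Split parameter names into a decayed and a non-decayed group.

    named_shapes: dict mapping parameter name -> shape tuple.
    no_decay_keywords: substrings that force a parameter into the non-decayed group.
    """
    decay, no_decay = [], []
    for name, shape in named_shapes.items():
        is_matrix = len(shape) >= 2          # biases and norm scales are 1-D
        exempt = any(key in name for key in no_decay_keywords)
        (decay if is_matrix and not exempt else no_decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


if __name__ == "__main__":
    shapes = {
        "wte.weight": (50257, 768),
        "h.0.attn.c_attn.weight": (2304, 768),
        "h.0.attn.c_attn.bias": (2304,),
        "h.0.ln_1.weight": (768,),
    }
    for group in split_weight_decay_groups(shapes):
        print(group)
```

In PyTorch the same two-dict structure, holding the parameter tensors instead of their names, can be passed as the first argument to torch.optim.AdamW, which applies each group's weight_decay separately.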

6. References

Papers

  • Kingma & Ba 2014 — Adam: A Method for Stochastic Optimization (arXiv:1412.6980) — the Adam paper; bias correction derivation
  • Loshchilov & Hutter 2017 — Decoupled Weight Decay Regularization (AdamW, arXiv:1711.05101) — shows why L2 in Adam is wrong
  • Chen et al. 2023 — Symbolic Discovery of Optimization Algorithms (Lion, arXiv:2302.06675) — signed update, 1.5× memory savings over Adam
  • Jordan et al. 2024 — Muon: MomentUM Orthogonalized by Newton-Schulz (arXiv:2409.20325) — cheap Shampoo approximation
  • Li et al. 2018 — Visualizing the Loss Landscape of Neural Nets (arXiv:1712.09913) — filter normalisation, sharp vs flat minima
  • Martens & Grosse 2015 — Optimizing Neural Networks with Kronecker-factored Approximate Curvature (K-FAC, arXiv:1503.05671)
  • Gupta et al. 2018 — Shampoo: Preconditioned Stochastic Tensor Optimization (arXiv:1802.09568)
  • Hu et al. 2024 — MiniCPM: Unveiling the Potential of Small LMs (arXiv:2404.06395) — WSD schedule ablation
  • Loshchilov & Hutter 2016 — SGDR: Stochastic Gradient Descent with Warm Restarts (arXiv:1608.03983)
  • Nesterov 1983 — A method for solving the convex programming problem with convergence rate O(1/k²)

Lectures

  • Stanford CS229 — Andrew Ng, Lecture Notes on Optimization (cs229.stanford.edu/notes/) — SGD, convergence, convexity
  • Stanford CS231n — Fei-Fei Li, Optimization and SGD variants (Lecture 3, 2024 version) — momentum, RMSProp, Adam
  • MIT 6.S191 — Alexander Amini, Deep Learning Optimization (2024) — LR schedules, adaptive methods
  • CMU 10-708 — Eric Xing, Optimization for Deep Learning — second-order methods, K-FAC
  • NYU DS-GA 1008 — Yann LeCun, Energy-Based Models and Optimization — loss landscape geometry
  • Oxford ML — Phil Blunsom, Gradient-Based Optimisation for Deep NLP Models (Hilary term lectures)
  • DeepMind × UCL — Lecture 3: Function Approximation with Neural Networks — gradient descent basics
  • fast.ai — Jeremy Howard, Lesson 2: SGD from Scratch — practical intuition for LR and momentum

Textbooks

  • Boyd & Vandenberghe — Convex Optimization (free PDF at web.stanford.edu/~boyd/cvxbook/) — Chapters 9–10: gradient and Newton methods
  • Goodfellow, Bengio & Courville — Deep Learning (2016) — Chapter 8: Optimization for Training Deep Models
  • Murphy — Probabilistic Machine Learning: Advanced Topics (Vol 2, free draft at probml.ai) — Ch 8: optimization
  • Nocedal & Wright — Numerical Optimization (2nd ed.) — rigorous convergence theory for first and second-order methods

Code / Repos

  • pytorch/pytorch: torch/optim/adamw.py (fused CUDA AdamW), torch/optim/lr_scheduler.py
  • lucidrains/lion-pytorch — clean Lion implementation (<50 lines)
  • facebookresearch/optimizers — Distributed Shampoo at production scale
  • karpathy/nanoGPT: model.py, configure_optimizers() for AdamW setup and parameter grouping; the LR decay schedule lives in train.py
  • huggingface/transformers: optimization.py, get_cosine_schedule_with_warmup and WSD

Blog Posts

  • Sebastian Ruder — An overview of gradient descent optimisation algorithms (ruder.io) — clearest single-page survey
  • distill.pub — Why Momentum Really Works — variance reduction and quadratic convergence analysis
  • Lilian Weng — An Overview of Optimization Algorithms (lilianweng.github.io) — momentum, Adam, natural gradient
  • Andrej Karpathy — A Recipe for Training Neural Networks — practical Adam tuning advice from a practitioner
  • Tim Dettmers — Optimizers and Learning Rates in Large-Scale Deep Learning — memory analysis of optimizer states

7. Interview Prep

Questions from ML / LLM interviews at Anthropic, OpenAI, DeepMind, Meta AI, and NVIDIA — optimization round.

Q1. Why does Adam need bias correction? What goes wrong without it?

Both m and v are initialised to 0, so early estimates are biased toward zero by the factors (1−β₁ᵗ) and (1−β₂ᵗ). At step 1: m₁ = (1−β₁)g₁ = 0.1g₁ and v₁ = 0.001g₁². The two biases do not cancel: the uncorrected update is m₁/√v₁ = 0.1g₁/√(0.001g₁²) = 0.1/0.032 ≈ 3.2·sign(g₁), whereas the corrected update is m̂₁/√v̂₁ = sign(g₁). Without correction the early steps are therefore up to ~3.2× larger than α, exactly when the model is most fragile, and the excess fades only as v warms up over roughly 1/(1−β₂) ≈ 1000 steps. Correction puts both estimates at the right scale from step 1.

Q2. What is the difference between Adam and AdamW? Why does it matter?

Adam + L2 adds λθ to the gradient before the adaptive step: the weight decay is then divided by √v̂ alongside the gradient — giving less decay to high-variance parameters. AdamW applies weight decay directly to the parameters (θ ← (1−αλ)θ − α·adam_step), bypassing the adaptive scaling. Result: all parameters are regularised equally proportional to their magnitude, not their gradient history. Loshchilov & Hutter showed AdamW outperforms Adam+L2 on every benchmark they tested; it is now the default for LLM training.

Q3. What is the condition number of a matrix and why does it matter for optimizer design?

κ = λ_max/λ_min is the ratio of the largest to smallest eigenvalue of the Hessian. For gradient descent on a quadratic, convergence requires t = O(κ log 1/ε) steps. When κ ≫ 1 (ill-conditioned): the iterate oscillates in the high-curvature direction while making tiny progress in the low-curvature direction. Adam's per-parameter adaptive LR implicitly rescales each coordinate by 1/√v̂, which acts as a rough diagonal preconditioner (closer to diag(F)^(-½) than to H⁻¹) and shrinks the effective condition number when the curvature is close to diagonal. Shampoo/K-FAC handle the off-diagonal structure too.

Q4. A training run diverges at step 20,000. Walk through your diagnosis checklist.

(1) Plot grad norm: if it spikes just before divergence, enable or tighten gradient clipping (typical: max_norm=1.0). (2) Check the LR schedule: is the warmup long enough? A too-high LR after warmup is the #1 cause. (3) Check for NaN/Inf in the loss at that step; it may be numerical instability in log or softmax (use log-sum-exp). (4) Check the data batch: corrupted or anomalous data can cause sudden loss spikes. (5) If using mixed precision, check loss scaling for FP16 (BF16 normally does not need it): FP16 underflow causes silent zero gradients, then the model coasts, then NaN when it encounters a nonzero grad again.
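
For step (1), clipping by global norm is simple enough to write out. The sketch below represents each gradient tensor as a list of floats and follows the same idea as torch.nn.utils.clip_grad_norm_ (scale every gradient by max_norm / total_norm when the total exceeds max_norm); the function name and the eps value are illustrative choices, not the library's code.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Clip a list of gradient tensors (lists of floats) by their joint L2 norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm


if __name__ == "__main__":
    clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
    print("norm before:", norm, " clipped:", clipped)
```

Logging the returned pre-clip norm at every step gives the grad-norm plot that the checklist starts from.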

Q5. Explain momentum intuitively. When does it help and when does it hurt?

Momentum accumulates a velocity vector v — the weighted sum of past gradients. In a ravine (narrow valley), gradients consistently point along the valley floor (constructive interference → v accelerates) but oscillate across the valley (destructive interference → oscillations cancel). This accelerates convergence on ill-conditioned quadratics by a factor of O(√κ). Hurt: if the learning rate is too high, momentum overshoots the minimum and “bounces back” — the velocity carries the parameter past the target. With LR decay, momentum should usually be ramped down (or the effective LR ramped down via the schedule) near convergence.

Q6. Why do large-batch LLM runs use a higher peak learning rate? What is the linear scaling rule and when does it break?

Goyal et al. (2017) linear scaling rule: if you multiply batch size by k, multiply LR by k. Intuition: averaging over a batch of size kB cuts gradient variance by a factor of k (noise standard deviation by √k) relative to batch size B, so you can take a proportionally larger step. Warmup becomes essential because the rule breaks at initialisation: a large LR with random weights causes divergence before the model has meaningful activations. The rule also breaks down at very large batches (>~4k per GPU for LLMs) where the gradient signal saturates and diminishing returns set in. In practice, labs tune LR empirically.

Q7. What does implicit regularisation of SGD mean? Does Adam have it?

SGD implicitly prefers flat minima: because its updates are noisy, parameters near a sharp minimum are perturbed out of it, while flat minima absorb the noise and remain stable. Theoretically (Smith & Le 2018), SGD with LR α and batch size B finds solutions that minimise the sharpness ‖∇²L‖ subject to low loss. Adam does not have the same implicit bias — its adaptive preconditioning changes the noise structure, and empirically Adam tends to find sharper minima than SGD. This is partly why AdamW + cosine schedule (which explicitly decays LR at the end) is better than Adam+constant: the decaying LR gives Adam the final consolidation step that SGD gets for free from noise.

Q8. How does optimizer state interact with ZeRO (Zero Redundancy Optimizer)? What memory does a 7B model need for AdamW?

AdamW for a 7B model in BF16 training: model weights 14 GB (7B × 2 bytes), optimizer states (m, v in FP32) 56 GB (7B × 2 tensors × 4 bytes), gradient buffer 14 GB = ~84 GB total minimum, not counting activations. Most mixed-precision recipes also keep an FP32 master copy of the weights, adding 28 GB (112 GB total). ZeRO-1 shards optimizer states across GPUs; ZeRO-2 additionally shards gradients; ZeRO-3 shards parameters too. With ZeRO-3 on 8×H100: optimizer state per GPU ≈ 56 GB / 8 = 7 GB. DeepSpeed implements ZeRO directly, and PyTorch FSDP provides equivalent sharding. The key point: m and v are normally kept in FP32 even when training in BF16, because storing them in low precision loses the small increments that the moving averages depend on.
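
The arithmetic above generalises to other model sizes and ZeRO stages. The sketch below is a back-of-the-envelope calculator under the same assumptions (BF16 weights and gradients, FP32 m and v, no master weights, no activations, 1 GB = 10⁹ bytes); the function name and its parameters are illustrative.

```python
def adamw_memory_gb(n_params, n_gpus=1, zero_stage=0,
                    weight_bytes=2, grad_bytes=2, optim_bytes=8):
    """Per-GPU memory in GB for weights, gradients, and Adam moments.

    optim_bytes=8 corresponds to m and v in FP32 (2 tensors x 4 bytes).
    zero_stage 1 shards optimizer state, 2 also shards gradients, 3 also shards weights.
    """
    weights = n_params * weight_bytes
    grads = n_params * grad_bytes
    optim = n_params * optim_bytes
    if zero_stage >= 1:
        optim /= n_gpus
    if zero_stage >= 2:
        grads /= n_gpus
    if zero_stage >= 3:
        weights /= n_gpus
    return {
        "weights": weights / 1e9,
        "grads": grads / 1e9,
        "optimizer": optim / 1e9,
        "total": (weights + grads + optim) / 1e9,
    }


if __name__ == "__main__":
    for stage in (0, 1, 2, 3):
        print(stage, adamw_memory_gb(7_000_000_000, n_gpus=8, zero_stage=stage))
```

Passing weight_bytes=6 is one way to account for an additional FP32 master copy of the weights.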