§ 10 Practical Training — Init / Norm / Mixed-Precision
Weight initialisation, normalisation layers, residual connections, regularisation, mixed-precision training, gradient monitoring, and loss-curve diagnostics — the engineering craft behind making deep networks actually converge.
1. Overview
Building a network architecture is only the first step. Training it reliably on deep stacks — without activations vanishing, gradients exploding, or loss curves diverging — requires engineering care at every layer. This section covers six levers every practitioner must understand:
- Initialisation — seed weights so signal propagates to layer 100 on step 1.
- Normalisation — smooth the loss landscape so large learning rates are safe.
- Residual connections — route a gradient highway through every block.
- Regularisation — prevent memorisation without starving the model of capacity.
- Mixed precision — 2x memory savings with BF16 while keeping numerical stability.
- Monitoring — read a training curve and diagnose a run in under ten seconds.
2. Weight Initialisation
Poor initialisation can kill training before it starts. If weights are too small, activations shrink to zero layer-by-layer (vanishing signal). If too large, they explode to infinity. The goal: choose a weight variance so that signal variance is preserved as it propagates forward through all layers.
Variance Analysis
For a single linear layer y = Wx, where each entry of W is i.i.d. with mean 0 and variance σ², and the inputs x_j are zero-mean, identically distributed, and independent of the weights:
Var(y_i) = Σ_j Var(W_ij) · Var(x_j) # independence + zero means
= fan_in · σ² · Var(x)
To preserve variance across layers: fan_in · σ² = 1
-> σ² = 1/fan_in # LeCun 1998 (linear / tanh)
-> σ² = 2/(fan_in + fan_out) # Xavier 2010 (harmonic mean of 1/fan_in and 1/fan_out: fwd + bwd)
-> σ² = 2/fan_in # Kaiming 2015 (ReLU: E[a²] = ½ Var(pre))
Core Mechanism — Kaiming Derivation
Background: Xavier assumes linear activations (or symmetric activations like tanh). ReLU zeros out roughly half of its inputs, so the expected square of the post-activation is only half that of a linear layer. Using Xavier with ReLU still leads to signal vanishing in deep networks.
Plan:
- For the pre-activation h = Wx, compute Var(h_i) = fan_in · σ² · E[x²], where x is the layer input.
- Post-ReLU: a = max(0, h). Because h is zero-mean and symmetric, E[a²] = ½ · E[h²] = ½ · Var(h).
- Chain through L layers: E[a_L²] = (½ · fan_in · σ²)^L · E[a_0²].
- Set the base equal to 1: ½ · fan_in · σ² = 1 implies σ² = 2/fan_in.
Concrete walkthrough: width=256, depth=8. Kaiming: σ² = 2/256 ≈ 0.0078, so the variance factor per layer is ½ · 256 · 2/256 = 1.0 and after 8 layers 1.0^8 = 1: signal preserved. Bad-small (σ=0.01, σ²=0.0001): the factor is ½ · 256 · 0.0001 = 0.0128 per layer, and after 8 layers 0.0128^8 ≈ 7 × 10⁻¹⁶: signal annihilated.
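The walkthrough's arithmetic can be reproduced in a few lines. This is a minimal sketch of the per-layer factor ½ · fan_in · σ² only; the helper names `relu_variance_factor` and `signal_after` are made up for illustration, and the Xavier row assumes fan_in = fan_out = width.

```python
def relu_variance_factor(fan_in: int, sigma2: float) -> float:
    """Per-layer multiplier on the second moment of a ReLU layer: 0.5 * fan_in * sigma^2."""
    return 0.5 * fan_in * sigma2


def signal_after(depth: int, fan_in: int, sigma2: float) -> float:
    """Second-moment ratio after `depth` identical ReLU layers."""
    return relu_variance_factor(fan_in, sigma2) ** depth


width, depth = 256, 8
kaiming = signal_after(depth, width, 2.0 / width)            # sigma^2 = 2 / fan_in
xavier = signal_after(depth, width, 2.0 / (width + width))   # assumes fan_in == fan_out
bad_small = signal_after(depth, width, 0.01 ** 2)            # sigma = 0.01

for name, value in [("kaiming", kaiming), ("xavier", xavier), ("bad_small", bad_small)]:
    print(f"{name:10s} {value:.3e}")
```

With Xavier the factor is ½ per ReLU layer, so the signal still shrinks geometrically with depth, just far more slowly than with the bad-small setting.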
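nn.init.kaiming_uniform_(w, nonlinearity="relu") applies Kaiming init for ReLU layers; use nn.init.xavier_uniform_(w) for linear/tanh layers. Note that the built-in default of nn.Linear is kaiming_uniform_ with a=√5, which works out to U(−1/√fan_in, 1/√fan_in) rather than the ReLU gain, so call the ReLU variant explicitly if you want it.
3. Normalisation Layers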
Even with good initialisation, the distribution of each layer's inputs drifts during training as the layers before it update (internal covariate shift, the original motivation for BatchNorm), so later layers chase a moving target. Normalisation layers counter this by standardising activations at every layer, which substantially widens the range of stable learning rates.
BatchNorm
Normalise each channel across the batch and spatial dimensions. Learns affine parameters γ and β. Maintains running mean/variance for inference-time use.
mu_c = mean over {n, h, w} # mean across batch + spatial
sigma2 = var over {n, h, w}
x_hat = (x - mu_c) / sqrt(sigma2 + eps)
y_c = gamma_c * x_hat + beta_c # learned per-channel scale and shift
BN degrades at very small batch sizes (with a single value per channel the variance is zero) and during autoregressive generation, where batch semantics are ill-defined. This is why transformers and LLMs almost never use it.
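The BatchNorm equations above, including the running statistics used at inference, can be sketched for a single channel in plain Python. This is an illustrative toy, not the PyTorch implementation: the class name `ChannelBatchNorm` is hypothetical, `values` is assumed to hold one channel's entries flattened over {n, h, w}, and the running variance uses the biased batch variance for simplicity.

```python
import math


class ChannelBatchNorm:
    """BatchNorm for one channel; `values` holds that channel's entries over {n, h, w}."""

    def __init__(self, eps: float = 1e-5, momentum: float = 0.1):
        self.gamma, self.beta = 1.0, 0.0            # learned scale and shift (fixed here)
        self.running_mean, self.running_var = 0.0, 1.0
        self.eps, self.momentum = eps, momentum

    def __call__(self, values, training: bool = True):
        if training:
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            # Exponential moving average of batch statistics for inference-time use.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var
        scale = 1.0 / math.sqrt(var + self.eps)
        return [self.gamma * (v - mean) * scale + self.beta for v in values]


bn = ChannelBatchNorm()
train_out = bn([1.0, 2.0, 3.0, 4.0])           # batch statistics
eval_out = bn([2.5], training=False)           # running statistics
single_out = ChannelBatchNorm()([5.0])         # one value: x - mean is exactly 0
print(train_out, eval_out, single_out)
```

The last call shows the degenerate case: with a single value per channel the output collapses to beta regardless of the input.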
LayerNorm (used in transformers)
Normalise over the feature dimension for each sample independently — no cross-sample statistics, no running averages, works at any batch size including 1.
mu_n = mean over features (C, ...) for sample n
sigma2 = var over features
x_hat = (x_n - mu_n) / sqrt(sigma2 + eps)
y_n = gamma * x_hat + beta # shared gamma, beta across all samples
RMSNorm (LLaMA, PaLM, Gemma)
Drops mean-centring — only divides by the root-mean-square of activations. Empirically equivalent to LayerNorm for LLMs with 15–20% fewer operations and simpler kernel fusion.
RMS(x) = sqrt( mean(x^2) + eps )
y = gamma * (x / RMS(x)) # no beta, no mean subtraction
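To make the difference between the two per-sample normalisers concrete, here is a small pure-Python sketch of the LayerNorm and RMSNorm formulas above. It assumes a scalar `gamma` and `beta` for brevity (real layers learn one value per feature), and the function names are illustrative.

```python
import math


def layer_norm(x, gamma: float = 1.0, beta: float = 0.0, eps: float = 1e-5):
    """Subtract the per-sample mean, divide by the per-sample standard deviation."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]


def rms_norm(x, gamma: float = 1.0, eps: float = 1e-6):
    """Divide by the root-mean-square only: no mean subtraction, no beta."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [gamma * v / rms for v in x]


x = [1.0, 2.0, 3.0, 4.0]
ln_out = layer_norm(x)
rms_out = rms_norm(x)
print(ln_out)    # centred: values sum to roughly zero
print(rms_out)   # not centred: all values stay positive
```

For inputs that already have zero mean the two functions agree (given the same eps), which is one intuition for why dropping the centring step costs little in practice.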
4. Residual Connections
Before ResNet (He et al., 2015), plain 56-layer networks performed worse than 20-layer networks — not due to overfitting, but optimisation failure. Adding more layers made training harder. Residual connections solve this by giving gradients a direct path from the loss back to every earlier layer.
Gradient Highway
For output H(x) = F(x) + x, the backward-pass gradient decomposes into two terms:
dL/dx = dL/dH · (dF(x)/dx + I)
      = dL/dH · dF(x)/dx   # through the residual branch F
      + dL/dH              # through the identity (no decay)
After L stacked blocks, with x_l = x_{l-1} + F_l(x_{l-1}):
dL/dx_0 = dL/dx_L · prod_{l=1}^{L} (dF_l/dx_{l-1} + I)
        = dL/dx_L · (I + terms involving the dF_l/dx_{l-1}) <- a direct dL/dx_L term always survives
The identity term keeps the product from collapsing even when the individual residual Jacobians are small. This is why ResNet-152 trains successfully, whereas a plain 152-layer network is far harder to optimise.
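A scalar toy version of the product above shows the effect. This is a sketch under a strong simplification: each block's Jacobian is replaced by a single number (here alternating ±0.1, chosen arbitrarily), and `chained_gradient` is a hypothetical helper.

```python
def chained_gradient(jacobians, residual: bool) -> float:
    """Multiply per-block derivative factors: (j + 1) with a skip connection, j without."""
    grad = 1.0
    for j in jacobians:
        grad *= (j + 1.0) if residual else j
    return grad


# 50 blocks whose own Jacobians are small and alternate in sign.
small_jacobians = [0.1 if i % 2 == 0 else -0.1 for i in range(50)]

plain = chained_gradient(small_jacobians, residual=False)
with_skip = chained_gradient(small_jacobians, residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {with_skip:.3f}")
```

Without the skip path the gradient magnitude is 0.1^50; with it, each factor stays close to 1 and the product remains of order one.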
Pre-norm vs Post-norm
The original Transformer applies LayerNorm after the residual add (post-norm). Modern transformers use pre-norm: apply LayerNorm/RMSNorm to the input before each sub-layer, then add the result to the (unnormalised) residual stream. Pre-norm stabilises training and makes it less sensitive to LR warmup tuning. It is now the default in most LLMs (GPT-2+, LLaMA, Mistral, PaLM, Gemma).
# Post-norm (original Transformer)
x = x + SubLayer(x)
x = LayerNorm(x)
# Pre-norm (GPT-2, LLaMA, most modern LLMs)
x = x + SubLayer(LayerNorm(x)) # cleaner gradient path
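The two orderings can be compared on a toy vector. In this sketch `sublayer` is a hypothetical stand-in (a plain 0.5x scaling) for attention or an MLP, and `layer_norm` omits the learned gamma and beta; only the placement of the normalisation is the point.

```python
import math


def layer_norm(x, eps: float = 1e-5):
    """Per-sample normalisation over features, without learned gamma/beta."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def sublayer(x):
    """Hypothetical stand-in for attention / MLP: a fixed elementwise scaling."""
    return [0.5 * v for v in x]


def post_norm_block(x):
    # Normalise after the residual add: the skip path itself gets rescaled.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x):
    # Normalise only the sub-layer input: the skip path carries x through untouched.
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


x = [10.0, 20.0, 30.0, 40.0]
post = post_norm_block(x)
pre = pre_norm_block(x)
print("post-norm:", [round(v, 3) for v in post])
print("pre-norm: ", [round(v, 3) for v in pre])
```

The post-norm output is always re-standardised, whereas the pre-norm output keeps the scale and mean of the incoming residual stream and only adds a normalised update to it.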
5. Regularisation
Regularisation prevents a model from memorising the training set. Each technique attacks overfitting from a different angle.
| Technique | Mechanism | Best For | Typical Value |
|---|---|---|---|
| L2 weight decay | Adds λ‖θ‖² to loss; penalises large weights | Universal baseline | λ = 0.01–0.1 |
| Dropout | Zero each activation with prob p; scale by 1/(1−p) at train time | MLPs, attention layers | p = 0.1–0.5 |
| Label smoothing | Replace one-hot y with (1−ε)·y + ε/K; prevents overconfidence | Classification, LLM vocab head | ε = 0.1 |
| Stochastic depth | Drop each residual block with probability p_l during training (survival probability 1 − p_l) | Very deep nets (EfficientNet, ViT) | p_L = 0.1 at the last block (drop rate rising linearly with depth) |
| Mixup / CutMix | Convex combination of two training samples and their labels | Vision; improves calibration | α = 0.2 |
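Two rows of the table translate directly into code. The following is a minimal sketch of label smoothing and inverted dropout using the formulas as stated in the table; the function names are illustrative and the random generator is seeded only so the demo is repeatable.

```python
import random


def smooth_labels(target_index: int, num_classes: int, eps: float = 0.1):
    """(1 - eps) * one_hot + eps / K for each of the K classes."""
    return [
        (1.0 - eps) * (1.0 if k == target_index else 0.0) + eps / num_classes
        for k in range(num_classes)
    ]


def inverted_dropout(x, p: float, rng: random.Random):
    """Zero each activation with probability p and scale survivors by 1 / (1 - p)."""
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in x]


smoothed = smooth_labels(target_index=2, num_classes=5, eps=0.1)
print(smoothed)  # target class gets 0.92, every other class 0.02

rng = random.Random(0)
dropped = inverted_dropout([1.0] * 10_000, p=0.5, rng=rng)
print(sum(dropped) / len(dropped))  # stays close to the undropped mean of 1.0
```

Because the surviving activations are scaled up at train time, no rescaling is needed at inference: the layer simply becomes the identity.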
In modern LLMs, dropout is often set to 0 during pre-training — the massive dataset volume serves as the regulariser. Dropout is reintroduced at low rates during fine-tuning. Weight decay remains universal: AdamW applies it directly to weights rather than adding to the gradient (decoupled decay).
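The difference between decoupled decay and classic L2 under Adam can be seen on a single scalar weight. This is a simplified sketch, not the PyTorch optimiser: `adam_step` is a hypothetical helper, the hyperparameters are arbitrary, and the task gradient is set to zero so that only the decay term acts.

```python
import math


def adam_step(w, grad, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8,
              weight_decay=0.0, decoupled=True):
    """One Adam update on a scalar weight, with L2 or decoupled (AdamW-style) decay."""
    if not decoupled:
        grad = grad + weight_decay * w          # classic L2: decay enters the gradient
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1 ** t)                   # bias correction
    v_hat = v / (1 - b2 ** t)
    if decoupled:
        w = w * (1.0 - lr * weight_decay)       # decay applied directly to the weight
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


w_adamw, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1, decoupled=True)
w_l2, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1, decoupled=False)
print(w_adamw, w_l2)
```

With decoupled decay the weight shrinks by lr · λ · w. With L2 folded into the gradient, Adam's normalisation rescales the decay term, so on this first step the weight moves by about lr regardless of λ.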
6. Mixed-Precision Training
Training in full FP32 wastes memory and under-utilises tensor cores, which peak at BF16/FP16. Mixed precision uses low-precision for compute-intensive ops (matmul, conv) and high-precision for accumulation and weight updates.
FP16 vs BF16
| Format | Sign | Exponent bits | Mantissa bits | Max value | Loss scaling? |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | 3.4 × 10³⁸ | No |
| BF16 | 1 | 8 | 7 | 3.4 × 10³⁸ | No — same exp range as FP32 |
| FP16 | 1 | 5 | 10 | 65 504 | Yes — gradients underflow |
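The range and precision trade-off in the table can be checked without a GPU. This sketch emulates the formats with the standard `struct` module: FP16 through the IEEE half-precision `'e'` format code, and BF16 by keeping only the top 16 bits of a float32. The BF16 emulation truncates rather than rounding to nearest, which is a simplification, and the helper names are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision; values beyond the FP16 range become +/- inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x: float) -> float:
    """Keep sign, 8 exponent bits and 7 mantissa bits of the float32 encoding (truncation)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


for value in (1e-8, 1.001, 65504.0, 70000.0):
    print(f"{value:>10g}  fp16={to_fp16(value)!r:<22} bf16={to_bf16(value)!r}")
```

FP16 loses the tiny and the huge values but resolves 1.001; BF16 keeps both extremes but collapses 1.001 to 1.0 because it has only 7 mantissa bits.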
Loss Scaling (FP16 only)
FP16's smallest normal value is about 6 × 10⁻⁵ and its smallest subnormal about 6 × 10⁻⁸, so small gradients lose precision (e.g. 10⁻⁶) or flush to zero outright (e.g. 10⁻⁸). Loss scaling multiplies the loss by a large scalar S before backward, then divides gradients by S before the optimiser step. This shifts gradient magnitudes into the FP16 representable range.
# Manual FP16 loss scaling
scaled_loss = loss * S # push gradients into FP16 range
scaled_loss.backward()
for p in model.parameters():
    if p.grad is not None:
        p.grad /= S # undo scale before optimizer.step()
optimizer.step()
# PyTorch GradScaler automates this (dynamic S)
scaler = torch.cuda.amp.GradScaler() # newer releases: torch.amp.GradScaler("cuda")
with torch.autocast("cuda", dtype=torch.float16):
    loss = model(x) # assumes the model returns a scalar loss
scaler.scale(loss).backward() # scale before backward
scaler.step(optimizer) # unscale, check for inf/nan, step
scaler.update() # adjust S for next iteration
On A100/H100, prefer BF16: no loss scaling needed, the same peak tensor-core throughput as FP16, and the same dynamic range as FP32.
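The effect of loss scaling described in this section can be reproduced numerically with an emulated FP16 round trip. This is a sketch only: `to_fp16` relies on the `'e'` half-precision code of Python's `struct` module, the gradient value and the scale S = 2^16 are arbitrary illustrative choices, and the unscale step is done in Python floats to mimic a higher-precision copy.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision; values beyond the FP16 range become +/- inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


true_grad = 1e-8          # below the smallest FP16 subnormal
scale = 2.0 ** 16         # loss scale S (illustrative choice)

unscaled = to_fp16(true_grad)              # stored directly in FP16
scaled = to_fp16(true_grad * scale)        # scaled before the FP16 store
recovered = scaled / scale                 # unscaled again in higher precision

print(f"without scaling: {unscaled}")
print(f"with scaling:    {recovered:.4e}")
```

A dynamic scaler builds on the same idea: if the scaled gradients overflow to inf, skip the step and lower S; after a run of clean steps, raise S again.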
7. Gradient Monitoring and Loss Diagnostics
Gradient Clipping
When gradient norms spike — bad batch, LR too high, early training instability — they can push weights to extreme values. Gradient clipping rescales the entire gradient vector to fit within a maximum norm:
grad_norm = sqrt( sum_theta ||grad_theta||^2 )
if grad_norm > clip_value:
    for each grad_theta:
        grad_theta *= clip_value / grad_norm # scale all grads uniformly
# PyTorch (returns the pre-clip total norm, handy for logging):
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Log the pre-clip grad-norm every step. A healthy run has a stable, slowly decreasing norm. Persistent spikes: reduce LR or increase warmup steps. One-off spike: likely a bad batch; clipping absorbed it.
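The global-norm clipping rule above fits in a few lines of plain Python. This is a simplified sketch, assuming gradients are given as a list of per-parameter lists; `clip_grad_norm` is an illustrative name, and unlike library implementations it adds no small epsilon to the denominator.

```python
import math


def clip_grad_norm(grads, max_norm: float) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping, which is the value worth logging.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        factor = max_norm / total_norm
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= factor          # same factor everywhere: direction is preserved
    return total_norm


grads = [[3.0], [4.0]]                     # two "parameters", global norm 5
pre_clip = clip_grad_norm(grads, max_norm=1.0)
print(pre_clip, grads)
```

Because every entry is multiplied by the same factor, clipping changes the step length but not the update direction.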
Loss Curve Taxonomy
| Pattern | Train loss | Val loss | Diagnosis | Fix |
|---|---|---|---|---|
| High bias | High, barely drops | Tracks train | Underfit — model too small | Increase capacity; reduce regularisation |
| High variance | Low, still dropping | Rising gap | Overfit — memorising | More data; dropout; weight decay; early stop |
| Instability | Oscillating / spiking | Noisy | LR too high or bad batch | Reduce LR 10x; add warmup; clip gradients |
| Healthy | Smooth decay | Slightly above train | Converging normally | Monitor for late overfit; adjust LR schedule |
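The taxonomy can be turned into a rough automatic check. The heuristic below is only a sketch: the function name `diagnose` and all three thresholds (`rise_tol`, `gap_tol`, `min_drop`) are hypothetical values picked for illustration, and real curves usually need smoothing before rules like these are applied.

```python
def diagnose(train, val, rise_tol=0.1, gap_tol=0.2, min_drop=0.1):
    """Map a pair of loss curves to one of the four patterns in the table."""
    # Instability: training loss jumps upward by more than rise_tol in a single step.
    if any(b - a > rise_tol for a, b in zip(train, train[1:])):
        return "instability"
    # High variance: validation sits well above train and has risen off its best value.
    if val[-1] - train[-1] > gap_tol and val[-1] > min(val):
        return "high variance"
    # High bias: training loss has fallen by less than min_drop (as a fraction) overall.
    if train[-1] > (1.0 - min_drop) * train[0]:
        return "high bias"
    return "healthy"


curves = {
    "underfit": ([2.0, 1.95, 1.93, 1.92, 1.91], [2.05, 2.0, 1.98, 1.97, 1.96]),
    "overfit": ([2.0, 1.2, 0.6, 0.3, 0.1], [2.1, 1.4, 1.0, 1.1, 1.3]),
    "unstable": ([2.0, 1.2, 2.5, 0.9, 3.0], [2.1, 1.5, 2.6, 1.2, 3.1]),
    "normal": ([2.0, 1.5, 1.1, 0.8, 0.6], [2.1, 1.6, 1.2, 0.9, 0.7]),
}
for name, (train, val) in curves.items():
    print(f"{name:9s} -> {diagnose(train, val)}")
```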
Gradient Accumulation
When a single device cannot hold a large global batch, run N micro-batches and accumulate gradients across them before the optimiser step. Effective batch size = micro_batch × accumulation_steps. Common for LLM pre-training with global batches of 2M+ tokens.
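That accumulated micro-batch gradients reproduce the large-batch gradient can be verified on a one-parameter model. This sketch assumes a scalar linear model with mean-squared-error loss, equal-sized micro-batches, and made-up data; `mse_grad` is an illustrative helper.

```python
def mse_grad(w: float, xs, ys) -> float:
    """d/dw of mean((w * x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.9, 12.1]
w = 0.5

full_batch_grad = mse_grad(w, xs, ys)

accumulation_steps, micro_batch = 3, 2
accumulated_grad = 0.0
for step in range(accumulation_steps):
    chunk = slice(step * micro_batch, (step + 1) * micro_batch)
    # Dividing by accumulation_steps turns the sum of micro-batch means into one mean.
    accumulated_grad += mse_grad(w, xs[chunk], ys[chunk]) / accumulation_steps

print(full_batch_grad, accumulated_grad)
```

The match is exact (up to floating-point rounding) only when every micro-batch has the same size; with ragged micro-batches the per-step weights would need to reflect the sample counts.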
optimizer.zero_grad()
for micro_batch in micro_batches: # the N micro-batches of one global batch
    loss = model(micro_batch) / accumulation_steps # scale to mean
    loss.backward() # accumulate into .grad
optimizer.step() # one update per N micro-batches
8. Minimal Demos
Demo A — Init Explorer
Watch how activation standard deviation evolves layer-by-layer under four initialisation schemes. Kaiming is the only one that stays numerically stable through 8 ReLU layers. Run it to see the signal vanish (bad_small) and explode (bad_large).
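For readers without access to the interactive demo, a rough offline stand-in is sketched below. It is not the demo's actual code: the four scheme names and their σ values are assumptions (bad_small and bad_large come from the description, Xavier and Kaiming from the section above), the width of 128 is arbitrary, and the RMS of the activations is used as a proxy for their standard deviation.

```python
import math
import random


def relu_stack_rms(sigma: float, width: int = 128, depth: int = 8, seed: int = 0):
    """Push one random vector through `depth` ReLU layers with N(0, sigma^2) weights.

    Returns the activation RMS at the input and after every layer.
    """
    rng = random.Random(seed)
    a = [rng.gauss(0.0, 1.0) for _ in range(width)]
    history = [math.sqrt(sum(v * v for v in a) / width)]
    for _ in range(depth):
        # Each output unit draws a fresh row of weights, then applies ReLU.
        a = [max(0.0, sum(rng.gauss(0.0, sigma) * v for v in a)) for _ in range(width)]
        history.append(math.sqrt(sum(v * v for v in a) / width))
    return history


width = 128
schemes = {
    "bad_small": 0.01,
    "xavier": math.sqrt(2.0 / (width + width)),   # fan_in == fan_out == width
    "kaiming": math.sqrt(2.0 / width),
    "bad_large": 1.0,
}
results = {name: relu_stack_rms(sigma, width) for name, sigma in schemes.items()}
for name, history in results.items():
    print(f"{name:10s} " + " ".join(f"{r:9.2e}" for r in history))
```

Each row lists the RMS at layers 0 to 8; only the Kaiming row should stay within the same order of magnitude from start to finish.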
Demo B — Training-Loss Diagnostic
Enter a scenario number 1–4. The simulator prints synthetic training and validation loss curves for that failure mode, then gives the fix. Practise reading the pattern before you encounter a real run.
9. Production / Source Pointers
| Concept | PyTorch path | Key symbol |
|---|---|---|
| Weight init | torch/nn/init.py | kaiming_uniform_, xavier_uniform_, trunc_normal_ |
| BatchNorm | torch/nn/modules/batchnorm.py | BatchNorm2d — running_mean / running_var accumulation |
| LayerNorm | torch/nn/modules/normalization.py | LayerNorm calls F.layer_norm (fused CUDA kernel) |
| RMSNorm | transformers/models/llama/modeling_llama.py | LlamaRMSNorm / torch.nn.RMSNorm (PyTorch 2.4+) |
| Mixed precision BF16 | torch/amp/autocast_mode.py | torch.autocast(dtype=torch.bfloat16) |
| Loss scaling FP16 | torch/amp/grad_scaler.py (older releases: torch/cuda/amp/grad_scaler.py) | GradScaler.scale() / .step() / .update() |
| Gradient clipping | torch/nn/utils/clip_grad.py | clip_grad_norm_(params, max_norm) |
| Stochastic depth | timm/layers/drop.py | DropPath — applies bernoulli mask to residual branch |
10. References
Papers
- He et al. 2015 — Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification (Kaiming init). arXiv:1502.01852
- He et al. 2015 — Deep Residual Learning for Image Recognition (ResNet). arXiv:1512.03385
- Glorot & Bengio 2010 — Understanding the Difficulty of Training Deep Feedforward Neural Networks (Xavier init). AISTATS 2010
- Ioffe & Szegedy 2015 — Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv:1502.03167
- Ba et al. 2016 — Layer Normalization. arXiv:1607.06450
- Zhang & Sennrich 2019 — Root Mean Square Layer Normalization. arXiv:1910.07467
- Micikevicius et al. 2018 — Mixed Precision Training. arXiv:1710.03740
- Loshchilov & Hutter 2019 — Decoupled Weight Decay Regularization (AdamW). arXiv:1711.05101
- Srivastava et al. 2014 — Dropout: A Simple Way to Prevent Neural Networks from Overfitting. JMLR 15(1)
- Huang et al. 2016 — Deep Networks with Stochastic Depth. arXiv:1603.09382
Lectures (Stanford / MIT / CMU / NYU / fast.ai)
- Stanford CS231n — Lecture 6: Training Neural Networks I (Kaiming, BatchNorm, optimisation)
- Stanford CS230 — Week 2: Regularisation and improving deep networks (Andrew Ng)
- MIT 6.S191 — Lecture 1: Introduction to Deep Learning (Alexander Amini)
- CMU 11-785 — Lecture 8 recitation: Initialisation and normalisation
- NYU DS-GA 1008 — Yann LeCun, Lecture 5: Training dynamics and loss landscape
- fast.ai — Practical Deep Learning Part 1, Lesson 5: training improvements
Textbooks
- Goodfellow, Bengio, Courville — Deep Learning, Ch. 7 (Regularisation), Ch. 8 (Optimisation) — deeplearningbook.org
- Zhang et al. — Dive Into Deep Learning, Ch. 4 (Multilayer Perceptrons) — d2l.ai
Code / Repos
- karpathy/nanoGPT — uses LayerNorm, AdamW, gradient clipping, BF16; minimal GPT pre-training reference
- pytorch/pytorch — torch.nn.init, torch.cuda.amp.GradScaler
- huggingface/transformers — LlamaRMSNorm, BF16/FP16 training arguments
- microsoft/DeepSpeed — ZeRO + BF16 mixed-precision pipeline
- huggingface/pytorch-image-models (timm) — stochastic depth, EfficientNet training recipe
Blog Posts
- Andrej Karpathy — A Recipe for Training Neural Networks (karpathy.github.io/2019/04/25/recipe/)
- Sebastian Ruder — An Overview of Gradient Descent Optimisation Algorithms
- Lilian Weng — Normalization in Neural Networks (lilianweng.github.io)
- PyTorch Blog — Automatic Mixed Precision Training
11. Interview Prep
- Derive the Kaiming init formula for a ReLU network.
For h = Wx with W ~ N(0,σ²): Var(h_i) = fan_in·σ²·E[x²]. Post-ReLU: E[a²] = ½·Var(h) by Gaussian symmetry. Chained over L layers, the second moment scales by (½·fan_in·σ²)^L. Set the base to 1 → σ² = 2/fan_in.
- Compare BatchNorm and LayerNorm. Why do LLMs use LayerNorm or RMSNorm?
BN normalises across the batch, so it degrades at batch-size 1 and in autoregressive decoding. LN normalises per sample over the feature dimension: batch-agnostic, no running stats, works at any sequence length. RMSNorm drops mean-centring: faster and empirically equivalent for LLMs.
- How do residual connections solve optimisation failure in deep networks?
The gradient splits into a residual path (potentially vanishing) and an identity path (which passes the upstream gradient through unchanged). The identity path lets the signal reach shallow layers even in networks with 100+ layers. Without it, gradients shrink multiplicatively with depth.
- Why does BF16 not need loss scaling but FP16 does?
BF16 has 8 exponent bits (same as FP32), so it covers the same range: roughly 10⁻³⁸ up to 3.4×10³⁸. FP16 has only 5 exponent bits: max ≈ 65504, smallest normal ≈ 6×10⁻⁵, smallest subnormal ≈ 6×10⁻⁸. Small gradients therefore lose precision or flush to zero in FP16. Loss scaling multiplies the loss by a large S to shift gradients into the FP16 range before backward.
- A training loss oscillates wildly step-to-step. Name three causes and fixes.
(1) LR too high → halve it or add linear warmup. (2) Gradient norms exploding → add clip_grad_norm_ with max_norm=1.0. (3) Bad batch (corrupt data, extreme outlier) → check data pipeline, clip individual sample losses.
- Explain label smoothing and when it helps.
Replace one-hot target y with (1−ε)·y + ε/K. This prevents overconfidence on training labels, improves calibration, and often improves generalisation. Useful when labels are noisy (web data) or the vocabulary is large (LLM next-token prediction head). Standard ε = 0.1.
- A run diverges at step 10k. Walk through your debugging checklist.
(1) Check grad-norm at step 10k: is it a spike? (2) Was there a LR schedule event (cosine restart, warmup end)? (3) Inspect the data batch from around that step for corruption or extreme values. (4) Enable torch.autograd.set_detect_anomaly(True) to find the first NaN. (5) Restore checkpoint at step 9k and resume with a lower LR or by skipping the suspect batch.
- What is gradient accumulation and when is it needed?
Run N micro-batches, accumulating gradients across them before one optimiser step. Effective batch size = micro_batch × N. Used when a target global batch size (e.g., 2M tokens for LLM pre-training) exceeds single-device memory capacity. Note: divide each micro-batch loss by N to keep the gradient scale consistent.