§ 4 · Probability & Statistics
Distributions, Bayes rule, MLE / MAP, KL divergence, cross-entropy, mutual information, and sampling — the probabilistic backbone of every loss function, training objective, and decoding algorithm in LLMs.
1. Overview
Every quantity an LLM computes or optimises is probabilistic at its core. The training loss is negative log-likelihood, a measure of how poorly the model distribution fits the data distribution. In RLHF, a KL divergence between the fine-tuned policy and the SFT reference is the penalty that keeps PPO-trained policies from drifting too far. Token generation is categorical sampling from a softmax distribution. Scaling-law curves are power-law fits to cross-entropy loss (log-perplexity). Without fluency in the language of probability, none of these can be read from first principles.
2. Key Concepts
Random Variables and Moments
A random variable X maps outcomes to real numbers. Its first two moments — expectation and variance — characterise the centre and spread of its distribution:
E[X] = ∑ x P(X=x) (discrete)
= ∫ x p(x) dx (continuous)
Var(X) = E[(X - E[X])²]
= E[X²] - E[X]² (computational formula)
Cov(X,Y) = E[(X-EX)(Y-EY)] e.g. correlated activations
Corr(X,Y) = Cov(X,Y) / (σ_X σ_Y) ∈ [-1, 1]
The law of total expectation E[X] = E[E[X|Y]] underlies marginalisation over latent variables, e.g. p(x) = E_{p(z)}[p(x|z)] in the ELBO of VAEs. The central limit theorem says that for iid samples with mean μ and finite variance σ², the mean of n samples is approximately N(μ, σ²/n) for large n. The same large-sample reasoning motivates Gaussian (Laplace) approximations to parameter posteriors in large-scale training.
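The following is a quick numerical check of these definitions, written as a minimal sketch in plain Python (standard library only; the helper names are our own). It shows three things. The two variance formulas agree. The sample correlation of a constructed pair lands near its theoretical value inside [−1, 1]. Means of n iid Uniform(0, 1) draws have the σ²/n spread the CLT predicts. The seed and sample sizes are arbitrary choices.

```python
import math
import random

def mean(xs):
    return sum(xs) / len(xs)

def var_definition(xs):
    # Var(X) = E[(X - E[X])^2], population form (divides by n)
    m = mean(xs)
    return mean([(x - m) ** 2 for x in xs])

def var_computational(xs):
    # Var(X) = E[X^2] - E[X]^2, algebraically identical to the definition
    return mean([x * x for x in xs]) - mean(xs) ** 2

def cov(xs, ys):
    mx, my = mean(xs), mean(ys)
    return mean([(x - mx) * (y - my) for x, y in zip(xs, ys)])

def corr(xs, ys):
    return cov(xs, ys) / math.sqrt(var_definition(xs) * var_definition(ys))

rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
# Y = 2X + independent noise, so Cov = 2, Var(Y) = 5 and Corr(X, Y) = 2 / sqrt(5)
ys = [2.0 * x + rng.gauss(0.0, 1.0) for x in xs]

print(var_definition(xs), var_computational(xs))   # agree up to rounding
print(corr(xs, ys), 2 / math.sqrt(5))

# CLT: the mean of n iid Uniform(0, 1) draws has mean 1/2 and variance (1/12) / n
n, trials = 50, 4_000
sample_means = [mean([rng.random() for _ in range(n)]) for _ in range(trials)]
print(mean(sample_means), var_definition(sample_means), 1 / 12 / n)
```

The computational formula E[X²] − E[X]² is fine here. It can lose precision when the mean is large relative to the spread, so the two-pass definition is the safer default.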
Key Distributions
Five distributions dominate ML. Discrete distributions drive token-level decisions; continuous ones govern weight initialization, latent spaces, and Bayesian priors.
| Distribution | PMF / PDF | LLM connection |
|---|---|---|
| Bernoulli(p) | pˣ(1−p)^(1−x), x∈{0,1} | Binary classification head; preference probability in a Bradley–Terry reward model |
| Categorical(π) | ∏ πₖ^[y=k] | Next-token prediction; softmax output layer |
| Gaussian N(μ,σ²) | (2πσ²)^−½ exp(−(x−μ)²/(2σ²)) | Xavier/Kaiming init, VAE latent, additive noise |
| Beta(α,β) | x^(α−1)(1−x)^(β−1)/B(α,β) | Conjugate prior for coin bias p; Bayesian A/B tests |
| Dirichlet(α) | ∏ πₖ^(αₖ−1) / B(α) | Prior over token distributions; LDA topic modelling |
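The PMF / PDF column can be checked directly. This is a minimal sketch in plain Python that uses `math.gamma` for the normalisers B(·). The function names are illustrative, not a library API; in PyTorch the same densities live in `torch.distributions`.

```python
import math

def bernoulli_pmf(x, p):
    return p ** x * (1 - p) ** (1 - x)          # x in {0, 1}

def categorical_pmf(y, pi):
    return pi[y]                                 # prod_k pi_k^[y=k] just picks out pi_y

def gaussian_pdf(x, mu, sigma2):
    return math.exp(-(x - mu) ** 2 / (2 * sigma2)) / math.sqrt(2 * math.pi * sigma2)

def beta_fn(alphas):
    # multivariate Beta function: B(alpha) = prod_k Gamma(alpha_k) / Gamma(sum_k alpha_k)
    return math.prod(math.gamma(a) for a in alphas) / math.gamma(sum(alphas))

def beta_pdf(x, a, b):
    return x ** (a - 1) * (1 - x) ** (b - 1) / beta_fn([a, b])

def dirichlet_pdf(pi, alpha):
    # density of a point pi on the probability simplex
    return math.prod(p ** (a - 1) for p, a in zip(pi, alpha)) / beta_fn(alpha)

print(bernoulli_pmf(1, 0.3), categorical_pmf(2, [0.1, 0.2, 0.7]))
print(gaussian_pdf(0.0, 0.0, 1.0))                # equals 1 / sqrt(2 pi)
print(beta_pdf(0.5, 2, 2))                        # approximately 1.5, since B(2, 2) = 1/6
print(dirichlet_pdf([0.2, 0.3, 0.5], [1, 1, 1]))  # approximately 2: Dirichlet(1,1,1) is uniform on the simplex
```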
MLE, MAP, and Bayesian Inference
Three nested frameworks for estimating parameters from data; each adds one ingredient to the one before:
| Framework | Objective | Output | In practice |
|---|---|---|---|
| MLE | argmax_θ P(D\|θ) | Single point estimate | Language model pre-training cross-entropy |
| MAP | argmax_θ P(D\|θ)P(θ) | Single point estimate | MLE + L2 weight decay (Gaussian prior on θ) |
| Bayesian | full posterior P(θ\|D) | Distribution over θ | Laplace approximation, MCMC, variational inference |
Why MLE = cross-entropy minimization: NLL loss = −log P(D|θ) = −∑ₜ log P(yₜ | y_<t, θ). Minimizing the average NLL is equivalent to minimizing the KL divergence between the empirical data distribution p_data and the model distribution p_θ. The reason is that KL(p_data‖p_θ) = H(p_data, p_θ) − H(p_data), and H(p_data) does not depend on θ. Standard pre-training is nothing more than maximum-likelihood estimation over the next-token Categorical distribution.
argmin_θ -∑_t log P_θ(y_t | y_<t) ← language model training loss
= argmin_θ KL(p_data || p_θ) ← same objective, information-theory view
= argmin_θ H(p_data, p_θ) ← cross-entropy (p_data is fixed)
KL Divergence and Information Distances
The KL divergence measures "how much information is lost when Q is used to approximate P":
KL(P||Q) = E_P[ log P(x)/Q(x) ]
= ∑_x P(x) log P(x)/Q(x) (discrete)
= H(P, Q) − H(P) (cross-entropy minus entropy)
Properties:
KL(P||Q) ≥ 0 (Gibbs inequality)
KL(P||Q) = 0 iff P = Q a.e.
KL(P||Q) ≠ KL(Q||P) (asymmetric — not a metric!)
| Divergence | Formula | Symmetric? | Bounded? | LLM use |
|---|---|---|---|---|
| KL(P‖Q) | ∑ P log P/Q | No | No (can be ∞) | MLE / NLL training |
| KL(Q‖P) | ∑ Q log Q/P | No | No | VAE ELBO, RLHF PPO |
| JS(P,Q) | (KL(P‖M) + KL(Q‖M))/2, with M = (P+Q)/2 | Yes | [0, ln 2] (nats) | GAN loss analysis |
| Wasserstein-1 | inf_γ E_γ[‖x−y‖] over couplings γ of P and Q | Yes | No, but it is a true metric (triangle inequality holds) | WGAN; OT-based alignment |
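Here is a small sketch of the table's first three rows, plus 1-D Wasserstein-1, for discrete distributions in plain Python (the helper names are our own). It assumes distributions are given as aligned probability lists. For W1 it also assumes the support is the integer grid 0..K−1 with unit spacing; on that grid W1 reduces to the sum of absolute CDF differences.

```python
import math

def kl(p, q):
    """KL(p || q) in nats for aligned lists of probabilities."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0:
            continue                # 0 * log(0 / q) = 0 by convention
        if qi == 0:
            return math.inf         # p puts mass where q has none
        total += pi * math.log(pi / qi)
    return total

def js(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]   # mixture M = (P + Q) / 2
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_on_grid(p, q):
    """Wasserstein-1 on the integer grid 0..K-1: sum over k of |CDF_p(k) - CDF_q(k)|."""
    cp = cq = dist = 0.0
    for pi, qi in zip(p, q):
        cp += pi
        cq += qi
        dist += abs(cp - cq)
    return dist

P = [0.7, 0.2, 0.1, 0.0]
Q = [0.1, 0.2, 0.3, 0.4]
print(kl(P, Q), kl(Q, P))                  # finite vs inf: KL is asymmetric
print(js(P, Q), js(Q, P), math.log(2))     # symmetric and below ln 2
print(w1_on_grid(P, Q), w1_on_grid(Q, P))  # symmetric

# Disjoint supports: KL is infinite, JS sits at its ln 2 ceiling, W1 still measures the gap
A = [1.0, 0.0, 0.0, 0.0]
B = [0.0, 0.0, 0.0, 1.0]
print(kl(A, B), js(A, B), w1_on_grid(A, B))
```

The disjoint-support case is the usual motivation for WGAN. There, KL is infinite and JS is stuck at ln 2 however far apart the supports are, while W1 still grows with the distance.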
Entropy, Cross-Entropy, and Mutual Information
The information-theoretic trio that appears repeatedly in LLM papers:
H(P) = -∑ P(x) log P(x) entropy (bits if log₂, nats if ln)
= average surprise under P
H(P, Q) = -∑ P(x) log Q(x) cross-entropy of Q under P
= H(P) + KL(P||Q) always ≥ H(P)
Perplexity = exp(H(P, Q)) in nats; 2^H(P,Q) in bits
= exp(avg NLL) lower is better; a uniform model over V tokens gives PPL = V
I(X;Y) = H(X) - H(X|Y) mutual information
= KL( P(X,Y) || P(X)P(Y) ) how much knowing Y reduces uncertainty in X
= 0 iff X, Y are independent
Sampling Methods
Sampling from a distribution is non-trivial when only the unnormalised density is available (as in parameter posteriors). Four strategies in increasing generality:
| Method | Requires | Scales to | Used for |
|---|---|---|---|
| Inverse CDF | Closed-form or tabulated CDF⁻¹ (a running sum for discrete) | 1D distributions | Categorical token sampling, incl. after top-p / nucleus truncation |
| Rejection | Proposal q(x) ≥ p(x)/M | Low-dim; acceptance rate = 1/M | Simulator rejection, data augmentation |
| Importance | Tractable proposal q(x) | Any dim; variance can explode | RLHF off-policy correction, PPO clip |
| MCMC (MH / HMC) | Unnormalised p(x) evaluations (HMC also needs ∇ log p) | High-dim; slow mixing | Bayesian neural networks, posterior sampling in PyMC / Stan |
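The first row can be written as code: a minimal sketch of inverse-CDF sampling from a Categorical distribution, applied after a simple top-p (nucleus) filter. The helper names are our own assumptions, as is the exact boundary rule (keep tokens until cumulative mass reaches `top_p`). Production implementations such as HF's `TopPLogitsWarper` operate on logits and may handle the boundary slightly differently.

```python
import bisect
import itertools
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs, top_p):
    """Keep the smallest set of highest-probability tokens whose mass reaches top_p, then renormalise."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered

def sample_categorical(probs, u):
    """Inverse CDF: first index whose cumulative probability exceeds u * total, for u ~ U[0, 1)."""
    cdf = list(itertools.accumulate(probs))
    # bisect_right skips zero-probability entries; min() clamps the rare case u * total == total
    return min(bisect.bisect_right(cdf, u * cdf[-1]), len(probs) - 1)

probs = softmax([2.0, 1.0, 0.5, -1.0, -3.0])
nucleus = top_p_filter(probs, top_p=0.9)
rng = random.Random(0)
draws = [sample_categorical(nucleus, rng.random()) for _ in range(20_000)]
freqs = [draws.count(k) / len(draws) for k in range(len(probs))]
print([round(p, 3) for p in nucleus])   # the two tail tokens are zeroed out
print([round(f, 3) for f in freqs])     # empirical frequencies track the filtered probabilities
```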
Importance sampling estimate of E_p[f(x)] using samples from a proposal q:
E_p[f(x)] = E_q[ f(x) · p(x)/q(x) ] # exact identity, provided q(x) > 0 wherever p(x)f(x) ≠ 0; p/q are the importance weights
In RLHF PPO, the old policy π_old is the proposal and the new policy π_θ is the target:
r(a|s) = π_θ(a|s) / π_old(a|s) # per-token importance weight
PPO clips r to [1-ε, 1+ε] to prevent high-variance updates when π_θ drifts from π_old.
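This is a minimal Monte Carlo sketch of the identity above, in plain Python. The target, proposals and f are arbitrary illustrative choices: p = N(1, 1) and f(x) = x², so E_p[f] = 2. The code also reports the effective sample size ESS = (Σw)² / Σw², a common diagnostic for weight degeneracy. A proposal much narrower than the target under-covers its tails, and then the ESS collapses. This is a close analogue of the large-ratio problem that PPO's clipping is meant to contain.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def importance_estimate(f, target, proposal, n, rng):
    """Estimate E_p[f(x)] from n draws of q, weighting each draw by w = p(x) / q(x).
    target and proposal are (mean, std) pairs of Gaussians. Returns (estimate, ESS)."""
    weights, total = [], 0.0
    for _ in range(n):
        x = rng.gauss(*proposal)
        w = math.exp(normal_logpdf(x, *target) - normal_logpdf(x, *proposal))
        weights.append(w)
        total += f(x) * w
    ess = sum(weights) ** 2 / sum(w * w for w in weights)
    return total / n, ess

def square(x):
    return x * x

rng = random.Random(0)
target = (1.0, 1.0)      # p = N(1, 1), so E_p[x^2] = Var + mean^2 = 2
est_wide, ess_wide = importance_estimate(square, target, (0.0, 2.0), 50_000, rng)
est_narrow, ess_narrow = importance_estimate(square, target, (0.0, 0.5), 50_000, rng)
print(est_wide, ess_wide)        # close to 2; ESS is a large fraction of 50,000
print(est_narrow, ess_narrow)    # q too narrow: a few huge weights dominate and ESS collapses
```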
Fisher Information
The Fisher information matrix F(θ) measures how much information data carries about the parameters. Equivalently, it is the expected curvature (negative expected Hessian) of the log-likelihood:
F(θ) = E[ (∇_θ log p(x|θ))(∇_θ log p(x|θ))ᵀ ]
= −E[ ∇²_θ log p(x|θ) ] (under mild regularity conditions)
Cramér-Rao bound: Var(θ̂) ≥ 1/(n·F(θ)) for any unbiased θ̂ built from n iid samples (scalar θ); no unbiased estimator can beat this
In ML:
Natural gradient = F⁻¹ · ∇L (preconditions by geometry of p, not Euclidean)
K-FAC ≈ block-diagonal, Kronecker-factored F⁻¹; Shampoo uses a related Kronecker-factored preconditioner (efficient approximate second-order optimisers)
3. Core Mechanism: Bayesian Inference
Bayesian inference is the mathematically consistent framework for updating beliefs in the face of evidence. MLE and MAP are point-estimate special cases: MAP is the mode of the posterior, and MLE is MAP under a flat prior. The ELBO in VAEs, the KL penalty in PPO, and the conjugate prior trick in Beta-Binomial coin inference all reduce to the same Bayes rule applied in different settings.
Background
You want to estimate the bias p of a coin after observing h heads in n flips. Three strategies: (1) MLE — set p̂ = h/n; gives 0 or 1 at extreme counts, no uncertainty. (2) MAP — add a prior on p, maximise the joint; prior acts as regularisation but still gives a point. (3) Bayesian — maintain a full distribution over p, updating it with every new observation. Beta-Binomial is the textbook example because the posterior is analytically tractable.
Plan
- Choose prior: p ~ Beta(α₀, β₀). Mean = α₀/(α₀+β₀); concentration = α₀+β₀ (higher → tighter prior).
- Observe data: h heads and t = n−h tails from Binomial(n, p). Likelihood: P(h,t | p) = C(n,h) · pʰ(1−p)ᵗ.
- Apply Bayes rule. The normalising constant Z = ∫ Beta(p; α₀,β₀) · Binomial(h | n, p) dp = C(n,h) · B(α₀+h, β₀+t) / B(α₀,β₀) is a ratio of Beta functions, so it is available in closed form. Result: posterior = Beta(α₀+h, β₀+t).
- Read off: posterior mean = (α₀+h)/(α₀+β₀+n); MAP = (α₀+h−1)/(α₀+β₀+n−2), valid when α₀+h > 1 and β₀+t > 1. As n → ∞, both converge to the MLE p̂ = h/n.
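Here is the plan above as a minimal Python sketch. The helper names are our own, and the inputs mirror Demo 1's alpha_0, beta_0, heads, tails. The code returns the closed-form posterior and its summaries. It then checks conjugacy by brute force: normalising prior × likelihood with a midpoint-rule estimate of Z reproduces the Beta(α₀+h, β₀+t) density. The MAP entry uses the interior-mode formula, so it is reported only when both posterior parameters exceed 1.

```python
import math

def beta_binomial_update(alpha0, beta0, heads, tails):
    """Conjugate update Beta(alpha0, beta0) -> Beta(alpha0 + h, beta0 + t), with summary statistics."""
    a, b = alpha0 + heads, beta0 + tails
    n = heads + tails
    return {
        "posterior": (a, b),
        "mean": a / (a + b),
        "map": (a - 1) / (a + b - 2) if a > 1 and b > 1 else None,  # interior mode needs a, b > 1
        "mle": heads / n if n > 0 else None,
    }

def log_beta_fn(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def numeric_posterior_pdf(p, alpha0, beta0, heads, tails, grid=20_000):
    """Brute-force Bayes rule: prior x likelihood, divided by a midpoint-rule estimate of Z."""
    def unnorm(x):
        prior = x ** (alpha0 - 1) * (1 - x) ** (beta0 - 1) / math.exp(log_beta_fn(alpha0, beta0))
        lik = math.comb(heads + tails, heads) * x ** heads * (1 - x) ** tails
        return prior * lik
    z = sum(unnorm((i + 0.5) / grid) for i in range(grid)) / grid
    return unnorm(p) / z

summary = beta_binomial_update(2, 2, heads=4, tails=1)
print(summary)   # posterior (6, 3); mean 6/9, MAP 5/7, MLE 4/5

# The closed-form Beta(6, 3) density matches the brute-force normalised posterior
a, b = summary["posterior"]
for p in (0.3, 0.6, 0.9):
    closed = p ** (a - 1) * (1 - p) ** (b - 1) / math.exp(log_beta_fn(a, b))
    print(p, closed, numeric_posterior_pdf(p, 2, 2, 4, 1))
```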
Walkthrough — 5 coin flips with weak prior
Initial conditions: prior Beta(2, 2) — symmetric, mean = 0.5, weakly believes the coin is fair. Observe flips: H, H, H, T, H (4 heads, 1 tail).
Start: Beta(2, 2) mean = 0.500 (prior: coin probably fair)
After H→1: Beta(3, 2) mean = 0.600
After H→2: Beta(4, 2) mean = 0.667
After H→3: Beta(5, 2) mean = 0.714
After T→4: Beta(5, 3) mean = 0.625 ← T pulls mean back
After H→5: Beta(6, 3) mean = 0.667
Final posterior: Beta(6, 3), mean = 6/9 = 0.667 (Bayesian estimate: 4 heads in 5 flips, plus the prior's 2+2 = 4 pseudo-counts)
MLE = 4/5 = 0.800 (ignores the prior; unreliable after only 5 flips)
MAP = (6−1)/(6+3−2) = 5/7 = 0.714
As n → ∞, the prior's 4 pseudo-counts become negligible and the Bayesian estimate converges to the MLE.
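The trace can be reproduced flip by flip. This is a sketch, with flips encoded as a string such as "HHHTH". Each observation only increments a count, so the final posterior does not depend on the order of the flips. It also matches the one-shot update Beta(α₀+h, β₀+t).

```python
def sequential_posterior(alpha0, beta0, flips):
    """Update Beta(alpha, beta) one flip at a time; return (alpha, beta, posterior mean) after each step."""
    a, b = alpha0, beta0
    trace = [(a, b, a / (a + b))]
    for flip in flips:
        if flip == "H":
            a += 1      # a head adds one pseudo-count to alpha
        else:
            b += 1      # a tail adds one pseudo-count to beta
        trace.append((a, b, a / (a + b)))
    return trace

for a, b, m in sequential_posterior(2, 2, "HHHTH"):
    print(f"Beta({a}, {b})  mean = {m:.3f}")
```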
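The prior acts as regularisation: stronger priors (larger α₀+β₀) require more data to be overcome. In ML terms, an L2 penalty λ‖θ‖₂² is equivalent to MAP inference with a Gaussian prior N(0, σ²I) on the parameters, with λ = 1/(2σ²). The regularisation coefficient λ therefore encodes prior belief strength. With plain SGD this penalty matches weight decay up to a rescaling of the coefficient. With Adam-style optimisers the two differ, which is why AdamW applies decoupled weight decay.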
4. MCMC — Sampling When Z Is Intractable
When the posterior has no conjugate closed form, the normalising constant Z = ∫ p(θ)p(D|θ)dθ is intractable. Neural-network weights are an example, since the likelihood depends nonlinearly on θ. Markov Chain Monte Carlo bypasses Z: it constructs a Markov chain whose stationary distribution is the target posterior. After burn-in, the chain's samples are approximately distributed as P(θ|D), although successive samples are correlated rather than iid.
| Algorithm | Proposal q | Gradient needed? | Mixing speed |
|---|---|---|---|
| Metropolis-Hastings | Symmetric Gaussian | No | Slow (random walk) |
| Gibbs | Full conditional P(xᵢ\|x₋ᵢ) | No | Slow when variables are highly correlated |
| HMC | Leapfrog integrator | Yes (∇ log p) | Fast; used in PyMC / Stan |
| NUTS (No-U-Turn) | Adaptive HMC path length | Yes | Fastest in practice |
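This is a minimal random-walk Metropolis-Hastings sketch in plain Python; the step size, chain length and burn-in are arbitrary choices. It targets the walkthrough's posterior (a Beta(2, 2) prior with 4 heads and 1 tail). It only ever evaluates the unnormalised log density, so Z never appears because it cancels in the acceptance ratio. The exact answer is Beta(6, 3) with mean 6/9, so the chain's estimate can be checked.

```python
import math
import random

def log_unnormalised_posterior(p, heads=4, tails=1, alpha0=2, beta0=2):
    """log[prior(p) * likelihood(p)] up to an additive constant; Z is never computed."""
    if not 0.0 < p < 1.0:
        return -math.inf                      # outside the support: density 0
    return (alpha0 - 1 + heads) * math.log(p) + (beta0 - 1 + tails) * math.log(1.0 - p)

def metropolis_hastings(log_target, x0, n_steps, step, rng):
    """Random-walk MH: the Gaussian proposal is symmetric, so q cancels in the acceptance ratio."""
    x, logp = x0, log_target(x0)
    samples, accepted = [], 0
    for _ in range(n_steps):
        proposal = x + rng.gauss(0.0, step)
        logp_new = log_target(proposal)
        u = 1.0 - rng.random()                # uniform on (0, 1], avoids log(0)
        if math.log(u) < logp_new - logp:     # accept with probability min(1, target ratio)
            x, logp = proposal, logp_new
            accepted += 1
        samples.append(x)                     # on rejection the current state is repeated
    return samples, accepted / n_steps

rng = random.Random(0)
samples, acc_rate = metropolis_hastings(log_unnormalised_posterior, 0.5, 60_000, 0.2, rng)
kept = samples[5_000:]                        # discard burn-in
mc_mean = sum(kept) / len(kept)
print(acc_rate, mc_mean, 6 / 9)               # chain mean vs the exact Beta(6, 3) mean
```

Successive samples are correlated, so the effective number of independent draws is smaller than the chain length. HMC and NUTS reduce this correlation by using ∇ log p to make long, informed moves.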
5. Minimal Demos
Demo 1 — Bayesian Beta-Binomial Update
Enter alpha_0 beta_0 heads tails. The demo performs the exact conjugate Bayesian update — no MCMC. Watch how the posterior mean, MAP estimate, and MLE compare at small vs large sample sizes. The prior strength (α₀+β₀) determines how many observations are needed to overcome it.
Demo 2 — KL vs JS Divergence between Gaussians
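Enter mu1 sigma1 mu2 sigma2. Computes KL(P‖Q) and KL(Q‖P) in closed form for Gaussians, together with JS(P,Q). JS has no closed form here, because the mixture M = (P+Q)/2 is not Gaussian. Observe the asymmetry: with σ1 ≠ σ2, KL(P‖Q) ≠ KL(Q‖P). As the means move apart, both KLs grow without bound (quadratically in μ1 − μ2) while JS saturates at ln 2. This is why PPO clips importance weights: a large divergence between π_θ and π_old causes high-variance gradient estimates.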
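Below is a sketch of the quantities Demo 2 reports, in plain Python (the function names are our own). KL uses the standard closed form KL(N(μ₁, σ₁²) ‖ N(μ₂, σ₂²)) = log(σ₂/σ₁) + (σ₁² + (μ₁ − μ₂)²)/(2σ₂²) − ½. JS is computed by midpoint-rule integration on a grid, which is an assumption of this sketch. The grid width (±10σ) and resolution are ample for these parameters.

```python
import math

def kl_gauss(mu1, s1, mu2, s2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2)) in nats."""
    return math.log(s2 / s1) + (s1 ** 2 + (mu1 - mu2) ** 2) / (2 * s2 ** 2) - 0.5

def normal_pdf(x, mu, s):
    return math.exp(-0.5 * ((x - mu) / s) ** 2) / (s * math.sqrt(2 * math.pi))

def js_gauss_numeric(mu1, s1, mu2, s2, n=20_000):
    """JS(P, Q) by midpoint-rule integration; M = (P + Q) / 2 is not Gaussian, so no closed form."""
    lo = min(mu1 - 10 * s1, mu2 - 10 * s2)
    hi = max(mu1 + 10 * s1, mu2 + 10 * s2)
    dx = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * dx
        p, q = normal_pdf(x, mu1, s1), normal_pdf(x, mu2, s2)
        m = 0.5 * (p + q)
        if p > 0:
            total += 0.5 * p * math.log(p / m)
        if q > 0:
            total += 0.5 * q * math.log(q / m)
    return total * dx

for gap in (0.0, 1.0, 3.0, 10.0):
    # P = N(0, 1), Q = N(gap, 2^2)
    print(gap, kl_gauss(0.0, 1.0, gap, 2.0), kl_gauss(gap, 2.0, 0.0, 1.0), js_gauss_numeric(0.0, 1.0, gap, 2.0))
```

At gap 0 the two KLs already differ because σ₁ ≠ σ₂. As the gap grows, both KLs rise quadratically while JS approaches ln 2 ≈ 0.693.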
6. Production & Source Pointers
| Concept | PyTorch / library | Source file |
|---|---|---|
| Cross-entropy (NLL) loss | F.cross_entropy(logits, targets) | torch/nn/functional.py → nll_loss |
| KL divergence | F.kl_div(log_q, p, reduction='sum') | torch/nn/functional.py → kl_div |
| Distributions (Beta, Dirichlet, …) | torch.distributions.Beta(a, b) | torch/distributions/beta.py |
| Top-p / nucleus sampling | TopPLogitsWarper in HF Transformers | transformers/generation/logits_process.py |
| RLHF KL penalty (PPO) | kl_penalty() in trl.PPOTrainer | trl/trainer/ppo_trainer.py |
| Importance sampling weights | ratio = exp(log_probs_new − log_probs_old), then clip | trl/trainer/ppo_trainer.py → compute_loss() |
| Fisher / natural gradient | Shampoo (Kronecker-factored preconditioner); torch.optim.LBFGS is quasi-Newton, approximating the Hessian rather than F | distributed_shampoo / google/distributed_shampoo |
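Pitfall: do not call F.kl_div(q, p) on raw probabilities. It expects log q as the first argument (log-space) and computes p · (log p − log q) pointwise, which sums to KL(p‖q). Mixing raw and log-space probabilities is a silent bug that gives wrong KL values without raising an error. Use F.kl_div(torch.log(q), p, reduction='batchmean'), or pass F.log_softmax(logits, dim=-1) as the first argument, which avoids taking log(0). If both arguments are log-probs, pass log_target=True.
7. References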
Papers
- Schulman et al. 2017 — Proximal Policy Optimization Algorithms (arXiv:1707.06347) — clipped surrogate objective and KL-penalty variant; the optimiser later used for RLHF
- Ouyang et al. 2022 — Training language models to follow instructions with human feedback (arXiv:2203.02155) — InstructGPT; KL(π‖π_ref) penalty
- Kingma & Welling 2013 — Auto-Encoding Variational Bayes (arXiv:1312.6114) — ELBO derivation; KL(q‖p) + reconstruction
- Kaplan et al. 2020 — Scaling Laws for Neural Language Models (arXiv:2001.08361) — power-law fits of cross-entropy loss against model size, data, and compute
- Hofmann 2001 — Unsupervised Learning by Probabilistic Latent Semantic Analysis — Categorical mixture topic model; LDA (Blei, Ng & Jordan 2003) adds the Dirichlet prior
- Martens 2014 — New insights and perspectives on the natural gradient method (arXiv:1412.1193) — Fisher information & NGD
Lectures
- Stanford CS229 — Andrew Ng, Lecture Notes on Probability and Statistics (free PDF at cs229.stanford.edu)
- MIT 6.041 — Probabilistic Systems Analysis and Applied Probability — full lecture videos on OpenCourseWare
- CMU 10-708 — Eric Xing, Probabilistic Graphical Models — Bayesian networks, MCMC, variational inference
- Berkeley CS294-158 — Pieter Abbeel, Deep Unsupervised Learning — ELBO, VAE, flow models
- Oxford Hilary Term — Machine Learning — Bayesian inference lectures (Yee Whye Teh)
- DeepMind × UCL — David Silver, Reinforcement Learning lecture 7 — importance sampling in policy gradient
Textbooks
- Bishop — Pattern Recognition and Machine Learning (PRML), Chapters 1–3: probability, Bayes, distributions
- Murphy — Probabilistic Machine Learning: Introduction (free draft at probml.ai) — Vol 1 Ch 2–5
- Murphy — Probabilistic Machine Learning: Advanced Topics — Vol 2 Ch 5: information theory
- MacKay — Information Theory, Inference, and Learning Algorithms (free PDF at inference.org.uk)
- Casella & Berger — Statistical Inference (2nd ed.) — rigorous treatment of MLE/MAP/sufficient statistics
- Gelman et al. — Bayesian Data Analysis (BDA3, free PDF) — MCMC in practice
Code / Repos
- huggingface/trl — PPO trainer; KL penalty & importance weights in ppo_trainer.py
- pymc-devs/pymc — Bayesian modelling with NUTS/HMC; sampling from arbitrary posteriors
- pytorch/pytorch — torch/distributions/ (Beta, Dirichlet, Categorical); torch/nn/functional.py (kl_div, cross_entropy)
Blog Posts
- Lilian Weng — From Autoencoder to Beta-VAE (lilianweng.github.io) — ELBO & KL decomposition
- Lilian Weng — Policy Gradient Algorithms — importance sampling, PPO, TRPO derivations
- Eric Jang — A Beginner's Guide to Variational Methods — mean-field vs full Bayes
- Sebastian Ruder — An overview of gradient descent optimisation algorithms — momentum, Nesterov, Adagrad, RMSprop, Adam
- distill.pub — Gabriel Goh, Why Momentum Really Works — momentum analysed on convex quadratics
8. Interview Prep
Questions of the kind asked in ML / LLM interviews at Anthropic, OpenAI, DeepMind, Meta AI, and NVIDIA, probability and statistics round.
Q1. Why is minimizing cross-entropy loss equivalent to maximum-likelihood estimation?
MLE: maximize E_data[log p_θ(y|x)]. Cross-entropy H(p_data, p_θ) = −E_data[log p_θ(y|x)], so minimizing H is identical to maximizing log-likelihood. Crucially, H(p_data, p_θ) = H(p_data) + KL(p_data‖p_θ). Since H(p_data) is fixed, minimizing H is equivalent to minimizing KL(p_data‖p_θ) — training pushes the model distribution toward the data distribution.
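The identity can be checked on a toy dataset in plain Python; the tokens and model probabilities are made up for illustration. The check shows three things. The average NLL equals the cross-entropy against the empirical distribution. That cross-entropy decomposes as H(p_data) + KL(p_data‖p_θ). The MLE q = p_data attains the minimum, with KL = 0.

```python
import math
from collections import Counter

data = ["a", "b", "a", "c", "a", "b", "a", "a"]          # toy next-token observations
vocab = ["a", "b", "c"]
counts = Counter(data)
p_data = {v: counts[v] / len(data) for v in vocab}        # empirical distribution

def avg_nll(q):
    return -sum(math.log(q[y]) for y in data) / len(data)

def cross_entropy(p, q):
    return -sum(p[v] * math.log(q[v]) for v in vocab if p[v] > 0)

def entropy(p):
    return -sum(p[v] * math.log(p[v]) for v in vocab if p[v] > 0)

def kl(p, q):
    return sum(p[v] * math.log(p[v] / q[v]) for v in vocab if p[v] > 0)

q_model = {"a": 0.5, "b": 0.3, "c": 0.2}
print(avg_nll(q_model), cross_entropy(p_data, q_model))                       # identical
print(cross_entropy(p_data, q_model), entropy(p_data) + kl(p_data, q_model))  # H(p, q) = H(p) + KL(p || q)
print(avg_nll(p_data), entropy(p_data))   # at the MLE q = p_data, KL = 0 and the loss is minimal
```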
Q2. What is the difference between KL(P‖Q) and KL(Q‖P)? When does each appear in LLM training?
KL(P‖Q) = E_P[log P/Q]. If Q(x) = 0 where P(x) > 0, KL → ∞, so the optimiser must spread Q across all of P's support (mass-covering / mean-seeking). Used in MLE training. KL(Q‖P) = E_Q[log Q/P]. If P(x) = 0 where Q(x) > 0, KL → ∞, so Q avoids placing mass where P is zero. When Q is less expressive than P, it tends to concentrate on a single mode (mode-seeking). Used in the VAE ELBO and the RLHF PPO penalty term, which keeps the policy from deviating too far from the SFT reference.
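Here is a small numerical illustration of the two behaviours in plain Python; the bimodal target, grid and candidate set are arbitrary choices. It fits a single Gaussian Q to an equal mixture of N(−3, 1) and N(3, 1) by grid search, once minimising KL(P‖Q) and once minimising KL(Q‖P). Integrals are approximated by Riemann sums on [−10, 10], and log densities are kept in log space to avoid underflow.

```python
import math

def normal_logpdf(x, mu, s):
    return -0.5 * ((x - mu) / s) ** 2 - math.log(s * math.sqrt(2 * math.pi))

def log_mixture(x):
    # log P(x) for P = 0.5 N(-3, 1) + 0.5 N(3, 1), via a two-term log-sum-exp
    a, b = normal_logpdf(x, -3.0, 1.0), normal_logpdf(x, 3.0, 1.0)
    m = max(a, b)
    return math.log(0.5) + m + math.log(math.exp(a - m) + math.exp(b - m))

dx = 0.05
xs = [-10 + i * dx for i in range(401)]          # grid on [-10, 10]
log_p = [log_mixture(x) for x in xs]

def forward_kl(mu, s):
    """KL(P || Q): expectation under the target P, as a Riemann sum."""
    return sum(math.exp(lp) * (lp - normal_logpdf(x, mu, s)) for x, lp in zip(xs, log_p)) * dx

def reverse_kl(mu, s):
    """KL(Q || P): expectation under the approximation Q, as a Riemann sum."""
    total = 0.0
    for x, lp in zip(xs, log_p):
        lq = normal_logpdf(x, mu, s)
        total += math.exp(lq) * (lq - lp)
    return total * dx

# candidate Gaussians: mu in [-4, 4] and s in [0.5, 4], both in steps of 0.25
candidates = [(m / 4, k / 4) for m in range(-16, 17) for k in range(2, 17)]
best_forward = min(candidates, key=lambda c: forward_kl(*c))
best_reverse = min(candidates, key=lambda c: reverse_kl(*c))
print("argmin KL(P||Q):", best_forward)   # mass-covering: mu near 0, wide s
print("argmin KL(Q||P):", best_reverse)   # mode-seeking: sits on one component, s near 1
```

For a Gaussian Q, minimising KL(P‖Q) is moment matching (mean 0, variance 1 + 9 = 10). The forward fit therefore straddles both modes with σ near √10 ≈ 3.16. The reverse fit sits on one component with σ near 1; which mode it picks is a tie, broken by the search order.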
Q3. Compute perplexity if the model's average NLL on a test set is 3.0 nats. What does it mean?
PPL = e^3.0 ≈ 20.1. Intuition: the model is, on average, as uncertain as if it had to choose uniformly among ≈20 equally likely next tokens. A uniform baseline over GPT-2's 50,257-token vocabulary gives PPL = 50,257. Strong word-level LSTM language models reach perplexities of roughly 60 on PTB. Modern LLMs reach single-digit token-level perplexity on many well-studied datasets, but perplexities are only comparable between models that share a tokenizer and evaluation set.
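Here is the arithmetic as a tiny Python helper (our own name). It assumes the NLLs are per-token and in nats. The examples show that perplexity depends only on the average NLL, that a uniform model over V tokens scores exactly V, and that computing in bits with base 2 gives the same number.

```python
import math

def perplexity(token_nlls):
    """Perplexity = exp(mean per-token NLL), with NLLs in nats."""
    return math.exp(sum(token_nlls) / len(token_nlls))

print(perplexity([3.0, 3.0, 3.0]))   # e^3, about 20.09
print(perplexity([2.0, 4.0]))        # also e^3: only the average NLL matters

# A uniform model over V tokens assigns NLL = ln V to every token, so its perplexity is V
V = 50_257
print(perplexity([math.log(V)] * 10))

# In bits: 2 ** (mean NLL in bits) gives the same perplexity
nll_bits = [x / math.log(2) for x in (2.0, 4.0)]
print(2 ** (sum(nll_bits) / len(nll_bits)))
```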
Q4. How does L2 weight decay relate to MAP inference? What prior does it correspond to?
MAP objective: argmax_θ P(D|θ)P(θ). With a Gaussian prior θ ~ N(0, σ²I), log P(θ) = −‖θ‖₂² / (2σ²) + const. So MAP = argmax [log-likelihood − ‖θ‖₂²/(2σ²)] = argmin [NLL + λ‖θ‖₂²] where λ = 1/(2σ²). Standard L2 regularisation is exactly MAP estimation with an isotropic Gaussian prior on all parameters; stronger λ = stronger (tighter) prior belief that weights should stay near zero.
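This sketch shows the equivalence for the simplest case, scalar linear regression, in plain Python. The data, the unit noise variance and σ² = 0.5 are illustrative assumptions. The NLL is summed over the dataset rather than averaged; averaging would rescale λ by 1/n. Gradient descent on NLL + λθ² with λ = 1/(2σ²) lands on the closed-form MAP estimate, which is shrunk toward 0 relative to the MLE.

```python
import random

rng = random.Random(0)
true_theta = 2.0
xs = [rng.uniform(-1.0, 1.0) for _ in range(20)]
ys = [true_theta * x + rng.gauss(0.0, 1.0) for x in xs]   # unit noise variance, assumed known

sigma2 = 0.5                  # prior theta ~ N(0, sigma2)
lam = 1 / (2 * sigma2)        # implied coefficient for the L2 penalty lam * theta**2

def grad(theta):
    # d/dtheta [ sum_i (y_i - theta * x_i)^2 / 2 + lam * theta^2 ], i.e. NLL (up to a constant) + L2 penalty
    return -sum(x * (y - theta * x) for x, y in zip(xs, ys)) + 2 * lam * theta

theta = 0.0
for _ in range(2_000):        # plain gradient descent on the regularised objective
    theta -= 0.05 * grad(theta)

sxy = sum(x * y for x, y in zip(xs, ys))
sxx = sum(x * x for x in xs)
theta_map = sxy / (sxx + 1 / sigma2)   # mode of the Gaussian posterior over theta
theta_mle = sxy / sxx                  # no prior
print(theta, theta_map, theta_mle)     # GD agrees with MAP; MAP is shrunk toward 0
```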
Q5. Explain importance sampling. Why does PPO clip the importance weights?
Importance sampling lets us estimate E_p[f(x)] using samples from a different distribution q: E_p[f(x)] = E_q[f(x) w(x)] where w(x) = p(x)/q(x). In PPO, samples are collected under π_old (the proposal) and we optimise the objective under π_θ (the target). The ratio r = π_θ/π_old is the importance weight. When π_θ drifts far from π_old, r can be very large → high-variance gradient estimates → unstable training. PPO clips r to [1−ε, 1+ε] (typically ε=0.2) and takes the minimum of the clipped and unclipped surrogate terms, trading a small bias for much lower variance.
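The clipped objective itself is shown below as a per-token sketch in plain Python. The function name and toy numbers are our own; real trainers work on batched tensors and add value and entropy terms. Because the inputs are log-probabilities, the ratio is computed as exp(log π_θ − log π_old).

```python
import math

def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Mean over tokens of min(r * A, clip(r, 1 - eps, 1 + eps) * A); PPO maximises this."""
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        r = math.exp(lp_new - lp_old)                 # importance weight pi_theta / pi_old
        r_clipped = min(max(r, 1 - eps), 1 + eps)
        total += min(r * adv, r_clipped * adv)        # pessimistic (lower) of the two terms
    return total / len(advantages)

logp_old = [math.log(0.2), math.log(0.5), math.log(0.1)]
logp_new = [math.log(0.4), math.log(0.45), math.log(0.05)]   # ratios 2.0, 0.9, 0.5
advantages = [1.0, -1.0, 1.0]
print(ppo_clip_objective(logp_new, logp_old, advantages))   # (1.2 - 0.9 + 0.5) / 3
```

In the third token, r = 0.5 < 1 − ε with a positive advantage, yet the unclipped 0.5·A is kept. The min means clipping only removes the incentive to push r beyond the band in the direction that raises the objective. Changes that lower the objective are not capped.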
Q6. Describe the Beta-Binomial conjugate update in one sentence, then give the posterior for Beta(1,1) prior and 7 heads / 3 tails.
Observing h heads and t tails updates Beta(α,β) to Beta(α+h, β+t): prior pseudo-counts simply accumulate with real counts. Beta(1,1) is a Uniform prior (no prior belief). After 7H/3T: posterior = Beta(8, 4), mean = 8/12 = 0.667, MAP = (8−1)/(8+4−2) = 7/10 = 0.700, MLE = 7/10 = 0.700. MAP equals MLE here for any number of flips. The mode formula uses α−1 and β−1, which are both 0 for Beta(1,1), so the flat prior does not move the mode. The posterior mean, by contrast, still adds one pseudo-count to each side.
Q7. What is mutual information I(X;Y) and why does it appear in information bottleneck / representation learning?
I(X;Y) = H(X) − H(X|Y) = KL(P(X,Y)‖P(X)P(Y)) measures how much knowing Y reduces uncertainty in X. I(X;Y) = 0 iff X ⊥ Y. In the Information Bottleneck (Tishby et al.), a representation Z is learned to minimize I(X;Z) (compression) while maximizing I(Z;Y) (task-relevant information). In contrastive representation learning (SimCLR, CLIP), minimizing the InfoNCE loss maximizes a lower bound on the mutual information between input views or modalities: I(X;Y) ≥ log N − L_InfoNCE. Here N is the number of candidates per positive (one positive plus N − 1 negatives).
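Both expressions for I(X;Y) can be checked on a 2×2 joint distribution in plain Python; the probabilities are made up. H(X) − H(X|Y), with H(X|Y) = H(X,Y) − H(Y), matches KL(P(X,Y)‖P(X)P(Y)). A joint that factorises gives I = 0.

```python
import math

# Joint distribution P(X, Y) with X, Y in {0, 1} (rows: x, columns: y)
joint = [[0.4, 0.1],
         [0.1, 0.4]]

def marginals(j):
    px = [sum(row) for row in j]
    py = [sum(j[x][y] for x in range(len(j))) for y in range(len(j[0]))]
    return px, py

def entropy(ps):
    return -sum(p * math.log(p) for p in ps if p > 0)

def mi_via_entropies(j):
    """I(X;Y) = H(X) - H(X|Y), using H(X|Y) = H(X, Y) - H(Y)."""
    px, py = marginals(j)
    h_xy = entropy([p for row in j for p in row])
    return entropy(px) - (h_xy - entropy(py))

def mi_via_kl(j):
    """I(X;Y) = KL(P(X, Y) || P(X) P(Y))."""
    px, py = marginals(j)
    return sum(j[x][y] * math.log(j[x][y] / (px[x] * py[y]))
               for x in range(len(px)) for y in range(len(py)) if j[x][y] > 0)

print(mi_via_entropies(joint), mi_via_kl(joint))   # the two formulas agree

independent = [[0.3 * 0.6, 0.3 * 0.4],
               [0.7 * 0.6, 0.7 * 0.4]]              # P(X, Y) = P(X) P(Y)
print(mi_via_kl(independent))                      # zero up to round-off
```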
Q8. What is the Fisher information matrix and why does it matter for optimisation?
F(θ) = E[(∇ log p)(∇ log p)ᵀ] = −E[∇² log p] is the expected curvature of the log-likelihood. The natural gradient ∇̃L = F⁻¹∇L preconditions the gradient by the geometry of the model distribution rather than Euclidean parameter space. This makes the update invariant, to first order, to smooth reparameterisation. In practice, F is d×d (intractable for large LLMs). K-FAC therefore uses a Kronecker-factored block-diagonal approximation, F ≈ A ⊗ B for each layer, where A ≈ E[xxᵀ] (input covariance) and B ≈ E[δδᵀ] (covariance of back-propagated gradients). Shampoo applies a similar Kronecker-factored preconditioner built from gradient statistics.
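Below is a scalar sanity check of the two Fisher formulas and the Cramér–Rao bound, in plain Python; p = 0.3, n and the replication count are arbitrary. For one Bernoulli(p) observation, both E[score²] and −E[∂² log p] equal 1/(p(1−p)). The MLE p̂ = mean(x) from n flips has variance p(1−p)/n = 1/(n F(p)), so it attains the bound. In the d-dimensional case the square becomes an outer product and F becomes a d×d matrix.

```python
import random

p_true = 0.3

def score(x, p):
    """d/dp log p(x | p) for one Bernoulli observation x in {0, 1}."""
    return x / p - (1 - x) / (1 - p)

def neg_second_derivative(x, p):
    """-d^2/dp^2 log p(x | p)."""
    return x / p ** 2 + (1 - x) / (1 - p) ** 2

# Exact expectations over x ~ Bernoulli(p_true); both should equal 1 / (p (1 - p))
fisher_outer = p_true * score(1, p_true) ** 2 + (1 - p_true) * score(0, p_true) ** 2
fisher_hess = p_true * neg_second_derivative(1, p_true) + (1 - p_true) * neg_second_derivative(0, p_true)
print(fisher_outer, fisher_hess, 1 / (p_true * (1 - p_true)))

# Cramér-Rao: compare the spread of the MLE p_hat = mean(x) with the bound 1 / (n F(p))
rng = random.Random(0)
n, reps = 100, 5_000
estimates = [sum(rng.random() < p_true for _ in range(n)) / n for _ in range(reps)]
m = sum(estimates) / reps
var_mle = sum((e - m) ** 2 for e in estimates) / (reps - 1)
print(m, var_mle, 1 / (n * fisher_outer))   # unbiased, and variance close to the bound p(1-p)/n
```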