Part II — Mathematical Foundations

§ 4 · Probability & Statistics

Distributions, Bayes rule, MLE / MAP, KL divergence, cross-entropy, mutual information, and sampling — the probabilistic backbone of every loss function, training objective, and decoding algorithm in LLMs.

1. Overview

Every quantity an LLM computes or optimises is probabilistic at its core. The training loss is negative log-likelihood, a measure of how far the model distribution is from the data distribution. In RLHF, PPO maximises reward subject to a penalty on the KL divergence between the fine-tuned policy and the SFT reference. Token generation is categorical sampling from a softmax distribution. Scaling-law curves are power-law fits to cross-entropy loss, the logarithm of perplexity. Without fluency in the language of probability, none of these can be read from first principles.

2. Key Concepts

Random Variables and Moments

A random variable X maps outcomes to real numbers. Its first two moments — expectation and variance — characterise the centre and spread of its distribution:

E[X]       = ∑ x P(X=x)         (discrete)
           = ∫ x p(x) dx         (continuous)

Var(X)     = E[(X - E[X])²]
           = E[X²] - E[X]²       (computational formula)

Cov(X,Y)   = E[(X-EX)(Y-EY)]     shape intuition: correlated activations
Corr(X,Y)  = Cov(X,Y) / (σ_X σ_Y)  ∈ [-1, 1]

The law of total expectation E[X] = E[E[X|Y]] appears in importance sampling (section below) and in the ELBO of VAEs. The central limit theorem says that the standardised mean of n iid samples with finite variance converges in distribution to a Gaussian as n → ∞; this motivates Gaussian approximations to parameter posteriors in large-scale training.
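The moment formulas above can be checked numerically. The following is a minimal sketch using only the Python standard library; the sample values `xs` and `ys` are made up for illustration, and the functions compute population (not sample) moments.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

def variance(xs):
    # Population variance via the definition E[(X - E[X])^2].
    m = mean(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

def covariance(xs, ys):
    mx, my = mean(xs), mean(ys)
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / len(xs)

def correlation(xs, ys):
    return covariance(xs, ys) / math.sqrt(variance(xs) * variance(ys))

# Hypothetical data: ys is almost a linear function of xs.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.5]

# The computational formula E[X^2] - E[X]^2 agrees with the definition.
var_alt = mean([x * x for x in xs]) - mean(xs) ** 2

print(variance(xs), var_alt, correlation(xs, ys))
```

Because `ys` is nearly linear in `xs`, the correlation comes out close to 1, while the two variance formulas agree up to floating-point error.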

Key Distributions

Five distributions dominate ML. Discrete distributions drive token-level decisions; continuous ones govern weight initialization, latent spaces, and Bayesian priors.

Distribution | PMF / PDF | LLM connection
Bernoulli(p) | pˣ(1−p)^(1−x), x∈{0,1} | Binary classification head, reward model output
Categorical(π) | ∏ πₖ^[y=k] | Next-token prediction; softmax output layer
Gaussian N(μ,σ²) | (2πσ²)^−½ exp(−(x−μ)²/2σ²) | Xavier/Kaiming init, VAE latent, additive noise
Beta(α,β) | x^(α−1)(1−x)^(β−1)/B(α,β) | Conjugate prior for coin bias p; Bayesian A/B tests
Dirichlet(α) | ∏ πₖ^(αₖ−1) / B(α) | Prior over token distributions; LDA topic modelling

Conjugate priors: when the prior and posterior belong to the same family, Bayesian updating reduces to a closed-form parameter update with no approximation error. Beta is conjugate to Bernoulli/Binomial; Dirichlet is conjugate to Categorical/Multinomial. This is why the coin-flip posterior update in Demo 1 requires no MCMC.

MLE, MAP, and Bayesian Inference

Three nested frameworks for estimating parameters from data — each adds one additional ingredient:

Framework | Objective | Output | In practice
MLE | argmax_θ P(D|θ) | Single point estimate | Language model pre-training cross-entropy
MAP | argmax_θ P(D|θ)P(θ) | Single point estimate | MLE + L2 weight decay (Gaussian prior on θ)
Bayesian | full posterior P(θ|D) | Distribution over θ | Laplace approximation, MCMC, variational inference

Why MLE = cross-entropy minimization: NLL loss = −log P(D|θ) = −∑_t log P(y_t | y_<t, θ). Minimizing NLL is equivalent to minimizing the KL divergence between the empirical data distribution p_data and the model distribution p_θ, since KL(p_data‖p_θ) = H(p_data, p_θ) − H(p_data), and H(p_data) does not depend on θ. Standard pre-training is nothing more than maximum-likelihood estimation over the next-token Categorical distribution.

argmin_θ  -∑_t log P_θ(y_t | y_<t)         ← language model training loss
         =  argmin_θ  KL(p_data || p_θ)     ← same objective, information-theory view
         =  argmin_θ  H(p_data, p_θ)        ← cross-entropy (p_data is fixed)
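The chain of equalities above can be verified on a toy example. This sketch assumes a context-free (unigram) model over a three-token vocabulary; the corpus `tokens` and the model distribution `q_model` are hypothetical values chosen for illustration.

```python
import math

def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def cross_entropy(p, q):
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical toy corpus over the vocabulary {0, 1, 2}.
tokens = [0, 0, 1, 2, 0, 1, 0, 2, 0, 1]
p_data = [tokens.count(k) / len(tokens) for k in range(3)]  # empirical distribution
q_model = [0.6, 0.25, 0.15]                                 # assumed model distribution

# Average negative log-likelihood of the corpus under the model.
avg_nll = -sum(math.log(q_model[t]) for t in tokens) / len(tokens)

print(avg_nll, cross_entropy(p_data, q_model))
print(entropy(p_data) + kl(p_data, q_model))
```

All three printed values coincide up to rounding: the average NLL is the cross-entropy against the empirical distribution, which in turn splits into a θ-independent entropy term plus the KL divergence.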

KL Divergence and Information Distances

The KL divergence measures "how much information is lost when Q is used to approximate P":

KL(P||Q)  =  E_P[ log P(x)/Q(x) ]
           =  ∑_x P(x) log P(x)/Q(x)          (discrete)
           =  H(P, Q) − H(P)                   (cross-entropy minus entropy)

Properties:
  KL(P||Q) ≥ 0                               (Gibbs inequality)
  KL(P||Q) = 0  iff  P = Q  a.e.
  KL(P||Q) ≠ KL(Q||P)                        (asymmetric — not a metric!)

Divergence | Formula | Symmetric? | Bounded? | LLM use
KL(P‖Q) | ∑ P log P/Q | No | No (can be ∞) | MLE / NLL training
KL(Q‖P) | ∑ Q log Q/P | No | No (can be ∞) | VAE ELBO, RLHF PPO
JS(P,Q) | (KL(P‖M) + KL(Q‖M))/2, with M = (P+Q)/2 | Yes | Yes, [0, ln 2] | GAN loss analysis
Wasserstein-1 | inf_γ E_γ[‖x−y‖] | Yes | No, but a true metric (triangle inequality holds) | WGAN; OT-based alignment
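The asymmetry of KL and the boundedness of JS can be seen on small discrete distributions. This is a minimal sketch with hand-picked example distributions; it follows the convention 0 · log 0 = 0 and returns infinity when the first argument has mass where the second has none.

```python
import math

def kl(p, q):
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0:
            continue            # 0 * log(0 / q) is taken to be 0
        if qi == 0:
            return math.inf     # p has mass where q has none
        total += pi * math.log(pi / qi)
    return total

def js(p, q):
    # Jensen-Shannon divergence with the equal-weight mixture M = (P + Q) / 2.
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.5, 0.5, 0.0]
q = [0.1, 0.1, 0.8]
disjoint_a = [1.0, 0.0]
disjoint_b = [0.0, 1.0]

print(kl(p, q), kl(q, p))          # finite vs infinite
print(js(p, q), js(disjoint_a, disjoint_b), math.log(2))
```

Here KL(P‖Q) is finite while KL(Q‖P) is infinite, and JS reaches its maximum of ln 2 only when the supports are disjoint.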

Entropy, Cross-Entropy, and Mutual Information

The information-theoretic trio that appears repeatedly in LLM papers:

H(P)        = -∑ P(x) log P(x)           entropy (bits if log₂, nats if ln)
            = average surprise under P

H(P, Q)     = -∑ P(x) log Q(x)           cross-entropy of Q under P
            = H(P) + KL(P||Q)             always ≥ H(P)

Perplexity  = exp(H(P, Q))               in nats;  2^H(P,Q)  in bits
            = exp(avg NLL)               lower is better; random = vocab size

I(X;Y)      = H(X) - H(X|Y)             mutual information
            = KL( P(X,Y) || P(X)P(Y) )   how much knowing Y reduces uncertainty in X
            = 0  iff X, Y are independent

Perplexity quick-math. GPT-2 (117M) reports a zero-shot perplexity of roughly 29 on WikiText-2 (and roughly 66 on PTB). A token-level cross-entropy of 3.4 nats gives PPL = e^3.4 ≈ 30. Each nat = 1/ln 2 ≈ 1.44 bits. A model that predicts uniformly over a vocabulary of size V has PPL = V, so the random baseline for GPT-2's vocabulary is PPL = 50,257.
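The perplexity arithmetic above is easy to reproduce. In this sketch the helper `perplexity` is a hypothetical name that takes per-token natural-log probabilities, as a language model would assign to the observed tokens.

```python
import math

def perplexity(token_log_probs):
    # exp of the average negative log-likelihood, in nats.
    avg_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(avg_nll)

vocab_size = 50257  # GPT-2 vocabulary size, as quoted in the text

# A model that assigns uniform probability to every token.
uniform_log_probs = [math.log(1 / vocab_size)] * 8
ppl_uniform = perplexity(uniform_log_probs)

ppl_34 = math.exp(3.4)              # perplexity at 3.4 nats per token
bits_per_token = 3.4 / math.log(2)  # the same loss expressed in bits

print(ppl_uniform, ppl_34, bits_per_token)
```

The uniform model recovers the vocabulary size as its perplexity, and 3.4 nats per token corresponds to a perplexity just under 30, or about 4.9 bits per token.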

Sampling Methods

Sampling from a distribution is non-trivial when only the unnormalised density is available (as in parameter posteriors). Four strategies in increasing generality:

Method | Requires | Scales to | Used for
Inverse CDF | Closed-form CDF⁻¹ | 1D distributions | Top-p / nucleus sampling of token probs
Rejection | Proposal q with M·q(x) ≥ p(x) | Low-dim; acceptance rate = 1/M | Simulator rejection, data augmentation
Importance | Tractable proposal q(x) | Any dim; variance can explode | RLHF off-policy correction, PPO clip
MCMC (MH / HMC) | Unnormalised p(x) evaluations | High-dim; slow mixing | Bayesian neural networks
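As an illustration of the first row, here is a rough sketch of inverse-CDF sampling from a categorical distribution after nucleus (top-p) truncation. The function names and the example probabilities are assumptions made for this sketch, not any library's API.

```python
import random

def top_p_filter(probs, top_p):
    # Keep the smallest set of highest-probability tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass  # renormalise the surviving tokens
    return filtered

def sample_inverse_cdf(probs, rng):
    # Draw u ~ Uniform(0, 1) and return the first index whose cumulative mass exceeds u.
    u = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    # Guard against floating-point rounding in the cumulative sum.
    return max(i for i, p in enumerate(probs) if p > 0)

probs = [0.5, 0.3, 0.15, 0.05]        # hypothetical next-token distribution
filtered = top_p_filter(probs, 0.9)   # drops the 0.05 tail token
rng = random.Random(0)
draws = [sample_inverse_cdf(filtered, rng) for _ in range(1000)]
```

With `top_p = 0.9` the three most likely tokens are kept (cumulative mass 0.95), so index 3 can never be drawn.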

Importance sampling estimate of E_p[f(x)] using samples from a proposal q:

E_p[f(x)] = E_q[ f(x) · p(x)/q(x) ]      # exact; p/q are importance weights

In RLHF PPO: old policy π_old is the proposal; new policy π_θ is the target.
Ratio  r(a|s) = π_θ(a|s) / π_old(a|s)     # per-token importance weight
PPO clips r to [1-ε, 1+ε] to prevent high-variance updates when π_θ drifts.
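The identity and the clipping rule above can be sketched in a few lines. The target p = N(1, 1), the proposal q = N(0, 2²), and the function names are assumptions chosen for illustration; `ppo_clipped_term` follows the per-sample clipped surrogate min(r·A, clip(r)·A) from the PPO paper.

```python
import math
import random

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def importance_estimate(f, n, rng):
    # Assumed target p = N(1, 1) and proposal q = N(0, 2^2).
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, 2.0)                                # sample from q
        w = normal_pdf(x, 1.0, 1.0) / normal_pdf(x, 0.0, 2.0)  # weight p(x) / q(x)
        total += f(x) * w
    return total / n

def ppo_clipped_term(ratio, advantage, eps=0.2):
    # Pessimistic minimum of the unclipped and clipped surrogate terms.
    clipped = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped * advantage)

rng = random.Random(0)
estimate = importance_estimate(lambda x: x, 50_000, rng)  # true value E_p[x] = 1

print(estimate)
print(ppo_clipped_term(1.5, 1.0), ppo_clipped_term(0.5, -1.0))
```

The estimate lands near the true mean of 1 even though no sample was drawn from p. In the clipped term, a ratio of 1.5 with positive advantage is capped at 1.2, which removes the incentive to push the policy further from π_old.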

Fisher Information

The Fisher information matrix F(θ) measures how much information data carries about parameters — equivalently, the curvature of the log-likelihood:

F(θ)  =  E[ (∇_θ log p(x|θ))(∇_θ log p(x|θ))ᵀ ]
       =  −E[ ∇²_θ log p(x|θ) ]             (under mild regularity conditions)

Cramér-Rao bound:  Var(θ̂) ≥ 1/(n·F(θ))     scalar θ, n iid samples; no unbiased estimator can beat this

In ML:
  Natural gradient = F⁻¹ · ∇L             (preconditions by geometry of p, not Euclidean)
  K-FAC            ≈ block-diagonal, Kronecker-factored F⁻¹   (second-order optimiser)
  Shampoo          = related Kronecker-factored preconditioner built from gradient statistics
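For a concrete case, the two expressions for F(θ) can be compared on a Bernoulli(p) model, where the expectation is an exact two-term sum and the known answer is F(p) = 1/(p(1−p)). This is a sketch under that single-parameter assumption; the function names are illustrative.

```python
def score(x, p):
    # d/dp log P(x | p) for a Bernoulli outcome x in {0, 1}.
    return x / p - (1 - x) / (1 - p)

def hessian(x, p):
    # d^2/dp^2 log P(x | p).
    return -x / p ** 2 - (1 - x) / (1 - p) ** 2

def fisher_from_score(p):
    # E[score^2], summing over x = 1 (probability p) and x = 0 (probability 1 - p).
    return sum(prob * score(x, p) ** 2 for x, prob in ((1, p), (0, 1 - p)))

def fisher_from_hessian(p):
    # -E[hessian], the curvature form.
    return -sum(prob * hessian(x, p) for x, prob in ((1, p), (0, 1 - p)))

p = 0.3
n = 100
cramer_rao = 1.0 / (n * fisher_from_score(p))  # lower bound on Var of an unbiased estimator
mle_variance = p * (1 - p) / n                 # variance of the sample proportion h / n

print(fisher_from_score(p), fisher_from_hessian(p), 1 / (p * (1 - p)))
print(cramer_rao, mle_variance)
```

Both forms give 1/(p(1−p)), and the sample proportion attains the Cramér-Rao bound exactly in this model.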

3. Core Mechanism: Bayesian Inference

Bayesian inference is the mathematically consistent framework for updating beliefs in the face of evidence. MLE and MAP are special cases. The ELBO in VAEs, the KL penalty in PPO, and the conjugate prior trick in Beta-Binomial coin inference all reduce to the same Bayes rule applied in different settings.

Background

You want to estimate the bias p of a coin after observing h heads in n flips. Three strategies: (1) MLE — set p̂ = h/n; gives 0 or 1 at extreme counts, no uncertainty. (2) MAP — add a prior on p, maximise the joint; prior acts as regularisation but still gives a point. (3) Bayesian — maintain a full distribution over p, updating it with every new observation. Beta-Binomial is the textbook example because the posterior is analytically tractable.

Plan

  1. Choose prior: p ~ Beta(α₀, β₀). Mean = α₀/(α₀+β₀); concentration = α₀+β₀ (higher → tighter prior).
  2. Observe data: h heads and t = n−h tails from Binomial(n, p). Likelihood: P(h,t | p) = C(n,h) · pʰ(1−p)ᵗ.
  3. Apply Bayes rule. The normalising constant Z = ∫ Beta(α₀,β₀) · Binomial(h,t|p) dp is a Beta function — closed-form. Result: posterior = Beta(α₀+h, β₀+t).
  4. Read off: posterior mean = (α₀+h)/(α₀+β₀+n); MAP = (α₀+h−1)/(α₀+β₀+n−2). As n → ∞, both converge to the MLE p̂ = h/n.

Walkthrough — 5 coin flips with weak prior

Initial conditions: prior Beta(2, 2) — symmetric, mean = 0.5, weakly believes the coin is fair. Observe flips: H, H, H, T, H (4 heads, 1 tail).

Start:       Beta(2, 2)          mean=0.500  (prior: coin probably fair)

After H→1:  Beta(3, 2)           α=3, β=2    mean=0.600
After H→2:  Beta(4, 2)           α=4, β=2    mean=0.667
After H→3:  Beta(5, 2)           α=5, β=2    mean=0.714
After T→4:  Beta(5, 3)           α=5, β=3    mean=0.625   ← T pulls mean back
After H→5:  Beta(6, 3)           α=6, β=3    mean=0.667

Final posterior: Beta(6, 3)
  mean = 6/9 = 0.667   (Bayesian estimate: 4 heads in 5 flips, but prior adds 2+2=4 pseudo-counts)
  MLE  = 4/5 = 0.800   (ignores prior; unreliable at only 5 flips)
  MAP  = (6-1)/(6+3-2) = 5/7 = 0.714

As n → ∞, the prior's 4 pseudo-counts become negligible and Bayes → MLE.

The prior acts as regularisation: stronger priors (larger α₀+β₀) require more data to be overcome. In ML terms: L2 weight decay with penalty λ‖θ‖₂² is equivalent to MAP inference with a Gaussian prior N(0, 1/(2λ)) on each parameter, so the regularisation coefficient λ encodes prior belief strength.
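The walkthrough above can be reproduced with a few lines of Python. This is a minimal sketch of the conjugate update; the helper names are illustrative, and the mode formula assumes both posterior parameters exceed 1.

```python
def update(alpha, beta, flip):
    # One conjugate update: a head increments alpha, a tail increments beta.
    return (alpha + 1, beta) if flip == "H" else (alpha, beta + 1)

def posterior_mean(alpha, beta):
    return alpha / (alpha + beta)

def posterior_mode(alpha, beta):
    # MAP estimate; valid when alpha > 1 and beta > 1.
    return (alpha - 1) / (alpha + beta - 2)

alpha, beta = 2, 2              # Beta(2, 2) prior from the walkthrough
history = [(alpha, beta)]
for flip in "HHHTH":
    alpha, beta = update(alpha, beta, flip)
    history.append((alpha, beta))

for a, b in history:
    print(f"Beta({a}, {b})  mean={posterior_mean(a, b):.3f}")

mle = "HHHTH".count("H") / 5
print(posterior_mean(alpha, beta), posterior_mode(alpha, beta), mle)
```

The printed sequence of means matches the table in the walkthrough, ending at Beta(6, 3) with mean 0.667, MAP 5/7 and MLE 0.8.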

4. MCMC — Sampling When Z Is Intractable

When the posterior has no conjugate closed form (e.g. neural network weights with a non-Gaussian likelihood), the normalising constant Z = ∫ p(θ)p(D|θ)dθ is intractable. Markov Chain Monte Carlo bypasses Z: it constructs a Markov chain whose stationary distribution is the target posterior, so after burn-in the samples from a long-enough run are approximately distributed as P(θ|D). Successive samples are autocorrelated rather than iid, which is why the effective sample size is smaller than the chain length.

Algorithm | Proposal q | Gradient needed? | Mixing speed
Metropolis-Hastings | Symmetric Gaussian (random-walk variant) | No | Slow (random walk)
Gibbs | Full conditional P(xᵢ|x₋ᵢ) | No | Slow when variables are highly correlated
HMC | Leapfrog integrator | Yes (∇ log p) | Fast; used in PyMC / Stan
NUTS (No-U-Turn) | Adaptive HMC path length | Yes | Typically fastest in practice; no hand-tuned path length
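To make the first row concrete, here is a rough sketch of random-walk Metropolis targeting the unnormalised Beta(6, 3) density p⁵(1−p)². The conjugate answer is known in this case, so the chain can be checked against the exact posterior mean 6/9. Step size, chain length and burn-in are arbitrary choices for this sketch.

```python
import math
import random

def log_unnormalised_posterior(p, alpha=6, beta=3):
    # log of p^(alpha-1) * (1-p)^(beta-1); the normalising constant is never needed.
    if not 0.0 < p < 1.0:
        return -math.inf
    return (alpha - 1) * math.log(p) + (beta - 1) * math.log(1 - p)

def metropolis(log_target, start, n_steps, step_size, rng):
    samples, current = [], start
    current_lp = log_target(current)
    for _ in range(n_steps):
        proposal = current + rng.gauss(0.0, step_size)
        proposal_lp = log_target(proposal)
        # Symmetric proposal: accept with probability min(1, target ratio).
        accept_prob = math.exp(min(0.0, proposal_lp - current_lp))
        if rng.random() < accept_prob:
            current, current_lp = proposal, proposal_lp
        samples.append(current)
    return samples

rng = random.Random(0)
samples = metropolis(log_unnormalised_posterior, 0.5, 20_000, 0.2, rng)
kept = samples[2_000:]                  # discard burn-in
mcmc_mean = sum(kept) / len(kept)

print(mcmc_mean, 6 / 9)
```

Proposals outside (0, 1) have log-density −∞ and are always rejected, so the chain stays in the support. The estimated mean should sit close to the exact value 0.667.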

5. Minimal Demos

Demo 1 — Bayesian Beta-Binomial Update

Enter alpha_0 beta_0 heads tails. The demo performs the exact conjugate Bayesian update — no MCMC. Watch how the posterior mean, MAP estimate, and MLE compare at small vs large sample sizes. The prior strength (α₀+β₀) determines how many observations are needed to overcome it.

Bayesian Beta-Binomial Update — C Demo

Demo 2 — KL vs JS Divergence between Gaussians

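Enter mu1 sigma1 mu2 sigma2. Computes KL(P‖Q) and KL(Q‖P) in closed form for Gaussians, together with JS(P,Q), which has no Gaussian closed form (the mixture M = (P+Q)/2 is not Gaussian) and must be evaluated numerically. Observe two effects: with unequal variances KL(P‖Q) ≠ KL(Q‖P), and shifting the means apart makes both KL values grow without bound while JS stays below ln 2. The same unboundedness is related to why PPO clips importance ratios: when the new policy drifts far from the old one, the ratios become large and gradient estimates become high-variance.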
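A Python sketch of the same computation follows; it is not the page's original C demo. KL uses the standard closed form for univariate Gaussians, while JS is approximated by trapezoidal integration over a ±10σ window, which is an assumption of this sketch.

```python
import math

def kl_gauss(mu1, s1, mu2, s2):
    # Closed-form KL( N(mu1, s1^2) || N(mu2, s2^2) ).
    return math.log(s2 / s1) + (s1 ** 2 + (mu1 - mu2) ** 2) / (2 * s2 ** 2) - 0.5

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def js_gauss_numeric(mu1, s1, mu2, s2, n=20001):
    # Trapezoidal integration of the JS integrand on a grid covering both densities.
    lo = min(mu1 - 10 * s1, mu2 - 10 * s2)
    hi = max(mu1 + 10 * s1, mu2 + 10 * s2)
    dx = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * dx
        p, q = normal_pdf(x, mu1, s1), normal_pdf(x, mu2, s2)
        m = 0.5 * (p + q)
        f = 0.0
        if p > 0 and m > 0:
            f += 0.5 * p * math.log(p / m)
        if q > 0 and m > 0:
            f += 0.5 * q * math.log(q / m)
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * f * dx
    return total

kl_pq = kl_gauss(0.0, 1.0, 1.0, 2.0)
kl_qp = kl_gauss(1.0, 2.0, 0.0, 1.0)
js_near = js_gauss_numeric(0.0, 1.0, 1.0, 2.0)
kl_far = kl_gauss(0.0, 1.0, 20.0, 1.0)
js_far = js_gauss_numeric(0.0, 1.0, 20.0, 1.0)

print(kl_pq, kl_qp, js_near)
print(kl_far, js_far, math.log(2))
```

With means 20 standard deviations apart the KL value is 200 nats, while JS saturates just below ln 2.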

KL vs JS Divergence — C Demo

6. Production & Source Pointers

Concept | PyTorch / library | Source file
Cross-entropy (NLL) loss | F.cross_entropy(logits, targets) | torch/nn/functional.py → nll_loss
KL divergence | F.kl_div(log_q, p, reduction='sum') | torch/nn/functional.py → kl_div
Distributions (Beta, Dirichlet, …) | torch.distributions.Beta(a, b) | torch/distributions/beta.py
Top-p / nucleus sampling | TopPLogitsWarper in HF Transformers | transformers/generation/logits_process.py
RLHF KL penalty (PPO) | kl_penalty() in trl.PPOTrainer | trl/trainer/ppo_trainer.py
Importance sampling weights | ratio = exp(log_probs_new - log_probs_old); clip | trl/trainer/ppo_trainer.py → compute_loss()
Fisher / natural gradient | torch.optim.LBFGS (quasi-Newton, curvature-based rather than Fisher-based), Shampoo | distributed_shampoo / google/distributed_shampoo

Numerical tip: never call F.kl_div(q, p) on raw probabilities. The first argument must be log-probabilities (log_q), the second is the target p, and the result is KL(p‖q). Mixing raw and log-space probabilities is a silent bug that gives wrong KL values without raising an error. Always use F.kl_div(torch.log(q), p, reduction='batchmean') or pass log_target=True if both args are log-probs.

7. References

Papers

  • Schulman et al. 2017 — Proximal Policy Optimization Algorithms (arXiv:1707.06347) — KL penalty + importance sampling in RLHF
  • Ouyang et al. 2022 — Training language models to follow instructions with human feedback (arXiv:2203.02155) — InstructGPT; KL(π‖π_ref) penalty
  • Kingma & Welling 2013 — Auto-Encoding Variational Bayes (arXiv:1312.6114) — ELBO derivation; KL(q‖p) + reconstruction
  • Kaplan et al. 2020 — Scaling Laws for Neural Language Models (arXiv:2001.08361) — MLE fitting of perplexity curves
  • Hofmann 1999 — Probabilistic Latent Semantic Analysis (UAI 1999) — topic model that LDA later extended with Dirichlet priors
  • Martens 2014 — New insights and perspectives on the natural gradient method (arXiv:1412.1193) — Fisher information & NGD

Lectures

  • Stanford CS229 — Andrew Ng, Lecture Notes on Probability and Statistics (free PDF at cs229.stanford.edu)
  • MIT 6.041 / 6.431 — Probabilistic Systems Analysis — full lecture videos on OpenCourseWare
  • CMU 10-708 — Eric Xing, Probabilistic Graphical Models — Bayesian networks, MCMC, variational inference
  • Berkeley CS294-158 — Pieter Abbeel, Deep Unsupervised Learning — ELBO, VAE, flow models
  • Oxford Hilary Term — Machine Learning — Bayesian inference lectures (Yee Whye Teh)
  • DeepMind × UCL — David Silver, Reinforcement Learning lecture 7 — importance sampling in policy gradient

Textbooks

  • Bishop — Pattern Recognition and Machine Learning (PRML), Chapters 1–3: probability, Bayes, distributions
  • Murphy — Probabilistic Machine Learning: Introduction (free draft at probml.ai) — Vol 1 Ch 2–5
  • Murphy — Probabilistic Machine Learning: Advanced Topics — Vol 2 Ch 5: information theory
  • MacKay — Information Theory, Inference, and Learning Algorithms (free PDF at inference.org.uk)
  • Casella & Berger — Statistical Inference (2nd ed.) — rigorous treatment of MLE/MAP/sufficient statistics
  • Gelman et al. — Bayesian Data Analysis (BDA3, free PDF) — MCMC in practice

Code / Repos

  • huggingface/trl — PPO trainer; KL penalty & importance weights in ppo_trainer.py
  • pymc-devs/pymc — Bayesian modelling with NUTS/HMC; sampling from arbitrary posteriors
  • pytorch/pytorch — torch/distributions/ (Beta, Dirichlet, Categorical); torch/nn/functional.py (kl_div, cross_entropy)

Blog Posts

  • Lilian Weng — From Autoencoder to Beta-VAE (lilianweng.github.io) — ELBO & KL decomposition
  • Lilian Weng — Policy Gradient Algorithms — importance sampling, PPO, TRPO derivations
  • Eric Jang — A Beginner's Guide to Variational Methods — mean-field vs full Bayes
  • Sebastian Ruder — An overview of gradient descent optimisation algorithms — connects SGD to natural gradient
  • distill.pub — Why Momentum Really Works — variance reduction interpretation

8. Interview Prep

Questions asked in ML / LLM interviews at Anthropic, OpenAI, DeepMind, Meta AI, and NVIDIA — probability and statistics round.

Q1. Why is minimizing cross-entropy loss equivalent to maximum-likelihood estimation?

MLE: maximize E_data[log p_θ(y|x)]. Cross-entropy H(p_data, p_θ) = −E_data[log p_θ(y|x)], so minimizing H is identical to maximizing log-likelihood. Crucially, H(p_data, p_θ) = H(p_data) + KL(p_data‖p_θ). Since H(p_data) is fixed, minimizing H is equivalent to minimizing KL(p_data‖p_θ) — training pushes the model distribution toward the data distribution.

Q2. What is the difference between KL(P‖Q) and KL(Q‖P)? When does each appear in LLM training?

KL(P‖Q) = E_P[log P/Q]. If Q(x) = 0 where P(x) > 0, KL → ∞ — so the optimiser must spread Q across all of P's support (mass-covering / mean-seeking). Used in MLE training. KL(Q‖P) = E_Q[log Q/P]. If P(x) = 0 where Q(x) > 0, KL → ∞ — so Q avoids placing mass where P is zero, concentrating on one mode (mode-seeking). Used in VAE ELBO and RLHF PPO penalty term (keeps policy from deviating too far from the SFT reference).

Q3. Compute perplexity if the model's average NLL on a test set is 3.0 nats. What does it mean?

PPL = e^3.0 ≈ 20.1. Intuition: the model is, on average, as uncertain as if it had to choose uniformly among ≈20 equally likely next tokens. A random baseline for GPT-2's 50,257-token vocabulary gives PPL = 50,257. For comparison, word-level neural language models on PTB report perplexities of roughly 60–70, and modern LLMs achieve single-digit perplexity on well-studied datasets. Perplexities are only directly comparable when they are computed over the same tokenisation.

Q4. How does L2 weight decay relate to MAP inference? What prior does it correspond to?

MAP objective: argmax_θ P(D|θ)P(θ). With a Gaussian prior θ ~ N(0, σ²I), log P(θ) = −‖θ‖₂² / (2σ²) + const. So MAP = argmax [log-likelihood − ‖θ‖₂²/(2σ²)] = argmin [NLL + λ‖θ‖₂²] where λ = 1/(2σ²). Standard L2 regularisation is exactly MAP estimation with an isotropic Gaussian prior on all parameters; stronger λ = stronger (tighter) prior belief that weights should stay near zero.
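The correspondence λ = 1/(2σ²) can be checked on a one-parameter example. This sketch assumes unit-variance Gaussian observations y_i ~ N(θ, 1), so that the NLL is ½∑(y_i − θ)² up to a constant; the data values, prior variance and function names are made up for illustration.

```python
def map_closed_form(ys, sigma2):
    # Maximiser of log-likelihood + log N(theta | 0, sigma2) for unit-variance data.
    return sum(ys) / (len(ys) + 1.0 / sigma2)

def l2_minimiser(ys, lam, lr=0.01, steps=5000):
    # Plain gradient descent on NLL(theta) + lam * theta^2.
    theta = 0.0
    for _ in range(steps):
        grad = -sum(y - theta for y in ys) + 2.0 * lam * theta
        theta -= lr * grad
    return theta

ys = [1.2, 0.8, 1.5, 0.9]    # hypothetical observations
sigma2 = 0.5                 # assumed prior variance
lam = 1.0 / (2.0 * sigma2)   # matching L2 coefficient

theta_map = map_closed_form(ys, sigma2)
theta_l2 = l2_minimiser(ys, lam)
theta_mle = sum(ys) / len(ys)

print(theta_map, theta_l2, theta_mle)
```

The MAP estimate and the L2-regularised minimiser agree, and both are shrunk toward zero relative to the MLE (the sample mean).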

Q5. Explain importance sampling. Why does PPO clip the importance weights?

Importance sampling lets us estimate E_p[f(x)] using samples from a different distribution q: E_p[f(x)] = E_q[f(x) w(x)] where w(x) = p(x)/q(x). In PPO, samples are collected under π_old (the proposal) and we optimise the objective under π_θ (the target). Ratio r = π_θ/π_old is the importance weight. When π_θ drifts far from π_old, r can be very large → high-variance gradient estimates → unstable training. PPO clips r to [1−ε, 1+ε] (typically ε=0.2), trading off bias for much lower variance.

Q6. Describe the Beta-Binomial conjugate update in one sentence, then give the posterior for Beta(1,1) prior and 7 heads / 3 tails.

Observing h heads and t tails updates Beta(α,β) to Beta(α+h, β+t): prior pseudo-counts simply accumulate with real counts. Beta(1,1) is a Uniform prior (no prior belief). After 7H/3T: posterior = Beta(8, 4), mean = 8/12 = 0.667, MAP = (8−1)/(12−2) = 7/10 = 0.700, MLE = 7/10 = 0.700. MLE = MAP here at any sample size, because the flat prior contributes one pseudo-count to each side and the −1 terms in the mode formula cancel them exactly. The posterior mean still differs from the MLE because it keeps those pseudo-counts.

Q7. What is mutual information I(X;Y) and why does it appear in information bottleneck / representation learning?

I(X;Y) = H(X) − H(X|Y) = KL(P(X,Y)‖P(X)P(Y)) measures how much knowing Y reduces uncertainty in X. I(X;Y) = 0 iff X ⊥ Y. In the Information Bottleneck (Tishby et al.), a representation Z is learned to minimize I(X;Z) (compression) while maximizing I(Z;Y) (task-relevant information). In contrastive representation learning (SimCLR, CLIP), minimising the InfoNCE loss maximises a lower bound on I(X;Y) between input views or modalities: I(X;Y) ≥ log N − L_InfoNCE for N candidates per positive.

Q8. What is the Fisher information matrix and why does it matter for optimisation?

F(θ) = E[(∇ log p)(∇ log p)ᵀ] = −E[∇² log p] is the expected curvature of the log-likelihood. The natural gradient ∇̃L = F⁻¹∇L preconditions the gradient by the geometry of the model's predictive distribution rather than Euclidean parameter space, so the update direction is invariant to reparameterisation. In practice, F is d×d (intractable for large LLMs), so K-FAC uses a Kronecker-factored block-diagonal approximation: F ≈ A ⊗ B for each layer, where A ≈ E[xxᵀ] (input covariance) and B ≈ E[δδᵀ] (covariance of back-propagated gradients). Shampoo applies a similar Kronecker-factored preconditioner, built from accumulated gradient statistics rather than an explicit Fisher estimate.