Part VI — Transformer

§ 6.1 Attention is All You Need

Scaled dot-product attention from first principles: Q/K/V projections, the √d_k argument, multi-head splitting, causal and padding masks, and the O(T²) memory cliff that will motivate FlashAttention.

1. Overview

Self-attention is the operation at the heart of every modern transformer. It takes a sequence of T token vectors and produces a new sequence of T vectors where each output position is a content-addressed mixture of all (or all earlier) inputs. The mixture weights are computed on the fly from the inputs themselves — no learned routing table, no recurrence — which is what makes attention parallelisable across the entire sequence in one matrix multiply.

Vaswani et al. (2017, arXiv:1706.03762) showed that this single primitive, stacked into N blocks alongside a feed-forward layer, can replace the recurrence and convolution of earlier sequence models entirely, and that the resulting model beat the previous state of the art on machine translation at lower training cost. Every subsequent topic in this curriculum (positional encodings, MQA/GQA, KV-caching, FlashAttention, MoE, speculative decoding) is either a refinement of, or a system built around, scaled dot-product attention.

One-line summary: attention(Q, K, V) = softmax(QKᵀ / √d_k) V. Each row of QKᵀ is a similarity vector between one query and all keys; softmax turns it into a distribution; the matrix multiply with V is a weighted average of value vectors.

2. Q, K, V — Projections and Tensor Shapes

Attention starts with three learned linear projections of the same input X ∈ ℝ^{B×T×d_model}:

Q = X W_Q     # what each position is "looking for"
K = X W_K     # what each position "advertises"
V = X W_V     # what each position contributes if attended to

The intuition is a soft, content-addressed memory: a query vector q_i is compared against every key vector k_j, the comparisons are turned into a probability distribution, and the output is the corresponding mixture of the value vectors. Because Q, K, V are all linear projections of the same X, this is self-attention; in cross-attention (encoder–decoder) Q comes from one sequence and K/V from another.
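A minimal plain-Python sketch of the three projections, with the batch dimension dropped and random Gaussian matrices standing in for the learned W_Q, W_K, W_V (the helper names `matmul` and `rand_matrix` are illustrative, not from any library). It shows the one structural difference between self- and cross-attention: in cross-attention the query and key/value sequences can have different lengths, so the score matrix is T_dec × T_enc.

```python
import random

def matmul(A, B):
    """Plain-Python product of an [n, k] list-of-lists with a [k, m] one."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def rand_matrix(rows, n_cols, rng):
    """Random Gaussian matrix standing in for learned weights or token vectors."""
    return [[rng.gauss(0.0, 1.0) for _ in range(n_cols)] for _ in range(rows)]

rng = random.Random(0)
d_model = 8
W_Q, W_K, W_V = (rand_matrix(d_model, d_model, rng) for _ in range(3))

# Self-attention: Q, K and V are all projections of the same X (batch dim dropped).
X = rand_matrix(5, d_model, rng)                  # T = 5 tokens
Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)

# Cross-attention: queries from the decoder, keys/values from the encoder.
# (A real decoder gives its cross-attention layer its own weights; reusing
# W_Q/W_K/W_V here only keeps the sketch short.)
X_dec = rand_matrix(3, d_model, rng)              # T_dec = 3
X_enc = rand_matrix(7, d_model, rng)              # T_enc = 7
Q_x = matmul(X_dec, W_Q)
K_x, V_x = matmul(X_enc, W_K), matmul(X_enc, W_V)

print(len(Q), len(K), len(V))        # 5 5 5
print(len(Q_x), len(K_x), len(V_x))  # 3 7 7, so the score matrix Q_x K_x^T is 3 x 7
```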

Shape conventions used on this site

Symbol | Meaning | Typical value | Where it lives
B | Batch size | 1 – 4096 | Outermost dim; parallel sequences
T | Sequence length (tokens) | 128 – 128K | Time axis
d_model | Hidden / residual-stream width | 768 (GPT-2 small) … 18432 (PaLM-540B) | Width of every block input/output
H | Number of attention heads | 12 (GPT-2 small) … 96 (GPT-3 175B) | Parallel attention subspaces per layer
d_k = d_v = d_model / H | Per-head feature width | 64 – 128 typical | By convention H · d_k = d_model (assumed throughout)
W_Q, W_K, W_V, W_O | Learned projection weights | [d_model, d_model] each | All four matrices ⇒ 4 · d_model² params per layer (biases ignored)

After projecting and reshaping by head, each of Q, K, V lives in [B, H, T, d_k]. The two leading dims (B, H) are pure broadcast dims — every operation that follows treats them as independent — so a single fused tensor op runs B · H independent attention computations in parallel on the GPU.
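The per-head reshape is pure index bookkeeping: head h owns a contiguous block of d_k columns. A plain-Python sketch for one sequence (batch dimension dropped; `split_heads` and `merge_heads` are hypothetical helper names). In PyTorch the same layout change is usually written `x.view(B, T, H, d_k).transpose(1, 2)`.

```python
def split_heads(x, H):
    """[T, d_model] -> [H, T, d_k]: head h owns columns h*d_k .. (h+1)*d_k - 1."""
    d_model = len(x[0])
    assert d_model % H == 0, "this sketch assumes H * d_k == d_model"
    d_k = d_model // H
    return [[row[h * d_k:(h + 1) * d_k] for row in x] for h in range(H)]

def merge_heads(xh):
    """[H, T, d_k] -> [T, H * d_k]: concatenate the heads back along the feature axis."""
    H, T = len(xh), len(xh[0])
    return [[v for h in range(H) for v in xh[h][t]] for t in range(T)]

# T = 2 tokens, d_model = 6, H = 3 heads of width d_k = 2
x = [[0, 1, 2, 3, 4, 5],
     [6, 7, 8, 9, 10, 11]]
xh = split_heads(x, H=3)
print(xh[1])                  # head 1 sees columns 2..3 of every token: [[2, 3], [8, 9]]
assert merge_heads(xh) == x   # the reshape is lossless
```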

3. Core Mechanism — Scaled Dot-Product Attention

Background: Given T queries and T keys (one per token, each of dim d_k), we want a similarity score between every query and every key: a T × T matrix. The dot product is the cheapest content-aware similarity available: it can be computed for all pairs simultaneously with one BLAS GEMM.

Plan:

  1. Compute raw scores S = Q Kᵀ → shape [B, H, T, T].
  2. Scale by 1/√d_k to keep logit variance ≈ 1 (derived below).
  3. Add a mask: set forbidden positions to -∞ so they get probability 0 after softmax (causal mask, padding mask).
  4. Row-wise softmax → attention probabilities P (each row sums to 1).
  5. Weighted sum of values: O = P V, shape [B, H, T, d_v].

Closed form

Attention(Q, K, V) = softmax( Q K^T / sqrt(d_k) + M ) V

Shapes:
  Q : [B, H, T, d_k]
  K : [B, H, T, d_k]    -> K^T : [B, H, d_k, T]
  V : [B, H, T, d_v]
  M : [T, T]   (mask, broadcasts over B and H; -inf where forbidden, 0 elsewhere)

  Q K^T          -> [B, H, T, T]   (one logit per (i, j) pair)
  / sqrt(d_k)    -> [B, H, T, T]   (same shape, just rescaled)
  + M            -> [B, H, T, T]   (broadcast mask add)
  softmax dim=-1 -> [B, H, T, T]   (row-stochastic)
  @ V            -> [B, H, T, d_v] (one output vector per query position)
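The five steps and the closed form can be sketched in plain Python for a single (batch, head) slice. This is an illustrative reference, not an efficient kernel: lists stand in for tensors, and `sdpa`, `softmax` and `causal_mask` are hypothetical helper names. The usage at the bottom checks the same property Demo A asserts: with a causal mask, output row 0 equals V row 0.

```python
import math

NEG_INF = float("-inf")

def softmax(row):
    """Numerically stable softmax; -inf entries get probability exactly 0."""
    m = max(row)                                 # assumes at least one finite entry
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def sdpa(Q, K, V, M=None):
    """softmax(Q K^T / sqrt(d_k) + M) V for one (batch, head) slice.

    Q: [T_q, d_k], K: [T_k, d_k], V: [T_k, d_v], M: [T_q, T_k] additive mask or None.
    """
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    # Steps 1-2: S = Q K^T / sqrt(d_k) -> [T_q, T_k]
    S = [[scale * sum(q * k for q, k in zip(qi, kj)) for kj in K] for qi in Q]
    # Step 3: additive mask (0 = keep, -inf = forbidden)
    if M is not None:
        S = [[s + m for s, m in zip(srow, mrow)] for srow, mrow in zip(S, M)]
    # Step 4: row-wise softmax -> P, each row sums to 1
    P = [softmax(row) for row in S]
    # Step 5: O = P V -> [T_q, d_v]
    d_v = len(V[0])
    O = [[sum(p * V[j][c] for j, p in enumerate(prow)) for c in range(d_v)] for prow in P]
    return O, P

def causal_mask(T):
    """0 on and below the diagonal, -inf strictly above it."""
    return [[0.0 if j <= i else NEG_INF for j in range(T)] for i in range(T)]

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
O, P = sdpa(Q, K, V, causal_mask(3))
print(len(O), len(O[0]))   # 3 2  -> [T, d_v]
print(O[0] == V[0])        # True: position 0 can only attend to itself
```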

Why √d_k? The variance argument

Treat the components of q and k as iid with mean 0 and variance 1, with q independent of k (a reasonable approximation after standard initialisation and LayerNorm). Then the dot product is a sum of d_k iid terms:

q · k = Σ_{i=1..d_k}  q_i k_i

E[q_i k_i]    = E[q_i] E[k_i] = 0         (independent, zero-mean)
Var(q_i k_i)  = E[(q_i k_i)^2] - 0
              = E[q_i^2] E[k_i^2]          (independence)
              = 1 · 1 = 1

Var(q · k)    = sum of d_k iid unit-variance terms
              = d_k

  std(q · k)  = sqrt(d_k)

For d_k = 64 the logits have std ≈ 8, so nearly all of them (±3σ) land in [−24, +24]. Softmax over logits spread this widely is close to one-hot: the largest logit gets probability ≈ 1, the rest get ≈ 0, and the softmax Jacobian collapses toward zero (∂softmax/∂x → 0). Gradients through the attention weights vanish and training stalls.

Dividing by √d_k rescales the variance back to 1 independent of head dim, keeping the softmax in its soft, gradient-friendly regime. This is exactly why Vaswani et al. called it scaled dot-product attention.
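The variance claim is easy to check by Monte Carlo under the same assumption as the derivation (q and k with iid N(0, 1) components). The sketch below (the helper name `dot_product_variance` is hypothetical; sample count and seed are arbitrary choices) estimates Var(q · k) and Var(q · k / √d_k) for a few head widths.

```python
import random
import statistics

def dot_product_variance(d_k, n_samples=5000, scaled=False, seed=0):
    """Empirical Var(q . k) for q, k with iid N(0, 1) components."""
    rng = random.Random(seed)
    scale = 1.0 / d_k ** 0.5 if scaled else 1.0
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(scale * sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pvariance(dots)

for d_k in (4, 16, 64):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scaled=True)
    print(f"d_k={d_k:3d}  Var(q.k)={raw:7.2f}  Var(q.k/sqrt(d_k))={scaled:5.2f}")
```

Expect the raw variance to track d_k and the scaled one to stay near 1, up to sampling noise of a few percent.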

Worked numerical walkthrough

Take T = 3, d_k = 4. Pick concrete Q, K, V and trace every shape transition end to end (softmax probabilities are rounded to 3 d.p., and Step 5 uses the rounded values):

Q = [[ 1.0,  0.0,  0.0,  0.0],     # query for token 1
     [ 0.0,  1.0,  0.0,  0.0],     # query for token 2
     [ 0.5,  0.5,  0.0,  0.0]]     # query for token 3 (a blend of 1 and 2)

K = [[ 1.0,  0.0,  0.0,  0.0],     # key for token 1
     [ 0.0,  1.0,  0.0,  0.0],     # key for token 2
     [ 0.0,  0.0,  1.0,  0.0]]     # key for token 3 (unrelated direction)

V = [[10, 20],                      # value for token 1
     [30, 40],
     [50, 60]]

Step 1   S = Q K^T          (shape 3 x 3)
         S[i,j] = <q_i, k_j>
         S = [[ 1.0,  0.0,  0.0],
              [ 0.0,  1.0,  0.0],
              [ 0.5,  0.5,  0.0]]

Step 2   S / sqrt(d_k) = S / 2     (since d_k = 4)
         S' = [[ 0.50, 0.00, 0.00],
               [ 0.00, 0.50, 0.00],
               [ 0.25, 0.25, 0.00]]

Step 3   Apply CAUSAL mask: positions j > i become -inf.
         S' = [[ 0.50, -inf, -inf],
               [ 0.00,  0.50, -inf],
               [ 0.25,  0.25,  0.00]]

Step 4   Row-wise softmax (each row sums to 1):
         row 1: [1.00, 0,    0   ]
         row 2: softmax([0.0, 0.5])   = [0.378, 0.622, 0]
         row 3: softmax([0.25, 0.25, 0.0])
                = [exp(.25), exp(.25), exp(0)] / norm
                = [1.284, 1.284, 1.0] / 3.568
                = [0.360, 0.360, 0.280]

         P = [[1.000, 0.000, 0.000],
              [0.378, 0.622, 0.000],
              [0.360, 0.360, 0.280]]

Step 5   O = P V                    (shape 3 x 2)
         O[0] = 1.000*[10,20] + 0*[30,40] + 0*[50,60]
              = [10.0, 20.0]
         O[1] = 0.378*[10,20] + 0.622*[30,40] + 0
              = [22.44, 32.44]
         O[2] = 0.360*[10,20] + 0.360*[30,40] + 0.280*[50,60]
              = [28.4, 38.4]

         O = [[10.00, 20.00],       # token 1 sees only itself
              [22.44, 32.44],       # token 2 mostly attends to itself, some to token 1
              [28.40, 38.40]]       # token 3 spreads attention across all three

The output for token 1 is exactly its own value vector V[0] = [10, 20]: the causal mask forces position 0 to attend only to itself. Token 2 splits its attention between itself and token 1. Token 3 mixes all three. This is the essential picture: each output is a data-dependent convex combination of the values at or before its position.
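The walkthrough can be replayed in a few lines of plain Python. The printed probabilities match the hand computation; the outputs differ in the second decimal (for example 22.449 instead of 22.44) because Step 5 above multiplies the probabilities after rounding them to 3 d.p.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]    # exp(-inf) == 0.0 for masked entries
    s = sum(exps)
    return [e / s for e in exps]

Q = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0]]
K = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
V = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
T, d_k = len(Q), len(Q[0])

# Steps 1-3: scaled scores with the causal mask applied (j > i -> -inf)
S = [[sum(a * b for a, b in zip(Q[i], K[j])) / math.sqrt(d_k) if j <= i else float("-inf")
      for j in range(T)] for i in range(T)]
# Step 4: row-wise softmax
P = [softmax(row) for row in S]
# Step 5: O = P V, shape 3 x 2
O = [[sum(P[i][j] * V[j][c] for j in range(T)) for c in range(2)] for i in range(T)]

for i in range(T):
    print([round(p, 3) for p in P[i]], [round(o, 3) for o in O[i]])
```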

4. Multi-Head Attention

A single attention head is already useful, but its output for each token is a single weighted mean — it can only model one type of relationship at a time (e.g., "previous noun" or "matching bracket"). Multi-head attention runs H independent attention computations in parallel — each with its own small Q/K/V projections — and concatenates their outputs. Different heads learn different attention patterns: short-range, long-range, syntactic, coreferent, induction-style copy heads, etc.

Closed form

MultiHead(Q, K, V) = Concat(head_1, ..., head_H) W_O

with   head_h = Attention( Q W_Q^h, K W_K^h, V W_V^h )

Implementation trick: instead of H separate matmuls, allocate one
[d_model, d_model] matrix for each of W_Q, W_K, W_V and just RESHAPE
the result into H heads of width d_k = d_model / H.

So H heads of width d_k cost the same FLOPs as one head of width d_model,
but force the model to spend attention capacity across H parallel subspaces.

The output projection W_O is essential: without it, each head's output would stay in its own fixed slice of the residual stream and the heads could not mix. W_O takes the H concatenated d_v-vectors and produces the next residual-stream update, allowing information from head 7 to be used by future computation in any other channel.
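A plain-Python sketch of multi-head attention built the way the implementation trick describes: one [d_model, d_model] matrix per projection, with head h reading its own d_k-wide column block. Weights are random stand-ins, biases are omitted (so the count is exactly 4 · d_model²), and all helper names are hypothetical. The last check confirms that slicing the big projection is the same as giving head h its own small W_Q^h.

```python
import math
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    """Unmasked single-head scaled dot-product attention on list-of-lists."""
    d_k = len(Q[0])
    P = [softmax([sum(q * k for q, k in zip(qi, kj)) / math.sqrt(d_k) for kj in K]) for qi in Q]
    return matmul(P, V)

def cols(W, lo, hi):
    """Column block W[:, lo:hi]."""
    return [row[lo:hi] for row in W]

def multi_head(X, W_Q, W_K, W_V, W_O, H):
    """Concat(head_1, ..., head_H) W_O with one [d_model, d_model] matrix per projection."""
    d_k = len(W_Q) // H
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)   # one big matmul each
    heads = [attention(cols(Q, h * d_k, (h + 1) * d_k),
                       cols(K, h * d_k, (h + 1) * d_k),
                       cols(V, h * d_k, (h + 1) * d_k)) for h in range(H)]
    concat = [[v for head in heads for v in head[t]] for t in range(len(X))]
    return matmul(concat, W_O)                                  # W_O mixes the heads

def rand_matrix(rows, n_cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(n_cols)] for _ in range(rows)]

rng = random.Random(0)
T, d_model, H = 4, 8, 2
d_k = d_model // H
X = rand_matrix(T, d_model, rng)
W_Q, W_K, W_V, W_O = (rand_matrix(d_model, d_model, rng) for _ in range(4))

out = multi_head(X, W_Q, W_K, W_V, W_O, H)
n_params = sum(len(W) * len(W[0]) for W in (W_Q, W_K, W_V, W_O))
print(len(out), len(out[0]), n_params)   # 4 8 256, i.e. [T, d_model] and 4 * d_model**2

# Head 0 via its own small projection W_Q^0 = W_Q[:, 0:d_k] (likewise for K, V) ...
head0_small = attention(matmul(X, cols(W_Q, 0, d_k)), matmul(X, cols(W_K, 0, d_k)),
                        matmul(X, cols(W_V, 0, d_k)))
# ... equals head 0 computed by slicing the full-width projections.
head0_sliced = attention(cols(matmul(X, W_Q), 0, d_k), cols(matmul(X, W_K), 0, d_k),
                         cols(matmul(X, W_V), 0, d_k))
```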

Why split into many small heads instead of one big one?

  • Specialisation. Each head can lock onto a different pattern; the model gets multiple "channels" of relational reasoning per layer.
  • No extra FLOPs. Because d_k = d_model/H, the QKᵀ matmul costs the same as one big head (B·H·T²·d_k = B·T²·d_model).
  • Better gradient flow. Smaller per-head QKᵀ outputs are easier to normalise and produce less-saturated softmax distributions.
  • Interpretability. Heads can be analysed individually — see Elhage et al. 2021 (Anthropic) on induction heads and circuits.

5. Masks — Causal, Padding, Prefix-LM

A mask is a tensor added to the pre-softmax scores. Anywhere it equals -∞ the corresponding probability after softmax becomes exactly 0. Masks are how transformers express "you may not look here": causal masking makes a decoder, padding masks ignore padding tokens, prefix-LM masks make a bidirectional prefix followed by causal continuation.

Causal mask (decoder / GPT)

For autoregressive next-token prediction, token i must not see any token j > i; otherwise the loss is trivial (the model can just copy the next token). The mask is a strictly upper-triangular matrix of −∞, with 0 on and below the diagonal:

[Figure: causal mask. Green cells keep the real score; red cells are −∞, which the softmax drives to 0. Row i shows which key positions query i may attend to.]
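A sketch of the causal mask as an additive [T, T] matrix, plus a masked softmax showing that −∞ entries come out as exactly 0; Python's `math.exp(float("-inf"))` returns 0.0, so no special casing is needed. As an aside rather than a rule: some implementations use a large finite negative value (such as the dtype's minimum) instead of −∞, so that a row with every entry masked does not produce NaN. The causal mask never hits that case, because the diagonal is always allowed. Helper names are hypothetical.

```python
import math

NEG_INF = float("-inf")

def causal_mask(T):
    """[T, T] additive mask: 0 on/below the diagonal, -inf strictly above it."""
    return [[0.0 if j <= i else NEG_INF for j in range(T)] for i in range(T)]

def masked_softmax(scores, mask_row):
    row = [s + m for s, m in zip(scores, mask_row)]
    mx = max(row)                              # finite: the diagonal is never masked
    exps = [math.exp(v - mx) for v in row]     # exp(-inf) == 0.0 exactly
    total = sum(exps)
    return [e / total for e in exps]

M = causal_mask(4)
for row in M:
    print(" ".join(f"{v:>5}" for v in row))

# Query 1 may only see keys 0 and 1, however large the later scores are.
p = masked_softmax([0.3, 2.0, -1.0, 5.0], M[1])
print(p)
```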

Padding mask

When sequences in a batch have different lengths they are padded to the max length with a PAD token. A padding mask puts −∞ in the columns of padding positions so that legitimate tokens never spend probability mass on pad. Packed-sequence kernels (FlashAttention's varlen interface in training, vLLM at inference time) avoid padding entirely by indexing the sequences with cumulative offsets (cu_seqlens).
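A sketch of both approaches for a batch of three sequences: a key-side padding mask (−∞ in the pad columns of every query row) and the cumulative offsets that a packed, padding-free layout uses instead. The helper names are hypothetical; the `cu_seqlens` function only mirrors the offset convention (batch_size + 1 entries, starting at 0) and does not call any kernel.

```python
from itertools import accumulate

NEG_INF = float("-inf")

def padding_mask(lengths, T):
    """One [T, T] additive mask per sequence: columns j >= length are padding.

    Rows for pad *queries* are left unmasked here; their outputs are simply
    ignored by the loss.
    """
    return [[[0.0 if j < L else NEG_INF for j in range(T)] for _ in range(T)] for L in lengths]

def cu_seqlens(lengths):
    """Cumulative offsets [0, L0, L0+L1, ...] into one packed, padding-free buffer."""
    return [0] + list(accumulate(lengths))

lengths = [3, 5, 2]            # three sequences in one batch
T_max = max(lengths)
masks = padding_mask(lengths, T_max)
print(masks[2][0])             # [0.0, 0.0, -inf, -inf, -inf]
print(cu_seqlens(lengths))     # [0, 3, 8, 10]: sequence s occupies tokens [cu[s], cu[s+1])
```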

Encoder, decoder, prefix-LM

Mask pattern | Model class | Effect
Full (all zeros) | BERT, T5 encoder, ViT | Bidirectional self-attention: every token sees everything
Strictly upper-triangular −∞ | GPT-2/3/4, LLaMA, Mistral | Causal: token i sees tokens 1..i only
Cross-attention: full K from encoder | Original Transformer decoder | Decoder Q attends to all encoder K/V (no mask), plus causal self-attention
Prefix bidirectional + causal suffix | Prefix-LM, UL2, GLM | First k tokens see each other freely; tokens after k are causal
Block-diagonal (varlen / packed) | FlashAttention varlen, vLLM | Multiple sequences packed into one tensor; mask hides cross-sequence attention
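The table's patterns can be written as boolean "may attend" matrices and converted to additive masks. The sketch below is illustrative (all function names are hypothetical); for the packed case it assumes a causal model, so each diagonal block is itself lower-triangular.

```python
def allowed_prefix_lm(T, k):
    """First k tokens attend to each other freely; token i >= k sees tokens 0..i."""
    return [[(j < k) or (j <= i) for j in range(T)] for i in range(T)]

def allowed_packed_causal(cu_seqlens):
    """Packed sequences: causal within each segment, nothing across segments."""
    T = cu_seqlens[-1]
    seg = [0] * T
    for s in range(len(cu_seqlens) - 1):
        for t in range(cu_seqlens[s], cu_seqlens[s + 1]):
            seg[t] = s
    return [[seg[i] == seg[j] and j <= i for j in range(T)] for i in range(T)]

def to_additive(allowed):
    """True -> 0.0 (keep), False -> -inf (forbidden)."""
    return [[0.0 if a else float("-inf") for a in row] for row in allowed]

def show(allowed):
    for row in allowed:
        print("".join("x" if a else "." for a in row))

show(allowed_prefix_lm(5, k=2))          # 2-token bidirectional prefix, causal afterwards
print()
show(allowed_packed_causal([0, 2, 5]))   # two packed sequences of lengths 2 and 3
```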

6. FLOPs & Memory Cost

Knowing the cost of attention by heart is part of the LLM-infra interview bar. Per layer, per sequence (drop B for clarity):

Compute (FLOPs, multiplies only — multiply by 2 if you count mul+add):

  QKV projections     : 3 * T * d_model * d_model      = 3 T d^2
  Q K^T               : H * T * T * d_k = T^2 d_model  = T^2 d
  softmax             : H * T^2 exps (O(H T^2) but cheap; dominated by the matmuls)
  P V                 : H * T * T * d_v = T^2 d_model  = T^2 d
  Output projection   : T * d_model * d_model          = T d^2

  Total per layer ~ 4 T d^2  +  2 T^2 d

  Activation-memory peak (the part everyone trips over):
    scores tensor S   :  B * H * T * T  floats   <-- O(T^2) PER HEAD

  For B=4, H=32, T=8192, fp16:
    S = 4 * 32 * 8192 * 8192 * 2 bytes = 17.18 GB    (one tensor!)

  This single allocation is what blows up at long context and what
  FlashAttention removes by never materialising S in HBM.
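The bookkeeping above fits in one small function. It counts multiplies exactly as the block does, plus the bytes of a single [B, H, T, T] scores tensor. `attention_layer_cost` is a hypothetical helper, and d_model = 4096 in the example is an assumption (it does not affect the scores-tensor size).

```python
def attention_layer_cost(T, d_model, H, B=1, bytes_per_elem=2):
    """Per-layer multiply count and size of one [B, H, T, T] scores tensor.

    Multiplies only, as in the block above; double them for mul+add FLOPs.
    """
    d_k = d_model // H
    qkv_proj = 3 * T * d_model * d_model      # 3 T d^2
    out_proj = T * d_model * d_model          # T d^2
    qk = H * T * T * d_k                      # T^2 d_model
    pv = H * T * T * d_k                      # T^2 d_model
    mults = B * (qkv_proj + out_proj + qk + pv)
    score_bytes = B * H * T * T * bytes_per_elem
    return mults, score_bytes

# B=4, H=32, T=8192 in fp16/bf16 (2 bytes); d_model=4096 is an illustrative assumption.
mults, score_bytes = attention_layer_cost(T=8192, d_model=4096, H=32, B=4)
print(f"{score_bytes:,} bytes = {score_bytes / 1e9:.2f} GB = {score_bytes / 2**30:.0f} GiB")
# 17,179,869,184 bytes = 17.18 GB = 16 GiB
```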

Two regimes matter when reading roofline plots:

  • Short sequence (T < 2·d_model, where 4Td² exceeds 2T²d): compute is dominated by the QKV/output projections (∝ T d²). These matmuls are compute-bound, so tensor cores saturate.
  • Long sequence (T ≫ d_model): the QKᵀ and PV matmuls (∝ T²d) dominate, and the O(T²) score tensor dominates HBM traffic. Attention becomes memory-bandwidth-bound; this is exactly the regime FlashAttention v1–v3 (T44) is designed for.
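Dividing the T² terms by the total gives a one-line way to see which regime a shape falls in: the share is 2T²d / (4Td² + 2T²d) = T / (T + 2d), which crosses one half at T = 2·d_model. A sketch, with d_model = 4096 chosen only for illustration (`attention_matmul_share` is a hypothetical name):

```python
def attention_matmul_share(T, d_model):
    """Fraction of per-layer multiplies in the T^2 terms: T / (T + 2 * d_model)."""
    return (2 * T * T * d_model) / (4 * T * d_model ** 2 + 2 * T * T * d_model)

d = 4096    # illustrative width (assumption)
for T in (512, 2048, 8192, 32768, 131072):
    print(f"T={T:6d}  T^2 share of multiplies = {attention_matmul_share(T, d):.2f}")
```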

7. Minimal Demos

Demo A — Scaled dot-product attention end-to-end

Tiny B=1, H=2, T=4, D=8 example. Every shape is printed. The final assertion confirms that with a causal mask, token 0 attends only to itself — so the output at position 0 is exactly V at position 0.

Scaled Dot-Product Attention from Scratch — C Demo

Demo B — Why divide by √d_k?

Sweeps d_k ∈ {4, 16, 64, 256, 1024} and prints the entropy and max-probability of the resulting attention distribution, with and without scaling. Without scaling, entropy collapses toward zero as d_k grows: the model degenerates into a hard argmax. With scaling, entropy stays near log2(T).
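A plain-Python sketch of the same sweep (not the C demo itself): for each d_k it draws one random query and T random keys with iid N(0, 1) components, then averages the entropy and the largest probability of the softmax over a few trials. T = 32, 20 trials and the seed are arbitrary choices; `sweep` and `entropy_bits` are hypothetical names.

```python
import math
import random

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]     # very negative exponents underflow to 0.0
    s = sum(exps)
    return [e / s for e in exps]

def entropy_bits(p):
    """Shannon entropy in bits; log2(len(p)) for a uniform distribution."""
    return -sum(x * math.log2(x) for x in p if x > 0)

def sweep(T=32, trials=20, seed=0):
    """Mean entropy and max probability of softmax(q . k_j [/ sqrt(d_k)]) over T keys."""
    rng = random.Random(seed)
    results = {}
    for d_k in (4, 16, 64, 256, 1024):
        acc = {"raw": [0.0, 0.0], "scaled": [0.0, 0.0]}   # -> [mean entropy, mean max p]
        for _ in range(trials):
            q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
            keys = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(T)]
            logits = [sum(a * b for a, b in zip(q, k)) for k in keys]
            for name, scale in (("raw", 1.0), ("scaled", 1.0 / math.sqrt(d_k))):
                p = softmax([scale * x for x in logits])
                acc[name][0] += entropy_bits(p) / trials
                acc[name][1] += max(p) / trials
        results[d_k] = acc
    return results

res = sweep()
print(f"uniform entropy log2(T) = {math.log2(32):.2f} bits")
for d_k, acc in res.items():
    print(f"d_k={d_k:5d}  raw: H={acc['raw'][0]:.2f} bits, max p={acc['raw'][1]:.2f}   "
          f"scaled: H={acc['scaled'][0]:.2f} bits, max p={acc['scaled'][1]:.2f}")
```

With scaling, the entropy should stay roughly constant, somewhat below log2(32) = 5 bits, at every width; without it, the entropy falls and the max probability approaches 1 as d_k grows.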

√d_k Variance Effect — C Demo

Demo C — Multi-head attention by hand

Builds MHA from nn.Linear only, with no nn.MultiheadAttention magic. Watch the shape go [B,T,d_model] → [B,H,T,d_k] → [B,H,T,d_v] → [B,T,d_model]. The final line confirms the canonical 4·d_model² parameter count.

Multi-Head Attention from Scratch — C Demo

8. Production / Source Pointers

Library / File | Symbol | What it does
torch/nn/functional.py | F.scaled_dot_product_attention | Reference SDPA; dispatches to math / efficient / FlashAttention backends (PyTorch 2.x)
torch/nn/modules/activation.py | nn.MultiheadAttention | Stock MHA module; uses fused QKV projection when in_proj_weight is set
transformers/models/llama/modeling_llama.py | LlamaAttention.forward | Reference HF LLaMA attention impl; reads cleanly, includes RoPE + GQA
Karpathy/nanoGPT model.py | CausalSelfAttention | ~30 lines of educational MHA with causal mask; identical math to GPT-2
flash-attn/flash_attn/flash_attn_interface.py | flash_attn_func | IO-aware fused attention kernel; never materialises the T x T tensor (T44)
vllm/model_executor/layers/attention.py | PagedAttention | Inference-time MHA on top of paged KV-cache; production serving (T34)
xformers/ops/fmha/ | memory_efficient_attention | Meta's alternative IO-aware kernel, predates FlashAttention v1
cutlass/examples/41_fused_multi_head_attention | (example) | CUTLASS reference for hand-fused MHA on Ampere/Hopper tensor cores

9. References

Papers

  • Vaswani et al. (2017) — Attention is All You Need. arXiv:1706.03762
  • Bahdanau, Cho, Bengio (2014) — Neural Machine Translation by Jointly Learning to Align and Translate (additive attention, the predecessor). arXiv:1409.0473
  • Luong, Pham, Manning (2015) — Effective Approaches to Attention-based NMT (dot-product attention). arXiv:1508.04025
  • Britz et al. (2017) — Massive Exploration of NMT Architectures (empirical case for scaled dot-product). arXiv:1703.03906
  • Elhage et al. / Anthropic (2021) — A Mathematical Framework for Transformer Circuits (residual stream + head decomposition). transformer-circuits.pub
  • Phuong & Hutter (2022) — Formal Algorithms for Transformers. arXiv:2207.09238
  • Dao et al. (2022) — FlashAttention: Fast and Memory-Efficient Exact Attention. arXiv:2205.14135 (forward reference for §6.7 cost discussion)

Lectures

  • Stanford CS224n (Manning) — Lecture 8: Attention; Lecture 9: Self-Attention & Transformers
  • Stanford CS25 — Transformers United (lecture 1: introduction; lecture 2: attention deep dive)
  • Stanford CS336 (Hashimoto & Liang) — Language Modeling from Scratch (2024), Attention lecture
  • CMU 11-785 (Bhiksha Raj) — Deep Learning, Attention & Transformer lectures
  • NYU DS-GA 1008 (Yann LeCun & Alfredo Canziani) — Deep Learning, Self-Attention
  • Karpathy — Let's build GPT: from scratch, in code, spelled out (YouTube, 2 hrs)
  • 3Blue1Brown — But what is a GPT? / Attention in transformers, visually explained

Textbooks

  • Jurafsky & Martin — Speech and Language Processing, 3rd ed., Ch. 10–11 (Transformers and Large Language Models)
  • Bishop & Bishop (2024) — Deep Learning: Foundations and Concepts, Ch. 12 (Transformers)
  • Goodfellow, Bengio, Courville — Deep Learning (foundational, predates transformers but covers attention basics)

Code / Repos

  • karpathy/nanoGPT — minimal causal self-attention in ~30 lines; the canonical learning reference
  • karpathy/llm.c — same model in pure C/CUDA; great for kernel intuition
  • huggingface/transformers — LlamaAttention and GPT2Attention reference implementations
  • Dao-AILab/flash-attention — production IO-aware attention kernels (FA-1/2/3)
  • harvardnlp/annotated-transformer — line-by-line annotated PyTorch implementation of Vaswani et al.

Blog Posts

  • Sasha Rush — The Annotated Transformer (Harvard NLP, 2018 / 2022 update)
  • Jay Alammar — The Illustrated Transformer & The Illustrated GPT-2
  • Lilian Weng — The Transformer Family & Attention? Attention! (lilianweng.github.io)
  • Anthropic — A Mathematical Framework for Transformer Circuits (transformer-circuits.pub)
  • Karpathy — The Unreasonable Effectiveness of Recurrent Neural Networks & later GPT notes

10. Interview Prep

1. Derive scaled dot-product attention end-to-end. State every shape.

Project X ∈ [B, T, d_model] to Q, K, V via three linear maps W_Q, W_K, W_V (all d_model × d_model). Reshape into H heads: Q, K, V ∈ [B, H, T, d_k] with d_k = d_model / H. Compute S = QKᵀ ∈ [B, H, T, T], divide by √d_k, add the mask, take row-wise softmax to get P, then O = PV ∈ [B, H, T, d_v]. Concatenate heads back to [B, T, d_model] and apply output projection W_O. Total params per layer: 4 d_model²; compute: ~4 T d_model² + 2 T² d_model; activation memory peaks at the T × T scores tensor.

2. Why divide by √d_k? Show the variance computation.

Assume q, k components are iid with mean 0, variance 1 (a reasonable post-LayerNorm assumption). Then q·k = Σ q_i k_i has Var = Σ Var(q_i k_i) = Σ E[q_i²] E[k_i²] = d_k. So the dot product has std √d_k; for d_k = 64 that is 8, so logits routinely span about ±24 (±3σ), the softmax saturates and gradients vanish. Dividing by √d_k restores unit variance, keeping softmax in its soft, gradient-friendly regime independent of head dim.

3. Compute the activation memory of the attention scores tensor for B=4, H=32, T=8192 in bf16.

S has shape [B, H, T, T] = [4, 32, 8192, 8192]. Bytes per element = 2 (bf16). Total = 4 · 32 · 8192 · 8192 · 2 = 17,179,869,184 bytes ≈ 16 GiB ≈ 17.18 GB. This is per layer; for 80 layers naive attention would need ~1.4 TB just for the scores tensors. FlashAttention reduces this to O(T·d_k) by tiling.

4. Why is the output of MultiHead(Q,K,V) wrapped in a final W_O projection?

Each head produces a vector in its own d_v subspace. Concatenation alone makes the outputs disjoint blocks of the residual stream — block 1 is head 1, block 2 is head 2, etc. W_O is a learned linear map that mixes information across heads and projects back into the shared residual stream so future layers can read what any earlier head wrote. Without W_O, heads cannot communicate.

5. What is the causal mask, and how is it implemented numerically?

A causal mask is a [T, T] additive matrix with 0 on and below the diagonal and -inf strictly above. After the QKᵀ/√d_k step, S += M; the softmax of -inf is exactly 0, so token i puts zero probability on token j > i. In practice it is precomputed once and broadcast over (B, H). Modern fused kernels (FA-2/3) generate the mask on the fly rather than allocating a [T, T] buffer.

6. Why are the W_Q, W_K, W_V matrices often fused into a single [d_model, 3·d_model] weight in production?

All three projections read the same input X. Fusing them lets the GPU do one large GEMM instead of three smaller ones: better tensor-core utilisation, fewer kernel launches, and a single HBM read of X instead of three. This is a standard transformer optimisation in nn.MultiheadAttention (in_proj_weight), HF transformers (e.g. GPT-2's c_attn), Megatron, and vLLM.
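A sketch of the fusion: concatenating W_Q, W_K, W_V column-wise gives one [d_model, 3·d_model] weight, and splitting the single GEMM's output recovers exactly the three separate projections. nanoGPT does the same split in PyTorch with `self.c_attn(x).split(self.n_embd, dim=2)`; the plain-Python helpers here are hypothetical.

```python
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def hcat(*mats):
    """Concatenate equal-height matrices along columns."""
    return [[v for m in mats for v in m[i]] for i in range(len(mats[0]))]

def rand_matrix(rows, n_cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(n_cols)] for _ in range(rows)]

rng = random.Random(0)
T, d_model = 5, 6
X = rand_matrix(T, d_model, rng)
W_Q, W_K, W_V = (rand_matrix(d_model, d_model, rng) for _ in range(3))

W_QKV = hcat(W_Q, W_K, W_V)             # [d_model, 3 * d_model]
QKV = matmul(X, W_QKV)                  # one GEMM; X is read once
Q = [row[:d_model] for row in QKV]      # split the output back into three blocks
K = [row[d_model:2 * d_model] for row in QKV]
V = [row[2 * d_model:] for row in QKV]
```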

7. Attention is permutation-equivariant. Why? And what is the consequence?

Applying a permutation π to the input rows of X permutes Q, K, V the same way; the QKᵀ matrix gets its rows and columns permuted; softmax and the matmul with V preserve this. So, with no causal mask, Attention(πX) = π · Attention(X), and the operation itself has no notion of position. Consequence: we must inject position information, either via additive positional embeddings or via rotary/relative schemes that act on Q and K (see §6.3).
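A quick numerical check of the argument, using identity projections (Q = K = V = X) to isolate the attention operation; real projections act row by row, so they commute with the permutation and do not change the conclusion. The second check shows that adding a causal mask breaks the equivariance. All helper names are hypothetical.

```python
import math
import random

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def self_attention(X, causal=False):
    """Attention with identity projections (Q = K = V = X); optional causal mask."""
    T, d = len(X), len(X[0])
    out = []
    for i in range(T):
        s = [sum(a * b for a, b in zip(X[i], X[j])) / math.sqrt(d)
             if (not causal or j <= i) else float("-inf") for j in range(T)]
        p = softmax(s)
        out.append([sum(p[j] * X[j][c] for j in range(T)) for c in range(d)])
    return out

def permute(rows, perm):
    return [rows[p] for p in perm]

def close(A, B, tol=1e-9):
    return all(abs(a - b) < tol for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(1)
X = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(5)]
perm = [3, 0, 4, 1, 2]

print(close(self_attention(permute(X, perm)), permute(self_attention(X), perm)))  # True
print(close(self_attention(permute(X, perm), causal=True),
            permute(self_attention(X, causal=True), perm)))                         # False
```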

8. Where does the O(T²) memory cliff come from, and how does FlashAttention sidestep it?

The scores tensor S = QKᵀ has shape [B, H, T, T]; its size grows quadratically with sequence length. For T = 32k it can exceed HBM capacity per layer. FlashAttention (Dao et al. 2022) tiles Q, K, V into blocks that fit in on-chip SRAM, computes a partial softmax per tile, and combines tiles via the online-softmax recurrence, never materialising the full S in HBM. Memory is O(T·d_k) per head instead of O(T²) per head (O(H·T²) for the whole layer). Same math, IO-optimal. Full derivation in Part XIV (T44).
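The online-softmax recurrence that FlashAttention builds on can be sketched for a single query row in plain Python: keys and values are visited one block at a time, and only a running max, a running denominator and an un-normalised output accumulator are kept. This illustrates the math only; it has none of the SRAM tiling, parallelism or backward-pass machinery of the real kernel, and `online` / `direct` are hypothetical names.

```python
import math
import random

def scores(q, K):
    d = len(q)
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]

def direct(q, K, V):
    """Reference: materialise every score, softmax, weighted sum."""
    s = scores(q, K)
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(w[j] * V[j][c] for j in range(len(V))) / z for c in range(len(V[0]))]

def online(q, K, V, block=4):
    """Same result, visiting K/V in blocks and never holding all scores at once."""
    m = float("-inf")                   # running max of the scores seen so far
    z = 0.0                             # running softmax denominator (relative to m)
    acc = [0.0] * len(V[0])             # running sum of exp(s_j - m) * v_j
    for start in range(0, len(K), block):
        s = scores(q, K[start:start + block])
        Vb = V[start:start + block]
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)      # rescale old statistics to the new max (0.0 at start)
        p = [math.exp(x - m_new) for x in s]
        z = z * corr + sum(p)
        acc = [a * corr + sum(pj * v[c] for pj, v in zip(p, Vb)) for c, a in enumerate(acc)]
        m = m_new
    return [a / z for a in acc]

rng = random.Random(0)
d, d_v, T = 8, 3, 10
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
K = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(T)]
V = [[rng.gauss(0.0, 1.0) for _ in range(d_v)] for _ in range(T)]
print(direct(q, K, V))
print(online(q, K, V, block=4))   # same vector up to float rounding
```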