§ 6.1 Attention is All You Need
Scaled dot-product attention from first principles. Q/K/V projections, the √d_k argument, multi-head splitting, causal & padding masks, and the O(T²) memory cliff that will motivate FlashAttention.
1. Overview
Self-attention is the operation at the heart of every modern transformer. It takes a sequence of T token vectors and produces a new sequence of T vectors where each output position is a content-addressed mixture of all (or all earlier) inputs. The mixture weights are computed on the fly from the inputs themselves — no learned routing table, no recurrence — which is what makes attention parallelisable across the entire sequence in one matrix multiply.
Vaswani et al. (2017, arXiv:1706.03762) showed that this single primitive, stacked into N blocks alongside a feed-forward layer, can replace recurrence and convolution entirely in a sequence-to-sequence model, and set a new state of the art on WMT 2014 English–German and English–French translation at a small fraction of the training cost of earlier models. Every subsequent topic in this curriculum (positional encodings, MQA/GQA, KV-caching, FlashAttention, MoE, speculative decoding) is either a refinement of, or a system built around, scaled dot-product attention.
2. Q, K, V — Projections and Tensor Shapes
Attention starts with three learned linear projections of the same input X ∈ ℝ^{B×T×d_model}:
Q = X W_Q   # what each position is "looking for"
K = X W_K   # what each position "advertises"
V = X W_V   # what each position contributes if attended to
The intuition is a soft, content-addressed memory: a query vector q_i is compared against every key vector k_j, the comparisons are turned into a probability distribution, and the output is the corresponding mixture of the value vectors. Because Q, K, V are all linear projections of the same X, this is self-attention; in cross-attention (encoder–decoder) Q comes from one sequence and K/V from another.
Shape conventions used on this site
| Symbol | Meaning | Typical value | Where it lives |
|---|---|---|---|
| B | Batch size | 1 – 4096 | Outer-most dim, parallel sequences |
| T | Sequence length (tokens) | 128 – 128K | Time axis |
| d_model | Hidden / residual stream width | 768 (GPT-2 small) … 18432 (PaLM-540B) | Width of every block input/output |
| H | Number of attention heads | 12 (GPT-2) … 96 (GPT-3 175B) | Parallel attention spaces per layer |
| d_k = d_v = d_model / H | Per-head feature width | 64 – 128 typical | By convention, H · d_k = d_model |
| W_Q, W_K, W_V, W_O | Learned projection weights | [d_model, d_model] | All four matrices ⇒ 4·d_model² params per layer |
After projecting and reshaping by head, each of Q, K, V lives in [B, H, T, d_k]. The two leading dims (B, H) are pure broadcast dims — every operation that follows treats them as independent — so a single fused tensor op runs B · H independent attention computations in parallel on the GPU.
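A minimal NumPy sketch of the projection and head split described above, assuming toy sizes (B=2, T=5, d_model=16, H=4) and random weights. The helper `split_heads` is a hypothetical name for illustration. It shows that "splitting into heads" is only a reshape plus transpose: head h is a contiguous slice of d_k columns of the full projection.

```python
import numpy as np

# Toy sizes chosen for illustration only.
B, T, d_model, H = 2, 5, 16, 4
d_k = d_model // H
assert H * d_k == d_model

rng = np.random.default_rng(0)
X = rng.standard_normal((B, T, d_model))
# One [d_model, d_model] weight per projection; 1/sqrt(d_model) keeps activations O(1).
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(3))


def split_heads(Z):
    """[B, T, d_model] -> [B, H, T, d_k]: carve the feature axis into H chunks of width d_k."""
    b, t, _ = Z.shape
    return Z.reshape(b, t, H, d_k).transpose(0, 2, 1, 3)


Q = split_heads(X @ W_Q)
K = split_heads(X @ W_K)
V = split_heads(X @ W_V)
print(Q.shape, K.shape, V.shape)  # (2, 4, 5, 4) (2, 4, 5, 4) (2, 4, 5, 4)

# Head h is simply columns h*d_k:(h+1)*d_k of the full projection.
h = 1
assert np.allclose(Q[:, h], (X @ W_Q)[:, :, h * d_k:(h + 1) * d_k])
```

Because (B, H) end up as leading broadcast dims, every later matmul treats the B·H heads as independent problems.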
3. Core Mechanism — Scaled Dot-Product Attention
Background: Given T queries and T keys (one per token, each of dim d_k), we want a similarity score between every query and every key, i.e. a T × T matrix. The dot product is the cheapest content-aware similarity available: it can be computed for all pairs simultaneously with one BLAS GEMM.
Plan:
- Compute raw scores S = Q Kᵀ → shape [B, H, T, T].
- Scale by 1/√d_k to keep logit variance ≈ 1 (derived below).
- Add a mask: set forbidden positions to −∞ so they get probability 0 after softmax (causal mask, padding mask).
- Row-wise softmax → attention probabilities P (each row sums to 1).
- Weighted sum of values: O = P V, shape [B, H, T, d_v].
Closed form
Attention(Q, K, V) = softmax( Q K^T / sqrt(d_k) + M ) V

Shapes:
  Q : [B, H, T, d_k]
  K : [B, H, T, d_k]  ->  K^T : [B, H, d_k, T]
  V : [B, H, T, d_v]
  M : [T, T]  (mask, broadcasts over B and H; -inf where forbidden, 0 elsewhere)

  Q K^T           -> [B, H, T, T]    (one logit per (i, j) pair)
  / sqrt(d_k)     -> [B, H, T, T]    (same shape, just rescaled)
  + M             -> [B, H, T, T]    (broadcast mask add)
  softmax dim=-1  -> [B, H, T, T]    (row-stochastic)
  @ V             -> [B, H, T, d_v]  (one output vector per query position)
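A minimal NumPy sketch of the closed form above, assuming toy random tensors. `scaled_dot_product_attention` here is a hypothetical local helper, not PyTorch's `F.scaled_dot_product_attention`. Each line mirrors one row of the shape table, and the additive [T, T] mask broadcasts over B and H exactly as described.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for stability; exp(-inf) = 0 zeroes out masked entries.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, M=None):
    """Q, K: [B, H, T, d_k]; V: [B, H, T, d_v]; M: additive [T, T] mask or None."""
    d_k = Q.shape[-1]
    S = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # [B, H, T, T]
    if M is not None:
        S = S + M                                    # broadcasts over B and H
    P = softmax(S, axis=-1)                          # each row sums to 1
    return P @ V, P                                  # O: [B, H, T, d_v], P: [B, H, T, T]


rng = np.random.default_rng(0)
B, H, T, d_k, d_v = 2, 3, 5, 8, 6
Q = rng.standard_normal((B, H, T, d_k))
K = rng.standard_normal((B, H, T, d_k))
V = rng.standard_normal((B, H, T, d_v))
# Causal mask: -inf strictly above the diagonal, 0 on and below it.
M = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, 0.0)

O, P = scaled_dot_product_attention(Q, K, V, M)
print(O.shape, P.shape)  # (2, 3, 5, 6) (2, 3, 5, 5)
```

Returning P alongside O is only for inspection; production kernels avoid keeping P around.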
Why √d_k? The variance argument
Treat the components of q and k as iid with mean 0 and variance 1, with q independent of k (a reasonable approximation after standard initialisation and LayerNorm). Then the dot product is a sum of d_k iid terms:
q · k = Σ_{i=1..d_k} q_i k_i
E[q_i k_i] = E[q_i] E[k_i] = 0 (independent, zero-mean)
Var(q_i k_i) = E[(q_i k_i)^2] - 0
= E[q_i^2] E[k_i^2] (independence)
= 1 · 1 = 1
Var(q · k) = sum of d_k iid unit-variance terms
= d_k
std(q · k) = sqrt(d_k)

For d_k = 64 the logits have std 8, so typical values (±3σ) land in [−24, +24]. A softmax over logits that large is close to one-hot: the largest logit takes almost all the probability, and the softmax Jacobian diag(p) − p pᵀ approaches zero, so gradients through the attention weights vanish and training stalls.
Dividing by √d_k rescales the variance back to 1 independent of head dim, keeping the softmax in its soft, gradient-friendly regime. This is why Vaswani et al. called it scaled dot-product attention.
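A quick empirical check of the derivation, assuming the same idealised setting (iid standard-normal components, q independent of k). It samples many (q, k) pairs per head width and compares std(q · k) with √d_k before and after scaling. The sample size of 20,000 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000                      # random (q, k) pairs per head width
results = {}
for d_k in (16, 64, 256):
    q = rng.standard_normal((n, d_k))   # iid N(0, 1) components, as assumed above
    k = rng.standard_normal((n, d_k))   # drawn independently of q
    logits = (q * k).sum(axis=1)        # one raw dot product q . k per pair
    scaled = logits / np.sqrt(d_k)
    results[d_k] = (logits.std(), scaled.std())
    print(f"d_k={d_k:4d}  std(q.k)={logits.std():6.2f}  sqrt(d_k)={np.sqrt(d_k):5.1f}  "
          f"std(q.k / sqrt(d_k))={scaled.std():.3f}")
```

The raw std tracks √d_k (about 4, 8 and 16), while the scaled std stays close to 1 for every head width.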
Worked numerical walkthrough
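Take T = 3, d_k = 4, with the concrete Q, K, V below. Probabilities are rounded to 3 d.p., and the outputs in Step 5 are computed from those rounded probabilities, so their last digit can differ slightly from a full-precision calculation. Trace every shape transition end-to-end: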
Q = [[ 1.0, 0.0, 0.0, 0.0], # query for token 1
[ 0.0, 1.0, 0.0, 0.0], # query for token 2
[ 0.5, 0.5, 0.0, 0.0]] # query for token 3 (a blend of 1 and 2)
K = [[ 1.0, 0.0, 0.0, 0.0], # key for token 1
[ 0.0, 1.0, 0.0, 0.0], # key for token 2
[ 0.0, 0.0, 1.0, 0.0]] # key for token 3 (unrelated direction)
V = [[10, 20], # value for token 1
[30, 40],
[50, 60]]
Step 1 S = Q K^T (shape 3 x 3)
S[i,j] = <q_i, k_j>
S = [[ 1.0, 0.0, 0.0],
[ 0.0, 1.0, 0.0],
[ 0.5, 0.5, 0.0]]
Step 2 S / sqrt(d_k) = S / 2 (since d_k = 4)
S' = [[ 0.50, 0.00, 0.00],
[ 0.00, 0.50, 0.00],
[ 0.25, 0.25, 0.00]]
Step 3 Apply CAUSAL mask: positions j > i become -inf.
S' = [[ 0.50, -inf, -inf],
[ 0.00, 0.50, -inf],
[ 0.25, 0.25, 0.00]]
Step 4 Row-wise softmax (each row sums to 1):
row 1: [1.00, 0, 0 ]
row 2: softmax([0.0, 0.5]) = [0.378, 0.622, 0]
row 3: softmax([0.25, 0.25, 0.0])
= [exp(.25), exp(.25), exp(0)] / norm
= [1.284, 1.284, 1.0] / 3.568
= [0.360, 0.360, 0.280]
P = [[1.000, 0.000, 0.000],
[0.378, 0.622, 0.000],
[0.360, 0.360, 0.280]]
Step 5 O = P V (shape 3 x 2)
O[0] = 1.000*[10,20] + 0*[30,40] + 0*[50,60]
= [10.0, 20.0]
O[1] = 0.378*[10,20] + 0.622*[30,40] + 0
= [22.44, 32.44]
O[2] = 0.360*[10,20] + 0.360*[30,40] + 0.280*[50,60]
= [28.4, 38.4]
O = [[10.00, 20.00], # token 1 sees only itself
[22.44, 32.44], # token 2 mostly attends to itself, some to token 1
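[28.40, 38.40]] # token 3 spreads attention across all three

The output for token 1 is exactly its own value vector, V[0] in 0-indexed terms: the causal mask forces position 0 to attend only to itself. Token 2 puts most of its attention on itself (0.622) and the rest on token 1. Token 3 mixes all three; even token 3's unrelated key gets 0.280, because a zero score becomes exp(0) = 1 before normalisation rather than 0. This is the essential picture: each output is a data-dependent convex combination of the current and earlier values.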
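The same walkthrough in plain Python (standard library only), as a sketch for checking the arithmetic. It follows Steps 1 to 5 row by row and applies the causal mask by simply not scoring keys j > i. At full precision the last two outputs come out as ≈[22.45, 32.45] and ≈[28.41, 38.41]; the small differences from the hand calculation come from rounding the probabilities first.

```python
import math

Q = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0]]
K = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
V = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
T, d_k, d_v = len(Q), len(Q[0]), len(V[0])


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


P, O = [], []
for i in range(T):
    # Steps 1-3: scaled scores against keys j <= i; the causal mask removes j > i.
    scores = [dot(Q[i], K[j]) / math.sqrt(d_k) for j in range(i + 1)]
    # Step 4: softmax over the visible keys; masked keys get exactly 0.
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    row = [e / total for e in exps] + [0.0] * (T - i - 1)
    P.append(row)
    # Step 5: output = probability-weighted sum of value rows.
    O.append([sum(row[j] * V[j][c] for j in range(T)) for c in range(d_v)])

for p_row, o_row in zip(P, O):
    print([round(p, 3) for p in p_row], [round(o, 2) for o in o_row])
```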
4. Multi-Head Attention
A single attention head is already useful, but its output for each token is a single weighted mean — it can only model one type of relationship at a time (e.g., "previous noun" or "matching bracket"). Multi-head attention runs H independent attention computations in parallel — each with its own small Q/K/V projections — and concatenates their outputs. Different heads learn different attention patterns: short-range, long-range, syntactic, coreferent, induction-style copy heads, etc.
Closed form
MultiHead(Q, K, V) = Concat(head_1, ..., head_H) W_O
    with head_h = Attention(Q W_Q^h, K W_K^h, V W_V^h)

Implementation trick: instead of H separate matmuls, allocate one [d_model, d_model] matrix for each of W_Q, W_K, W_V and RESHAPE the result into H heads of width d_k = d_model / H. H heads of width d_k therefore cost the same FLOPs as one head of width d_model, but force the model to spread its attention capacity across H parallel subspaces.
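The output projection W_O matters: without it, each head's output would stay pinned to its own fixed d_v-wide block of the concatenated vector, and the heads would not mix within the layer. W_O maps the H concatenated d_v-vectors back to d_model and produces the residual-stream update, so each head can write into any direction of the residual stream and information from, say, head 7 can be read by later computation through any channel.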
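A small NumPy check of why W_O lets every head write anywhere, assuming random stand-ins for the head outputs. Concatenating and then projecting is algebraically identical to each head writing through its own row block W_O^h and summing the writes, the per-head view used in Elhage et al. (2021). Sizes are toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
T, H, d_v, d_model = 4, 3, 2, 6                             # toy sizes with H * d_v == d_model
heads = [rng.standard_normal((T, d_v)) for _ in range(H)]   # stand-ins for per-head outputs
W_O = rng.standard_normal((H * d_v, d_model))

# Textbook form: concatenate the heads, then project.
concat_then_project = np.concatenate(heads, axis=-1) @ W_O   # [T, d_model]

# Equivalent form: head h writes through its own row block W_O^h; the writes are summed.
per_head_writes = [heads[h] @ W_O[h * d_v:(h + 1) * d_v] for h in range(H)]
summed = sum(per_head_writes)

assert np.allclose(concat_then_project, summed)
print(concat_then_project.shape)   # (4, 6)
```

Each W_O^h is a full [d_v, d_model] map, so every head's write can land in any residual direction rather than in its own block.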
Why split into many small heads instead of one big one?
- Specialisation. Each head can lock onto a different pattern; the model gets multiple "channels" of relational reasoning per layer.
- No extra FLOPs. Because d_k = d_model/H, the QKᵀ matmul costs the same as one big head (B·H·T²·d_k = B·T²·d_model).
- Better gradient flow. Smaller per-head QKᵀ outputs are easier to normalise and produce less-saturated softmax distributions.
- Interpretability. Heads can be analysed individually — see Elhage et al. 2021 (Anthropic) on induction heads and circuits.
5. Masks — Causal, Padding, Prefix-LM
A mask is a tensor added to the pre-softmax scores. Anywhere it equals -∞ the corresponding probability after softmax becomes exactly 0. Masks are how transformers express "you may not look here": causal masking makes a decoder, padding masks ignore padding tokens, prefix-LM masks make a bidirectional prefix followed by causal continuation.
Causal mask (decoder / GPT)
For autoregressive next-token prediction, token i must not see any token j > i; otherwise the loss is trivial (the model can just copy the next token). The mask is a strictly upper-triangular matrix of −∞, with 0 on and below the diagonal:
[Figure: causal mask grid. Green cells keep the real score; red cells are −∞, so softmax drives them to 0. Row i shows which key positions query i may attend to.]
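A minimal NumPy sketch of the causal mask, assuming T = 5. The first print shows the kept/forbidden pattern as 1/0. Feeding all-equal raw scores through the mask and softmax then shows the effect directly: query i spreads its probability uniformly over keys 0..i and puts exactly 0 on later keys.

```python
import numpy as np

T = 5
# Additive causal mask: 0 on and below the diagonal, -inf strictly above (key j > query i).
causal = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, 0.0)
print(np.isfinite(causal).astype(int))   # 1 = score kept, 0 = -inf

# Apply it to all-equal raw scores, then softmax row-wise.
S = np.zeros((T, T)) + causal
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
# Query i spreads probability 1/(i+1) over keys 0..i and puts exactly 0 on later keys.
print(P.round(3))
```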
Padding mask
When sequences in a batch have different lengths, they are padded to the max length with a PAD token. A padding mask sets the score columns of PAD key positions to −∞, so legitimate tokens never spend probability mass on padding. Packed-sequence kernels (FlashAttention's varlen interface, vLLM) avoid padding entirely: sequences are concatenated back to back and delimited by a cumulative-length index, cu_seqlens.
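A NumPy sketch of a key-padding mask combined with a causal mask, assuming right padding and a toy batch with lengths [5, 3]. The padding mask is shaped [B, 1, 1, T] so it broadcasts over heads and query rows and only blanks out PAD key columns.

```python
import numpy as np

lengths = [5, 3]                 # toy batch: real-token count per sequence, rest is PAD
B, T = len(lengths), max(lengths)
is_pad = np.arange(T)[None, :] >= np.array(lengths)[:, None]        # [B, T], True at PAD
pad_mask = np.where(is_pad, -np.inf, 0.0)[:, None, None, :]         # [B, 1, 1, T]: PAD key columns
causal = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, 0.0)   # [T, T]
M = pad_mask + causal            # [B, 1, T, T]; broadcasts over heads when added to scores
print(M.shape)                   # (2, 1, 5, 5)
print(np.isfinite(M[1, 0]).astype(int))   # sequence 2: key columns 3 and 4 are never visible
```

With right padding, key 0 is always a real token, so no query row ends up fully masked (which would otherwise produce NaNs after softmax).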
Encoder, decoder, prefix-LM
| Mask pattern | Model class | Effect |
|---|---|---|
| Full (all zeros) | BERT, T5 encoder, ViT | Bidirectional self-attention — every token sees everything |
| Strict upper-triangular -inf | GPT-2/3/4, LLaMA, Mistral | Causal — token i sees tokens 1..i only |
| Cross: full K from encoder | Original Transformer decoder | Decoder Q attends to all encoder K/V (no mask), plus causal self-attn |
| Prefix bidirectional + causal suffix | Prefix-LM, UL2, GLM | First k tokens see each other freely; tokens after k are causal |
| Block-diagonal (varlen / packed) | FlashAttention varlen, vLLM | Multiple sequences packed into one tensor; mask hides cross-sequence attention |
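Dense NumPy sketches of two rows of the table, the prefix-LM mask and the block-diagonal packed mask, assuming a prefix length k = 3 and packed lengths [2, 3]. The function names are hypothetical. Real varlen kernels do not build this dense [T, T] matrix; they work from the cumulative offsets (`cu_seqlens`), computed here only to show how they relate to the block boundaries.

```python
import numpy as np


def prefix_lm_mask(T, k):
    """First k tokens see each other freely; positions >= k are causal."""
    allowed = np.tril(np.ones((T, T), dtype=bool))   # causal base: key j <= query i
    allowed[:k, :k] = True                           # open up the prefix block
    return np.where(allowed, 0.0, -np.inf)


def packed_causal_mask(seq_lens):
    """Block-diagonal causal mask for sequences packed back to back into one row."""
    seq_id = np.repeat(np.arange(len(seq_lens)), seq_lens)   # which sequence each token is from
    same_seq = seq_id[:, None] == seq_id[None, :]
    causal = np.tril(np.ones((len(seq_id), len(seq_id)), dtype=bool))
    return np.where(same_seq & causal, 0.0, -np.inf)


seq_lens = [2, 3]
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)])   # [0, 2, 5]: sequence boundaries
print(np.isfinite(prefix_lm_mask(5, 3)).astype(int))
print(np.isfinite(packed_causal_mask(seq_lens)).astype(int), cu_seqlens)
```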
6. FLOPs & Memory Cost
Knowing the cost of attention by heart is part of the LLM-infra interview bar. Per layer, per sequence (drop B for clarity):
Compute (FLOPs, multiplies only — multiply by 2 if you count mul+add):
QKV projections : 3 * T * d_model * d_model = 3 T d^2
Q K^T : H * T * T * d_k = T^2 d_model = T^2 d
softmax : H * T^2 elements (O(H T^2) but cheap; dominated by the two matmuls)
P V : H * T * T * d_v = T^2 d_model = T^2 d
Output projection : T * d_model * d_model = T d^2
Total per layer ~ 4 T d^2 + 2 T^2 d
Activation-memory peak (the part everyone trips over):
scores tensor S : B * H * T * T floats <-- O(T^2) PER HEAD
For B=4, H=32, T=8192, fp16:
S = 4 * 32 * 8192 * 8192 * 2 bytes = 17.18 GB (one tensor!)
This single allocation is what blows up at long context and what
FlashAttention removes by never materialising S in HBM.

Two regimes matter when reading roofline plots:
- Short sequence (T < 2·d_model, since 4Td² > 2T²d exactly when T < 2d): compute is dominated by the QKV/output projections (∝ T d²). These GEMMs are compute-bound and saturate the tensor cores.
- Long sequence (T ≫ d_model): the QKᵀ and PV matmuls (∝ T²d) dominate, and the O(T²) score tensor dominates HBM traffic. Attention becomes memory-bandwidth-bound; this is exactly the regime FlashAttention v1–v3 (T44) is designed for.
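A small calculator for the formulas above, written as a sketch in plain Python. `attention_layer_cost` is a hypothetical helper, and d_model = 4096 is an assumed value (the score-tensor size does not depend on it). It reproduces the 17.18 GB figure and shows the projection/attention multiply ratio 2·d_model/T crossing 1 at T = 2·d_model.

```python
def attention_layer_cost(T, d_model, H, B=1, bytes_per_el=2):
    """Per-layer multiply counts (as in the table above) and naive score-tensor bytes."""
    proj_mults = 4 * T * d_model**2        # QKV projections (3 T d^2) + output projection (T d^2)
    attn_mults = 2 * T**2 * d_model        # Q K^T and P V, each T^2 d
    score_bytes = B * H * T * T * bytes_per_el
    return proj_mults, attn_mults, score_bytes


# Memory example from the text: B=4, H=32, T=8192, 2-byte elements (fp16/bf16).
# d_model=4096 is an assumed value; the score tensor size does not depend on it.
_, _, s = attention_layer_cost(T=8192, d_model=4096, H=32, B=4)
print(f"scores tensor: {s / 1e9:.2f} GB = {s / 2**30:.0f} GiB")   # 17.18 GB = 16 GiB

# Regime check: projection/attention ratio is 4 T d^2 / (2 T^2 d) = 2 d / T.
for T in (1024, 8192, 65536):
    proj, attn, _ = attention_layer_cost(T, d_model=4096, H=32)
    print(f"T={T:6d}  projection/attention multiply ratio = {proj / attn:.3f}")
```

The ratios come out as 8.000, 1.000 and 0.125: projections dominate at short context, the T² terms at long context.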
7. Minimal Demos
Demo A — Scaled dot-product attention end-to-end
Tiny B=1, H=2, T=4, D=8 example. Every shape is printed. The final assertion confirms that with a causal mask, token 0 attends only to itself — so the output at position 0 is exactly V at position 0.
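A NumPy sketch of what Demo A describes (not the demo's original code). It assumes D is the per-head width d_k = d_v and uses random Q, K, V. Every intermediate shape is printed, and the final assertions check that position 0 attends only to itself, so its output equals V at position 0.

```python
import numpy as np

rng = np.random.default_rng(42)
B, H, T, D = 1, 2, 4, 8                       # D is taken to be the per-head width d_k = d_v
Q = rng.standard_normal((B, H, T, D))
K = rng.standard_normal((B, H, T, D))
V = rng.standard_normal((B, H, T, D))
print("Q, K, V:", Q.shape)

S = Q @ K.swapaxes(-1, -2) / np.sqrt(D)
print("S = QK^T / sqrt(D):", S.shape)
M = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, 0.0)
print("mask M:", M.shape, "(broadcasts over B, H)")
S = S + M
P = np.exp(S - S.max(axis=-1, keepdims=True))
P = P / P.sum(axis=-1, keepdims=True)
print("P = softmax(S):", P.shape)
O = P @ V
print("O = PV:", O.shape)

# Causal mask: query 0 can only see key 0, so its output is exactly V at position 0.
assert np.allclose(P[..., 0, :], np.eye(T)[0])
assert np.allclose(O[..., 0, :], V[..., 0, :])
```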
Demo B — Why divide by √dk?
Sweeps d_k ∈ {4, 16, 64, 256, 1024} and prints the entropy and max-probability of the resulting attention distribution, with and without scaling. Without scaling, entropy collapses toward zero as d_k grows and the model degenerates into a hard argmax. With scaling, entropy stays roughly constant across d_k, a little below the uniform maximum log2(T).
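A NumPy sketch in the spirit of Demo B (not its original code), assuming T = 32 keys, 200 random query rows per setting, and iid standard-normal q and k. It reports mean entropy in bits and mean max-probability for raw and 1/√d_k-scaled logits.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def entropy_bits(p):
    # Treat 0 * log 0 as 0 by replacing zeros with 1 inside the log.
    return -(p * np.log2(np.where(p > 0, p, 1.0))).sum(axis=-1)


rng = np.random.default_rng(0)
T, n_rows = 32, 200                      # toy sizes: 32 keys, 200 independent query rows
print(f"max possible entropy log2(T) = {np.log2(T):.2f} bits")
results = {}
for d_k in (4, 16, 64, 256, 1024):
    q = rng.standard_normal((n_rows, 1, d_k))
    k = rng.standard_normal((n_rows, T, d_k))
    logits = (q * k).sum(axis=-1)        # [n_rows, T] raw dot products, variance ~ d_k
    for name, s in (("unscaled", logits), ("scaled", logits / np.sqrt(d_k))):
        p = softmax(s)
        results[d_k, name] = (entropy_bits(p).mean(), p.max(axis=-1).mean())
        print(f"d_k={d_k:5d}  {name:8s}  entropy={results[d_k, name][0]:5.2f} bits  "
              f"max p={results[d_k, name][1]:.3f}")
```

Because scaled logits have variance ≈ 1, their entropy sits near log2(T) minus a roughly constant gap, whatever the head width.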
Demo C — Multi-head attention by hand
Builds MHA from nn.Linear only, with no nn.MultiheadAttention magic. Watch the shape go [B, T, d_model] → [B, H, T, d_k] → [B, H, T, d_v] → [B, T, d_model]. The final line confirms the canonical 4·d_model² parameter count (weights only; bias vectors, if enabled, add 4·d_model).
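A NumPy analogue of Demo C, assuming plain [d_model, d_model] matrices stand in for bias-free nn.Linear layers (so this is not the PyTorch demo itself). The class name and sizes (d_model = 64, H = 8) are illustrative. It follows the same shape path and checks that the weight count is exactly 4·d_model².

```python
import numpy as np


class MultiHeadAttention:
    """NumPy stand-in for an MHA layer built from four bias-free linear maps."""

    def __init__(self, d_model, H, rng):
        assert d_model % H == 0, "d_model must be divisible by H"
        self.H, self.d_k = H, d_model // H
        scale = 1.0 / np.sqrt(d_model)
        # W_Q, W_K, W_V, W_O, each [d_model, d_model] (plays the role of nn.Linear(bias=False)).
        self.W_Q, self.W_K, self.W_V, self.W_O = (
            rng.standard_normal((d_model, d_model)) * scale for _ in range(4))

    def num_params(self):
        return sum(W.size for W in (self.W_Q, self.W_K, self.W_V, self.W_O))

    def __call__(self, X, causal=True):
        B, T, d_model = X.shape

        def split(Z):  # [B, T, d_model] -> [B, H, T, d_k]
            return Z.reshape(B, T, self.H, self.d_k).transpose(0, 2, 1, 3)

        Q, K, V = split(X @ self.W_Q), split(X @ self.W_K), split(X @ self.W_V)
        S = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(self.d_k)            # [B, H, T, T]
        if causal:
            future = np.triu(np.ones((T, T), dtype=bool), k=1)
            S = np.where(future, -np.inf, S)                          # hide keys j > i
        P = np.exp(S - S.max(axis=-1, keepdims=True))
        P /= P.sum(axis=-1, keepdims=True)
        O = P @ V                                                     # [B, H, T, d_v]
        O = O.transpose(0, 2, 1, 3).reshape(B, T, d_model)            # concat heads
        return O @ self.W_O                                           # [B, T, d_model]


rng = np.random.default_rng(0)
mha = MultiHeadAttention(d_model=64, H=8, rng=rng)
Y = mha(rng.standard_normal((2, 10, 64)))
print(Y.shape, mha.num_params(), 4 * 64**2)   # (2, 10, 64) 16384 16384
```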
8. Production / Source Pointers
| Library / File | Symbol | What it does |
|---|---|---|
| torch/nn/functional.py | F.scaled_dot_product_attention | Reference SDPA; dispatches to math / efficient / FlashAttention backends (PyTorch 2.x) |
| torch/nn/modules/activation.py | nn.MultiheadAttention | Stock MHA module; uses a single fused QKV weight (in_proj_weight) when kdim == vdim == embed_dim |
| transformers/models/llama/modeling_llama.py | LlamaAttention.forward | Reference HF LLaMA attention impl; reads cleanly, includes RoPE + GQA |
| Karpathy/nanoGPT model.py | CausalSelfAttention | ~30 lines of educational MHA with causal mask; identical math to GPT-2 |
| flash-attn/flash_attn/flash_attn_interface.py | flash_attn_func | IO-aware fused attention kernel; never materialises the T x T tensor (T44) |
| vllm/model_executor/layers/attention.py | PagedAttention | Inference-time MHA on top of paged KV-cache; production serving (T34) |
| xformers/ops/fmha/ | memory_efficient_attention | Meta's memory-efficient attention op; dispatches among several kernels (FlashAttention, CUTLASS-based, and others) |
| cutlass/examples/41_fused_multi_head_attention | — | CUTLASS reference for hand-fused MHA on Ampere/Hopper tensor cores |
9. References
Papers
- Vaswani et al. (2017) — Attention is All You Need. arXiv:1706.03762
- Bahdanau, Cho, Bengio (2014) — Neural Machine Translation by Jointly Learning to Align and Translate (additive attention, the predecessor). arXiv:1409.0473
- Luong, Pham, Manning (2015) — Effective Approaches to Attention-based NMT (dot-product attention). arXiv:1508.04025
- Britz et al. (2017) — Massive Exploration of NMT Architectures (empirical case for scaled dot-product). arXiv:1703.03906
- Elhage et al. / Anthropic (2021) — A Mathematical Framework for Transformer Circuits (residual stream + head decomposition). transformer-circuits.pub
- Phuong & Hutter (2022) — Formal Algorithms for Transformers. arXiv:2207.09238
- Dao et al. (2022) — FlashAttention: Fast and Memory-Efficient Exact Attention. arXiv:2205.14135 (forward reference for §6.7 cost discussion)
Lectures
- Stanford CS224n (Manning) — Lecture 8: Attention; Lecture 9: Self-Attention & Transformers
- Stanford CS25 — Transformers United (lecture 1: introduction; lecture 2: attention deep dive)
- Stanford CS336 (Hashimoto & Liang) — Language Modeling from Scratch (2024), Attention lecture
- CMU 11-785 (Bhiksha Raj) — Deep Learning, Attention & Transformer lectures
- NYU DS-GA 1008 (Yann LeCun & Alfredo Canziani) — Deep Learning, Self-Attention
- Karpathy — Let's build GPT: from scratch, in code, spelled out (YouTube, 2 hrs)
- 3Blue1Brown — But what is a GPT? / Attention in transformers, visually explained
Textbooks
- Jurafsky & Martin — Speech and Language Processing, 3rd ed., Ch. 10–11 (Transformers and Large Language Models)
- Bishop & Bishop (2024) — Deep Learning: Foundations and Concepts, Ch. 12 (Transformers)
- Goodfellow, Bengio, Courville — Deep Learning (foundational, predates transformers but covers attention basics)
Code / Repos
- karpathy/nanoGPT — minimal causal self-attention in ~30 lines; the canonical learning reference
- karpathy/llm.c — same model in pure C/CUDA; great for kernel intuition
- huggingface/transformers — LlamaAttention and GPT2Attention reference implementations
- Dao-AILab/flash-attention — production IO-aware attention kernels (FA-1/2/3)
- harvardnlp/annotated-transformer — line-by-line annotated PyTorch implementation of Vaswani et al.
Blog Posts
- Sasha Rush — The Annotated Transformer (Harvard NLP, 2018 / 2022 update)
- Jay Alammar — The Illustrated Transformer & The Illustrated GPT-2
- Lilian Weng — The Transformer Family & Attention? Attention! (lilianweng.github.io)
- Anthropic — A Mathematical Framework for Transformer Circuits (transformer-circuits.pub)
- Karpathy — The Unreasonable Effectiveness of Recurrent Neural Networks & later GPT notes
10. Interview Prep
1. Derive scaled dot-product attention end-to-end. State every shape.
Project X ∈ [B, T, d_model] to Q, K, V via three linear maps W_Q, W_K, W_V (all d_model × d_model). Reshape into H heads: Q, K, V ∈ [B, H, T, d_k] with d_k = d_model / H. Compute S = QKᵀ ∈ [B, H, T, T], divide by √d_k, add the mask, take row-wise softmax to get P, then O = PV ∈ [B, H, T, d_v]. Concatenate heads back to [B, T, d_model] and apply output projection W_O. Total params per layer: 4 d_model²; compute: ~4 T d_model² + 2 T² d_model; activation memory peaks at the T × T scores tensor.
2. Why divide by √d_k? Show the variance computation.
Assume q, k components are iid with mean 0, variance 1 (a reasonable post-LayerNorm assumption). Then q·k = Σ q_i k_i has Var = Σ Var(q_i k_i) = Σ E[q_i²] E[k_i²] = d_k. So the dot product has std √d_k, and for d_k = 64 typical logits are ±24 — softmax saturates and gradients vanish. Dividing by √d_k restores unit variance, keeping softmax in its soft, gradient-friendly regime independent of head dim.
3. Compute the activation memory of the attention scores tensor for B=4, H=32, T=8192 in bf16.
S has shape [B, H, T, T] = [4, 32, 8192, 8192]. Bytes per element = 2 (bf16). Total = 4 · 32 · 8192 · 8192 · 2 = 17,179,869,184 bytes = 16 GiB ≈ 17.18 GB. That is for one layer; if the score tensors of all 80 layers of an 80-layer model are kept for the backward pass, naive attention needs ~1.4 TB for them alone. FlashAttention reduces the per-head footprint to O(T·d_k) by tiling.
4. Why is the output of MultiHead(Q,K,V) wrapped in a final W_O projection?
Each head produces a vector in its own d_v subspace. Concatenation alone leaves the outputs in disjoint blocks: block 1 is head 1, block 2 is head 2, etc. W_O is a learned linear map that mixes information across heads and projects back into the shared residual stream, so each head can write into any residual direction and future layers can read what any earlier head wrote. Without W_O, each head would be confined to its own fixed block and heads would not mix within the layer.
5. What is the causal mask, and how is it implemented numerically?
A causal mask is a [T, T] additive matrix with 0 on and below the diagonal and -inf strictly above. After the QKᵀ/√d_k step, S += M; the softmax of -inf is exactly 0, so token i puts zero probability on token j > i. In practice it is precomputed once and broadcast over (B, H). Modern fused kernels (FA-2/3) generate the mask on the fly rather than allocating a [T, T] buffer.
6. Why are the W_Q, W_K, W_V matrices often fused into a single [d_model, 3·d_model] weight in production?
All three projections read the same input X. Fusing them lets the GPU run one large GEMM instead of three smaller ones: better tensor-core utilisation, fewer kernel launches, and a single HBM read of X instead of three. This is a standard transformer optimisation, seen in nn.MultiheadAttention (in_proj_weight), several HF transformers models (e.g. GPT-2's c_attn), Megatron, and vLLM.
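A NumPy sketch of the fusion, assuming the X @ W convention used throughout this page, so the fused weight is [d_model, 3·d_model]. (PyTorch stores Linear weights as [out, in], which is why nn.MultiheadAttention's in_proj_weight is [3·embed_dim, embed_dim].) One matmul followed by a split gives the same Q, K, V as three separate matmuls.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d_model = 2, 5, 16
X = rng.standard_normal((B, T, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) for _ in range(3))

# Fused weight [d_model, 3*d_model]: one GEMM reads X once.
W_qkv = np.concatenate([W_Q, W_K, W_V], axis=1)
Q, K, V = np.split(X @ W_qkv, 3, axis=-1)      # each [B, T, d_model]

assert np.allclose(Q, X @ W_Q) and np.allclose(K, X @ W_K) and np.allclose(V, X @ W_V)
```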
7. Attention is permutation-equivariant. Why? And what is the consequence?
Without a mask, applying a permutation π to the input rows of X permutes Q, K, V the same way; the QKᵀ matrix gets its rows and columns permuted; row-wise softmax and the matmul with V preserve this. So Attention(πX) = π · Attention(X), and the operation itself carries no notion of position. (A causal mask is tied to positions, so masked attention is not equivariant; the statement applies to the unmasked case.) Consequence: we must inject position information, either via additive positional embeddings or via rotary/relative schemes that act on Q and K (see §6.3).
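A NumPy check of the argument, assuming a single head with random weights scaled by 1/√d and a reversal as the permutation. Unmasked attention commutes with the permutation; the causal version does not, because its mask is tied to positions.

```python
import numpy as np


def self_attention(X, W_Q, W_K, W_V, causal=False):
    """Single-head self-attention on X: [T, d]."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    if causal:
        T = X.shape[0]
        S = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, S)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V


rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.standard_normal((T, d))
W = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)]
perm = np.arange(T)[::-1]          # reverse the sequence: a non-trivial permutation

equivariant = np.allclose(self_attention(X[perm], *W), self_attention(X, *W)[perm])
equivariant_causal = np.allclose(self_attention(X[perm], *W, causal=True),
                                 self_attention(X, *W, causal=True)[perm])
print(equivariant, equivariant_causal)   # True False
```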
8. Where does the O(T²) memory cliff come from, and how does FlashAttention sidestep it?
The scores tensor S = QKᵀ has shape [B, H, T, T]; its size grows quadratically with sequence length. At T = 32k it can exceed HBM capacity for a single layer. FlashAttention (Dao et al. 2022) tiles Q, K, V into blocks that fit in on-chip SRAM, computes a partial softmax per tile, and combines tiles via the online-softmax recurrence, never materialising the full S in HBM. Memory per head grows as O(T·d_k) instead of O(T²). Same math, IO-optimal. Full derivation in Part XIV (T44).
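A single-query NumPy sketch of the online-softmax recurrence, assuming a block size of 4 keys. It illustrates the idea only: real FlashAttention kernels also tile queries, keep tiles in SRAM, and fuse everything into one GPU kernel. The running state (max m, normaliser l, accumulator acc) is rescaled whenever a new tile raises the max, and the result matches the version that materialises every score.

```python
import numpy as np


def attention_one_query_tiled(q, K, V, block=4):
    """Process keys/values tile by tile; keep only a running max m, normaliser l, accumulator acc."""
    d_k = q.shape[-1]
    m, l = -np.inf, 0.0
    acc = np.zeros(V.shape[-1])
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q / np.sqrt(d_k)     # scores for this tile only
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)                     # rescale old state to the new max
        p = np.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ V[start:start + block]
        m = m_new
    return acc / l


rng = np.random.default_rng(0)
T, d_k, d_v = 10, 8, 3
q, K, V = rng.standard_normal(d_k), rng.standard_normal((T, d_k)), rng.standard_normal((T, d_v))

s = K @ q / np.sqrt(d_k)                   # reference: materialise every score at once
p = np.exp(s - s.max())
p /= p.sum()
print(np.allclose(attention_one_query_tiled(q, K, V), p @ V))   # True
```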