2.2. Multi-Head Attention
🪄 Step 1: Intuition & Motivation
Core Idea (in 1 short paragraph): Self-attention lets each token look at others, but a single attention pattern can miss nuances. Multi-Head Attention (MHA) runs several attention “viewpoints” in parallel. Each head looks at the same sentence through a slightly different lens (subspace), so one head might track grammar links while another follows meaning or coreference. When we stitch these views together, the model gets a richer understanding.
Simple Analogy (only if needed): Imagine a movie review panel: one critic focuses on acting, another on story, another on music. Their combined opinions feel fuller than any single review.
🌱 Step 2: Core Concept
What’s Happening Under the Hood?
We start with token embeddings of size $d_{\text{model}}$. Multi-head attention creates $h$ parallel heads. For each head, we linearly project the input into Queries ($Q$), Keys ($K$), and Values ($V$) but with smaller dimensions, typically $d_k = d_v = d_{\text{model}}/h$.
Each head computes its own self-attention weights (who to listen to) and produces a head-specific output. Then we concatenate all head outputs (bringing us back to $d_{\text{model}}$) and pass through a final linear layer to mix information across heads.
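Here is a minimal NumPy sketch of that flow (the function name, toy sizes, and random weights are illustrative assumptions, not a particular library's API):

```python
import numpy as np

def softmax(scores, axis=-1):
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    """X: (n, d_model); W_Q/W_K/W_V: (h, d_model, d_k); W_O: (h*d_k, d_model)."""
    h, _, d_k = W_Q.shape
    heads = []
    for i in range(h):
        Q = X @ W_Q[i]                       # (n, d_k) queries for head i
        K = X @ W_K[i]                       # (n, d_k) keys for head i
        V = X @ W_V[i]                       # (n, d_v) values, here d_v = d_k
        scores = Q @ K.T / np.sqrt(d_k)      # (n, n) scaled attention logits
        A = softmax(scores, axis=-1)         # rows sum to 1: who each token listens to
        heads.append(A @ V)                  # (n, d_v) head-specific output
    concat = np.concatenate(heads, axis=-1)  # (n, h*d_v) = (n, d_model)
    return concat @ W_O                      # (n, d_model), mixed across heads

# Toy sizes: n=5 tokens, d_model=16, h=4 heads -> d_k = d_v = 4.
n, d_model, h = 5, 16, 4
d_k = d_model // h
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d_model))
W_Q, W_K, W_V = (rng.normal(size=(h, d_model, d_k)) for _ in range(3))
W_O = rng.normal(size=(h * d_k, d_model))
print(multi_head_attention(X, W_Q, W_K, W_V, W_O).shape)  # (5, 16): same width as the input
```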
Why It Works This Way
How It Fits in ML Thinking
📐 Step 3: Mathematical Foundation
Multi-Head Attention Equations
For head $i \in \{1,\dots,h\}$:
$$ \text{head}_i = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right)V_i $$
where
- $Q_i = X W_i^Q,\; K_i = X W_i^K,\; V_i = X W_i^V$
- $X \in \mathbb{R}^{n \times d_{\text{model}}}$ (sequence length $n$)
- $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k},\; W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$
Concatenate all heads and project:
$$ \text{MHA}(X) = \big[\text{head}_1 \;\|\; \cdots \;\|\; \text{head}_h\big] W^O,\quad W^O \in \mathbb{R}^{(h d_v)\times d_{\text{model}}} $$
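For concreteness, with the base sizes from the original Transformer ($d_{\text{model}} = 512$, $h = 8$, hence $d_k = d_v = 64$), the shapes work out to:
$$ Q_i, K_i, V_i \in \mathbb{R}^{n \times 64},\quad \text{head}_i \in \mathbb{R}^{n \times 64},\quad \big[\text{head}_1 \;\|\; \cdots \;\|\; \text{head}_8\big] \in \mathbb{R}^{n \times 512},\quad W^O \in \mathbb{R}^{512 \times 512}. $$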
Parameter Count: Per Head and Total
Common setting: $d_k = d_v = d_{\text{model}}/h$.
Per-head parameters (ignoring biases): $W_i^Q: d_{\text{model}} \times d_k,\; W_i^K: d_{\text{model}} \times d_k,\; W_i^V: d_{\text{model}} \times d_v$. Total per head $= 3\, d_{\text{model}} d_k$.
All heads combined: $3\, d_{\text{model}} d_k \times h = 3\, d_{\text{model}} (h d_k) = 3\, d_{\text{model}}^2$.
Output projection: $W^O: (h d_v) \times d_{\text{model}} = d_{\text{model}}^2$.
Grand total (no biases): $\;\; 3\, d_{\text{model}}^2 + d_{\text{model}}^2 = \boxed{4\, d_{\text{model}}^2}$ (which is why many libraries implement $Q,K,V$ with a single $d_{\text{model}} \times 3 d_{\text{model}}$ weight).
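A quick numeric walk-through of that count in plain Python (the sizes are just an example; any $d_{\text{model}}$ divisible by $h$ works):

```python
d_model, h = 512, 8           # example sizes
d_k = d_v = d_model // h      # 64

per_head = 3 * d_model * d_k              # W_i^Q + W_i^K + W_i^V
all_heads = per_head * h                  # = 3 * d_model**2
output_proj = (h * d_v) * d_model         # W^O = d_model**2
total = all_heads + output_proj

print(per_head, all_heads, output_proj, total)  # 98304 786432 262144 1048576
assert total == 4 * d_model ** 2                # 4 * 512**2 = 1,048,576
```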
Memory/Compute Implications
- Attention maps: Each head forms an $n \times n$ score matrix $\Rightarrow$ memory $\propto h\, n^2$ (dominant).
- Projections ($Q,K,V$): $O(n\, d_{\text{model}})$ activations; relatively smaller.
- Overall: attention-map memory scales as $O(h\, n^2)$, often reported as $O(n^2)$ since $h$ is fixed, while compute is $O(n^2 d_{\text{model}})$ and does not grow with $h$ (each head does proportionally less work when $d_k = d_{\text{model}}/h$). Longer sequences increase cost quadratically; more heads increase activation memory linearly, as sketched below.
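A back-of-the-envelope sketch of that memory term (a hypothetical helper; it counts only the fp32 score matrices for one sequence in one layer, ignoring batch size, softmax temporaries, and backward-pass storage):

```python
def attention_map_bytes(n, h, bytes_per_element=4):
    # h heads each hold an n x n matrix of attention scores.
    return h * n * n * bytes_per_element

for n in (512, 2048, 8192):
    mib = attention_map_bytes(n, h=8) / 2**20
    print(f"n={n:5d}, h=8 -> {mib:8.1f} MiB of attention maps")
# Doubling n quadruples this cost; doubling h only doubles it.
```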
🧠 Step 4: Assumptions or Key Ideas (if applicable)
- Heads see the same tokens but through different learned projections.
- Typical choice $d_k = d_v = d_{\text{model}}/h$ keeps total width constant.
- The information bottleneck is the final $W^O$ that recombines head outputs back to $d_{\text{model}}$.
⚖️ Step 5: Strengths, Limitations & Trade-offs
- Encourages specialization across heads (syntax vs. semantics, local vs. global).
- Preserves parallelism; multiple views computed simultaneously.
- Parameterization stays tidy (≈ $4d_{\text{model}}^2$) independent of $h$ (when $d_k=d_{\text{model}}/h$).
- Adds linear head-wise overhead to attention memory ($h\, n^2$).
- Some heads may become redundant (not all heads stay useful).
- Interpretation can be tricky; attention ≠ direct explanation.
- More heads → richer, more disentangled features but more activation memory.
- Fewer heads → cheaper but risks blending distinct relations.
- A single very wide head tends to entangle patterns that multiple narrow heads can separate.
🚧 Step 6: Common Misunderstandings (Optional)
🚨 Common Misunderstandings
- “Multiple heads increase parameters massively.” With $d_k=d_{\text{model}}/h$, total $Q,K,V$ parameters remain $3d_{\text{model}}^2$ regardless of $h$; the main added cost is activation memory for $h$ score matrices (see the quick check after this list).
- “One huge head equals many small heads.” A single head can’t easily specialize; multiple projections encourage diverse subspaces.
- “More heads are always better.” After a point, extra heads can be redundant and waste memory; performance gains saturate.
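To make the first point concrete, here is a quick check. It assumes PyTorch's `nn.MultiheadAttention`, which (when query, key, and value dims match) fuses the $Q,K,V$ projections into a single weight; only 2-D weight matrices are counted, so biases are excluded.

```python
import torch.nn as nn

d_model = 512
for h in (1, 2, 4, 8, 16):
    mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=h)
    # Count only weight matrices (2-D tensors); biases are 1-D and skipped.
    weights = sum(p.numel() for p in mha.parameters() if p.dim() == 2)
    print(f"h={h:2d}: {weights:,} weight parameters")  # 1,048,576 = 4 * 512**2 every time
```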
🧩 Step 7: Mini Summary
🧠 What You Learned: Multi-Head Attention creates several parallel attention views, each operating in its own subspace.
⚙️ How It Works: Inputs are projected to $(Q,K,V)$ per head, attention is computed per head, outputs are concatenated and mixed via $W^O$.
🎯 Why It Matters: Multiple heads capture heterogeneous relationships (syntactic and semantic) better than one wide head, at the cost of head-wise activation memory.