2.2. Multi-Head Attention

🪄 Step 1: Intuition & Motivation

  • Core Idea: Self-attention lets each token look at the others, but a single attention pattern can miss nuances. Multi-Head Attention (MHA) runs several attention “viewpoints” in parallel. Each head looks at the same sentence through a slightly different lens (a learned subspace), so one head might track grammatical links while another follows meaning or coreference. Stitching these views together gives the model a richer understanding of each token.

  • Simple Analogy: Imagine a movie review panel: one critic focuses on acting, another on story, another on music. Their combined opinions feel fuller than any single review.


🌱 Step 2: Core Concept

What’s Happening Under the Hood?

We start with token embeddings of size $d_{\text{model}}$. Multi-head attention creates $h$ parallel heads. For each head, we linearly project the input into Queries ($Q$), Keys ($K$), and Values ($V$) but with smaller dimensions, typically $d_k = d_v = d_{\text{model}}/h$.

Each head computes its own self-attention weights (who to listen to) and produces a head-specific output. Then we concatenate all head outputs (bringing us back to $d_{\text{model}}$) and pass through a final linear layer to mix information across heads.
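To make this bookkeeping concrete, here is a minimal Python sketch that only traces tensor shapes through one layer. It assumes a single unbatched sequence and the common $d_k = d_v = d_{\text{model}}/h$ setting; `mha_shapes` is a hypothetical helper, not a library function.

```python
def mha_shapes(n: int, d_model: int, h: int) -> dict:
    """Trace tensor shapes through one multi-head attention layer (single sequence, no batch)."""
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by h when d_k = d_v = d_model / h")
    d_k = d_v = d_model // h
    return {
        "X": (n, d_model),           # input token embeddings
        "Q_i / K_i": (n, d_k),       # per-head queries and keys (narrower than d_model)
        "V_i": (n, d_v),             # per-head values
        "scores_i": (n, n),          # Q_i K_i^T: raw attention scores, one row per query token
        "head_i": (n, d_v),          # softmax(scores_i / sqrt(d_k)) V_i
        "concat": (n, h * d_v),      # all h head outputs side by side: back to d_model
        "output": (n, d_model),      # concat @ W^O mixes information across heads
    }


# Example: 10 tokens, d_model = 512, 8 heads, so each head works in 64 dimensions.
for name, shape in mha_shapes(n=10, d_model=512, h=8).items():
    print(f"{name:>10}: {shape}")
```

With 8 heads and $d_{\text{model}} = 512$, every head attends in a 64-dimensional subspace, and concatenating the 8 head outputs restores the 512-wide representation that the final linear layer then mixes.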

Why It Works This Way

Different linguistic or structural relations tend to live in different “directions” of the representation space. Splitting into heads lets the model specialize: one head can focus on subject–verb links, another on long-range dependencies, another on named entities, and so on. A single head produces just one softmax distribution per query token, so a wide head that must track several relations at once ends up averaging them together; multiple smaller heads can each attend to different tokens and keep those patterns apart.

How It Fits in ML Thinking

This is classic divide-and-conquer in representation learning. By distributing attention across subspaces and then recombining, MHA increases expressivity at roughly the compute of one full-width head (with $d_k = d_{\text{model}}/h$ the total projection width is unchanged), and all heads are computed in parallel. It’s a structured way to learn multiple complementary features at once.

📐 Step 3: Mathematical Foundation

Multi-Head Attention Equations

For head $i \in \{1,\dots,h\}$:

$$ \text{head}_i = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right)V_i $$

where

  • $Q_i = X W_i^Q,\; K_i = X W_i^K,\; V_i = X W_i^V$
  • $X \in \mathbb{R}^{n \times d_{\text{model}}}$ (sequence length $n$)
  • $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k},\; W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$
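The per-head equation can be written out in plain Python as a minimal sketch. Matrices are lists of rows, the toy weights are made up for illustration, and `attention_head`, `matmul`, `transpose`, and `softmax_rows` are hypothetical helper names; a real implementation would use a tensor library and batch all heads together.

```python
import math


def matmul(A, B):
    """Multiply matrices stored as lists of rows: (n x m) @ (m x p) -> (n x p)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax_rows(S):
    """Row-wise softmax; subtracting each row's max avoids overflow in exp()."""
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention_head(X, W_q, W_k, W_v):
    """One head: softmax(Q_i K_i^T / sqrt(d_k)) V_i with Q_i = X W_i^Q, K_i = X W_i^K, V_i = X W_i^V."""
    Q, K, V = matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)
    d_k = len(W_q[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    weights = softmax_rows(scores)        # n x n; row t = how much token t attends to each token
    return matmul(weights, V), weights    # head output (n x d_v) and the attention weights


# Toy example: n = 3 tokens, d_model = 4, d_k = d_v = 2 (illustrative weights only).
X = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 2.0, 0.0, 2.0],
     [1.0, 1.0, 1.0, 1.0]]
W_q = [[0.5, 0.0], [0.0, 0.5], [0.5, 0.0], [0.0, 0.5]]
W_k = [[0.5, 0.0], [0.0, 0.5], [0.5, 0.0], [0.0, 0.5]]
W_v = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
head, weights = attention_head(X, W_q, W_k, W_v)
```

Each row of `weights` sums to 1, and `head` holds, for every token, the correspondingly weighted mix of the value vectors. Running `attention_head` once per set of $(W_i^Q, W_i^K, W_i^V)$ gives the $h$ head outputs.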

Concatenate all heads and project:

$$ \text{MHA}(X) = \big[\text{head}_1 \;\|\; \cdots \;\|\; \text{head}_h\big] W^O,\quad W^O \in \mathbb{R}^{(h d_v)\times d_{\text{model}}} $$

Each head is a small specialist that gathers context its own way. The output layer blends the specialists into one coherent summary per token.
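The concatenate-then-project step has an equivalent reading: if $W^O$ is split into $h$ row-blocks $W^O_i$ (each $d_v \times d_{\text{model}}$), then the output equals $\sum_i \text{head}_i W^O_i$, so each specialist writes into the shared output through its own slice of $W^O$. The sketch below checks this numerically on random toy values; the function names are hypothetical.

```python
import random


def matmul(A, B):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def concat_heads(heads):
    """[head_1 || ... || head_h]: for each token, join its h head vectors end to end."""
    n = len(heads[0])
    return [[x for head in heads for x in head[t]] for t in range(n)]


def mha_output(heads, W_o):
    """Concatenate h head outputs (each n x d_v), then project with W_o of shape (h*d_v) x d_model."""
    return matmul(concat_heads(heads), W_o)


def mha_output_as_sum(heads, W_o):
    """Equivalent view: sum_i head_i @ W_o_i, where W_o_i is the i-th block of d_v rows of W_o."""
    d_v = len(heads[0][0])
    n, d_model = len(heads[0]), len(W_o[0])
    out = [[0.0] * d_model for _ in range(n)]
    for i, head in enumerate(heads):
        contrib = matmul(head, W_o[i * d_v:(i + 1) * d_v])   # this head's n x d_model contribution
        for t in range(n):
            for j in range(d_model):
                out[t][j] += contrib[t][j]
    return out


# Random toy sizes: n = 3 tokens, h = 2 heads, d_v = 2, d_model = h * d_v = 4.
rng = random.Random(0)
n, h, d_v, d_model = 3, 2, 2, 4
heads = [[[rng.uniform(-1, 1) for _ in range(d_v)] for _ in range(n)] for _ in range(h)]
W_o = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(h * d_v)]
y = mha_output(heads, W_o)
y_sum = mha_output_as_sum(heads, W_o)   # matches y up to floating-point rounding
```

The concatenated form is what implementations typically compute (one matrix multiply); the summed form is just a way to see how $W^O$ blends the heads.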
Parameter Count — Per Head and Total

Common setting: $d_k = d_v = d_{\text{model}}/h$.

  • Per-head parameters (ignoring biases): $W_i^Q: d_{\text{model}}\!\times\! d_k,\; W_i^K: d_{\text{model}}\!\times\! d_k,\; W_i^V: d_{\text{model}}\!\times\! d_v$. Total per head $= 3\, d_{\text{model}}\, d_k$ (using $d_v = d_k$).

  • All heads combined: $3\, d_{\text{model}}\, d_k \times h = 3\, d_{\text{model}} (h d_k) = 3\, d_{\text{model}}^2$.

  • Output projection: $W^O: (h d_v)\!\times\! d_{\text{model}} = d_{\text{model}}^2$.

Grand total (no biases): $3\, d_{\text{model}}^2 + d_{\text{model}}^2 = \boxed{4\, d_{\text{model}}^2}$. Because the $Q$, $K$, and $V$ projections all read the same $X$ and together form a $d_{\text{model}}\!\times\!3d_{\text{model}}$ matrix, many libraries store them as one fused weight and compute all three with a single matrix multiply.
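The count can be checked with a small function. This is a sketch assuming $d_k = d_v = d_{\text{model}}/h$; `mha_param_count` is a hypothetical helper, and the optional bias term assumes one bias vector per projection ($Q$, $K$, $V$, and $W^O$), which adds $4\, d_{\text{model}}$.

```python
def mha_param_count(d_model: int, h: int, bias: bool = False) -> dict:
    """Count MHA parameters assuming d_k = d_v = d_model / h."""
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by h")
    d_k = d_v = d_model // h
    per_head = 2 * d_model * d_k + d_model * d_v    # W_i^Q, W_i^K, W_i^V
    qkv = h * per_head                               # = 3 * d_model**2
    w_o = (h * d_v) * d_model                        # = d_model**2
    total = qkv + w_o
    if bias:
        total += 3 * h * d_k + d_model               # Q/K/V biases plus the output-projection bias
    return {
        "per_head": per_head,
        "qkv": qkv,
        "w_o": w_o,
        "total": total,
        "fused_qkv_shape": (d_model, 3 * h * d_k),   # a single d_model x 3*d_model weight
    }


for h in (1, 2, 8, 16):
    print(h, mha_param_count(512, h)["total"])     # 1048576 (= 4 * 512**2) for every h
```

Only the split changes with $h$: each head's share shrinks as $h$ grows, while the total stays at $4\, d_{\text{model}}^2$.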

Memory/Compute Implications
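  • Attention maps: Each head forms an $n \times n$ score matrix $\Rightarrow$ memory $\propto h\, n^2$ (dominant for long sequences).
  • Projections ($Q,K,V$): $O(n\, d_{\text{model}})$ activations; usually much smaller than the score matrices once $n$ is large.
  • Overall: score memory scales as $O(h\, n^2)$, often reported as $O(n^2)$ since $h$ is fixed. The score products cost $h \cdot n^2 \cdot d_k = n^2 d_{\text{model}}$ multiply-adds, which does not depend on $h$ when $d_k = d_{\text{model}}/h$. Longer sequences increase cost quadratically; more heads (at fixed $d_{\text{model}}$) increase score memory linearly.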
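A back-of-the-envelope estimate of these scaling rules, as a sketch assuming a naive implementation that stores every $n \times n$ score matrix in fp16 (2 bytes per element) for one layer and one sequence; `score_memory_bytes` and `score_matmul_macs` are hypothetical helper names.

```python
def score_memory_bytes(n: int, h: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes to store one layer's h attention-score matrices (each n x n); 2 bytes/elem ~ fp16."""
    return batch * h * n * n * bytes_per_elem


def score_matmul_macs(n: int, d_model: int, h: int) -> int:
    """Multiply-adds for the h products Q_i K_i^T, each (n x d_k) @ (d_k x n), with d_k = d_model / h."""
    d_k = d_model // h
    return h * n * n * d_k          # = n^2 * d_model, independent of h


MiB = 1024 ** 2
print(score_memory_bytes(n=4096, h=16) / MiB)   # 512.0
print(score_memory_bytes(n=8192, h=16) / MiB)   # 2048.0: 2x the tokens -> 4x the memory
print(score_memory_bytes(n=4096, h=32) / MiB)   # 1024.0: 2x the heads -> 2x the memory
```

Implementations that compute attention in tiles without materializing the full score matrices avoid most of this memory, but the quadratic compute remains.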

🧠 Step 4: Assumptions or Key Ideas

  • Heads see the same tokens but through different learned projections.
  • Typical choice $d_k = d_v = d_{\text{model}}/h$ keeps total width constant.
  • The final $W^O$ is the only place inside the attention block where head outputs are mixed; it recombines them back to $d_{\text{model}}$.

⚖️ Step 5: Strengths, Limitations & Trade-offs

Strengths
  • Encourages specialization across heads (syntax vs. semantics, local vs. global).
  • Preserves parallelism: the multiple views are computed simultaneously.
  • Parameterization stays tidy (≈ $4d_{\text{model}}^2$) and independent of $h$ (when $d_k=d_{\text{model}}/h$).

Limitations
  • Adds head-wise overhead to attention memory that grows linearly with $h$ ($h\, n^2$).
  • Some heads may become redundant (not all heads stay useful).
  • Interpretation can be tricky: attention weights are not a direct explanation of the model's output.

Trade-offs
  • More heads → can give richer, more disentangled features, but cost more activation memory and leave each head a smaller $d_k = d_{\text{model}}/h$.
  • Fewer heads → cheaper, but risks blending distinct relations.
  • A single very wide head tends to entangle patterns that multiple narrow heads can separate.

🚧 Step 6: Common Misunderstandings

  • “Multiple heads increase parameters massively.” With $d_k=d_{\text{model}}/h$, total $Q,K,V$ parameters remain $3d_{\text{model}}^2$ regardless of $h$; the main added cost is activation memory for $h$ score matrices.
  • “One huge head equals many small heads.” A single head can’t easily specialize; multiple projections encourage diverse subspaces.
  • “More heads are always better.” After a point, extra heads can be redundant and waste memory; performance gains saturate.

🧩 Step 7: Mini Summary

🧠 What You Learned: Multi-Head Attention creates several parallel attention views, each operating in its own subspace.

⚙️ How It Works: Inputs are projected to $(Q,K,V)$ per head, attention is computed per head, outputs are concatenated and mixed via $W^O$.

🎯 Why It Matters: Multiple heads capture heterogeneous relationships (syntactic and semantic) better than one wide head, at the cost of head-wise activation memory.
