
Multi-Head Attention

Instead of performing a single attention function, multi-head attention runs multiple attention operations in parallel. This allows the model to jointly attend to information from different representation subspaces at different positions.

Motivation: Why Multiple Heads?

Different aspects of a sentence require different types of attention patterns:

Example: “The cat sat on the mat”

A single attention head might focus on:

  • Syntactic relationships (subject-verb, verb-object)

But we also want to capture:

  • Semantic relationships (conceptual connections)
  • Positional patterns (nearby words)
  • Coreference resolution (what refers to what)
  • Long-range dependencies

Solution: Use multiple attention heads, each learning to capture different patterns.

Analogy: Like having multiple cameras viewing a scene from different angles - each provides different but complementary information.

Mathematical Formulation

Multi-head attention is computed as:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O

where each head is:

\text{head}_i = \text{Attention}(Q W^Q_i, K W^K_i, V W^V_i)

Components:

  • h: Number of heads (typically 8, 12, or 16)
  • W^Q_i, W^K_i, W^V_i: Learned projection matrices for head i
  • W^O: Output projection matrix that combines all heads

Step-by-Step Process

Step 1: Project Q, K, V for Each Head

For each head i, project the inputs:

Q_i = Q W^Q_i, \quad K_i = K W^K_i, \quad V_i = V W^V_i

Dimensions:

  • Input: Q, K, V \in \mathbb{R}^{n \times d_{\text{model}}}
  • Projections: W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d_{\text{model}} \times d_k}
  • Output: Q_i, K_i, V_i \in \mathbb{R}^{n \times d_k}

where d_k = d_{\text{model}} / h (the model dimension is split evenly across heads)
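
As a minimal sketch of this projection for a single head (using d_model = 512 and h = 8 as in the example further below, with a random matrix standing in for the learned W^Q_i):

import torch

n, d_model, h = 10, 512, 8
d_k = d_model // h                   # 64 dimensions per head

Q = torch.randn(n, d_model)          # queries for one sequence (no batch dimension)
W_q_i = torch.randn(d_model, d_k)    # projection for head i (learned in practice, random here)

Q_i = Q @ W_q_i                      # (n, d_k): queries in head i's subspace
print(Q_i.shape)                     # torch.Size([10, 64])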

Step 2: Compute Attention for Each Head

Apply scaled dot-product attention in parallel:

\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i

Output shape per head: (n, d_k)
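
Continuing the sketch, scaled dot-product attention for one head with random projected inputs:

import math
import torch
import torch.nn.functional as F

n, d_k = 10, 64
Q_i, K_i, V_i = (torch.randn(n, d_k) for _ in range(3))  # projected Q, K, V for head i

scores = Q_i @ K_i.T / math.sqrt(d_k)   # (n, n) scaled similarity scores
weights = F.softmax(scores, dim=-1)     # each row sums to 1
head_i = weights @ V_i                  # (n, d_k) output of this head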

Step 3: Concatenate Heads

Concatenate all head outputs:

\text{Concat}(\text{head}_1, \ldots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_k)} = \mathbb{R}^{n \times d_{\text{model}}}
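
In code, the concatenation is a single torch.cat over the head outputs (random stand-ins here):

import torch

n, h, d_k = 10, 8, 64
heads = [torch.randn(n, d_k) for _ in range(h)]  # stand-ins for head_1 ... head_h
concat = torch.cat(heads, dim=-1)                # (n, h * d_k) = (10, 512)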

Step 4: Final Linear Projection

Apply output projection to combine information from all heads:

\text{Output} = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O

where W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}
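
And the final mixing step, again with a random matrix standing in for the learned W^O:

import torch

n, d_model = 10, 512
concat = torch.randn(n, d_model)      # concatenated head outputs
W_o = torch.randn(d_model, d_model)   # output projection (learned in practice, random here)
output = concat @ W_o                 # (n, d_model): mixes information across all heads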

Complete PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: Model dimension (e.g., 512)
            num_heads: Number of attention heads (e.g., 8)
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Linear projections for Q, K, V for all heads
        # We can do all heads at once with a single matrix
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, d_k)
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            (batch_size, num_heads, seq_len, d_k)
        """
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, num_heads, seq_len, d_k)

    def combine_heads(self, x):
        """
        Combine heads back to original shape
        Args:
            x: (batch_size, num_heads, seq_len, d_k)
        Returns:
            (batch_size, seq_len, d_model)
        """
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2)  # (batch, seq_len, num_heads, d_k)
        return x.contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: Queries (batch_size, seq_len_q, d_model)
            K: Keys (batch_size, seq_len_k, d_model)
            V: Values (batch_size, seq_len_k, d_model)
            mask: Optional mask (batch_size, 1, seq_len_q, seq_len_k)
        Returns:
            output: (batch_size, seq_len_q, d_model)
            attn_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        # 1. Linear projections
        Q = self.W_q(Q)  # (batch, seq_len_q, d_model)
        K = self.W_k(K)  # (batch, seq_len_k, d_model)
        V = self.W_v(V)  # (batch, seq_len_k, d_model)

        # 2. Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len_q, d_k)
        K = self.split_heads(K)  # (batch, num_heads, seq_len_k, d_k)
        V = self.split_heads(V)  # (batch, num_heads, seq_len_k, d_k)

        # 3. Scaled dot-product attention for all heads in parallel
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        # scores: (batch, num_heads, seq_len_q, seq_len_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        # attn_weights: (batch, num_heads, seq_len_q, seq_len_k)

        context = attn_weights @ V
        # context: (batch, num_heads, seq_len_q, d_k)

        # 4. Concatenate heads
        context = self.combine_heads(context)
        # context: (batch, seq_len_q, d_model)

        # 5. Final linear projection
        output = self.W_o(context)
        # output: (batch, seq_len_q, d_model)

        return output, attn_weights

Dimension Example

Setup:

  • Model dimension: d_{\text{model}} = 512
  • Number of heads: h = 8
  • Dimension per head: d_k = 512 / 8 = 64
  • Sequence length: n = 10
  • Batch size: b = 32

Shapes throughout:

Input Q, K, V:       (32, 10, 512)
After linear proj:   (32, 10, 512)
After split_heads:   (32, 8, 10, 64)   (batch, heads, seq_len, d_k)
Attention scores:    (32, 8, 10, 10)   (batch, heads, seq_len_q, seq_len_k)
After attention:     (32, 8, 10, 64)
After combine_heads: (32, 10, 512)
After W_o:           (32, 10, 512)
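
As a sanity check, these shapes can be reproduced by running the MultiHeadAttention module defined above on random inputs (the class is assumed to be in scope):

import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(32, 10, 512)            # (batch, seq_len, d_model)

output, attn_weights = mha(x, x, x)     # self-attention: Q = K = V = x
print(output.shape)                     # torch.Size([32, 10, 512])
print(attn_weights.shape)               # torch.Size([32, 8, 10, 10])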

What Do Different Heads Learn?

Research shows that different heads specialize in different patterns:

  • Head 1: Short-range dependencies (adjacent words), e.g. "The" ↔ "cat"
  • Head 2: Long-range dependencies (subject-verb agreement), e.g. "cat" ↔ "was" in "The cat ... was hungry"
  • Head 3: Syntactic patterns (verb-object relationships), e.g. "on" ↔ "mat" in "... sat on the mat"
  • Head 4: Positional patterns (specific relative positions)
  • Heads 5-8: Various semantic and structural patterns

Each head learns to focus on different types of relationships, and their combination provides rich contextual understanding.
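
One way to explore this in practice is to inspect the per-head attention weights returned by the module above. The sketch below uses an untrained module, so the weights are arbitrary; with a trained model, the same loop reveals each head's preferred pattern:

import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(1, 6, 512)              # a single sequence of 6 tokens
_, attn_weights = mha(x, x, x)          # (1, 8, 6, 6): one weight matrix per head

for h in range(attn_weights.size(1)):
    head = attn_weights[0, h]           # (6, 6): rows = query positions, cols = key positions
    print(f"head {h}: most-attended key per query:", head.argmax(dim=-1).tolist())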

Comparison: Single-Head vs Multi-Head

Aspect             | Single Head | Multi-Head (8 heads)
Dimension per head | 512         | 64
Attention patterns | 1 type      | 8 different types
Parameters         | 3 × 512^2   | 4 × 512^2
Complexity         | O(n^2 d)    | O(n^2 d) (same!)
Expressiveness     | Limited     | High

Key insight: Multi-head attention has similar computational cost but much higher expressiveness.

Why Not Use Larger Dimensions?

Question: Why split 512 dimensions into 8 heads of 64, instead of using 8 separate 512-dimensional attention mechanisms?

Answer: Computational efficiency!

  • 8 heads of 64 dimensions each: total dimension = 512
  • 8 separate 512-dimensional attention mechanisms: total dimension = 4096 (8× more parameters and computation)

Multi-head attention gets the benefits of multiple attention patterns without dramatically increasing parameters or computation.
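
This is easy to verify by counting the weight parameters of the module defined earlier: regardless of the number of heads, the projections amount to 4 × d_model^2 weights (biases ignored here):

d_model = 512
mha = MultiHeadAttention(d_model=d_model, num_heads=8)

# Count only the weight matrices (W_q, W_k, W_v, W_o), ignoring biases
n_weights = sum(p.numel() for name, p in mha.named_parameters() if name.endswith("weight"))
print(n_weights)           # 1048576
print(4 * d_model ** 2)    # 1048576 -- the count is independent of num_heads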

Computational Complexity

For h heads, sequence length n, and dimension d:

Per head: O(n^2 d_k) = O(n^2 d / h)

All heads: O(h \cdot n^2 d / h) = O(n^2 d)

Same as single-head attention! We get multiple attention patterns “for free” (no additional asymptotic cost).

Practical Tips

Choosing Number of Heads

Common choices:

  • Small models (d=256-384): 4-8 heads
  • Base models (d=512-768): 8-12 heads
  • Large models (d=1024+): 16-32 heads

Rule of thumb: keep d_k in the range [32, 128]

  • Too small (d_k < 32): Not enough capacity per head
  • Too large (d_k > 128): Heads become too similar, less diversity

Initialization

Initialize the projection weights with Xavier (Glorot) uniform initialization:

nn.init.xavier_uniform_(self.W_q.weight)
nn.init.xavier_uniform_(self.W_k.weight)
nn.init.xavier_uniform_(self.W_v.weight)
nn.init.xavier_uniform_(self.W_o.weight)

This helps stabilize training in its early stages.
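
One way to wire this in is a small helper applied after construction (the function name below is just illustrative; the same calls could equally live at the end of __init__):

import torch.nn as nn

def init_mha_weights(mha):
    """Xavier-initialize the four projection layers of a MultiHeadAttention module."""
    for layer in (mha.W_q, mha.W_k, mha.W_v, mha.W_o):
        nn.init.xavier_uniform_(layer.weight)  # keeps activation variance roughly constant
        nn.init.zeros_(layer.bias)             # start with zero bias

mha = MultiHeadAttention(d_model=512, num_heads=8)
init_mha_weights(mha)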

Key Takeaways

  • Multi-head attention runs multiple attention operations in parallel
  • Each head learns different attention patterns (syntactic, semantic, positional)
  • Split d_{\text{model}} into h heads of dimension d_k = d_{\text{model}} / h
  • Concatenate head outputs and apply final projection
  • Same computational complexity as single-head, but more expressive
  • Typical configuration: 8-16 heads with d_k = 64 per head
  • Always used with residual connections and layer normalization in practice
