Multi-Head Attention
Instead of performing a single attention function, multi-head attention runs multiple attention operations in parallel. This allows the model to jointly attend to information from different representation subspaces at different positions.
Motivation: Why Multiple Heads?
Different aspects of a sentence require different types of attention patterns:
Example: “The cat sat on the mat”
A single attention head might focus on:
- Syntactic relationships (subject-verb, verb-object)
But we also want to capture:
- Semantic relationships (conceptual connections)
- Positional patterns (nearby words)
- Coreference resolution (what refers to what)
- Long-range dependencies
Solution: Use multiple attention heads, each learning to capture different patterns.
Analogy: Like having multiple cameras viewing a scene from different angles - each provides different but complementary information.
Mathematical Formulation
Multi-head attention consists of:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

Components:
- $h$: Number of heads (typically 8, 12, or 16)
- $W_i^Q, W_i^K, W_i^V$: Learned projection matrices for head $i$
- $W^O$: Output projection matrix that combines all heads
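To make the formula concrete, here is a minimal per-head sketch of these equations using plain tensors. The sizes ($d_{\text{model}} = 512$, $h = 8$, sequence length 10) are illustrative assumptions; the batched, optimized implementation appears later in this note.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not fixed by the formulas)
d_model, h, seq_len = 512, 8, 10
d_k = d_model // h

X = torch.randn(seq_len, d_model)  # use X as Q = K = V (self-attention)
W_q = [torch.randn(d_model, d_k) for _ in range(h)]  # W_i^Q
W_k = [torch.randn(d_model, d_k) for _ in range(h)]  # W_i^K
W_v = [torch.randn(d_model, d_k) for _ in range(h)]  # W_i^V
W_o = torch.randn(h * d_k, d_model)                  # W^O

heads = []
for i in range(h):
    Q_i, K_i, V_i = X @ W_q[i], X @ W_k[i], X @ W_v[i]  # each (seq_len, d_k)
    scores = Q_i @ K_i.T / d_k ** 0.5                   # (seq_len, seq_len)
    heads.append(F.softmax(scores, dim=-1) @ V_i)       # head_i: (seq_len, d_k)

multi_head = torch.cat(heads, dim=-1) @ W_o             # (seq_len, d_model)
print(multi_head.shape)  # torch.Size([10, 512])
```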
Step-by-Step Process
Step 1: Project Q, K, V for Each Head
For each head $i$, project the inputs:

$$Q_i = Q W_i^Q, \quad K_i = K W_i^K, \quad V_i = V W_i^V$$

Dimensions:
- Input: $Q, K, V \in \mathbb{R}^{n \times d_{\text{model}}}$ (sequence length $n$)
- Projections: $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$
- Output: $Q_i, K_i, V_i \in \mathbb{R}^{n \times d_k}$

where $d_k = d_{\text{model}} / h$ (split dimension across heads)
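In practice the per-head matrices are usually stored together as one $d_{\text{model}} \times d_{\text{model}}$ matrix whose output is then reshaped into heads, which is what the implementation below does. A small sketch of that equivalence (sizes are illustrative assumptions):

```python
import torch

d_model, h, seq_len = 512, 8, 10
d_k = d_model // h

X = torch.randn(seq_len, d_model)
W_q_big = torch.randn(d_model, d_model)  # columns i*d_k:(i+1)*d_k act as W_i^Q

# One big projection, then split the last dimension into h heads
Q_all = (X @ W_q_big).view(seq_len, h, d_k).transpose(0, 1)  # (h, seq_len, d_k)

# The same result for head 3, using only that head's slice of the matrix
Q_3 = X @ W_q_big[:, 3 * d_k : 4 * d_k]                      # (seq_len, d_k)
print(torch.allclose(Q_all[3], Q_3, atol=1e-4))              # True
```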
Step 2: Compute Attention for Each Head
Apply scaled dot-product attention in parallel:

$$\text{head}_i = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i$$

Output shape per head: $n \times d_k$
Step 3: Concatenate Heads
Concatenate all head outputs:

$$\text{Concat}(\text{head}_1, \ldots, \text{head}_h) \in \mathbb{R}^{n \times h d_k} = \mathbb{R}^{n \times d_{\text{model}}}$$
Step 4: Final Linear Projection
Apply the output projection to combine information from all heads:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

where $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ (since $h \cdot d_k = d_{\text{model}}$)
Complete PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
"""
Args:
d_model: Model dimension (e.g., 512)
num_heads: Number of attention heads (e.g., 8)
"""
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
# Linear projections for Q, K, V for all heads
# We can do all heads at once with a single matrix
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# Output projection
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""
Split the last dimension into (num_heads, d_k)
Args:
x: (batch_size, seq_len, d_model)
Returns:
(batch_size, num_heads, seq_len, d_k)
"""
batch_size, seq_len, _ = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, num_heads, seq_len, d_k)
def combine_heads(self, x):
"""
Combine heads back to original shape
Args:
x: (batch_size, num_heads, seq_len, d_k)
Returns:
(batch_size, seq_len, d_model)
"""
batch_size, _, seq_len, _ = x.size()
x = x.transpose(1, 2) # (batch, seq_len, num_heads, d_k)
return x.contiguous().view(batch_size, seq_len, self.d_model)
def forward(self, Q, K, V, mask=None):
"""
Args:
Q: Queries (batch_size, seq_len_q, d_model)
K: Keys (batch_size, seq_len_k, d_model)
V: Values (batch_size, seq_len_k, d_model)
mask: Optional mask (batch_size, 1, seq_len_q, seq_len_k)
Returns:
output: (batch_size, seq_len_q, d_model)
attn_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
"""
batch_size = Q.size(0)
# 1. Linear projections
Q = self.W_q(Q) # (batch, seq_len_q, d_model)
K = self.W_k(K) # (batch, seq_len_k, d_model)
V = self.W_v(V) # (batch, seq_len_k, d_model)
# 2. Split into multiple heads
Q = self.split_heads(Q) # (batch, num_heads, seq_len_q, d_k)
K = self.split_heads(K) # (batch, num_heads, seq_len_k, d_k)
V = self.split_heads(V) # (batch, num_heads, seq_len_k, d_k)
# 3. Scaled dot-product attention for all heads in parallel
scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
# scores: (batch, num_heads, seq_len_q, seq_len_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
# attn_weights: (batch, num_heads, seq_len_q, seq_len_k)
context = attn_weights @ V
# context: (batch, num_heads, seq_len_q, d_k)
# 4. Concatenate heads
context = self.combine_heads(context)
# context: (batch, seq_len_q, d_model)
# 5. Final linear projection
output = self.W_o(context)
# output: (batch, seq_len_q, d_model)
        return output, attn_weights
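A quick usage sketch of the module above; the batch size, sequence length, and causal mask here are illustrative assumptions:

```python
import torch

# Assumes the MultiHeadAttention class defined above is in scope
batch_size, seq_len, d_model, num_heads = 32, 10, 512, 8

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: queries, keys, and values all come from the same sequence
output, attn_weights = mha(x, x, x)
print(output.shape)        # torch.Size([32, 10, 512])
print(attn_weights.shape)  # torch.Size([32, 8, 10, 10])

# Optional causal mask: each position may attend only to itself and earlier positions
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)
output, attn_weights = mha(x, x, x, mask=causal_mask)
```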
Dimension Example

Setup:
- Model dimension: $d_{\text{model}} = 512$
- Number of heads: $h = 8$
- Dimension per head: $d_k = 512 / 8 = 64$
- Sequence length: $n = 10$
- Batch size: 32
Shapes throughout:
Input Q, K, V: (32, 10, 512)
After linear proj: (32, 10, 512)
After split_heads: (32, 8, 10, 64)
(batch, heads, seq_len, d_k)
Attention scores: (32, 8, 10, 10)
(batch, heads, seq_len_q, seq_len_k)
After attention: (32, 8, 10, 64)
After combine_heads: (32, 10, 512)
After W_o: (32, 10, 512)

What Do Different Heads Learn?
Research shows that different heads specialize in different patterns:
Head 1: Short-range dependencies (adjacent words)

    The [cat] sat ...
          ↑↓

Head 2: Long-range dependencies (subject-verb agreement)

    The cat ... [was] hungry
        ↑________________↑

Head 3: Syntactic patterns (verb-object relationships)

    ... sat [on] the [mat]
             ↑________↑

Head 4: Positional patterns (specific relative positions)

Heads 5-8: Various semantic and structural patterns
Each head learns to focus on different types of relationships, and their combination provides rich contextual understanding.
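One way to observe this is to inspect the per-head attention weights returned by the module above. A minimal sketch (with untrained, random weights here, so the printed patterns only become meaningful for a trained model):

```python
import torch

# Assumes the MultiHeadAttention class defined above is in scope
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(1, 6, 512)  # a single sequence of 6 tokens (illustrative)

_, attn_weights = mha(x, x, x)  # attn_weights: (1, 8, 6, 6)

# For each head, show which key position every query position attends to most strongly
for head in range(attn_weights.size(1)):
    strongest = attn_weights[0, head].argmax(dim=-1)
    print(f"Head {head}: most-attended position per query = {strongest.tolist()}")
```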
Comparison: Single-Head vs Multi-Head
| Aspect | Single Head | Multi-Head (8 heads) |
|---|---|---|
| Dimension per head | 512 | 64 |
| Attention patterns | 1 type | 8 different types |
| Parameters | $\approx 4 d_{\text{model}}^2$ | $\approx 4 d_{\text{model}}^2$ (same) |
| Complexity | $O(n^2 \cdot d_{\text{model}})$ | $O(n^2 \cdot d_{\text{model}})$ (same!) |
| Expressiveness | Limited | High |
Key insight: Multi-head attention has similar computational cost but much higher expressiveness.
Why Not Use Larger Dimensions?
Question: Why split 512 dimensions into 8 heads of 64, instead of using 8 separate 512-dimensional attention mechanisms?
Answer: Computational efficiency!
- 8 heads of 64: Total dimension = 512
- 8 separate full-dimension attention mechanisms: Total dimension = 4096 (roughly 8× more parameters and computation)
Multi-head attention gets the benefits of multiple attention patterns without dramatically increasing parameters or computation.
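A quick parameter count makes this concrete (bias terms included, since the implementation above uses nn.Linear with its default bias; the "separate mechanisms" figure is a rough estimate):

```python
# Assumes the MultiHeadAttention class defined above is in scope
d_model = 512
mha = MultiHeadAttention(d_model, num_heads=8)

n_params = sum(p.numel() for p in mha.parameters())
print(n_params)                            # 1050624
print(4 * (d_model * d_model + d_model))   # 1050624: four d_model x d_model projections (+ biases)

# Eight separate full-dimension attention mechanisms would need roughly 8x as many parameters
print(8 * 4 * (d_model * d_model + d_model))  # 8404992
```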
Computational Complexity
For $h$ heads, sequence length $n$, and model dimension $d_{\text{model}}$:

Per head: $O(n^2 \cdot d_k) = O(n^2 \cdot d_{\text{model}} / h)$

All heads combined: $O(h \cdot n^2 \cdot d_k) = O(n^2 \cdot d_{\text{model}})$
Same as single-head attention! We get multiple attention patterns “for free” (no additional asymptotic cost).
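Plugging in the example numbers from above ($h = 8$, $d_k = 64$, $d_{\text{model}} = 512$):

$$h \cdot n^2 \cdot d_k = 8 \cdot n^2 \cdot 64 = n^2 \cdot 512 = n^2 \cdot d_{\text{model}}$$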
Practical Tips
Choosing Number of Heads
Common choices:
- Small models (d=256-384): 4-8 heads
- Base models (d=512-768): 8-12 heads
- Large models (d=1024+): 16-32 heads
Rule of thumb: Keep $d_k = d_{\text{model}} / h$ around 64
- Too small ($d_k$ much below 64): Not enough capacity per head
- Too large ($d_k$ much above 64): Heads become too similar, less diversity
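The common configurations listed above all land near $d_k = 64$; a quick check (the specific pairings are illustrative choices within those ranges):

```python
# d_k = d_model / num_heads for typical configurations
for d_model, num_heads in [(256, 4), (512, 8), (768, 12), (1024, 16)]:
    print(f"d_model={d_model}, heads={num_heads}, d_k={d_model // num_heads}")
# d_model=256, heads=4, d_k=64
# d_model=512, heads=8, d_k=64
# d_model=768, heads=12, d_k=64
# d_model=1024, heads=16, d_k=64
```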
Initialization
Initialize the projection weights with Xavier (Glorot) uniform initialization:
nn.init.xavier_uniform_(self.W_q.weight)
nn.init.xavier_uniform_(self.W_k.weight)
nn.init.xavier_uniform_(self.W_v.weight)
nn.init.xavier_uniform_(self.W_o.weight)

Helps training stability in early stages.
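One place this could live is a small helper applied after constructing the module; a sketch, assuming the MultiHeadAttention class above (the function name init_mha_weights is an illustrative choice, not part of the original code):

```python
import torch.nn as nn

def init_mha_weights(mha):
    """Xavier-initialize the four projection layers and zero their biases."""
    for layer in (mha.W_q, mha.W_k, mha.W_v, mha.W_o):
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)

# Assumes the MultiHeadAttention class defined above is in scope
mha = MultiHeadAttention(d_model=512, num_heads=8)
init_mha_weights(mha)
```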
Key Takeaways
- Multi-head attention runs multiple attention operations in parallel
- Each head learns different attention patterns (syntactic, semantic, positional)
- Split $d_{\text{model}}$ into $h$ heads of dimension $d_k = d_{\text{model}} / h$
- Concatenate head outputs and apply final projection
- Same computational complexity as single-head, but more expressive
- Typical configuration: 8-16 heads with $d_k = 64$ per head
- Always used with residual connections and layer normalization in practice
Related Concepts
- Scaled Dot-Product Attention - The attention mechanism used in each head
- Attention Is All You Need - The transformer paper introducing multi-head attention
- Attention Mechanism - General attention formulation
References
Key Papers:
- Vaswani et al. (2017): Attention Is All You Need - Original transformer paper
- Voita et al. (2019): Analyzing Multi-Head Self-Attention - What do different heads learn?
- Michel et al. (2019): Are Sixteen Heads Really Better than One? - Head pruning analysis
Learning Resources:
- Jay Alammar: The Illustrated Transformer
- Harvard NLP: The Annotated Transformer
- Tensor2Tensor Visualization: Interactive Attention Visualization