Scaled Dot-Product Attention
The transformer architecture uses a specific formulation of attention called scaled dot-product attention. This is the foundation of all modern transformer models, from BERT to GPT to vision transformers.
Query, Key, Value Framework
The transformer introduces a powerful abstraction using three distinct representations:
- Query (Q): What am I looking for?
- Key (K): What do I contain?
- Value (V): What information do I actually provide?
Analogy - Library Database:
- Query: Your search terms (“machine learning books”)
- Keys: Book titles and keywords in the database
- Values: The actual book content
- The system compares your query against keys to find relevant books, then returns the values (content)
This separation allows the model to learn where to attend (the comparison of Q against K) independently from what information to retrieve (V).
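As a concrete sketch of this separation (the dimensions and variable names below are illustrative, not from the original text): Q, K, and V are typically produced by three separate learned linear projections of the same input.

```python
import torch
import torch.nn as nn

d_model, d_k, d_v = 512, 64, 64   # illustrative sizes

W_q = nn.Linear(d_model, d_k, bias=False)   # "what am I looking for?"
W_k = nn.Linear(d_model, d_k, bias=False)   # "what do I contain?"
W_v = nn.Linear(d_model, d_v, bias=False)   # "what do I actually provide?"

x = torch.randn(2, 10, d_model)             # (batch, seq_len, d_model)
Q, K, V = W_q(x), W_k(x), W_v(x)
print(Q.shape, K.shape, V.shape)            # each: torch.Size([2, 10, 64])
```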
The Attention Formula
This is arguably the single most important equation in modern AI:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let's break it down step by step.
Step 1: Compute Attention Scores
- $Q$: Query matrix of shape $(n, d_k)$, where $n$ is the sequence length
- $K$: Key matrix of shape $(m, d_k)$, where $m$ is the source sequence length
- $QK^T$: Matrix multiplication gives scores of shape $(n, m)$
- Each element $(i, j)$ is the dot product of query $i$ with key $j$
Intuition: How well does each query match each key?
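To make the shapes concrete, here is a small sketch (with illustrative sizes) of the raw scores:

```python
import torch

n, m, d_k = 4, 6, 64        # illustrative target length, source length, key dim
Q = torch.randn(n, d_k)
K = torch.randn(m, d_k)

scores = Q @ K.T            # (n, m): one score per (query, key) pair
print(scores.shape)         # torch.Size([4, 6])
print(torch.allclose(scores[0, 0], Q[0] @ K[0]))   # True: each entry is a dot product
```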
Step 2: Scale the Scores
Why scaling by $\sqrt{d_k}$?
For large $d_k$, dot products grow large in magnitude:
- If $q$ and $k$ have unit variance, $q \cdot k$ has variance $d_k$
- Large values push softmax into regions with tiny gradients (saturated)
- Scaling by $\sqrt{d_k}$ keeps the variance at 1
- Ensures stable gradients during training
Example: If $d_k = 64$, divide by $\sqrt{64} = 8$. If $d_k = 512$, divide by $\sqrt{512} \approx 22.6$.
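A quick empirical check of the variance claim (a sketch using random unit-variance vectors):

```python
import torch

d_k = 64
q = torch.randn(10_000, d_k)    # components with unit variance
k = torch.randn(10_000, d_k)

dots = (q * k).sum(dim=-1)      # 10,000 dot products
print(dots.var())               # ≈ d_k = 64
print((dots / d_k**0.5).var())  # ≈ 1 after scaling by sqrt(d_k)
```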
Step 3: Apply Softmax
Softmax normalizes each row to a probability distribution:
- All weights in a row sum to 1
- Larger scores get higher probabilities
- Shape remains $(n, m)$
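A short check that softmax over the last dimension produces row-wise probability distributions:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(4, 6)                  # (n, m) raw scores
attn_weights = F.softmax(scores, dim=-1)    # normalize each row

print(attn_weights.shape)         # torch.Size([4, 6]): shape unchanged
print(attn_weights.sum(dim=-1))   # tensor of ones: each row sums to 1
```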
Step 4: Compute Weighted Sum
- $V$: Value matrix of shape $(m, d_v)$
- Output shape: $(n, d_v)$
- Each output position is a weighted combination of all value vectors
Complete PyTorch Implementation
```python
import torch
import torch.nn.functional as F
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention

    Args:
        Q: Queries, shape (batch, seq_len_q, d_k)
        K: Keys, shape (batch, seq_len_k, d_k)
        V: Values, shape (batch, seq_len_k, d_v)
        mask: Optional mask, shape (batch, seq_len_q, seq_len_k)

    Returns:
        output: Attention output, shape (batch, seq_len_q, d_v)
        attn_weights: Attention weights, shape (batch, seq_len_q, seq_len_k)
    """
    # Get dimension for scaling
    d_k = Q.size(-1)

    # Step 1 & 2: Compute scaled scores
    # Q @ K^T gives (batch, seq_len_q, seq_len_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

    # Optional: Apply mask (for padding or causality)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 3: Apply softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = attn_weights @ V

    return output, attn_weights
```
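A quick usage sketch with random tensors and small illustrative shapes, continuing from the block above; the cross-check against PyTorch's built-in `F.scaled_dot_product_attention` assumes PyTorch 2.0+:

```python
# Continuing from the block above (torch, F, and the function are in scope)
Q = torch.randn(2, 5, 64)   # (batch, seq_len_q, d_k) - illustrative sizes
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 64)

output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)               # torch.Size([2, 5, 64])
print(attn_weights.sum(dim=-1))   # all ones: each row is a probability distribution

# Sanity check against the built-in (available in PyTorch 2.0+)
ref = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(output, ref, atol=1e-5))  # expected: True
```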
Self-Attention
In transformers, sequences attend to themselves - this is called self-attention:

$$\text{SelfAttention}(X) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V, \qquad Q = XW_Q,\ K = XW_K,\ V = XW_V$$

where $X$ is the input sequence (after linear projections).
What this means:
- Each position attends to all positions (including itself)
- Query, key, and value all derived from same sequence
- Builds rich contextual representations
Example: In “The cat sat on the mat”:
- “cat” might attend strongly to “The” (determiner) and “sat” (verb)
- “sat” might attend to “cat” (subject) and “mat” (prepositional object)
- Each word builds context from all other words
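A minimal sketch of self-attention over a short sequence (the projection names and sizes are illustrative), reusing the `scaled_dot_product_attention` function defined above:

```python
import torch
import torch.nn as nn

d_model, d_k = 512, 64
x = torch.randn(1, 6, d_model)   # e.g. embeddings for "The cat sat on the mat"

W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

# Q, K, V are all derived from the same sequence x
output, attn_weights = scaled_dot_product_attention(W_q(x), W_k(x), W_v(x))
print(attn_weights.shape)   # torch.Size([1, 6, 6]): every token attends to every token
```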
Matrix Dimensions Walkthrough
Understanding shapes is crucial for implementation:
Setup:
- Sequence length: $n = 10$ tokens
- Model dimension: $d_k = d_v = 64$
- Batch size: $32$
Shapes through the computation:
```
Q:    (32, 10, 64)   # batch_size, seq_len, d_k
K:    (32, 10, 64)   # batch_size, seq_len, d_k
V:    (32, 10, 64)   # batch_size, seq_len, d_v

QK^T: (32, 10, 10)   # batch_size, seq_len, seq_len
# For each batch, this is a 10×10 attention score matrix

After softmax: (32, 10, 10)
# Each row sums to 1 (probability distribution)

Output: (32, 10, 64)  # batch_size, seq_len, d_v
# Same shape as input V
```
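These shapes can be checked directly against the implementation above (a sketch with random tensors):

```python
Q = torch.randn(32, 10, 64)
K = torch.randn(32, 10, 64)
V = torch.randn(32, 10, 64)

output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(attn_weights.shape)   # torch.Size([32, 10, 10])
print(output.shape)         # torch.Size([32, 10, 64])
```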
Masking
Masks control which positions can attend to which:
Padding Mask
Prevent attention to padding tokens:
```python
# Sequence: [word1, word2, <pad>, <pad>]
# Don't attend to padding positions
mask = torch.ones(batch, seq_len, seq_len)
mask[:, :, 2:] = 0  # Mask positions 2, 3

# In attention:
scores = scores.masked_fill(mask == 0, -1e9)
# softmax(-1e9) ≈ 0, so padded positions get zero weight
```
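In practice the padding mask is usually built from the real sequence lengths; a sketch under that assumption (the `lengths` tensor here is hypothetical), again using the function defined above:

```python
lengths = torch.tensor([2, 3])   # hypothetical: real (non-pad) tokens per sequence
batch, seq_len, d = 2, 4, 64

# True where a key position is a real token, False where it is padding
key_is_real = torch.arange(seq_len)[None, :] < lengths[:, None]    # (2, 4)
mask = key_is_real[:, None, :].expand(batch, seq_len, seq_len)     # (2, 4, 4)

output, attn_weights = scaled_dot_product_attention(
    torch.randn(batch, seq_len, d),
    torch.randn(batch, seq_len, d),
    torch.randn(batch, seq_len, d),
    mask=mask,
)
print(attn_weights[0, 0])   # weights for pad positions 2 and 3 are ≈ 0
```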
Causal Mask (Look-Ahead Mask)
Prevent attending to future positions for autoregressive generation:
```python
# Lower triangular matrix
mask = torch.tril(torch.ones(seq_len, seq_len))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
# Position i can only attend to positions ≤ i
```

This ensures position $i$ can't "see the future" during training, critical for language modeling.
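Combining the causal mask with the implementation above (a sketch; the mask broadcasts across the batch dimension):

```python
batch, seq_len, d = 2, 4, 64
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)   # (1, 4, 4)

Q = torch.randn(batch, seq_len, d)
K = torch.randn(batch, seq_len, d)
V = torch.randn(batch, seq_len, d)

output, attn_weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(attn_weights[0])   # upper triangle ≈ 0: position i never sees positions > i
```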
Why Scaled Dot-Product?
Advantages:
- Efficient: Single matrix multiplication, highly parallelizable on GPUs
- Simple: No learned parameters in attention mechanism itself
- Stable: Scaling prevents gradient saturation
- Flexible: Works for any sequence length
Comparison to alternatives:
| Method | Complexity | Parameters | Stability |
|---|---|---|---|
| Dot Product | O(n²d) | 0 | Poor for large d |
| Scaled Dot Product | O(n²d) | 0 | Good |
| Additive | O(n²d) | 2d² + d | Good |
Scaled dot-product offers the best balance of speed, simplicity, and stability.
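For contrast, here is a minimal sketch of additive (Bahdanau-style) scoring, which introduces exactly the learned parameters counted in the table; the class name and exact formulation are illustrative, since variants differ slightly between papers:

```python
import torch
import torch.nn as nn

class AdditiveScorer(nn.Module):
    """Bahdanau-style scoring: score(q, k) = v^T tanh(W_q q + W_k k)."""
    def __init__(self, d):
        super().__init__()
        self.W_q = nn.Linear(d, d, bias=False)   # d^2 parameters
        self.W_k = nn.Linear(d, d, bias=False)   # d^2 parameters
        self.v = nn.Linear(d, 1, bias=False)     # d parameters

    def forward(self, Q, K):
        # Q: (n, d), K: (m, d) -> scores: (n, m)
        combined = self.W_q(Q)[:, None, :] + self.W_k(K)[None, :, :]   # (n, m, d)
        return self.v(torch.tanh(combined)).squeeze(-1)

scorer = AdditiveScorer(d=64)
scores = scorer(torch.randn(4, 64), torch.randn(6, 64))
print(scores.shape)                                   # torch.Size([4, 6])
print(sum(p.numel() for p in scorer.parameters()))    # 2*64^2 + 64 = 8256
```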
Computational Complexity
For sequence length $n$ and dimension $d$:
- Compute $QK^T$: $O(n^2 d)$
- Softmax: $O(n^2)$
- Multiply by $V$: $O(n^2 d)$
Total: $O(n^2 d)$
The $n^2$ term means attention is quadratic in sequence length - this is the main limitation of transformers for very long sequences (1000+ tokens).
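A back-of-the-envelope sketch of the quadratic cost: the $n \times n$ score matrix alone (float32, one head, batch size 1) grows quickly with sequence length:

```python
for n in [512, 2048, 8192, 32768]:
    bytes_per_matrix = n * n * 4   # float32 attention scores for a single head
    print(f"n={n:>6}: {bytes_per_matrix / 2**20:8.1f} MiB per attention matrix")

# n=   512:      1.0 MiB
# n=  2048:     16.0 MiB
# n=  8192:    256.0 MiB
# n= 32768:   4096.0 MiB
```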
Note: This motivated efficient attention variants like:
- Longformer (sparse attention)
- Performer (linear attention approximation)
- Flash Attention (memory-efficient computation)
Key Takeaways
- Scaled dot-product attention uses Query, Key, Value framework
- The formula $\text{softmax}\big(QK^T / \sqrt{d_k}\big)\,V$ is fundamental to all transformers
- Scaling by $\sqrt{d_k}$ prevents gradient saturation for large dimensions
- Self-attention means $Q$, $K$, and $V$ are derived from the same sequence
- Masking controls which positions can attend to which (padding, causality)
- Complexity is $O(n^2 d)$ - quadratic in sequence length
- Attention mechanism itself has no learned parameters (parameters are in Q/K/V projections)
Related Concepts
- Attention Mechanism - General attention formulation
- Multi-Head Attention - Multiple parallel attention operations
- Attention Is All You Need - The transformer paper
References
Key Papers:
- Vaswani et al. (2017): Attention Is All You Need - Original transformer paper
Learning Resources:
- Jay Alammar: The Illustrated Transformer
- Harvard NLP: The Annotated Transformer
- 3Blue1Brown: Visualizing Attention
Implementation:
- PyTorch Documentation: torch.nn.functional.scaled_dot_product_attention