Scaled Dot-Product Attention

The transformer architecture uses a specific formulation of attention called scaled dot-product attention. This is the foundation of all modern transformer models, from BERT to GPT to vision transformers.

Query, Key, Value Framework

The transformer introduces a powerful abstraction using three distinct representations:

  • Query (Q): What am I looking for?
  • Key (K): What do I contain?
  • Value (V): What information do I actually provide?

Analogy - Library Database:

  • Query: Your search terms (“machine learning books”)
  • Keys: Book titles and keywords in the database
  • Values: The actual book content
  • The system compares your query against keys to find relevant books, then returns the values (content)

This separation allows the model to learn where to attend (how queries match keys) independently from what information to retrieve (the values, V).
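
As a rough sketch of how these three representations are typically produced (the d_model size and projection names here are illustrative assumptions, not from the text above), each of Q, K, and V comes from its own learned linear projection of the same input:

import torch
import torch.nn as nn

d_model = 64                            # assumed embedding size for this sketch
x = torch.randn(2, 10, d_model)         # (batch, seq_len, d_model) input embeddings

# Separate learned projections produce queries, keys, and values from the same input
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

Q, K, V = W_q(x), W_k(x), W_v(x)        # each has shape (2, 10, 64)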

The Attention Formula

This is arguably the single most important equation in modern AI:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Let’s break it down step by step.

Step 1: Compute Attention Scores

\text{Scores} = QK^T
  • Q: Query matrix of shape (n, d_k), where n is the sequence length
  • K: Key matrix of shape (m, d_k), where m is the source sequence length
  • QK^T: Matrix multiplication gives scores of shape (n, m)
  • Each element (i, j) is the dot product of query i with key j

Intuition: How well does each query match each key?
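
A minimal sketch of this step (the tensor sizes are arbitrary), showing that the whole score matrix is a single matrix multiplication:

import torch

n, m, d_k = 4, 6, 8                     # assumed target length, source length, key dimension
Q = torch.randn(n, d_k)
K = torch.randn(m, d_k)

scores = Q @ K.T                        # shape (n, m)
print(scores.shape)                     # torch.Size([4, 6])

# Element (i, j) is the dot product of query i with key j
print(torch.allclose(scores[1, 2], Q[1] @ K[2]))   # True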

Step 2: Scale the Scores

\text{Scaled Scores} = \frac{QK^T}{\sqrt{d_k}}

Why scale by \sqrt{d_k}?

For large d_k, dot products grow large in magnitude:

  • If the entries of Q and K have unit variance, each entry of QK^T has variance d_k
  • Large values push softmax into regions with tiny gradients (saturation)
  • Scaling by \sqrt{d_k} keeps the variance at 1
  • This ensures stable gradients during training

Example: If d_k = 64, divide by 8. If d_k = 512, divide by ≈ 22.6.
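
A quick numerical check of the variance argument, assuming unit-variance random Q and K (sizes chosen arbitrarily):

import torch

d_k = 512
Q = torch.randn(1000, d_k)              # entries ~ N(0, 1)
K = torch.randn(1000, d_k)

raw = Q @ K.T
scaled = raw / d_k ** 0.5

print(raw.var().item())                 # ≈ 512, grows with d_k
print(scaled.var().item())              # ≈ 1, back to unit variance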

Step 3: Apply Softmax

\text{Attention Weights} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)

Softmax normalizes each row to a probability distribution:

  • All weights in a row sum to 1
  • Larger scores get higher probabilities
  • Shape remains (n, m)
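
A small check (with arbitrary scores) that softmax over the last dimension turns each row into a probability distribution:

import torch
import torch.nn.functional as F

scores = torch.randn(4, 6)              # arbitrary (n, m) score matrix
weights = F.softmax(scores, dim=-1)

print(weights.shape)                    # torch.Size([4, 6]), shape unchanged
print(weights.sum(dim=-1))              # each row sums to 1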

Step 4: Compute Weighted Sum

\text{Output} = \text{Attention Weights} \cdot V
  • V: Value matrix of shape (m, d_v)
  • Output shape: (n, d_v)
  • Each output position is a weighted combination of all value vectors
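
Continuing the same toy shapes, a short sketch of the final step: the weights mix the rows of V, so each output position is a convex combination of the value vectors:

import torch
import torch.nn.functional as F

n, m, d_v = 4, 6, 16
weights = F.softmax(torch.randn(n, m), dim=-1)   # (n, m) attention weights
V = torch.randn(m, d_v)                          # (m, d_v) value matrix

output = weights @ V                             # (n, d_v)
print(output.shape)                              # torch.Size([4, 16])

# Row i of the output is sum_j weights[i, j] * V[j]
print(torch.allclose(output[0], (weights[0].unsqueeze(-1) * V).sum(dim=0)))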

Complete PyTorch Implementation

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention

    Args:
        Q: Queries, shape (batch, seq_len_q, d_k)
        K: Keys, shape (batch, seq_len_k, d_k)
        V: Values, shape (batch, seq_len_k, d_v)
        mask: Optional mask, shape (batch, seq_len_q, seq_len_k)

    Returns:
        output: Attention output, shape (batch, seq_len_q, d_v)
        attn_weights: Attention weights, shape (batch, seq_len_q, seq_len_k)
    """
    # Get dimension for scaling
    d_k = Q.size(-1)

    # Step 1 & 2: Compute scaled scores
    # Q @ K^T gives (batch, seq_len_q, seq_len_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

    # Optional: Apply mask (for padding or causality)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 3: Apply softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = attn_weights @ V

    return output, attn_weights
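
A minimal usage example, continuing from the code above (the tensor sizes are arbitrary), to confirm the returned shapes:

Q = torch.randn(32, 10, 64)
K = torch.randn(32, 10, 64)
V = torch.randn(32, 10, 64)

output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)        # torch.Size([32, 10, 64])
print(attn_weights.shape)  # torch.Size([32, 10, 10])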

Self-Attention

In transformers, sequences attend to themselves - this is called self-attention:

Q = K = V = X

where X is the input sequence; in practice, Q, K, and V are produced from X by separate learned linear projections.

What this means:

  • Each position attends to all positions (including itself)
  • Query, key, and value all derived from same sequence
  • Builds rich contextual representations

Example: In “The cat sat on the mat”:

  • “cat” might attend strongly to “The” (determiner) and “sat” (verb)
  • “sat” might attend to “cat” (subject) and “mat” (prepositional object)
  • Each word builds context from all other words
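
A minimal self-attention sketch, reusing the scaled_dot_product_attention function defined above; the d_model size and projection layers are illustrative assumptions:

import torch
import torch.nn as nn

d_model = 64
x = torch.randn(1, 6, d_model)          # one sequence of 6 token embeddings

W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

# Q, K, and V are all projections of the same sequence x
output, attn_weights = scaled_dot_product_attention(W_q(x), W_k(x), W_v(x))
print(attn_weights.shape)               # torch.Size([1, 6, 6]): every token attends to every token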

Matrix Dimensions Walkthrough

Understanding shapes is crucial for implementation:

Setup:

  • Sequence length: n = 10 tokens
  • Model dimension: d_k = 64
  • Batch size: b = 32

Shapes through the computation:

Q:     (32, 10, 64)   # batch_size, seq_len, d_k
K:     (32, 10, 64)   # batch_size, seq_len, d_k
V:     (32, 10, 64)   # batch_size, seq_len, d_v

QK^T:  (32, 10, 10)   # batch_size, seq_len, seq_len
# For each batch, this is a 10×10 attention score matrix

After softmax: (32, 10, 10)
# Each row sums to 1 (probability distribution)

Output: (32, 10, 64)  # batch_size, seq_len, d_v
# Same shape as input V

Masking

Masks control which positions can attend to which:

Padding Mask

Prevent attention to padding tokens:

# Sequence: [word1, word2, <pad>, <pad>]
# Don't attend to padding positions
mask = torch.ones(batch, seq_len, seq_len)
mask[:, :, 2:] = 0  # Mask positions 2, 3

# In attention:
scores = scores.masked_fill(mask == 0, -1e9)
# softmax(-1e9) ≈ 0, so padded positions get zero weight

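In practice the padding mask is often built directly from the token IDs; a hedged sketch (pad_id and the tensor values are made up for illustration):

import torch

pad_id = 0
token_ids = torch.tensor([[5, 7, 0, 0]])            # (batch, seq_len), trailing padding
key_mask = (token_ids != pad_id)                    # (batch, seq_len): True for real tokens

# Broadcast to (batch, seq_len_q, seq_len_k) so every query ignores padded keys
mask = key_mask.unsqueeze(1).expand(-1, token_ids.size(1), -1)
print(mask.long())
# tensor([[[1, 1, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 0, 0]]])
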
Causal Mask (Look-Ahead Mask)

Prevent attending to future positions for autoregressive generation:

# Lower triangular matrix
mask = torch.tril(torch.ones(seq_len, seq_len))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
# Position i can only attend to positions ≤ i

This ensures position i can’t “see the future” during training, critical for language modeling.
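
Putting the pieces together, a short sketch (batch size and lengths are arbitrary) of causal self-attention using the scaled_dot_product_attention function defined earlier:

import torch

batch, seq_len, d_k = 2, 4, 8
Q = K = V = torch.randn(batch, seq_len, d_k)

causal_mask = torch.tril(torch.ones(seq_len, seq_len))          # (seq_len, seq_len)
causal_mask = causal_mask.unsqueeze(0).expand(batch, -1, -1)    # (batch, seq_len, seq_len)

output, attn_weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(attn_weights[0])   # entries above the diagonal are (effectively) zero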

Why Scaled Dot-Product?

Advantages:

  1. Efficient: Single matrix multiplication, highly parallelizable on GPUs
  2. Simple: No learned parameters in attention mechanism itself
  3. Stable: Scaling prevents gradient saturation
  4. Flexible: Works for any sequence length

Comparison to alternatives:

| Method             | Complexity | Parameters | Stability        |
|--------------------|------------|------------|------------------|
| Dot Product        | O(n²d)     | 0          | Poor for large d |
| Scaled Dot Product | O(n²d)     | 0          | Good             |
| Additive           | O(n²d)     | 2d² + d    | Good             |

Scaled dot-product offers the best balance of speed, simplicity, and stability.

Computational Complexity

For sequence length n and dimension d:

  • Compute QK^T: O(n^2 \cdot d)
  • Softmax: O(n^2)
  • Multiply by V: O(n^2 \cdot d)

Total: O(n^2 \cdot d)

The n^2 term means attention is quadratic in sequence length - this is the main limitation of transformers for very long sequences (1000+ tokens).
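
A quick arithmetic sketch of that growth (d = 64 is just an illustrative choice): increasing the sequence length 10x multiplies both the work and the size of the attention matrix by 100x.

d = 64
for n in (1_000, 10_000, 100_000):
    score_ops = n * n * d        # multiply-adds for QK^T (the multiply by V is the same order)
    score_entries = n * n        # entries in the n×n attention matrix
    print(f"n={n:>7}: ~{score_ops:.1e} ops, {score_entries:.1e} score entries")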

Note: This motivated efficient attention variants like:

  • Longformer (sparse attention)
  • Performer (linear attention approximation)
  • Flash Attention (memory-efficient computation)

Key Takeaways

  • Scaled dot-product attention uses Query, Key, Value framework
  • The formula \text{softmax}(QK^T/\sqrt{d_k})V is fundamental to all transformers
  • Scaling by \sqrt{d_k} prevents gradient saturation for large dimensions
  • Self-attention means Q = K = V (all derived from the same sequence)
  • Masking controls which positions can attend to which (padding, causality)
  • Complexity is O(n^2 d) - quadratic in sequence length
  • Attention mechanism itself has no learned parameters (parameters are in Q/K/V projections)
