Causal (Masked) Self-Attention
Causal attention is a variant of self-attention that prevents each position from attending to future positions in the sequence. This enables autoregressive generation where each token is predicted based only on previous tokens.
Autoregressive Generation
GPT and other decoder-only models generate text one token at a time, where each new token is conditioned on all previous tokens:

$$P(x_1, x_2, \ldots, x_n) = \prod_{t=1}^{n} P(x_t \mid x_1, \ldots, x_{t-1})$$

This factorization creates a causal (left-to-right) dependency structure where:
- Token 1 depends on nothing (just the start token)
- Token 2 depends on token 1
- Token 3 depends on tokens 1 and 2
- Token $t$ depends on tokens 1 through $t-1$ (see the sketch below)
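A minimal sketch of this factorization in code, assuming a hypothetical `model` that maps token ids to next-token logits (the unconditional term for the first token is omitted for brevity):

import torch
import torch.nn.functional as F

def sequence_log_prob(model, token_ids):
    # token_ids: (1, T); model returns logits of shape (1, T, vocab_size)
    logits = model(token_ids)
    log_probs = F.log_softmax(logits, dim=-1)
    targets = token_ids[:, 1:]                                     # token t is predicted from positions < t
    preds = log_probs[:, :-1, :]                                   # logits at position t-1 predict token t
    step_lp = preds.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # log P(x_t | x_1..x_{t-1})
    return step_lp.sum()                                           # sum over t = log of the product above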
Causal Masking
To enforce causal dependencies during training, we prevent each position from attending to future positions using a causal mask (also called a look-ahead mask).
Creating the Causal Mask
import torch
import torch.nn.functional as F

# Create lower-triangular mask
seq_len = 5
mask = torch.tril(torch.ones(seq_len, seq_len))
print(mask)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])

# Apply to attention scores (before softmax)
# scores: a (seq_len, seq_len) matrix of raw query-key dot products
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
How It Works
The mask enforces the autoregressive property:
- Position 0: Can only attend to itself (row 0 has one 1)
- Position 1: Can attend to positions 0 and 1 (row 1 has two 1s)
- Position 2: Can attend to positions 0, 1, and 2 (row 2 has three 1s)
- Position $t$: Can attend to positions 0 through $t$ (row $t$ has $t+1$ ones)
By setting masked positions to $-\infty$ before the softmax, those positions get probability 0:

$$\text{softmax}(s)_j = \frac{e^{s_j}}{\sum_k e^{s_k}}, \qquad e^{-\infty} = 0$$

This ensures predictions for position $t$ cannot “cheat” by using information from positions $> t$.
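A quick numeric check of the masked softmax (an illustrative sketch, not from the original text):

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, float('-inf'), float('-inf')])
print(F.softmax(scores, dim=-1))
# tensor([0.7311, 0.2689, 0.0000, 0.0000])  -- masked positions get exactly zero weight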
Decoder-Only Architecture
GPT uses only the decoder part of the original transformer architecture.
Key differences from encoder-decoder transformers:
- No encoder-decoder cross-attention: Only self-attention within the sequence
- Causal self-attention: Each position attends only to itself and earlier positions
- Unidirectional: Information flows strictly left-to-right
Comparison with other architectures:
| Architecture | Attention Type | Use Case |
|---|---|---|
| Encoder (BERT) | Bidirectional self-attention | Understanding, classification, embeddings |
| Decoder (GPT) | Causal self-attention | Generation, completion, autoregressive tasks |
| Encoder-Decoder (T5) | Both types | Sequence-to-sequence (translation, summarization) |
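To make the table concrete, the sketch below (illustrative only) shows that the mechanical difference between encoder-style and decoder-style self-attention over the same scores is just the mask:

import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(T, T)                       # raw query-key scores

# Encoder (BERT-style): every position sees every other position
bidirectional = F.softmax(scores, dim=-1)

# Decoder (GPT-style): lower-triangular mask blocks future positions
causal_mask = torch.tril(torch.ones(T, T))
causal = F.softmax(scores.masked_fill(causal_mask == 0, float('-inf')), dim=-1)

print(bidirectional[0])   # non-zero weights on all 4 positions
print(causal[0])          # tensor([1., 0., 0., 0.]) -- row 0 attends only to itself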
Complete Implementation
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len=1024, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        # Key, query, value projections for all heads (combined)
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        # Output projection
        self.c_proj = nn.Linear(d_model, d_model)
        # Regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_heads = n_heads
        self.d_model = d_model
        # Causal mask as a non-trainable buffer; max_seq_len bounds the
        # longest sequence this module will process
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(max_seq_len, max_seq_len)).view(1, 1, max_seq_len, max_seq_len)
        )

    def forward(self, x):
        B, T, C = x.shape  # batch, sequence length, embedding dim
        # Calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.d_model, dim=2)
        # Reshape for multi-head attention: (B, T, C) -> (B, nh, T, hs)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        # Causal self-attention with masking
        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Apply causal mask
        att = att.masked_fill(
            self.bias[:, :, :T, :T] == 0, float('-inf')
        )
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        # Apply attention to values
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # Reassemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # Output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
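A quick shape check of the module above (the hyperparameters are arbitrary, chosen only for illustration):

attn = CausalSelfAttention(d_model=64, n_heads=4, max_seq_len=128)
x = torch.randn(2, 10, 64)          # (batch=2, seq_len=10, d_model=64)
y = attn(x)
print(y.shape)                      # torch.Size([2, 10, 64]) -- same shape as the input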
Training vs Inference
During Training
Parallel processing:
- Process entire sequences at once
- Mask prevents looking at future tokens
- All positions computed simultaneously
- Efficient use of GPU parallelism
# Training: process full sequence in one forward pass
logits = model(input_ids) # (B, T, vocab_size)
# Causal mask ensures each position only uses past context
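Continuing the snippet above (an illustrative assumption about the training setup, with `F` as `torch.nn.functional`), the next-token loss pairs each position's logits with the following token, so every position is trained in parallel:

targets = input_ids[:, 1:]                          # position t's target is token t+1
pred_logits = logits[:, :-1, :]                     # drop the last position (nothing to predict)
loss = F.cross_entropy(
    pred_logits.reshape(-1, pred_logits.size(-1)),  # (B*(T-1), vocab_size)
    targets.reshape(-1),                            # (B*(T-1),)
)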
Sequential generation:
- Generate one token at a time
- Only attend to already-generated tokens
- Can cache key/value pairs for efficiency (KV cache; a minimal sketch follows the generation loop below)
- Inherently sequential process
# Inference: generate token by token
for i in range(max_new_tokens):
    logits = model(generated_tokens)                 # (1, current_len, vocab_size)
    probs = F.softmax(logits[:, -1, :], dim=-1)      # only the last position matters
    next_token = torch.multinomial(probs, num_samples=1)  # sample next token (greedy/top-k also possible)
    generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
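The loop above recomputes attention over the whole prefix at every step. A minimal sketch of the KV-cache idea mentioned earlier (hypothetical names and example sizes, not tied to any specific library): keep the already-computed keys and values and only project the newest token each step.

import torch

# Example sizes only
n_heads, head_dim = 4, 16
cache_k = torch.empty(1, n_heads, 0, head_dim)   # keys seen so far: (1, nh, T_so_far, hs)
cache_v = torch.empty(1, n_heads, 0, head_dim)   # values seen so far

def attend_with_cache(q_new, k_new, v_new):
    # q_new/k_new/v_new: projections of only the newest token, shape (1, nh, 1, hs)
    global cache_k, cache_v
    cache_k = torch.cat([cache_k, k_new], dim=2)   # append this step's key
    cache_v = torch.cat([cache_v, v_new], dim=2)   # append this step's value
    att = torch.softmax(q_new @ cache_k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
    # No causal mask needed here: the cache holds only past (and current) positions
    return att @ cache_v                           # (1, nh, 1, hs)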
Attention Pattern Visualization
Causal attention creates a characteristic triangular pattern:
Position: 0 1 2 3 4
┌────┬────┬────┬────┬────┐
Token 0 │ ✓ │ │ │ │ │ Can only see self
├────┼────┼────┼────┼────┤
Token 1 │ ✓ │ ✓ │ │ │ │ Sees 0, 1
├────┼────┼────┼────┼────┤
Token 2 │ ✓ │ ✓ │ ✓ │ │ │ Sees 0, 1, 2
├────┼────┼────┼────┼────┤
Token 3 │ ✓ │ ✓ │ ✓ │ ✓ │ │ Sees 0, 1, 2, 3
├────┼────┼────┼────┼────┤
Token 4 │ ✓ │ ✓ │ ✓ │ ✓ │ ✓ │ Sees all previous
└────┴────┴────┴────┴────┘
Why Causal Masking is Critical
Without causal masking (using bidirectional attention during training):
- Model would learn to “cheat” by looking at future tokens (demonstrated in the sketch below)
- Training loss would be artificially low
- Model would fail at generation (future tokens not available)
- Train/test mismatch: bidirectional during training, causal during inference
With causal masking:
- Model learns true conditional distributions
- Training matches inference conditions
- Enables autoregressive generation
- Generalizes to generating arbitrary-length sequences
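A small experiment makes the leakage concrete (a sketch assuming PyTorch ≥ 2.0 for F.scaled_dot_product_attention): with the causal flag, editing a future token leaves earlier outputs untouched; without it, they change.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x1 = torch.randn(1, 5, 8)            # (batch, seq_len, dim); used as q, k, and v for simplicity
x2 = x1.clone()
x2[:, 4, :] += 1.0                   # perturb only the LAST token

for is_causal in (True, False):
    y1 = F.scaled_dot_product_attention(x1, x1, x1, is_causal=is_causal)
    y2 = F.scaled_dot_product_attention(x2, x2, x2, is_causal=is_causal)
    changed = not torch.allclose(y1[:, 0, :], y2[:, 0, :])
    print(f"is_causal={is_causal}: output at position 0 changed = {changed}")
# is_causal=True:  position 0 is unaffected by the future edit
# is_causal=False: position 0 changes, i.e. information leaked from position 4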
Computational Complexity
Same as standard attention:
- Time complexity: $O(n^2 \cdot d)$, where $n$ is the sequence length and $d$ is the model dimension
- Space complexity: $O(n^2)$ for the attention matrix
- The mask doesn’t reduce complexity, just enforces causality
For very long sequences, this quadratic scaling is the bottleneck. Efficient attention variants (like Flash Attention) maintain causality while reducing memory.
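In PyTorch, one way to get such fused, memory-efficient kernels (when the backend supports them) is F.scaled_dot_product_attention with is_causal=True, which avoids materializing the full $n \times n$ attention matrix. A sketch, assuming multi-head-shaped tensors:

import torch
import torch.nn.functional as F

B, n_heads, T, head_dim = 2, 8, 4096, 64
q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_heads, T, head_dim)
v = torch.randn(B, n_heads, T, head_dim)

# Fused causal attention; may dispatch to FlashAttention-style kernels on supported hardware
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(y.shape)  # torch.Size([2, 8, 4096, 64])

# For comparison, explicitly materializing the T x T score matrix would take
# B * n_heads * T * T * 4 bytes ≈ 1.07 GB in float32
print(B * n_heads * T * T * 4 / 1e9, "GB")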
Key Insights
- Causal masking is essential: Without it, autoregressive models can’t work properly
- Parallel training, sequential inference: Training is efficient despite sequential generation
- No information leakage: Strict left-to-right information flow during both training and inference
- Foundation for LLMs: All GPT-family models use causal attention
- Simple but powerful: Single triangular mask enables autoregressive modeling
Related Concepts
- Scaled Dot-Product Attention - The base attention mechanism
- Multi-Head Attention - Parallel attention heads
- GPT Architecture - Uses causal attention in every block
- Text Generation - Autoregressive sampling using causal attention
- Attention Mechanism - General attention concepts
Learning Resources
Video Explanations
- Andrej Karpathy - Let’s Build GPT - Implements causal attention from scratch
- 3Blue1Brown - Attention in Transformers - Visual explanation of attention masking
Articles
- The Illustrated GPT-2 - Jay Alammar’s visual guide to GPT architecture
- The Annotated Transformer - Harvard NLP’s line-by-line implementation
Papers
- Attention Is All You Need (Vaswani et al., 2017) - Original transformer with causal masking in decoder
- Improving Language Understanding by Generative Pre-Training (Radford et al., 2018) - GPT’s decoder-only architecture