
Causal (Masked) Self-Attention

Causal attention is a variant of self-attention that prevents each position from attending to future positions in the sequence. This enables autoregressive generation where each token is predicted based only on previous tokens.

Autoregressive Generation

GPT and other decoder-only models generate text one token at a time, where each new token is conditioned on all previous tokens:

P(x_1, x_2, \dots, x_n) = \prod_{i=1}^{n} P(x_i \mid x_1, \dots, x_{i-1})

This factorization creates a causal (left-to-right) dependency structure where:

  • Token 1 depends on nothing (just the start token)
  • Token 2 depends on token 1
  • Token 3 depends on tokens 1 and 2
  • Token n depends on tokens 1 through n-1
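
In code, this factorization is usually evaluated in log space. The following minimal sketch (the tensor names and shapes are illustrative, not taken from any specific model) shows how the per-token terms fall out of next-token logits:

import torch
import torch.nn.functional as F

# Illustrative stand-ins: 5 token ids and the logits a model would produce
vocab_size, seq_len = 10, 5
x = torch.randint(vocab_size, (seq_len,))   # x_1, ..., x_n
logits = torch.randn(seq_len, vocab_size)   # logits at position i predict x_{i+1}

# Negative log-likelihood under the chain-rule factorization:
# -log P(x_2, ..., x_n | x_1) = -sum_i log P(x_{i+1} | x_1, ..., x_i)
nll = F.cross_entropy(logits[:-1], x[1:], reduction="sum")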

Causal Masking

To enforce causal dependencies during training, we prevent each position from attending to future positions using a causal mask (also called a look-ahead mask).

Creating the Causal Mask

import torch
import torch.nn.functional as F

# Create lower-triangular mask
seq_len = 5
mask = torch.tril(torch.ones(seq_len, seq_len))
print(mask)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])

# Apply to attention scores (before softmax)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)

How It Works

The mask enforces the autoregressive property:

  • Position 0: Can only attend to itself (row 0 has one 1)
  • Position 1: Can attend to positions 0 and 1 (row 1 has two 1s)
  • Position 2: Can attend to positions 0, 1, and 2 (row 2 has three 1s)
  • Position i: Can attend to positions 0 through i (row i has i+1 ones)

By setting masked positions to −∞ before softmax, those positions get probability 0:

\text{softmax}(-\infty) = \frac{e^{-\infty}}{\sum_k e^{s_k}} = \frac{0}{\sum_k e^{s_k}} = 0

This ensures predictions for position i cannot “cheat” by using information from positions i+1, i+2, ...
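
A tiny numerical check (with made-up scores) makes this concrete: after masking, positions to the right of the diagonal receive exactly zero weight, while each row still sums to 1:

import torch
import torch.nn.functional as F

seq_len = 5
scores = torch.randn(seq_len, seq_len)              # made-up attention scores
mask = torch.tril(torch.ones(seq_len, seq_len))
weights = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)

print(weights[1])           # e.g. tensor([0.37, 0.63, 0.00, 0.00, 0.00])
print(weights.sum(dim=-1))  # tensor([1., 1., 1., 1., 1.])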

Decoder-Only Architecture

GPT uses only the decoder part of the original transformer architecture.

Key differences from encoder-decoder transformers:

  • No encoder-decoder cross-attention: Only self-attention within the sequence
  • Causal self-attention: Each position attends only to itself and earlier positions
  • Unidirectional: Information flows strictly left-to-right

Comparison with other architectures:

Architecture          Attention Type                Use Case
Encoder (BERT)        Bidirectional self-attention  Understanding, classification, embeddings
Decoder (GPT)         Causal self-attention         Generation, completion, autoregressive tasks
Encoder-Decoder (T5)  Both types                    Sequence-to-sequence (translation, summarization)

Complete Implementation

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, block_size=1024):
        super().__init__()
        assert d_model % n_heads == 0
        # Key, query, value projections for all heads (combined)
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        # Output projection
        self.c_proj = nn.Linear(d_model, d_model)
        # Regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_heads = n_heads
        self.d_model = d_model
        # Causal mask: lower-triangular ones, shaped to broadcast over
        # (batch, heads, T, T). block_size is the maximum sequence length.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        )

    def forward(self, x):
        B, T, C = x.shape  # batch, sequence length, embedding dim

        # Calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.d_model, dim=2)

        # Reshape for multi-head attention: (B, T, C) -> (B, nh, T, hs)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)

        # Causal self-attention with masking
        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # Apply causal mask
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # Apply attention to values
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # Reassemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
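
A quick smoke test of the module above (the sizes are arbitrary; block_size defaults to 1024 here, so any shorter sequence works):

attn = CausalSelfAttention(d_model=64, n_heads=4)
x = torch.randn(2, 8, 64)      # (batch, seq_len, d_model)
y = attn(x)
print(y.shape)                 # torch.Size([2, 8, 64]): same shape as the input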

Training vs Inference

During Training

Parallel processing:

  • Process entire sequences at once
  • Mask prevents looking at future tokens
  • All positions computed simultaneously
  • Efficient use of GPU parallelism
# Training: process full sequence in one forward pass
logits = model(input_ids)  # (B, T, vocab_size)
# Causal mask ensures each position only uses past context

During Inference

Sequential generation:

  • Generate one token at a time
  • Only attend to already-generated tokens
  • Can cache key/value pairs for efficiency (KV cache)
  • Inherently sequential process
# Inference: generate token by token
for i in range(max_new_tokens):
    logits = model(generated_tokens)        # (1, current_len, vocab_size)
    next_token = sample(logits[:, -1, :])   # Only use last position
    generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
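
The KV cache mentioned above can be sketched for a single attention head as follows (purely illustrative; the CausalSelfAttention class in this article does not implement caching). Each step appends the new token's key and value and attends from only the newest query, so no explicit mask is needed: the cache contains nothing but past positions.

import math
import torch
import torch.nn.functional as F

d = 64
k_cache, v_cache = [], []

def cached_attention_step(q_new, k_new, v_new):   # each is (1, d) for the newest token
    k_cache.append(k_new)
    v_cache.append(v_new)
    K = torch.cat(k_cache, dim=0)                 # (t, d): keys of all tokens so far
    V = torch.cat(v_cache, dim=0)                 # (t, d): values of all tokens so far
    att = F.softmax(q_new @ K.T / math.sqrt(d), dim=-1)   # (1, t)
    return att @ V                                # (1, d)

for _ in range(3):   # three generation steps with random stand-in projections
    out = cached_attention_step(torch.randn(1, d), torch.randn(1, d), torch.randn(1, d))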

Attention Pattern Visualization

Causal attention creates a characteristic triangular pattern:

Position:    0    1    2    3    4
           ┌────┬────┬────┬────┬────┐
Token 0    │ ✓  │    │    │    │    │  Can only see self
           ├────┼────┼────┼────┼────┤
Token 1    │ ✓  │ ✓  │    │    │    │  Sees 0, 1
           ├────┼────┼────┼────┼────┤
Token 2    │ ✓  │ ✓  │ ✓  │    │    │  Sees 0, 1, 2
           ├────┼────┼────┼────┼────┤
Token 3    │ ✓  │ ✓  │ ✓  │ ✓  │    │  Sees 0, 1, 2, 3
           ├────┼────┼────┼────┼────┤
Token 4    │ ✓  │ ✓  │ ✓  │ ✓  │ ✓  │  Sees all previous
           └────┴────┴────┴────┴────┘

Why Causal Masking is Critical

Without causal masking (using bidirectional attention during training):

  • Model would learn to “cheat” by looking at future tokens
  • Training loss would be artificially low
  • Model would fail at generation (future tokens not available)
  • Train/test mismatch: bidirectional during training, causal during inference

With causal masking:

  • Model learns true conditional distributions P(x_i | x_{<i})
  • Training matches inference conditions
  • Enables autoregressive generation
  • Generalizes to generating arbitrary-length sequences
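
One way to convince yourself of the train/inference consistency is a small check (using the CausalSelfAttention module defined earlier, with dropout disabled): processing a full sequence in parallel produces the same outputs at the first few positions as processing only that prefix.

import torch

torch.manual_seed(0)
attn = CausalSelfAttention(d_model=32, n_heads=4)
attn.eval()                                # disable dropout so outputs are deterministic

x = torch.randn(1, 6, 32)
with torch.no_grad():
    full = attn(x)                         # all 6 positions in one pass
    prefix = attn(x[:, :4, :])             # only the first 4 tokens

print(torch.allclose(full[:, :4, :], prefix, atol=1e-6))   # True: no future leakage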

Computational Complexity

Same as standard attention:

  • Time complexity: O(n^2 d), where n is the sequence length and d is the model dimension
  • Space complexity: O(n^2) for the attention matrix
  • The mask doesn’t reduce complexity, just enforces causality

For very long sequences, this quadratic scaling is the bottleneck. Efficient attention variants (like Flash Attention) maintain causality while reducing memory.
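
In PyTorch 2.x, for example, F.scaled_dot_product_attention exposes such fused kernels behind a single call; passing is_causal=True applies the causal mask inside the kernel, and Flash-style backends avoid materializing the full n × n attention matrix (a sketch with arbitrary sizes):

import torch
import torch.nn.functional as F

B, n_heads, T, head_dim = 2, 4, 1024, 64
q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_heads, T, head_dim)
v = torch.randn(B, n_heads, T, head_dim)

# Fused causal attention: the mask is applied inside the kernel, and
# memory-efficient / Flash backends never build the full T x T matrix.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # (B, n_heads, T, head_dim)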

Key Insights

  1. Causal masking is essential: Without it, autoregressive models can’t work properly
  2. Parallel training, sequential inference: Training is efficient despite sequential generation
  3. No information leakage: Strict left-to-right information flow during both training and inference
  4. Foundation for LLMs: All GPT-family models use causal attention
  5. Simple but powerful: Single triangular mask enables autoregressive modeling

Learning Resources

Papers

  • Attention Is All You Need (Vaswani et al., 2017) - Original transformer with causal masking in decoder
  • Improving Language Understanding by Generative Pre-Training (Radford et al., 2018) - GPT’s decoder-only architecture