Vision Transformer (ViT)
Dosovitskiy et al. (2021) - “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”
Published: ICLR 2021 | Citations: 40,000+ (as of 2024) | Code: google-research/vision_transformer
Key Insight
While the Transformer architecture revolutionized NLP, it was unclear if it could replace CNNs for vision. ViT’s breakthrough: treat an image as a sequence of patches and apply a standard transformer encoder—no convolutions needed. With sufficient pre-training data, ViT matches or exceeds state-of-the-art CNNs.
Core Innovation
Patch-based image tokenization: Split images into fixed-size patches (e.g., 16×16 pixels), flatten each patch into a vector, and process the sequence with a standard transformer.
This simple approach:
- Removes reliance on convolutional inductive biases
- Enables direct application of transformer architectures to vision
- Scales better than CNNs with larger datasets
- Provides interpretability through attention maps
Architecture
The ViT Process
- Split image into patches (e.g., 16×16 pixels each)
- Flatten each patch to a 1D vector (768 values for 16×16×3)
- Apply linear projection to create patch embeddings
- Add learnable positional embeddings (transformers lack spatial awareness)
- Prepend CLS token (like BERT’s [CLS] for classification)
- Process with standard transformer encoder
- Use CLS token output for classification
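A quick check of the token counts these steps imply for the standard setting (224×224 RGB input, 16×16 patches):
image_size, patch_size, channels = 224, 16, 3
num_patches = (image_size // patch_size) ** 2   # 14 * 14 = 196 patches
patch_dim = channels * patch_size ** 2          # 16 * 16 * 3 = 768 values per flattened patch
seq_len = num_patches + 1                       # 197 tokens once the CLS token is prepended
print(num_patches, patch_dim, seq_len)          # 196 768 197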
Complete PyTorch Implementation
import torch
import torch.nn as nn
from einops import rearrange  # pip install einops


class VisionTransformer(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 num_classes=1000,
                 dim=768,        # Embedding dimension
                 depth=12,       # Number of transformer layers
                 heads=12,       # Number of attention heads
                 mlp_dim=3072):  # MLP hidden dimension (4 * dim)
        """
        Vision Transformer for image classification.

        Args:
            image_size: Input image size (assumes square images)
            patch_size: Size of each patch (16 or 32 typical)
            num_classes: Number of output classes
            dim: Embedding dimension
            depth: Number of transformer encoder layers
            heads: Number of attention heads
            mlp_dim: Hidden dimension of MLP layers
        """
        super().__init__()

        # Calculate number of patches
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2  # 3 channels (RGB)
        self.patch_size = patch_size

        # Patch embedding: linear projection of flattened patches
        self.patch_embedding = nn.Linear(patch_dim, dim)

        # Positional embedding (learnable); +1 for the CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        # CLS token (like BERT's [CLS] token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            norm_first=True  # Pre-norm (more stable)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """
        Forward pass through ViT.

        Args:
            img: Input images (batch, 3, H, W)
        Returns:
            Class logits (batch, num_classes)
        """
        batch_size = img.shape[0]

        # Create patches: (batch, 3, H, W) -> (batch, num_patches, patch_dim)
        patches = rearrange(img,
                            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                            p1=self.patch_size,
                            p2=self.patch_size)

        # Patch embedding: (batch, num_patches, patch_dim) -> (batch, num_patches, dim)
        x = self.patch_embedding(patches)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embedding
        x = x + self.pos_embedding

        # Transformer encoder
        x = self.transformer(x)

        # Use CLS token for classification
        cls_output = x[:, 0]
        return self.mlp_head(cls_output)

Model Configurations
ViT-Base (ViT-B/16):
- Layers: 12
- Hidden dimension: 768
- MLP dimension: 3072
- Attention heads: 12
- Patch size: 16×16
- Parameters: 86M
ViT-Large (ViT-L/16):
- Layers: 24
- Hidden dimension: 1024
- MLP dimension: 4096
- Attention heads: 16
- Patch size: 16×16
- Parameters: 307M
ViT-Huge (ViT-H/14):
- Layers: 32
- Hidden dimension: 1280
- MLP dimension: 5120
- Attention heads: 16
- Patch size: 14×14
- Parameters: 632M
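For reference, here is how these configurations map onto the VisionTransformer class defined above (a sketch; parameter counts from this simplified implementation come out close to, though not exactly at, the published figures):
vit_b16 = VisionTransformer(image_size=224, patch_size=16, dim=768, depth=12, heads=12, mlp_dim=3072)
vit_l16 = VisionTransformer(image_size=224, patch_size=16, dim=1024, depth=24, heads=16, mlp_dim=4096)
vit_h14 = VisionTransformer(image_size=224, patch_size=14, dim=1280, depth=32, heads=16, mlp_dim=5120)

for name, model in [("ViT-B/16", vit_b16), ("ViT-L/16", vit_l16), ("ViT-H/14", vit_h14)]:
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{name}: {n_params:.0f}M parameters")

print(vit_b16(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])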
Key Components
1. Patch Embedding
Instead of convolutions, ViT uses linear projection of flattened patches:
# For 16×16 patches on RGB images:
patch_dim = 16 * 16 * 3                 # = 768
embedding = nn.Linear(patch_dim, 768)   # often projects to the same dimensionality
Why it works: A linear projection on flattened patches is equivalent to a convolution with kernel size = patch size and stride = patch size, but it is conceptually simpler.
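The equivalence can be checked numerically: copying a Linear layer's weights into a Conv2d with kernel size = stride = patch size reproduces the same patch embeddings (a small sketch; the weight-copying step exists only for this comparison):
import torch
import torch.nn as nn
from einops import rearrange

dim, patch = 768, 16
linear = nn.Linear(3 * patch * patch, dim)
conv = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)

# Copy the linear weights into the conv kernel so both compute the same projection
conv.weight.data = rearrange(linear.weight.data, 'd (p1 p2 c) -> d c p1 p2', p1=patch, p2=patch, c=3)
conv.bias.data = linear.bias.data

img = torch.randn(1, 3, 224, 224)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch, p2=patch)
out_linear = linear(patches)                              # (1, 196, 768)
out_conv = rearrange(conv(img), 'b d h w -> b (h w) d')   # (1, 196, 768)
print(torch.allclose(out_linear, out_conv, atol=1e-5))    # True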
2. Learnable Positional Embeddings
Transformers have no notion of spatial order. ViT adds learnable 1D positional embeddings:
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x = x + pos_embedding
The model learns spatial relationships from data rather than having them hard-coded (unlike CNNs with spatial locality).
Alternatives explored:
- 2D positional embeddings (slightly better for fine-tuning at different resolutions)
- Sinusoidal encodings (works but learnable is standard)
- Relative positional encodings (used in Swin Transformer)
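For illustration, a minimal sketch of a fixed 2D sine-cosine embedding (the flavour later adopted by MAE; the function name and temperature value here are my own choices):
import torch

def sincos_2d_pos_embedding(grid_size, dim, temperature=10000.0):
    """Fixed 2D sine-cosine positional embedding over a square patch grid (no learned parameters)."""
    assert dim % 4 == 0, "dim must be divisible by 4 (sin/cos for each of x and y)"
    ys, xs = torch.meshgrid(
        torch.arange(grid_size, dtype=torch.float32),
        torch.arange(grid_size, dtype=torch.float32),
        indexing="ij",
    )
    omega = 1.0 / temperature ** (torch.arange(dim // 4, dtype=torch.float32) / (dim // 4))
    x_freq = xs.flatten()[:, None] * omega[None, :]   # (num_patches, dim // 4)
    y_freq = ys.flatten()[:, None] * omega[None, :]
    return torch.cat([x_freq.sin(), x_freq.cos(), y_freq.sin(), y_freq.cos()], dim=1)

pos = sincos_2d_pos_embedding(grid_size=14, dim=768)   # (196, 768)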
3. CLS Token
A special learnable token prepended to the sequence:
- Aggregates information from all patches through self-attention
- Used for the final classification (take CLS output, ignore patch outputs)
- Same idea as BERT’s [CLS] token
Alternative: Global average pooling over all patch embeddings. The paper's ablation finds both heads perform comparably once learning rates are tuned; the CLS token is the default (a GAP variant is sketched below).
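A sketch of the pooling variant, reusing the VisionTransformer module defined above (the CLS token is still fed through the encoder but ignored at the head; this is an illustration, not the paper's default head):
import torch
from einops import rearrange

def classify_with_gap(vit_model, img):
    """Classification via global average pooling of patch tokens instead of the CLS token."""
    batch_size = img.shape[0]
    patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                        p1=vit_model.patch_size, p2=vit_model.patch_size)
    x = vit_model.patch_embedding(patches)
    cls_tokens = vit_model.cls_token.expand(batch_size, -1, -1)
    x = torch.cat([cls_tokens, x], dim=1) + vit_model.pos_embedding
    x = vit_model.transformer(x)
    pooled = x[:, 1:].mean(dim=1)   # average the patch tokens, skip the CLS position
    return vit_model.mlp_head(pooled)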
ViT vs CNNs
| Aspect | CNN (ResNet) | ViT |
|---|---|---|
| Inductive bias | Strong (locality, translation equivariance, hierarchy) | Weak (learns from data) |
| Data requirements | Works with 100K-1M images | Requires 10M+ images or clever pre-training |
| Computational cost | Linear in image size | Quadratic in number of patches |
| Long-range dependencies | Requires deep stacking | Direct connections via attention |
| Interpretability | Feature maps (less clear) | Attention maps (shows what model looks at) |
| Training time | Faster convergence | Slower convergence (more epochs needed) |
| Transfer learning | Strong | Very strong (better than CNNs) |
| Best use case | Small-medium data, efficiency | Large data, maximum performance |
When to Use ViT
Use ViT when:
- Large pre-training datasets available (ImageNet-21K: 14M images)
- Maximum performance critical
- Sufficient compute resources available
- Transfer learning from large models
- Interpretability via attention important
Use CNNs when:
- Limited data (<1M images)
- Computational efficiency critical
- Faster training needed
- Strong spatial locality important
- Edge deployment required
Attention Visualization
One major advantage: ViT attention weights reveal what the model focuses on:
def visualize_attention(vit_model, image):
    """
    Visualize which patches the CLS token attends to.
    This shows what parts of the image contributed to the prediction.
    """
    # Run the model; nn.MultiheadAttention does not keep its weights by default,
    # so the forward pass (or a forward hook) must be modified to stash them.
    with torch.no_grad():
        outputs = vit_model(image)

    # Attention weights from the last encoder layer (assumes they were stashed
    # on the attention module as `attention_weights` during the forward pass)
    attention_weights = vit_model.transformer.layers[-1].self_attn.attention_weights

    # CLS token attention to patches
    # Shape: (batch, heads, 1 + num_patches, 1 + num_patches)
    cls_attention = attention_weights[:, :, 0, 1:]  # CLS to patches, skip CLS-to-CLS

    # Average across heads (or examine specific heads)
    cls_attention = cls_attention.mean(dim=1)  # (batch, num_patches)

    # Reshape to a spatial grid for visualization
    grid_size = int(cls_attention.shape[-1] ** 0.5)
    attention_map = cls_attention.reshape(-1, grid_size, grid_size)
    return attention_map

Attention maps often show:
- Early layers: Local patterns (edges, textures) similar to CNN early layers
- Middle layers: Object parts
- Late layers: Semantic object regions
This provides interpretability unavailable in CNNs.
Hybrid Architectures
Combine CNN inductive bias with ViT global reasoning:
class HybridViT(nn.Module):
    def __init__(self, cnn_backbone, transformer_encoder, num_patches, dim, num_classes):
        """
        Hybrid ViT: CNN for local features + Transformer for global reasoning.

        Args:
            cnn_backbone: CNN (e.g., a truncated ResNet-50) producing a (batch, dim, H', W') feature map
            transformer_encoder: Standard transformer encoder operating on dim-sized tokens
            num_patches: Number of spatial locations in the CNN feature map (H' * W')
            dim: Channel dimension of the CNN feature map (= transformer embedding dimension)
            num_classes: Output classes
        """
        super().__init__()
        self.cnn = cnn_backbone
        self.transformer = transformer_encoder
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # CNN extracts local features
        features = self.cnn(img)  # (batch, dim, H', W')

        # Treat spatial locations as patches
        patches = rearrange(features, 'b c h w -> b (h w) c')

        # Add CLS token and positional embeddings
        cls_token = self.cls_token.expand(patches.shape[0], -1, -1)
        patches = torch.cat([cls_token, patches], dim=1)
        patches = patches + self.pos_embedding

        # Transformer for global reasoning
        x = self.transformer(patches)
        return self.head(x[:, 0])  # Use CLS token

Benefits of the hybrid approach:
- Requires less pre-training data (CNN inductive bias helps)
- Combines CNN efficiency with ViT expressiveness
- Good for medical imaging (limited data domains)
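A minimal usage sketch of the HybridViT class above, assuming a reasonably recent torchvision: ResNet-50 truncated before its pooling and classifier layers yields a (batch, 2048, 7, 7) feature map for 224×224 inputs, i.e. a 7×7 token grid (the hyperparameters below are illustrative):
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Drop ResNet-50's average pool and classifier to keep the (batch, 2048, 7, 7) feature map
backbone = nn.Sequential(*list(resnet50(weights=None).children())[:-2])

encoder_layer = nn.TransformerEncoderLayer(d_model=2048, nhead=8, dim_feedforward=4096,
                                           activation='gelu', batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

model = HybridViT(backbone, encoder, num_patches=7 * 7, dim=2048, num_classes=1000)
logits = model(torch.randn(2, 3, 224, 224))   # (2, 1000)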
Training Strategy
Pre-training on Large Datasets
ViT’s key requirement: large-scale pre-training
The paper shows:
- ImageNet (1.3M): ViT underperforms ResNet
- ImageNet-21K (14M): ViT matches ResNet
- JFT-300M (300M): ViT significantly outperforms ResNet
Why? CNNs have built-in inductive biases (locality, translation equivariance). ViT learns these from data, requiring more examples.
Fine-tuning
After pre-training, fine-tune on downstream tasks:
- Replace classification head
- Optionally adjust positional embeddings for different input resolutions (see the interpolation sketch after this list)
- Fine-tune with smaller learning rate
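A minimal sketch of that resolution adjustment: the learned patch-position embeddings form a square grid, so they can be bilinearly interpolated to the new grid size while the CLS position is kept as-is (the helper name is my own):
import torch
import torch.nn.functional as F

def resize_pos_embedding(pos_embedding, old_grid, new_grid):
    """Interpolate (1, 1 + old_grid**2, dim) positional embeddings to a new patch grid."""
    cls_pos, patch_pos = pos_embedding[:, :1], pos_embedding[:, 1:]
    dim = patch_pos.shape[-1]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bilinear", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)

# e.g. fine-tuning a 224-pretrained ViT-B/16 at 384x384: 14x14 -> 24x24 patch grid
# new_pos = resize_pos_embedding(model.pos_embedding.data, old_grid=14, new_grid=24)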
Transfer learning results: ViT pre-trained on large datasets transfers better than ResNets to various downstream tasks (CIFAR, VTAB benchmark).
Results
ImageNet Classification
Pre-trained on JFT-300M, fine-tuned on ImageNet:
- ViT-H/14: 88.55% top-1 accuracy (state-of-the-art at publication)
- ViT-L/16: 87.76% top-1
- ViT-B/16: 84.15% top-1
Compared to BiT (ResNet-152x4) pre-trained on JFT-300M: 87.54% (ViT-H beats it).
Transfer Learning
On 19 downstream tasks (VTAB benchmark):
- ViT outperforms ResNets on 18/19 tasks after JFT-300M pre-training
- Particularly strong on tasks requiring global reasoning
Computational Efficiency
Pre-training compute (TPUv3-core-days, as reported in the paper):
- ViT-L/16: ~0.68k core-days
- ViT-H/14: ~2.5k core-days
- For comparison, BiT-L (ResNet152x4): ~9.9k core-days
ViT reaches comparable or better accuracy with substantially less pre-training compute than the largest ResNets.
ViT Variants and Follow-ups
DeiT (Data-efficient Image Transformer)
Problem: ViT requires huge datasets. Solution: Knowledge distillation from CNN teacher + stronger augmentation.
- Trains on ImageNet (1.3M images) without JFT-300M
- Achieves 83.1% top-1 (ViT-B/16 equivalent)
- Adds distillation token alongside CLS token
Swin Transformer
Problem: ViT has quadratic cost in number of patches. Solution: Hierarchical architecture with shifted windows.
- Introduces hierarchy (like CNNs: 4 stages with different resolutions)
- Window-based attention (linear complexity)
- Better for dense tasks (object detection, segmentation)
- State-of-the-art on COCO detection
BEiT (BERT pre-training for Images)
Problem: Supervised pre-training requires labels. Solution: Masked image modeling (predict masked patches).
- Self-supervised pre-training like BERT
- Uses visual tokens from discrete VAE
- Strong transfer learning without labels
MAE (Masked Autoencoders)
Problem: Efficient self-supervised pre-training for ViT. Solution: Mask 75% of patches, reconstruct them.
- Extremely simple and effective
- Asymmetric encoder-decoder (lightweight decoder)
- State-of-the-art self-supervised learning
- Can pre-train ViT-Huge in days
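A minimal sketch of the random-masking step (shapes only; the official MAE implementation additionally encodes only the visible patches and appends learned mask tokens before a lightweight decoder):
import torch

def random_masking(patch_tokens, mask_ratio=0.75):
    """Keep a random subset of patch tokens; return kept tokens and a mask (1 = masked)."""
    batch, num_patches, dim = patch_tokens.shape
    num_keep = int(num_patches * (1 - mask_ratio))
    noise = torch.rand(batch, num_patches)         # random score per patch
    keep_idx = noise.argsort(dim=1)[:, :num_keep]  # lowest-scoring patches are kept
    visible = torch.gather(patch_tokens, 1, keep_idx.unsqueeze(-1).expand(-1, -1, dim))
    mask = torch.ones(batch, num_patches)
    mask.scatter_(1, keep_idx, 0.0)                # 0 = visible, 1 = masked
    return visible, mask

tokens = torch.randn(2, 196, 768)
visible, mask = random_masking(tokens)             # visible: (2, 49, 768)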
Practical Considerations
1. Patch Size Trade-off
Smaller patches (14×14 or 8×8):
- Pros: More fine-grained detail, better for small objects
- Cons: More tokens → quadratic cost increase, slower training
Larger patches (32×32):
- Pros: Fewer tokens → faster, lower memory
- Cons: Less spatial resolution, may miss small details
Standard choice: 16×16 balances resolution and efficiency.
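The token counts behind this trade-off for a 224×224 input (self-attention cost grows with the square of the token count):
image_size = 224
for patch_size in (8, 14, 16, 32):
    num_tokens = (image_size // patch_size) ** 2 + 1   # +1 for the CLS token
    print(f"patch {patch_size:2d}: {num_tokens:4d} tokens "
          f"-> {num_tokens}x{num_tokens} attention matrix per head")
# patch 8: 785 tokens; patch 14: 257; patch 16: 197; patch 32: 50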
2. Positional Embeddings
1D learnable (standard ViT):
- Simple, works well
- Model learns 2D structure from data
- Can interpolate for different resolutions
2D learnable:
- Explicitly encode row/column position
- Slightly better for resolution changes
- More parameters
Relative positional encodings (Swin):
- Encode relative distances between patches
- Better generalization to unseen sizes
3. Pre-training Strategies
Supervised (original ViT):
- Requires labeled data (ImageNet-21K, JFT-300M)
- Strong performance, but expensive labels
Self-supervised (BEiT, MAE):
- No labels needed
- Can leverage unlimited unlabeled images
- Comparable or better transfer learning
Related masked-image-modeling variants (SimMIM, iBOT):
- SimMIM: simplifies masked image modeling to direct pixel regression
- iBOT: combines masked image modeling with self-distillation
- State-of-the-art results
Impact
ViT demonstrated that:
- Transformers work for vision: Architecture generality across modalities
- Inductive biases can be learned: Given enough data, models learn what CNNs hard-code
- Attention provides interpretability: Attention maps show model reasoning
- Scaling works: Larger models + more data = better performance (scaling laws)
Follow-up impact:
- CLIP uses ViT for vision-language learning
- Swin Transformer achieved state-of-the-art results for detection/segmentation
- BEiT/MAE enable efficient self-supervised learning
- Foundation for many multimodal models
Key Takeaways
- Patch-based tokenization enables transformers for vision
- Large-scale pre-training essential for ViT to outperform CNNs
- Less inductive bias = more data needed but better scaling
- Attention visualization provides interpretability
- Hybrid architectures balance efficiency and performance
- Transfer learning from ViT often better than CNNs
Related Concepts
- Attention Is All You Need - The transformer architecture
- Attention Mechanism - Core attention mechanism
- Multi-Head Attention - Multiple attention heads
- Scaled Dot-Product Attention - The attention formula
- CLIP - Uses ViT for vision-language learning
- Multimodal Learning - ViT enables multimodal models
- Convolution - What ViT replaces