
Vision Transformer (ViT)

Dosovitskiy et al. (2021) - “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”

Published: ICLR 2021 | Citations: 40,000+ (as of 2024) | Code: google-research/vision_transformer

Key Insight

While the Transformer architecture revolutionized NLP, it was unclear if it could replace CNNs for vision. ViT’s breakthrough: treat an image as a sequence of patches and apply a standard transformer encoder—no convolutions needed. With sufficient pre-training data, ViT matches or exceeds state-of-the-art CNNs.

Core Innovation

Patch-based image tokenization: Split images into fixed-size patches (e.g., 16×16 pixels), flatten each patch into a vector, and process the sequence with a standard transformer.

This simple approach:

  • Removes reliance on convolutional inductive biases
  • Enables direct application of transformer architectures to vision
  • Scales better than CNNs with larger datasets
  • Provides interpretability through attention maps

Architecture

The ViT Process

  1. Split image into patches (e.g., 16×16 pixels each)
  2. Flatten each patch to a 1D vector (768 values for 16×16×3)
  3. Apply linear projection to create patch embeddings
  4. Add learnable positional embeddings (transformers lack spatial awareness)
  5. Prepend CLS token (like BERT’s [CLS] for classification)
  6. Process with standard transformer encoder
  7. Use CLS token output for classification

Complete PyTorch Implementation

import torch
import torch.nn as nn
from einops import rearrange  # pip install einops


class VisionTransformer(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 num_classes=1000,
                 dim=768,        # Embedding dimension
                 depth=12,       # Number of transformer layers
                 heads=12,       # Number of attention heads
                 mlp_dim=3072):  # MLP hidden dimension (4 * dim)
        """
        Vision Transformer for image classification.

        Args:
            image_size: Input image size (assumes square images)
            patch_size: Size of each patch (16 or 32 typical)
            num_classes: Number of output classes
            dim: Embedding dimension
            depth: Number of transformer encoder layers
            heads: Number of attention heads
            mlp_dim: Hidden dimension of MLP layers
        """
        super().__init__()

        # Calculate number of patches
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2  # 3 channels (RGB)
        self.patch_size = patch_size

        # Patch embedding: linear projection of flattened patches
        self.patch_embedding = nn.Linear(patch_dim, dim)

        # Positional embedding (learnable); +1 for the CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        # CLS token (like BERT's [CLS] token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            norm_first=True  # Pre-norm (more stable)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """
        Forward pass through ViT.

        Args:
            img: Input images (batch, 3, H, W)

        Returns:
            Class logits (batch, num_classes)
        """
        batch_size = img.shape[0]

        # Create patches: (batch, 3, H, W) -> (batch, num_patches, patch_dim)
        patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                            p1=self.patch_size, p2=self.patch_size)

        # Patch embedding: (batch, num_patches, patch_dim) -> (batch, num_patches, dim)
        x = self.patch_embedding(patches)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embedding
        x = x + self.pos_embedding

        # Transformer encoder
        x = self.transformer(x)

        # Use CLS token for classification
        cls_output = x[:, 0]
        return self.mlp_head(cls_output)

Model Configurations

ViT-Base (ViT-B/16):

  • Layers: 12
  • Hidden dimension: 768
  • MLP dimension: 3072
  • Attention heads: 12
  • Patch size: 16×16
  • Parameters: 86M

ViT-Large (ViT-L/16):

  • Layers: 24
  • Hidden dimension: 1024
  • MLP dimension: 4096
  • Attention heads: 16
  • Parameters: 307M

ViT-Huge (ViT-H/14):

  • Layers: 32
  • Hidden dimension: 1280
  • MLP dimension: 5120
  • Attention heads: 16
  • Patch size: 14×14
  • Parameters: 632M
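For reference, the three variants above differ only in these hyperparameters. A minimal sketch of instantiating them with the toy VisionTransformer class from the implementation above (not the official code):

# Hyperparameters of the standard variants, usable with the VisionTransformer
# class sketched above (a toy reconstruction, not the official implementation).
vit_configs = {
    'ViT-B/16': dict(patch_size=16, dim=768,  depth=12, heads=12, mlp_dim=3072),
    'ViT-L/16': dict(patch_size=16, dim=1024, depth=24, heads=16, mlp_dim=4096),
    'ViT-H/14': dict(patch_size=14, dim=1280, depth=32, heads=16, mlp_dim=5120),
}

vit_b16 = VisionTransformer(image_size=224, num_classes=1000, **vit_configs['ViT-B/16'])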

Key Components

1. Patch Embedding

Instead of convolutions, ViT uses linear projection of flattened patches:

# For 16×16 patches on RGB images:
patch_dim = 16 * 16 * 3  # = 768
embedding = nn.Linear(patch_dim, 768)  # often projects to the same dimensionality

Why it works: a linear projection on flattened patches is equivalent to a convolution with kernel size = patch size and stride = patch size, but it is conceptually simpler.
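As a quick sanity check of this equivalence, the sketch below computes the patch embedding both ways for a 224×224 RGB batch; both produce a (batch, 196, 768) tensor (the tensor names here are illustrative):

import torch
import torch.nn as nn
from einops import rearrange

img = torch.randn(2, 3, 224, 224)

# (a) Flatten 16x16 patches, then project with nn.Linear (as in the ViT code above)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
out_linear = nn.Linear(16 * 16 * 3, 768)(patches)             # (2, 196, 768)

# (b) Conv2d with kernel_size = stride = patch size, then flatten the spatial grid
out_conv = rearrange(nn.Conv2d(3, 768, kernel_size=16, stride=16)(img),
                     'b d h w -> b (h w) d')                   # (2, 196, 768)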

2. Learnable Positional Embeddings

Transformers have no notion of spatial order. ViT adds learnable 1D positional embeddings:

pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x = x + pos_embedding

The model learns spatial relationships from data rather than having them hard-coded (unlike CNNs with spatial locality).

Alternatives explored (a fixed 2D sine-cosine variant is sketched after this list):

  • 2D positional embeddings (slightly better for fine-tuning at different resolutions)
  • Sinusoidal encodings (works but learnable is standard)
  • Relative positional encodings (used in Swin Transformer)
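As a concrete illustration of the fixed (sine-cosine) alternative in 2D, here is a minimal sketch; the helper names are illustrative and not from the ViT codebase:

import torch

def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard sinusoidal encoding of a 1D coordinate (dim must be even)."""
    omega = 1.0 / (10000 ** (torch.arange(dim // 2) / (dim // 2)))
    angles = positions[:, None] * omega[None, :]                      # (N, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)   # (N, dim)

def sincos_2d(grid_size: int, dim: int) -> torch.Tensor:
    """2D encoding: half the channels encode the row, half the column."""
    ys, xs = torch.meshgrid(torch.arange(grid_size, dtype=torch.float32),
                            torch.arange(grid_size, dtype=torch.float32),
                            indexing='ij')
    return torch.cat([sincos_1d(ys.flatten(), dim // 2),
                      sincos_1d(xs.flatten(), dim // 2)], dim=1)      # (grid_size**2, dim)

pos = sincos_2d(grid_size=14, dim=768)  # 196 patch positions for ViT-B/16 at 224x224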

3. CLS Token

A special learnable token prepended to the sequence:

  • Aggregates information from all patches through self-attention
  • Used for the final classification (take CLS output, ignore patch outputs)
  • Same idea as BERT’s [CLS] token

Alternative: Global average pooling over all patch embeddings. CLS token works slightly better in practice.
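Both readouts are one line each; a minimal sketch of the two options (x is the encoder output, laid out as in the forward pass above):

import torch

def pool_tokens(x: torch.Tensor, use_cls: bool = True) -> torch.Tensor:
    """x: encoder output of shape (batch, 1 + num_patches, dim)."""
    if use_cls:
        return x[:, 0]               # CLS-token readout (standard ViT)
    return x[:, 1:].mean(dim=1)      # global average pooling over patch tokens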

ViT vs CNNs

| Aspect | CNN (ResNet) | ViT |
| --- | --- | --- |
| Inductive bias | Strong (locality, translation equivariance, hierarchy) | Weak (learns from data) |
| Data requirements | Works with 100K-1M images | Requires 10M+ images or clever pre-training |
| Computational cost | Linear in image size | Quadratic in number of patches |
| Long-range dependencies | Requires deep stacking | Direct connections via attention |
| Interpretability | Feature maps (less clear) | Attention maps (show what the model looks at) |
| Training time | Faster convergence | Slower convergence (more epochs needed) |
| Transfer learning | Strong | Very strong (better than CNNs) |
| Best use case | Small-to-medium data, efficiency | Large data, maximum performance |

When to Use ViT

Use ViT when:

  • Large pre-training datasets available (ImageNet-21K: 14M images)
  • Maximum performance critical
  • Sufficient compute resources available
  • Transfer learning from large models
  • Interpretability via attention important

Use CNNs when:

  • Limited data (<1M images)
  • Computational efficiency critical
  • Faster training needed
  • Strong spatial locality important
  • Edge deployment required

Attention Visualization

One major advantage: ViT attention weights reveal what the model focuses on:

def visualize_attention(vit_model, image):
    """
    Visualize which patches the CLS token attends to.
    This shows what parts of the image contributed to the prediction.
    """
    # Run the model. nn.TransformerEncoder does not return attention weights
    # by default, so this assumes they were cached during the forward pass
    # (e.g., via a forward hook on the last layer's self-attention).
    with torch.no_grad():
        outputs = vit_model(image)

    # Attention weights of the last encoder layer
    # Shape: (batch, heads, 1 + num_patches, 1 + num_patches)
    attention_weights = vit_model.transformer.layers[-1].self_attn.attention_weights

    # CLS token attention to patches (skip CLS-to-CLS)
    cls_attention = attention_weights[:, :, 0, 1:]

    # Average across heads (or examine specific heads)
    cls_attention = cls_attention.mean(dim=1)  # (batch, num_patches)

    # Reshape to a spatial grid for visualization
    grid_size = int(cls_attention.shape[-1] ** 0.5)
    attention_map = cls_attention.reshape(-1, grid_size, grid_size)

    return attention_map

Attention maps often show:

  • Early layers: Local patterns (edges, textures) similar to CNN early layers
  • Middle layers: Object parts
  • Late layers: Semantic object regions

This provides interpretability unavailable in CNNs.

Hybrid Architectures

Combine CNN inductive bias with ViT global reasoning:

class HybridViT(nn.Module):
    def __init__(self, cnn_backbone, transformer_encoder, num_classes, dim, num_patches):
        """
        Hybrid ViT: CNN for local features + Transformer for global reasoning.

        Args:
            cnn_backbone: CNN (e.g., ResNet-50 without its head) for feature extraction
            transformer_encoder: Standard transformer encoder (d_model == dim)
            num_classes: Output classes
            dim: Channel dimension of the CNN feature map (= transformer width)
            num_patches: Number of spatial locations in the CNN feature map (H * W)
        """
        super().__init__()
        self.cnn = cnn_backbone
        self.transformer = transformer_encoder
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # CNN extracts local features
        features = self.cnn(img)  # (batch, C, H, W)

        # Treat spatial locations as patches
        patches = rearrange(features, 'b c h w -> b (h w) c')

        # Add CLS token and positional embeddings
        cls_token = self.cls_token.expand(patches.shape[0], -1, -1)
        patches = torch.cat([cls_token, patches], dim=1)
        patches = patches + self.pos_embedding

        # Transformer for global reasoning
        x = self.transformer(patches)

        return self.head(x[:, 0])  # Use CLS token

Benefits of hybrid approach:

  • Requires less pre-training data (CNN inductive bias helps)
  • Combines CNN efficiency with ViT expressiveness
  • Good for medical imaging (limited data domains)

Training Strategy

Pre-training on Large Datasets

ViT’s key requirement: large-scale pre-training

The paper shows:

  • ImageNet (1.3M): ViT underperforms ResNet
  • ImageNet-21K (14M): ViT matches ResNet
  • JFT-300M (300M): ViT significantly outperforms ResNet

Why? CNNs have built-in inductive biases (locality, translation equivariance). ViT learns these from data, requiring more examples.

Fine-tuning

After pre-training, fine-tune on downstream tasks:

  1. Replace classification head
  2. Optionally adjust positional embeddings for different resolutions (see the interpolation sketch after this list)
  3. Fine-tune with smaller learning rate
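Step 2 usually means resizing the learned patch position embeddings to the new grid with 2D interpolation, as the paper does for higher-resolution fine-tuning. A minimal sketch, assuming the (1, 1 + num_patches, dim) layout with the CLS entry first, as in the implementation above:

import torch
import torch.nn.functional as F

def interpolate_pos_embedding(pos_embedding: torch.Tensor, new_grid: int) -> torch.Tensor:
    """
    Resize learned positional embeddings for fine-tuning at a new resolution.
    pos_embedding: (1, 1 + old_grid**2, dim), CLS entry first.
    Returns: (1, 1 + new_grid**2, dim)
    """
    cls_pos, patch_pos = pos_embedding[:, :1], pos_embedding[:, 1:]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    dim = patch_pos.shape[-1]

    # (1, N, dim) -> (1, dim, old_grid, old_grid) for 2D interpolation
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode='bicubic', align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)

    return torch.cat([cls_pos, patch_pos], dim=1)

# Example: 224px / patch 16 (14x14 grid) -> 384px / patch 16 (24x24 grid)
# new_pos = interpolate_pos_embedding(model.pos_embedding.data, new_grid=24)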

Transfer learning results: ViT pre-trained on large datasets transfers better than ResNets to various downstream tasks (CIFAR, VTAB benchmark).

Results

ImageNet Classification

Pre-trained on JFT-300M, fine-tuned on ImageNet:

  • ViT-H/14: 88.55% top-1 accuracy (state-of-the-art at publication)
  • ViT-L/16: 87.76% top-1
  • ViT-B/16: 84.15% top-1

Compared to BiT (ResNet-152x4) pre-trained on JFT-300M: 87.54% (ViT-H beats it).

Transfer Learning

On 19 downstream tasks (VTAB benchmark):

  • ViT outperforms ResNets on 18/19 tasks after JFT-300M pre-training
  • Particularly strong on tasks requiring global reasoning

Computational Efficiency

Pre-training cost:

  • ViT-L/16 (JFT-300M pre-training): ~0.68k TPUv3-core-days
  • ViT-H/14 (JFT-300M pre-training): ~2.5k TPUv3-core-days
  • For comparison, BiT-L (ResNet152x4): ~9.9k and Noisy Student (EfficientNet-L2): ~12.3k TPUv3-core-days

Cheaper than largest ResNets at comparable performance.

ViT Variants and Follow-ups

DeiT (Data-efficient Image Transformer)

Problem: ViT requires huge datasets. Solution: Knowledge distillation from CNN teacher + stronger augmentation.

  • Trains on ImageNet (1.3M images) without JFT-300M
  • Achieves 83.1% top-1 (ViT-B/16 equivalent)
  • Adds a distillation token alongside the CLS token (see the loss sketch below)
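A hedged sketch of DeiT-style hard distillation, the loss behind the distillation token; the tensor names are illustrative, with separate logits from the CLS and distillation heads:

import torch
import torch.nn.functional as F

def deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """Average of CE(CLS head, labels) and CE(distillation head, teacher's argmax)."""
    ce = F.cross_entropy(cls_logits, labels)
    teacher_labels = teacher_logits.argmax(dim=1)       # teacher's hard predictions
    distill = F.cross_entropy(dist_logits, teacher_labels)
    return 0.5 * ce + 0.5 * distill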

Swin Transformer

Problem: ViT has quadratic cost in the number of patches. Solution: Hierarchical architecture with shifted-window attention (window partitioning is sketched after this list).

  • Introduces hierarchy (like CNNs: 4 stages with different resolutions)
  • Window-based attention (linear complexity)
  • Better for dense tasks (object detection, segmentation)
  • State-of-the-art on COCO detection
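Window-based attention can be sketched with a simple partition: restrict self-attention to non-overlapping M×M windows. This is only the partitioning idea; Swin's shifted windows and relative position bias are omitted, and the function is illustrative:

import torch
import torch.nn as nn
from einops import rearrange

def window_attention(x: torch.Tensor, attn: nn.MultiheadAttention, window: int) -> torch.Tensor:
    """
    x: (batch, H, W, dim) feature map with H, W divisible by `window`.
    Applies self-attention independently inside each window x window block, so cost
    grows linearly with the number of windows rather than quadratically in all tokens.
    """
    B, H, W, C = x.shape
    # Partition into non-overlapping windows: (B * num_windows, window*window, C)
    windows = rearrange(x, 'b (nh w1) (nw w2) c -> (b nh nw) (w1 w2) c',
                        w1=window, w2=window)
    out, _ = attn(windows, windows, windows, need_weights=False)
    # Merge windows back to the original spatial layout
    return rearrange(out, '(b nh nw) (w1 w2) c -> b (nh w1) (nw w2) c',
                     b=B, nh=H // window, nw=W // window, w1=window, w2=window)

# Example: 56x56 grid, 96 channels, 7x7 windows
# attn = nn.MultiheadAttention(embed_dim=96, num_heads=3, batch_first=True)
# y = window_attention(torch.randn(2, 56, 56, 96), attn, window=7)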

BEiT (BERT pre-training for Images)

Problem: Supervised pre-training requires labels. Solution: Masked image modeling (predict masked patches).

  • Self-supervised pre-training like BERT
  • Uses visual tokens from discrete VAE
  • Strong transfer learning without labels

MAE (Masked Autoencoders)

Problem: Self-supervised pre-training of large ViTs is expensive. Solution: Mask 75% of patches and reconstruct them (the masking step is sketched after this list).

  • Extremely simple and effective
  • Asymmetric encoder-decoder (lightweight decoder)
  • State-of-the-art self-supervised learning
  • Can pre-train ViT-Huge in days
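The masking step itself is only a few lines; a minimal sketch of random 75% patch masking in the spirit of MAE (the encoder then sees only the visible tokens; helper names are illustrative):

import torch

def random_masking(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """
    tokens: (batch, num_patches, dim) patch embeddings (no CLS token).
    Returns the visible subset passed to the encoder, plus the indices
    needed to restore the original patch order for the decoder.
    """
    B, N, D = tokens.shape
    num_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=tokens.device)   # one random score per patch
    shuffle = noise.argsort(dim=1)                   # random permutation of patches
    restore = shuffle.argsort(dim=1)                 # inverse permutation

    keep = shuffle[:, :num_keep]                     # indices of visible patches
    visible = torch.gather(tokens, 1, keep.unsqueeze(-1).expand(-1, -1, D))

    # Binary mask in the original patch order: 0 = visible, 1 = masked
    mask = torch.ones(B, N, device=tokens.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, restore)
    return visible, mask, restore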

Practical Considerations

1. Patch Size Trade-off

Smaller patches (14×14 or 8×8):

  • Pros: More fine-grained detail, better for small objects
  • Cons: More tokens → quadratic cost increase, slower training

Larger patches (32×32):

  • Pros: Fewer tokens → faster, lower memory
  • Cons: Less spatial resolution, may miss small details

Standard choice: 16×16 balances resolution and efficiency.
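A quick worked example at 224×224 input makes the trade-off concrete (self-attention cost grows with the square of the token count):

# Number of patch tokens and relative self-attention cost at 224x224 input.
image_size = 224
for patch in (8, 16, 32):
    n_tokens = (image_size // patch) ** 2    # 784, 196, 49 tokens
    print(patch, n_tokens, n_tokens ** 2)    # attention scales ~ n_tokens**2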

2. Positional Embeddings

1D learnable (standard ViT):

  • Simple, works well
  • Model learns 2D structure from data
  • Can interpolate for different resolutions

2D learnable:

  • Explicitly encode row/column position
  • Slightly better for resolution changes
  • More parameters

Relative positional encodings (Swin):

  • Encode relative distances between patches
  • Better generalization to unseen sizes

3. Pre-training Strategies

Supervised (original ViT):

  • Requires labeled data (ImageNet-21K, JFT-300M)
  • Strong performance, but expensive labels

Self-supervised (BEiT, MAE):

  • No labels needed
  • Can leverage unlimited unlabeled images
  • Comparable or better transfer learning

Hybrid (SimMIM, iBOT):

  • Combine multiple self-supervised objectives
  • State-of-the-art results

Impact

ViT demonstrated that:

  1. Transformers work for vision: Architecture generality across modalities
  2. Inductive biases can be learned: Given enough data, models learn what CNNs hard-code
  3. Attention provides interpretability: Attention maps show model reasoning
  4. Scaling works: Larger models + more data = better performance (scaling laws)

Follow-up impact:

  • CLIP uses ViT for vision-language learning
  • Swin Transformer state-of-the-art for detection/segmentation
  • BEiT/MAE enable efficient self-supervised learning
  • Foundation for many multimodal models

Key Takeaways

  1. Patch-based tokenization enables transformers for vision
  2. Large-scale pre-training essential for ViT to outperform CNNs
  3. Less inductive bias = more data needed but better scaling
  4. Attention visualization provides interpretability
  5. Hybrid architectures balance efficiency and performance
  6. Transfer learning from ViT often better than CNNs
