
CLIP: Contrastive Language-Image Pre-training

Radford et al. (2021) - “Learning Transferable Visual Models From Natural Language Supervision”

Published: ICML 2021 | Citations: 25,000+ (as of 2024) | Code: openai/CLIP

The Revolutionary Insight

While ImageNet changed computer vision by scaling labeled data, CLIP asks: what if we scale supervision from natural language instead? By training on 400 million image-text pairs scraped from the internet, CLIP learns visual concepts from text descriptions—enabling zero-shot transfer to new tasks without task-specific training.

Key breakthrough: Use contrastive learning to align images and text in a shared embedding space, then leverage natural language as a flexible interface for any vision task.

Core Innovation

From Supervised to Natural Language Supervision

Traditional approach (ImageNet):

  1. Collect images
  2. Manually label with predefined classes
  3. Train classifier on those classes
  4. Result: Works only on those specific 1000 classes

CLIP approach:

  1. Collect image-text pairs from the internet (no manual labeling!)
  2. Train to align images with their natural captions
  3. At test time, describe classes with text
  4. Result: Works on any classes described in natural language

This shift from closed-vocabulary to open-vocabulary learning is transformative.

The CLIP Training Objective

Given a batch of N (image, text) pairs:

\mathcal{L}_{\text{CLIP}} = \frac{1}{2}(\mathcal{L}_{I \to T} + \mathcal{L}_{T \to I})

where each direction uses InfoNCE contrastive loss:

\mathcal{L}_{I \to T} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(I_i, T_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(I_i, T_j) / \tau)}

Intuition: For each image, the model must identify its matching text caption among N options. This is a classification task with N classes, where the “class” is dynamically defined by the batch.
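
To make this batch-as-classification intuition concrete, here is a minimal, self-contained sketch (not the paper's code; the random embeddings and temperature value are illustrative placeholders) of the symmetric loss on a toy batch:

import torch
import torch.nn.functional as F

# Toy batch: N = 4 pairs of L2-normalized embeddings (illustrative values)
N, d = 4, 512
image_emb = F.normalize(torch.randn(N, d), dim=-1)
text_emb = F.normalize(torch.randn(N, d), dim=-1)
tau = 0.07  # temperature

# N x N similarity matrix; entry (i, j) compares image i with text j
logits = image_emb @ text_emb.T / tau

# The "class" of image i is text i, so the targets are simply 0..N-1
targets = torch.arange(N)
loss_i2t = F.cross_entropy(logits, targets)    # image -> text direction
loss_t2i = F.cross_entropy(logits.T, targets)  # text -> image direction
loss = 0.5 * (loss_i2t + loss_t2i)
print(loss.item())

Each row of logits defines an N-way classification problem whose correct answer is the diagonal entry, which is exactly what F.cross_entropy with arange targets computes.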

Architecture

CLIP consists of four main components:

1. Image Encoder

Transforms images into fixed-dimensional embeddings:

Options:

  • ResNet-50/101 (modified with attention pooling instead of average pooling)
  • Vision Transformer (ViT) in multiple sizes:
    • ViT-B/32: Base model with 32×32 patches
    • ViT-B/16: Base model with 16×16 patches
    • ViT-L/14: Large model with 14×14 patches (best performance)

Input: 224×224 RGB image
Output: Visual features whose width depends on the backbone (e.g., 768-d for ViT-Base, 1024-d for ViT-Large), which are then projected into the shared embedding space

See Vision Transformer for ViT details.
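
For a concrete sense of how a backbone plugs into the CLIP module defined in the implementation section below, the following sketch wraps torchvision's off-the-shelf ResNet-50 as a feature extractor. This is only a simplified stand-in: CLIP's actual ResNets use attention pooling and other modifications, and the names here are illustrative.

import torch
import torch.nn as nn
from torchvision.models import resnet50

# Off-the-shelf ResNet-50 with the classification head removed,
# so the forward pass returns the 2048-d pooled features.
vision_encoder = resnet50(weights=None)
vision_encoder.fc = nn.Identity()

dummy = torch.randn(2, 3, 224, 224)  # batch of two 224x224 RGB images
features = vision_encoder(dummy)     # -> shape (2, 2048)
print(features.shape)

With this backbone, the CLIP module's vision_dim argument would be 2048 rather than the 768 used for ViT-Base.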

2. Text Encoder

Transformer encoder (similar to GPT) for text:

Architecture:

  • 12 layers (standard CLIP)
  • 8 attention heads
  • 512-dimensional embeddings
  • Causal attention mask (like GPT)
  • Max sequence length: 77 tokens

Input: Tokenized text via BPE (49,152-token vocabulary)
Output: Text features taken from the [EOS] token position (512-d)

Why the [EOS] token? Because the attention mask is causal, the final [EOS] token is the only position that attends to every token in the sequence, so its output serves as a summary of the whole text (similar in spirit to BERT's [CLS] token).
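
A sketch of the pooling step, mirroring the trick in the openai/CLIP reference code where the end-of-text token has the highest id in the vocabulary, so an argmax over token ids locates its position; the tensor names here are illustrative.

import torch

# x: transformer outputs, shape (batch, seq_len, width)
# tokens: BPE token ids, shape (batch, seq_len); the EOS/EOT token has the
# largest id, so argmax over the sequence gives its position.
def pool_eos(x: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    eos_positions = tokens.argmax(dim=-1)               # (batch,)
    return x[torch.arange(x.shape[0]), eos_positions]   # (batch, width)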

3. Projection Heads

Linear projections to a shared multimodal embedding space:

# Both modalities project to the same dimensionality
vision_projection = nn.Linear(vision_dim, embed_dim)  # e.g., 768 → 512
text_projection = nn.Linear(text_dim, embed_dim)      # e.g., 512 → 512

Shared space dimension: Typically 512 (balances expressiveness and efficiency)

4. Contrastive Learning

Symmetric contrastive loss aligns the embeddings in both directions.

Complete PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class CLIP(nn.Module):
    def __init__(self, vision_encoder, text_encoder, vision_dim=768,
                 text_dim=512, embed_dim=512, temperature=0.07):
        """
        CLIP: Contrastive Language-Image Pre-training.

        Args:
            vision_encoder: Image encoder (ResNet or ViT)
            text_encoder: Text encoder (Transformer)
            vision_dim: Dimension of vision features
            text_dim: Dimension of text features
            embed_dim: Dimension of shared embedding space
            temperature: Initial temperature for contrastive loss
        """
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder

        # Projection heads to shared embedding space
        self.vision_projection = nn.Linear(vision_dim, embed_dim, bias=False)
        self.text_projection = nn.Linear(text_dim, embed_dim, bias=False)

        # Learnable temperature (log-parameterized for stability)
        self.log_temperature = nn.Parameter(torch.log(torch.tensor(temperature)))

    def encode_image(self, images):
        """
        Encode images to normalized embeddings.

        Args:
            images: (batch_size, 3, H, W)

        Returns:
            Normalized image embeddings (batch_size, embed_dim)
        """
        vision_features = self.vision_encoder(images)
        image_embeddings = self.vision_projection(vision_features)
        image_embeddings = F.normalize(image_embeddings, dim=-1)
        return image_embeddings

    def encode_text(self, texts):
        """
        Encode text to normalized embeddings.

        Args:
            texts: (batch_size, seq_len) - tokenized text

        Returns:
            Normalized text embeddings (batch_size, embed_dim)
        """
        text_features = self.text_encoder(texts)
        text_embeddings = self.text_projection(text_features)
        text_embeddings = F.normalize(text_embeddings, dim=-1)
        return text_embeddings

    def forward(self, images, texts):
        """
        Forward pass computing the symmetric contrastive loss.

        Args:
            images: (batch_size, 3, H, W)
            texts: (batch_size, seq_len)

        Returns:
            loss: Scalar contrastive loss
            logits_per_image: Similarity matrix (batch_size, batch_size)
            logits_per_text: Similarity matrix transposed
        """
        # Encode both modalities
        image_embeddings = self.encode_image(images)
        text_embeddings = self.encode_text(texts)

        # Compute similarity matrix scaled by the learned temperature
        temperature = self.log_temperature.exp()
        logits_per_image = image_embeddings @ text_embeddings.T / temperature
        logits_per_text = logits_per_image.T

        # Labels: diagonal elements are the positive pairs
        batch_size = len(images)
        labels = torch.arange(batch_size, device=images.device)

        # Symmetric contrastive loss
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        loss = (loss_i2t + loss_t2i) / 2

        return loss, logits_per_image, logits_per_text

Zero-Shot Classification

The killer application of CLIP:

def zero_shot_classify(clip_model, image, class_names, templates=None):
    """
    Classify an image using natural language class descriptions.

    Args:
        clip_model: Trained CLIP model
        image: Input image tensor (1, 3, H, W)
        class_names: List of class names ["cat", "dog", ...]
        templates: Optional prompt templates

    Returns:
        predicted_class: Index of most likely class
        probabilities: Softmax probabilities for each class
    """
    if templates is None:
        templates = ["a photo of a {}."]

    with torch.no_grad():
        # Encode the image once
        image_embedding = clip_model.encode_image(image)  # (1, embed_dim)

        # Encode all class descriptions
        text_embeddings = []
        for class_name in class_names:
            # Use multiple templates and average (prompt ensembling)
            class_embeddings = []
            for template in templates:
                text = template.format(class_name)
                text_tokens = tokenize(text)  # tokenizer assumed, e.g. clip.tokenize
                embedding = clip_model.encode_text(text_tokens)  # (1, embed_dim)
                class_embeddings.append(embedding)

            # Average over templates and re-normalize
            class_embedding = torch.cat(class_embeddings).mean(dim=0)
            class_embedding = F.normalize(class_embedding, dim=-1)
            text_embeddings.append(class_embedding)

        text_embeddings = torch.stack(text_embeddings)  # (num_classes, embed_dim)

        # Compute similarity scores
        logits = image_embedding @ text_embeddings.T  # (1, num_classes)
        probabilities = F.softmax(logits / 0.01, dim=-1).squeeze(0)
        predicted_class = logits.argmax(dim=-1).item()

    return predicted_class, probabilities


# Example usage
class_names = ["cat", "dog", "bird", "horse"]
templates = [
    "a photo of a {}.",
    "a picture of a {}.",
    "an image showing a {}.",
]
pred_idx, probs = zero_shot_classify(clip, image, class_names, templates)
print(f"Predicted: {class_names[pred_idx]}")
for name, prob in zip(class_names, probs):
    print(f"  {name}: {prob:.3f}")

Image-Text Retrieval

def retrieve_images(clip_model, text_query, image_database):
    """
    Retrieve the most relevant images for a text query.

    Args:
        clip_model: Trained CLIP model
        text_query: Natural language query string
        image_database: Tensor of images (N, 3, H, W)

    Returns:
        top_k_indices: Indices of the most similar images
        similarities: Similarity scores
    """
    with torch.no_grad():
        # Encode the text query
        text_tokens = tokenize(text_query)
        text_embedding = clip_model.encode_text(text_tokens)  # (1, embed_dim)

        # Encode all images
        image_embeddings = clip_model.encode_image(image_database)  # (N, embed_dim)

        # Compute similarities
        similarities = (text_embedding @ image_embeddings.T).squeeze(0)  # (N,)

        # Get the top-k most similar
        top_k_indices = similarities.topk(k=10).indices

    return top_k_indices, similarities[top_k_indices]


# Example
results, scores = retrieve_images(
    clip, "a golden retriever playing in the park", image_database
)

Training Details

Dataset

WIT (WebImageText): 400 million (image, text) pairs collected from the internet:

  • Scraped from various websites
  • Includes alt-text, captions, titles
  • No manual cleaning or verification
  • Noisy but diverse

Why web data? Unlike ImageNet’s 1000 manually defined classes, web data provides:

  • Natural language diversity (any concept can appear)
  • Massive scale (400M pairs vs. 1.3M ImageNet images)
  • Free supervision (no annotation cost)
  • Real-world distribution

Training Setup

Batch size: 32,768 image-text pairs

  • Each positive pair is contrasted against 32,767 in-batch negatives
  • Requires distributed training across many GPUs
  • Essential for learning discriminative embeddings

Optimization (see the code sketch at the end of this subsection):

  • AdamW optimizer
  • Decoupled weight decay
  • Cosine learning rate schedule with warmup
  • Mixed precision (float16) training

Compute: ~256 V100 GPUs for 12 days (ViT-L/14 model)

Augmentation: Minimal (just random crop and resize)

  • Unlike self-supervised vision methods such as SimCLR, CLIP does not rely on heavy augmentation; the paired text supplies the supervisory signal
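
A minimal sketch of the optimization recipe above (AdamW, cosine schedule with linear warmup, mixed precision). The learning rate, weight decay, and step counts are illustrative rather than the paper's exact hyperparameters, and vision_encoder, text_encoder, and loader are assumed to exist.

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

# Assumed to exist: the CLIP module defined above, a vision_encoder,
# a text_encoder, and a `loader` yielding (images, texts) batches.
model = CLIP(vision_encoder, text_encoder).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.2)

# Cosine decay with linear warmup (warmup/total step counts are illustrative)
warmup_steps, total_steps = 2_000, 100_000

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)
scaler = torch.cuda.amp.GradScaler()  # mixed-precision (float16) training

for images, texts in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss, _, _ = model(images.cuda(), texts.cuda())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()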

Model Configurations

Model     Vision Encoder              Params   Top-1 ImageNet (Zero-Shot)
RN50      ResNet-50                   102M     59.6%
RN101     ResNet-101                  119M     62.3%
ViT-B/32  ViT-Base, 32×32 patches     151M     63.3%
ViT-B/16  ViT-Base, 16×16 patches     149M     68.3%
ViT-L/14  ViT-Large, 14×14 patches    428M     75.5%

Key Design Decisions

1. Contrastive vs Predictive Objective

Why contrastive? The paper compared:

  • Predictive: Predict exact caption (like image captioning)
  • Contrastive: Match image to caption among alternatives

Result: Contrastive learning scales better and transfers more effectively.

Reason: Contrastive learning captures semantic alignment without requiring exact word prediction.

2. Symmetric Loss

Both directions trained:

  • Image → Text: “Which caption matches this image?”
  • Text → Image: “Which image matches this caption?”

Benefit: Enables bidirectional retrieval and stronger alignment.

3. Batch Size

Massive batches (32,768 pairs) crucial for performance:

  • More negatives = harder contrastive task = better representations
  • Compute cost: the similarity matrix is O(N²), but the scale is critical for quality (see the calculation below)
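
To put the O(N²) term in perspective, a quick back-of-the-envelope calculation (not from the paper) for a single full-batch similarity matrix:

N = 32_768
entries = N * N           # 1,073,741,824 logits
gib_fp16 = entries * 2 / 2**30  # 2.0 GiB in float16
gib_fp32 = entries * 4 / 2**30  # 4.0 GiB in float32
print(entries, gib_fp16, gib_fp32)

The paper shards this computation, with each GPU computing only the similarities needed for its local sub-batch.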

4. Prompt Engineering

Text encoding uses prompt templates:

# Simple prompt
"a photo of a {class}"

# More descriptive
"a photo of a {class}, a type of {category}"

# Context-specific
"a satellite photo of a {class}"   # for satellite imagery
"a CT scan showing {class}"        # for medical images

Finding: Prompt design significantly affects performance. Use ensemble of prompts for robustness.

5. No Fine-Tuning Required

Unlike traditional transfer learning:

  • No need to train on downstream task data
  • Just encode text descriptions of classes
  • Immediate application to new tasks

Results

Zero-Shot Transfer Performance

ImageNet:

  • CLIP ViT-L/14: 75.5% top-1 accuracy
  • Without any ImageNet training!
  • Matches ResNet-50 trained on ImageNet (76.2%)

27 additional datasets:

  • Tested on diverse domains: OCR, satellite, texture, action recognition, etc.
  • Competitive with task-specific models on many tasks
  • Particularly strong on fine-grained classification

Few-Shot Learning

When fine-tuned on small amounts of labeled data:

  • 16 shots per class: often competitive with models trained on far more labeled data
  • Linear probe on top of frozen features works well
  • Adaptation is data-efficient

Robustness

Zero-shot CLIP shows superior robustness to distribution shift:

  • ImageNet variants (ImageNet-A, ImageNet-R, ObjectNet): Smaller accuracy drop
  • Why? Natural language supervision provides more robust features than fixed ImageNet classes

Limitations

Despite impressive results, CLIP has notable weaknesses:

  1. Fine-grained classification: Struggles with subtle differences (e.g., car models, flower species)
  2. Counting: Poor at “how many X are in this image?”
  3. Compositional reasoning: Difficulty with spatial relationships (“X to the left of Y”)
  4. Abstract concepts: Better on concrete objects than abstract ideas
  5. Rare concepts: Depends on frequency in training data
  6. Biases: Reflects biases in web data (gender, race, etc.)
  7. OCR: Unreliable at reading text in images, particularly handwritten or low-resolution digits (e.g., MNIST)

Impact and Applications

Direct Applications

Zero-shot classification:

  • Classify into any categories without training
  • Useful for long-tail distributions

Image-text retrieval:

  • Search images with natural language
  • Generate captions by retrieving similar images

Prompt-based adaptation:

  • Change behavior by changing prompts
  • No model retraining needed

As Foundation for Other Models

CLIP embeddings used in:

  • DALL-E 2: Text-to-image generation (CLIP guides diffusion)
  • Stable Diffusion: Open-source text-to-image (CLIP text encoder)
  • Flamingo: Vision-language generalist model
  • GPT-4V: Multimodal GPT (likely uses CLIP-like pre-training)

Research Directions

Improvements:

  • OpenCLIP: Open-source reproduction with better data curation
  • ALIGN (Google): 1.8B noisy image-text pairs
  • Florence: Hierarchical vision-language pre-training
  • CoCa: Contrastive + captioning objectives combined

Extensions:

  • Video-CLIP: Extend to video understanding
  • Audio-CLIP: Add audio modality
  • 3D-CLIP: Apply to 3D data

Practical Considerations

Using Pre-trained CLIP

import torch
import clip
from PIL import Image

# Load pre-trained model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

# Preprocess image
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)

# Tokenize text
text = clip.tokenize(["a cat", "a dog"]).to(device)

# Compute features
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # Normalize
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Similarity
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(similarity)  # e.g. [[0.95, 0.05]] - 95% cat, 5% dog

Prompt Engineering Tips

Use descriptive prompts:

  • Bad: "cat"
  • Good: "a photo of a cat"
  • Better: "a high quality photo of a cat"

Ensemble multiple prompts:

templates = [
    "a photo of a {}",
    "a picture of a {}",
    "{} in the scene",
    "a rendering of a {}",
]
# Average the embeddings produced by all templates (see the sketch below)
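
A sketch of that averaging step, assuming the pre-trained openai/CLIP model loaded via clip.load earlier; each template is encoded, and the embeddings are normalized, averaged, and re-normalized to give one text vector per class.

import torch
import clip

def build_ensemble_weights(model, class_names, templates, device="cuda"):
    """Average text embeddings over prompt templates, one vector per class."""
    weights = []
    with torch.no_grad():
        for name in class_names:
            prompts = [t.format(name) for t in templates]
            tokens = clip.tokenize(prompts).to(device)   # (T, 77)
            emb = model.encode_text(tokens)              # (T, embed_dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            mean_emb = emb.mean(dim=0)
            weights.append(mean_emb / mean_emb.norm())   # re-normalize
    return torch.stack(weights)                          # (num_classes, embed_dim)

Image embeddings can then be scored against these class vectors exactly as in the zero-shot classification example above.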

Domain-specific prompts:

  • Medical: "a CT scan showing {}"
  • Satellite: "a satellite image of {}"
  • Sketches: "a sketch of {}"

Fine-Tuning Strategies

Linear probe (freeze CLIP, train classifier):

# Freeze CLIP
for param in clip_model.parameters():
    param.requires_grad = False

# Add a classifier head
classifier = nn.Linear(embed_dim, num_classes).to(device)

# Train only the classifier
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

Full fine-tuning (unfreeze all):

# Unfreeze CLIP
for param in clip_model.parameters():
    param.requires_grad = True

# Lower learning rate for the pre-trained weights
optimizer = torch.optim.Adam([
    {'params': clip_model.parameters(), 'lr': 1e-6},
    {'params': classifier.parameters(), 'lr': 1e-3},
])

For Healthcare Applications

CLIP’s approach is directly applicable to medical multimodal learning:

Adapting CLIP to Healthcare

Replace components:

  • Vision encoder: Process medical images (X-rays, CT scans) or symptom sketches
  • Text encoder: Use ClinicalBERT for medical text
  • Training data: (image, report) pairs from medical databases

Benefits:

  • Learn from naturally occurring (image, report) pairs
  • Zero-shot classification of rare conditions
  • Interpretable through text descriptions
  • No need for extensive manual labeling

Example: Symptom sketch + text multimodal model

# Vision encoder: process 3D body sketches
sketch_encoder = ResNet50()   # or a ViT

# Text encoder: ClinicalBERT for symptom descriptions
text_encoder = ClinicalBERT()

# Train with contrastive loss on (sketch, symptom text) pairs
# Zero-shot: "chest pain in left side" → retrieve similar sketches

Key Takeaways

  1. Natural language supervision scales better than fixed-class supervised learning
  2. Contrastive learning aligns vision and language in a shared space
  3. Zero-shot transfer emerges from massive-scale pre-training
  4. Prompt engineering is crucial for downstream performance
  5. Symmetric loss enables bidirectional retrieval
  6. Large batches (32,768!) critical for learning discriminative features
  7. Foundation model for many modern multimodal systems
