Vision Transformers (ViT)
Architecture, attention mechanism, variants, and the vanishing-gradient story — a complete undergraduate lecture adapted from the CEAMLS slide deck.
Why Transformers for Vision? CNNs vs ViT
Fundamental limitations of convolution and what Transformers offer instead.
A 3×3 kernel sees only 9 pixels at a time — a local receptive field. To let a top-left pixel influence a bottom-right pixel, you must stack many conv layers, and the global receptive field is still approximate and path-dependent.
- Global self-attention: every patch attends to every other patch in a single layer.
- Little baked-in inductive bias: the model learns translation invariance (or not) from data.
- Unified architecture with NLP Transformers → multi-modal models, transfer learning.
- Scales with data and parameters; ViT-G/14 (roughly 2B parameters) reaches just over 90% top-1 accuracy on ImageNet.
| Feature | CNN | ViT |
|---|---|---|
| Context range | Local (kernel size) | Global (all patches) |
| Inductive bias | Strong (translation inv.) | Weak — learned from data |
| Data needed | Less (biases help) | More, or pre-training |
| Computation | O(N·k²) | O(N²) — quadratic in tokens |
| Long-range dependencies | Many stacked layers | One single layer ✓ |
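To put rough numbers on the receptive-field argument, here is a small arithmetic sketch. It assumes stride-1 convolutions with no pooling or dilation (real CNNs downsample, so they need far fewer layers), and the function name `conv_layers_to_reach` is illustrative rather than taken from the lecture.

```python
import math

def conv_layers_to_reach(distance: int, kernel: int = 3) -> int:
    """Stride-1 k×k conv layers needed before an output pixel can be
    influenced by an input pixel `distance` pixels away along each axis."""
    reach_per_layer = (kernel - 1) // 2  # a 3×3 kernel adds 1 pixel of reach per side
    return math.ceil(distance / reach_per_layer)

image_size, patch_size = 224, 16

# Top-left to bottom-right is 223 pixels along each axis.
layers_needed = conv_layers_to_reach(image_size - 1)

# ViT: every token interacts with every other token in one attention layer,
# at the cost of N² token pairs.
num_tokens = (image_size // patch_size) ** 2
attention_pairs = num_tokens ** 2

print(layers_needed, num_tokens, attention_pairs)
```

Under these assumptions the corner-to-corner path needs 223 stacked 3×3 layers, while a single attention layer over 196 patch tokens covers it with 38,416 pairwise scores.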
Patch Embedding & Tokenisation
How an image becomes a sequence of tokens — the very first step of ViT.
- Original image: 224×224×3 RGB.
- Extract patches: divide into non-overlapping P×P squares (P=16) → 14×14 = 196 patches of 16×16×3 = 768 values each. Each patch is a "word".
- Linear projection: flatten each patch and multiply by a learnable matrix E ∈ ℝ^(P²C × D) → each patch becomes a D=768 embedding. Mathematically equivalent to a 16×16 convolution with stride 16 and D output channels.
- Prepend [CLS] token: a learnable vector that aggregates global information; its output is the classification feature (BERT-style).
- Add positional encoding: a learnable vector per position. Without it, self-attention is permutation-equivariant, so the [CLS] output is the same for any shuffling of the patches and the model cannot tell top-left from bottom-right.
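The steps above can be sketched in NumPy as follows. This is a minimal illustration with randomly initialised weights, not a trained model; the variable names (`E`, `cls_token`, `E_pos`) and the 0.02 initialisation scale are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

H = W = 224   # image height and width
C = 3         # RGB channels
P = 16        # patch size
D = 768       # embedding dimension

image = rng.random((H, W, C))

# 1. Cut the image into non-overlapping P×P patches and flatten each one.
grid = H // P  # 14 patches per side
patches = (
    image.reshape(grid, P, grid, P, C)   # split rows and columns into (grid, P)
         .transpose(0, 2, 1, 3, 4)       # -> [grid, grid, P, P, C]
         .reshape(grid * grid, P * P * C)  # -> [196, 768]
)

# 2. Linear projection with a (here random) matrix E of shape [P²C, D].
E = rng.normal(0.0, 0.02, size=(P * P * C, D))
tokens = patches @ E                     # [196, D]

# 3. Prepend the [CLS] token (learnable in a real model; zeros here).
cls_token = np.zeros((1, D))
tokens = np.concatenate([cls_token, tokens], axis=0)  # [197, D]

# 4. Add one positional vector per token position.
E_pos = rng.normal(0.0, 0.02, size=(tokens.shape[0], D))
z0 = tokens + E_pos

print(patches.shape, z0.shape)
```

The reshape and transpose pair is one common way to patchify without loops; a strided convolution would produce the same projected tokens.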
Self-Attention — Query, Key, Value
The mathematical core of the Transformer — how patches 'talk to' each other.
Each token produces three vectors: Q (what I'm looking for), K (what I offer), V (my content).
Attention(Q, K, V) = softmax( Q·Kᵀ / √d_k ) · V
Q = X · W_Q (queries)
K = X · W_K (keys)
V = X · W_V (values)
d_k = dimension of key vectors (scaling)
- Q·Kᵀ is an [N×N] matrix; entry (i,j) = how much token i wants to attend to token j.
- Softmax turns each row into probabilities summing to 1.
- ·V takes a weighted sum of values — the output for each token.
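The formula and the three bullets above map almost line for line onto NumPy. The sketch below uses tiny, randomly chosen dimensions (N=5 tokens, D=8, d_k=4) purely for illustration; the helper names `softmax` and `self_attention` are not from the lecture.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product self-attention over tokens X [N, D]."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # each [N, d_k]
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # [N, N]: entry (i, j) scores token i against token j
    weights = softmax(scores, axis=-1)    # each row sums to 1
    return weights @ V, weights           # weighted sum of values per token

rng = np.random.default_rng(0)
N, D, d_k = 5, 8, 4
X = rng.normal(size=(N, D))
W_Q, W_K, W_V = (rng.normal(size=(D, d_k)) for _ in range(3))

out, weights = self_attention(X, W_Q, W_K, W_V)
print(out.shape, weights.shape)
```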
Multi-Head Self-Attention (MSA)
Running attention in parallel across multiple representation subspaces.
MSA(X) = Concat(head₁, …, head_h) · W_O
head_i = Attention(X·W_Qi, X·W_Ki, X·W_Vi)
Complete ViT Architecture
From image → patches → 12 encoder blocks → CLS → classification head.
Image 224×224×3
↓ Patch Embed (16×16, stride 16)
↓ + [CLS] + Positional Encoding
[197 × 768]
↓ × 12 Encoder Blocks
[197 × 768]
↓ extract CLS token
[1 × 768]
↓ MLP head
[K classes]
| Variant | Layers | Hidden D | Heads | Params |
|---|---|---|---|---|
| ViT-Small | 12 | 384 | 6 | 22M |
| ViT-Base | 12 | 768 | 12 | 86M |
| ViT-Large | 24 | 1024 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 16 | 632M |
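The parameter column can be sanity-checked with a little arithmetic. The sketch below assumes details the table does not state: 16×16 patches on a 224×224 image, an FFN hidden width of 4×D, biases on every linear layer, one final LayerNorm, and a 1000-class linear head. The function name `vit_param_count` is illustrative.

```python
def vit_param_count(depth, dim, mlp_dim, patch=16, channels=3,
                    image=224, num_classes=1000):
    """Approximate parameter count of a plain ViT under the stated assumptions."""
    n_tokens = (image // patch) ** 2 + 1                  # patches + [CLS]
    patch_embed = patch * patch * channels * dim + dim    # projection E + bias
    cls_and_pos = dim + n_tokens * dim                    # [CLS] + positional table
    attn = 4 * (dim * dim + dim)                          # W_Q, W_K, W_V, W_O with biases
    ffn = dim * mlp_dim + mlp_dim + mlp_dim * dim + dim   # two linear layers
    norms = 2 * 2 * dim                                   # two LayerNorms (γ and β each)
    block = attn + ffn + norms
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * block + final_norm + head

small = vit_param_count(depth=12, dim=384, mlp_dim=1536)
base = vit_param_count(depth=12, dim=768, mlp_dim=3072)
large = vit_param_count(depth=24, dim=1024, mlp_dim=4096)

for name, count in [("ViT-Small", small), ("ViT-Base", base), ("ViT-Large", large)]:
    print(f"{name}: {count / 1e6:.1f}M")
```

Note that the number of heads does not appear: splitting D across heads changes how the projections are used, not how many weights they contain. Under these assumptions ViT-Base lands at about 86.6M, in line with the table; ViT-Large comes out a few million below the table's 307M, a gap that depends on counting details this sketch only assumes.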
Transformer Encoder Block — Inside One Block
Two residual sub-blocks: Pre-LN → MSA → ⊕, Pre-LN → FFN → ⊕.
# Pre-LN Transformer block (used in ViT)
x' = MSA( LN(x) ) + x # residual 1
x'' = FFN( LN(x') ) + x' # residual 2
# Layer Norm (per-token, across D features):
LN(x) = γ ⊙ (x − μ) / √(σ² + ε) + β
# Feed-Forward Network (per-token, independent):
FFN(x) = GELU(x·W₁ + b₁) · W₂ + b₂ # 768 → 3072 → 768
- Pre-LN (normalise before MSA/FFN) is more stable than Post-LN for deep stacks.
- MSA mixes tokens; FFN does not — FFN acts independently per token, a learned "memory" lookup.
- GELU (Gaussian Error Linear Unit) smoothly gates the input; it is the standard activation in ViT and BERT-style Transformers, used in place of ReLU.
- Two residuals per block × 12 blocks = 24 gradient highways from output back to input.
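One Pre-LN block, including the multi-head attention it contains, can be sketched in NumPy as below. This is an illustration under several assumptions: dimensions are shrunk (D=16, 4 heads, hidden width 64) so it runs instantly, the Q/K/V projections are fused into one matrix `W_qkv` without biases, GELU uses the common tanh approximation, and all parameter names in the dictionary `p` are hypothetical.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # Normalise each token across its D features, then scale and shift.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def gelu(x):
    # tanh approximation of GELU (an assumption; the exact form uses erf).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def msa(x, W_qkv, W_o, num_heads):
    """Multi-head self-attention: split D into heads, attend, concatenate, project."""
    N, D = x.shape
    d_k = D // num_heads
    q, k, v = np.split(x @ W_qkv, 3, axis=-1)               # each [N, D]
    # [N, D] -> [heads, N, d_k]
    q, k, v = (t.reshape(N, num_heads, d_k).transpose(1, 0, 2) for t in (q, k, v))
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_k))  # [heads, N, N]
    heads = attn @ v                                         # [heads, N, d_k]
    concat = heads.transpose(1, 0, 2).reshape(N, D)          # Concat(head_1..head_h)
    return concat @ W_o

def encoder_block(x, p, num_heads):
    # Residual 1: tokens are mixed by attention.
    x = x + msa(layer_norm(x, p["g1"], p["b1"]), p["W_qkv"], p["W_o"], num_heads)
    # Residual 2: FFN is applied to each token independently.
    h = layer_norm(x, p["g2"], p["b2"])
    return x + gelu(h @ p["W1"] + p["c1"]) @ p["W2"] + p["c2"]

rng = np.random.default_rng(0)
N, D, hidden, num_heads = 5, 16, 64, 4
p = {
    "g1": np.ones(D), "b1": np.zeros(D), "g2": np.ones(D), "b2": np.zeros(D),
    "W_qkv": rng.normal(0, 0.1, (D, 3 * D)), "W_o": rng.normal(0, 0.1, (D, D)),
    "W1": rng.normal(0, 0.1, (D, hidden)), "c1": np.zeros(hidden),
    "W2": rng.normal(0, 0.1, (hidden, D)), "c2": np.zeros(D),
}
x = rng.normal(size=(N, D))
y = encoder_block(x, p, num_heads)
print(x.shape, y.shape)
```

Because the block maps [N, D] to [N, D], stacking 12 of them is just a loop over independent parameter sets.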
Vanishing Gradient — Why Residuals Save ViT
A short argument for why the +x in x' = F(x)+x preserves gradient flow.
x_{ℓ+1} = F_ℓ(x_ℓ) + x_ℓ (residual block)
∂L/∂x_0 = ∂L/∂x_L · ∏_{ℓ=0..L-1} (∂F_ℓ/∂x_ℓ + I)
Even if ∂F_ℓ/∂x_ℓ → 0 (vanishing), each factor tends to I, so the product
tends to the identity and ∂L/∂x_0 ≈ ∂L/∂x_L instead of collapsing to 0.
Combined with LayerNorm (normalises each token to μ=0, σ=1 before the learned scale γ and shift β) and the √d_k scaling inside attention, ViT trains stably at 12, 24, even 32 layers deep, something CNNs only achieved after the ResNet (2015) skip-connection breakthrough.
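A quick numerical experiment makes the role of the +I term visible. The sketch below uses hypothetical random Jacobians with very small entries to stand in for ∂F_ℓ/∂x_ℓ (they are not taken from any trained network) and compares the 12-layer product with and without the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L = 8, 12  # feature dimension and number of layers (illustrative values)

# Hypothetical "vanishing" Jacobians: random matrices with spectral norm around 0.02.
jacobians = [0.01 * rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(L)]

plain = np.eye(D)      # product of ∂F/∂x only (no skip connection)
residual = np.eye(D)   # product of (∂F/∂x + I) (with skip connection)
for J in jacobians:
    plain = J @ plain
    residual = (J + np.eye(D)) @ residual

# Largest singular value of the plain product: how much gradient can survive at best.
plain_scale = np.linalg.norm(plain, 2)
# Smallest singular value of the residual product: how much survives at worst.
residual_min = np.linalg.svd(residual, compute_uv=False).min()

print(f"without residual: {plain_scale:.2e}")
print(f"with residual:    {residual_min:.2f}")
```

Without the skip connection the product shrinks geometrically with depth; with it, the product stays close to the identity in every direction. This only demonstrates the small-Jacobian case discussed above; large Jacobians can still cause exploding gradients, which is where normalisation helps.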
ViT Variants
Five families you should know — when to reach for each.
| Variant | Year | Key idea | Best for |
|---|---|---|---|
| DeiT | 2021 | Data-efficient ViT; distillation token from a CNN teacher. | Small/medium datasets — no JFT-300M pre-training needed. |
| Swin | 2021 | Shifted-window attention; hierarchical, linear complexity. | Dense prediction (detection, segmentation), high-res imagery. |
| BEiT | 2021 | BERT-style masked image modelling self-supervised pre-training. | Label-scarce domains (medical, satellite). |
| CaiT | 2021 | Class-Attention layers; LayerScale enables very deep ViTs. | Maximum ImageNet accuracy at fixed compute. |
| MaxViT | 2022 | Block-attention + grid-attention; hybrid with conv stem. | Long-range + local features in one network — strong all-rounder. |
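The "linear complexity" entry for Swin can be illustrated by counting attention scores. The sketch below assumes a 7×7 window and a 56×56 token grid (a 224×224 image with 4×4 patches); these settings are assumptions for the example rather than values given in the table, and the function names are illustrative.

```python
def global_attention_pairs(h: int, w: int) -> int:
    """Every token attends to every token: (h·w)² scores."""
    n = h * w
    return n * n

def window_attention_pairs(h: int, w: int, window: int) -> int:
    """Each token attends only within its window: h·w·window² scores.
    Assumes the window size divides both h and w."""
    return h * w * window * window

for side in (56, 112):  # doubling the resolution of the token grid
    g = global_attention_pairs(side, side)
    s = window_attention_pairs(side, side, window=7)
    print(f"{side}x{side} tokens: global={g:,}  windowed={s:,}")
```

Doubling the grid side multiplies the global count by 16 but the windowed count by only 4, which is why window-based designs suit high-resolution dense prediction.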
ViT vs CNN — When to Use Which
| Situation | Prefer |
|---|---|
| < 50k labelled images, no pre-training | CNN (ResNet/EfficientNet) |
| Large pre-training corpus available (ImageNet-21k, JFT) | ViT |
| High-resolution dense prediction (segmentation) | Swin / MaxViT |
| Multi-modal (image + text) | ViT (shares Transformer with NLP) |
| Edge device, tight FLOPs budget | CNN or MobileViT |
| Satellite / medical, self-supervised pre-training | BEiT / MAE-pretrained ViT |
Equations cheat sheet
Patch embedding: z_0 = [x_cls; x_p^1 E; x_p^2 E; … ; x_p^N E] + E_pos
Attention head: A_i = softmax(Q_i K_iᵀ / √d_k) V_i
MSA: MSA(x) = Concat(A_1,…,A_h) W_O
Encoder block: x' = MSA(LN(x)) + x ; z = FFN(LN(x')) + x'
FFN: FFN(x) = GELU(x W_1 + b_1) W_2 + b_2
Classification: ŷ = softmax( LN(z_L^{cls}) · W_head )
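The final classification step can be sketched in a few lines of NumPy. The encoder output here is random stand-in data, the LayerNorm omits the learned γ and β (equivalent to their identity initialisation), and K=10 classes is an arbitrary choice for the example.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # LayerNorm without learned scale/shift (an assumption for brevity).
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
D, K = 768, 10                       # embedding dimension, number of classes

z_L = rng.normal(size=(197, D))      # stand-in for the last encoder block's output
z_cls = z_L[0]                       # the [CLS] token sits at position 0
W_head = rng.normal(0.0, 0.02, size=(D, K))

probs = softmax(layer_norm(z_cls) @ W_head)  # ŷ: one probability per class
pred = int(probs.argmax())
print(probs.shape, pred)
```

Only the [CLS] row is used; the other 196 patch tokens influence the prediction solely through attention in the encoder blocks.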