Siamese Neural Networks
Architecture, variants and theory — from Bell Labs signature verification (1993) to modern xBD pre/post damage assessment.
What is a Siamese Network?
Origins, intuition, and the core idea.
A Siamese network is a neural architecture containing two (or more) identical sub-networks that share the same weights. Each branch processes a different input; their outputs are compared to produce a similarity score or difference map.
- Bromley et al. (1993) — Bell Labs, signature verification. The original "Siamese" twin network.
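- Koch et al. (2015) — one-shot image classification, trained as a same/different classifier with a cross-entropy loss on a learned weighted L1 distance between the twin embeddings.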
- Modern uses — face verification (FaceID-style), satellite change detection (xBD), medical image comparison, self-supervised learning (SimSiam, BYOL).
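The core idea (one set of weights, two inputs, one similarity score) can be sketched in a few lines of PyTorch. The toy `embed` network, its layer sizes, the 28×28 inputs and the choice of Euclidean distance are illustrative assumptions, not details taken from the works listed above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy branch. Any module works, as long as BOTH inputs pass through it.
embed = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 16),
)

def siamese_distance(a, b):
    """Embed both inputs with the same weights and return their Euclidean distance."""
    return F.pairwise_distance(embed(a), embed(b))

x1 = torch.rand(4, 1, 28, 28)
x2 = torch.rand(4, 1, 28, 28)

d_same = siamese_distance(x1, x1)  # identical inputs: distance is (near) zero
d_diff = siamese_distance(x1, x2)  # different inputs: positive distance
```

Because both calls go through the same `embed` module, there is only one set of parameters to train, and the distance is symmetric in its two arguments up to the small epsilon that `pairwise_distance` adds for numerical stability.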
Core Architecture for xBD Damage Assessment
Pre-event (t₁)  ──▶ [Shared Encoder θ] ──▶ F_pre  ┐
                                                  ├─▶ Difference Block |F_pre − F_post| ─▶ Decoder ─▶ Damage map (5 classes)
Post-event (t₂) ──▶ [Shared Encoder θ] ──▶ F_post ┘                                                        │
                                                                                                           ▼
                                                                                             Combined Loss = Focal + Dice
                                                                                                           │
                    ◀──── Backprop (updates ONE shared encoder θ) ─────────────────────────────────────────┘
- Pre/post inputs — 512×512×3 satellite images of the same geographic location at two times.
- Shared encoder — ResNet-50/101 or ViT, parameters θ₁ = θ₂. Output: feature maps [C × H/8 × W/8].
- Difference block — |F_pre − F_post| (or learned fusion). Highlights what changed.
- Decoder — upsamples H/8 → H/4 → H/2 → H with skip connections from encoder stages.
- Output — [5 × H × W] tensor; argmax gives per-pixel damage class.
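The difference block in the list above can be sketched in two forms. The channel count (512), the 64×64 feature size (a 512×512 input at H/8) and the 1×1-convolution `LearnedFusion` module are illustrative assumptions; the text does not specify how the learned fusion is built.

```python
import torch
import torch.nn as nn

class AbsDiff(nn.Module):
    """Parameter-free difference block: |F_pre - F_post|."""
    def forward(self, f_pre, f_post):
        return torch.abs(f_pre - f_post)

class LearnedFusion(nn.Module):
    """Hypothetical learned alternative: concatenate both maps, mix with a 1x1 conv."""
    def __init__(self, channels):
        super().__init__()
        self.fuse = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, f_pre, f_post):
        return self.fuse(torch.cat([f_pre, f_post], dim=1))

torch.manual_seed(0)
f_pre = torch.randn(2, 512, 64, 64)      # [B, C, H/8, W/8]
f_post = f_pre.clone()
f_post[:, :, 10:20, 10:20] += 1.0        # simulate a change in one region

diff = AbsDiff()(f_pre, f_post)              # zero everywhere except the changed region
fused = LearnedFusion(512)(f_pre, f_post)    # same shape, but with trainable mixing
```

The absolute difference is exactly zero wherever the two feature maps agree, which is what makes it a useful change signal; the learned variant trades that guarantee for flexibility.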
Inside the Shared Encoder
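Four blocks of Conv → BN → ReLU → downsampling, each halving the spatial resolution. Four halvings give H/16 (32×32 for a 512×512 input), whereas the architecture summary above quotes H/8, so treat this listing as a schematic.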
Input 512×512×3
→ Conv Block 1 → 256×256×64 (stride 2, 64 filters)
→ Conv Block 2 → 128×128×128
→ Conv Block 3 → 64×64×256
→ Conv Block 4 → 32×32×512 ← Feature map for comparison
Decoder — Skip Connections & Attention Gates
U-Net-style skip connections route encoder features directly to the matching decoder stage, restoring spatial detail that pooling discarded.
Decoder stage k:
u_k = Upsample( d_{k-1} ) # coarse decoder features
skip_k = Encoder feature at same resolution
d_k = Conv( Concat[ u_k, AttentionGate(skip_k, u_k) ] )
# Attention gate (Oktay et al., 2018)
α = σ( ψ( ReLU( W_x · x + W_g · g ) ) ) # gate ∈ [0,1] per pixel
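gated_skip = α ⊙ x
Here x is the encoder skip feature (skip_k) and g is the gating signal from the decoder (u_k). The attention gate lets the decoder ask "which encoder pixels are relevant here?" and suppresses the rest, giving sharper boundaries and fewer false positives on background.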
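The gate formula above can be sketched as a small PyTorch module. This is a simplified variant that assumes the gating signal g has already been upsampled to the resolution of the skip feature x (as with u_k in the decoder stage above). Implementing W_x, W_g and ψ as 1×1 convolutions and all channel counts are illustrative choices.

```python
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    """alpha = sigmoid(psi(ReLU(W_x x + W_g g))); returns alpha * x and alpha."""
    def __init__(self, skip_ch, gate_ch, inter_ch):
        super().__init__()
        self.W_x = nn.Conv2d(skip_ch, inter_ch, kernel_size=1, bias=False)
        self.W_g = nn.Conv2d(gate_ch, inter_ch, kernel_size=1, bias=True)
        self.psi = nn.Conv2d(inter_ch, 1, kernel_size=1, bias=True)

    def forward(self, x, g):
        # x: encoder skip feature, g: decoder feature at the same H x W
        alpha = torch.sigmoid(self.psi(torch.relu(self.W_x(x) + self.W_g(g))))
        return alpha * x, alpha          # alpha has shape [B, 1, H, W]

torch.manual_seed(0)
skip = torch.randn(2, 64, 32, 32)        # encoder feature (x)
up = torch.randn(2, 128, 32, 32)         # upsampled decoder feature (g)

gate = AttentionGate(skip_ch=64, gate_ch=128, inter_ch=32)
gated_skip, alpha = gate(skip, up)
```

The single-channel `alpha` map is broadcast over all skip channels, so each pixel of the skip connection is scaled by one value between 0 and 1 before concatenation.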
Loss Functions for Siamese Training
| Loss | Formula | When to use |
|---|---|---|
| Contrastive | L = y·D² + (1−y)·max(0, m−D)², with y = 1 for a matching pair, D the embedding distance, m the margin | Pairwise similarity (face verification, signature) |
| Triplet | L = max(0, D(a,p) − D(a,n) + m) | Anchor/positive/negative, e.g. face recognition |
| Focal | L_F = −α(1−pₜ)^γ log(pₜ) | Imbalanced segmentation (damage) |
| Dice | L_D = 1 − 2\|A∩B\| / (\|A\|+\|B\|) | Region overlap, boundary quality |
| Focal + Dice | L_F + L_D | DAHiTrA / xBD damage assessment ✓ |
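The Focal + Dice row of the table can be sketched as follows. This is a minimal multi-class version under stated assumptions: a single scalar α (per-class weights are also common), γ = 2, a soft Dice computed from softmax probabilities, and an unweighted sum of the two terms. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """-alpha * (1 - p_t)^gamma * log(p_t), averaged over all pixels."""
    log_p = F.log_softmax(logits, dim=1)                        # [B, C, H, W]
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)    # [B, H, W]
    pt = log_pt.exp()
    return (-alpha * (1.0 - pt) ** gamma * log_pt).mean()

def dice_loss(logits, target, eps=1e-6):
    """Soft Dice: 1 - 2|A∩B| / (|A| + |B|), averaged over classes."""
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    inter = (probs * one_hot).sum(dim=(0, 2, 3))
    denom = probs.sum(dim=(0, 2, 3)) + one_hot.sum(dim=(0, 2, 3))
    return (1.0 - (2.0 * inter + eps) / (denom + eps)).mean()

def focal_dice_loss(logits, target):
    return focal_loss(logits, target) + dice_loss(logits, target)

torch.manual_seed(0)
target = torch.randint(0, 5, (2, 32, 32))                       # class index per pixel
random_logits = torch.randn(2, 5, 32, 32)
perfect_logits = 20.0 * F.one_hot(target, 5).permute(0, 3, 1, 2).float()

loss_random = focal_dice_loss(random_logits, target)
loss_perfect = focal_dice_loss(perfect_logits, target)          # close to zero
```

Focal down-weights pixels that are already classified confidently, while Dice scores each class by overlap regardless of how many pixels it covers, which is why the pair is a common choice for imbalanced segmentation.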
Siamese Variants
| Variant | Key idea | When to use |
|---|---|---|
| Classic Siamese | Two identical branches, shared θ | Same modality, same domain (xBD) |
| Pseudo-Siamese | Same architecture, separate θ | Slight domain shift (e.g. different sensors) |
| Asymmetric | Different architectures per branch | Cross-modality (RGB vs SAR, optical vs infrared) |
| Triplet | Three branches (anchor/pos/neg), shared θ | Metric learning, face recognition |
| Self-supervised (SimSiam / BYOL / DINO) | Two augmented views of one image | Label-scarce pre-training |
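The difference between the first two rows of the table (shared θ versus separate θ) can be made concrete with a short sketch. The two-layer `make_branch` encoder is a hypothetical stand-in for a real backbone.

```python
import copy
import torch
import torch.nn as nn

def make_branch():
    # Hypothetical tiny encoder, used only to illustrate weight sharing.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
    )

class ClassicSiamese(nn.Module):
    """One encoder applied to both inputs: shared weights."""
    def __init__(self):
        super().__init__()
        self.encoder = make_branch()

    def forward(self, a, b):
        return self.encoder(a), self.encoder(b)

class PseudoSiamese(nn.Module):
    """Same architecture per branch, but each branch owns its weights."""
    def __init__(self):
        super().__init__()
        self.encoder_a = make_branch()
        self.encoder_b = copy.deepcopy(self.encoder_a)  # same init, trained independently

    def forward(self, a, b):
        return self.encoder_a(a), self.encoder_b(b)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

torch.manual_seed(0)
classic, pseudo = ClassicSiamese(), PseudoSiamese()
x = torch.randn(1, 3, 16, 16)

c_a, c_b = classic(x, x)                  # always identical for identical inputs
with torch.no_grad():
    pseudo.encoder_b[0].weight.add_(1.0)  # simulate branch B drifting during training
p_a, p_b = pseudo(x, x)                   # now the two branches disagree
```

An asymmetric variant would go one step further and call two different constructors for the two branches.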
Vanishing Gradient — Cause & Cure
Why deep Siamese stacks need help, and how skip connections + deep supervision fix it.
For a plain deep net:
∂L/∂x_0 = ∂L/∂x_L · ∏_{ℓ=0}^{L−1} ∂f_ℓ/∂x_ℓ, where x_{ℓ+1} = f_ℓ(x_ℓ)
If each ‖∂f_ℓ/∂x_ℓ‖ < 1, the product shrinks geometrically with depth and the gradient tends to 0.
Residual block: x_{ℓ+1} = f_ℓ(x_ℓ) + x_ℓ
∂x_{ℓ+1}/∂x_ℓ = ∂f_ℓ/∂x_ℓ + I
→ even if ∂f_ℓ/∂x_ℓ ≈ 0, each factor stays close to I. Expanding ∏ (∂f_ℓ/∂x_ℓ + I) always leaves an identity term, so the gradient reaches x_0 through a direct path.
Deep supervision goes further by attaching auxiliary Focal+Dice losses at intermediate decoder stages, so early encoder layers receive a strong gradient signal directly rather than only through a long chain.
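The contrast between the plain and residual gradient products can be checked numerically with autograd. The sketch below is a toy experiment under stated assumptions: each layer is f(x) = tanh(W x) with a small random W, the depth is 30, and the "loss" is simply the sum of the final activations. None of these choices come from the text.

```python
import torch

def input_grad_norm(depth, residual, dim=16, scale=0.5, seed=0):
    """Norm of dL/dx_0 for a stack of layers f(x) = tanh(W x), with or without the +x skip."""
    gen = torch.Generator().manual_seed(seed)
    x0 = torch.randn(dim, generator=gen, requires_grad=True)
    x = x0
    for _ in range(depth):
        # Small random weights, so each layer Jacobian has norm below 1 on average.
        W = scale * torch.randn(dim, dim, generator=gen) / dim ** 0.5
        fx = torch.tanh(W @ x)
        x = fx + x if residual else fx   # residual: x_{l+1} = f(x_l) + x_l
    (grad,) = torch.autograd.grad(x.sum(), x0)
    return grad.norm().item()

plain_norm = input_grad_norm(depth=30, residual=False)
residual_norm = input_grad_norm(depth=30, residual=True)
print(f"plain:    {plain_norm:.3e}")
print(f"residual: {residual_norm:.3e}")
```

In the plain stack the gradient at the input collapses by many orders of magnitude, while the residual stack keeps it at a usable scale because every Jacobian factor contains the identity.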
PyTorch — Minimal Siamese Skeleton
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# build_resnet50_encoder, UNetDecoder, focal_fn, dice_fn and loader are
# placeholders that must be defined elsewhere; they are not part of PyTorch.


class SharedEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # e.g. ResNet-50 backbone, output [B, 512, H/8, W/8]
        self.backbone = build_resnet50_encoder()

    def forward(self, x):
        return self.backbone(x)


class SiameseDamageNet(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.encoder = SharedEncoder()  # ONE encoder, used twice
        self.decoder = UNetDecoder(in_channels=512, out_channels=num_classes)

    def forward(self, pre, post):
        f_pre = self.encoder(pre)         # shared weights θ
        f_post = self.encoder(post)       # same θ: literally the same nn.Module
        diff = torch.abs(f_pre - f_post)  # change signal
        return self.decoder(diff)         # [B, 5, H, W] damage logits


# Training (with Focal + Dice from L2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseDamageNet().to(device)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

for pre, post, mask in loader:
    pre, post = pre.to(device), post.to(device)
    mask = mask.to(device).long()                                 # [B, H, W] class indices
    target = F.one_hot(mask, num_classes=5).permute(0, 3, 1, 2)   # [B, 5, H, W]
    logits = model(pre, post)
    loss = focal_fn(logits, mask) + dice_fn(logits, target.float())
    optim.zero_grad()
    loss.backward()
    optim.step()
```
Which variant for which task?
| Task | Recommended variant | Loss |
|---|---|---|
| xBD building damage (pre/post) | Classic Siamese + UNet decoder | Focal + Dice + deep supervision |
| Signature / face verification | Classic Siamese, embedding head | Contrastive or Triplet |
| Optical vs SAR change detection | Asymmetric Siamese | Focal + Dice |
| Self-supervised pre-training | SimSiam / BYOL / DINO | Cosine similarity (no negatives) |
| One-shot classification | Classic Siamese | Contrastive |
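For the verification and one-shot rows, the embedding-head losses can be sketched as follows. The contrastive loss uses the convention y = 1 for a matching pair, in line with the formula in the loss table; the triplet term uses PyTorch's built-in `nn.TripletMarginLoss`. The synthetic embeddings, the 16-dimensional size and the margin of 1.0 are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def contrastive_loss(emb_a, emb_b, y, margin=1.0):
    """y * D^2 + (1 - y) * max(0, m - D)^2, with y = 1 for matching pairs."""
    d = F.pairwise_distance(emb_a, emb_b)
    return (y * d.pow(2) + (1 - y) * F.relu(margin - d).pow(2)).mean()

triplet_loss = nn.TripletMarginLoss(margin=1.0)   # max(0, D(a,p) - D(a,n) + m)

torch.manual_seed(0)
# Synthetic unit-norm embeddings: positives sit near the anchor, negatives far away.
anchor = F.normalize(torch.randn(8, 16), dim=1)
positive = F.normalize(anchor + 0.05 * torch.randn(8, 16), dim=1)
negative = F.normalize(-anchor + 0.05 * torch.randn(8, 16), dim=1)

match, non_match = torch.ones(8), torch.zeros(8)

# Well-separated embeddings with correct labels: low loss.
loss_good = (contrastive_loss(anchor, positive, match)
             + contrastive_loss(anchor, negative, non_match))
# Same embeddings with swapped labels: high loss.
loss_bad = (contrastive_loss(anchor, negative, match)
            + contrastive_loss(anchor, positive, non_match))

loss_triplet = triplet_loss(anchor, positive, negative)
```

At inference time a verification system would threshold the distance D between two embeddings; the threshold itself is a tuning choice that this sketch does not cover.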