Focal Loss, Dice Loss & Class Weights
Advanced loss functions for severely imbalanced segmentation — the foundation of DAHiTrA-style damage assessment on xBD.
The Problem — Class Imbalance in Disaster Datasets
Why standard accuracy is a misleading metric for xBD.
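A model that labels every pixel as background would score about 96% accuracy while detecting no damage at all, because in xBD the pixel distribution is roughly: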
| Class | % of pixels |
|---|---|
| Background | 96.0% |
| No damage | 2.7% |
| Minor damage | 0.1% |
| Major damage | 0.1% |
| Destroyed | 0.1% |
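The effect of this distribution on accuracy can be checked with a few lines of arithmetic. The sketch below takes the rounded shares from the table and evaluates a degenerate model that always predicts background; `pixel_share` and `all_background_report` are illustrative names, not part of any library.

```python
# Rounded pixel shares (percent of all pixels) from the table above.
pixel_share = {
    "background": 96.0,
    "no_damage": 2.7,
    "minor_damage": 0.1,
    "major_damage": 0.1,
    "destroyed": 0.1,
}


def all_background_report(share_pct):
    """Metrics for a model that predicts 'background' for every pixel."""
    # Only background pixels are classified correctly.
    accuracy = share_pct["background"] / 100.0
    # Recall is 1 for background and 0 for every class that is never predicted.
    recall = {name: 1.0 if name == "background" else 0.0 for name in share_pct}
    mean_recall = sum(recall.values()) / len(recall)
    return accuracy, recall, mean_recall


accuracy, recall, mean_recall = all_background_report(pixel_share)
print(f"pixel accuracy:        {accuracy:.1%}")
print(f"recall on 'destroyed': {recall['destroyed']:.0%}")
print(f"mean per-class recall: {mean_recall:.0%}")
```

Pixel accuracy looks excellent while the mean per-class recall is only one in five, which is why per-class metrics and imbalance-aware losses are needed.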
Solution 1 — Class Weights
Make rare mistakes more expensive.
Standard CE : L = − Σ y_i · log(p_i)
Weighted CE : L = − Σ w_i · y_i · log(p_i)

| Class | Count (per 100k) | Weight w | Effect |
|---|---|---|---|
| Background | 96,000 | 1 | Baseline — very common |
| No damage | 2,700 | 5 | Moderately important |
| Minor damage | 100 | 20 | High penalty — rare |
| Major damage | 100 | 20 | High penalty — rare |
| Destroyed | 100 | 20 | Highest — critical |
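Worked example (natural log). The model assigns p = 0.2 to the true class of a destroyed pixel: CE = −log(0.2) = 1.61, and with w = 20 the weighted loss is 20 × 1.61 = 32.2. For a background pixel predicted at p = 0.95, CE = 0.051, and the weight of 1 leaves it at 0.051. The destroyed-pixel mistake now costs roughly 630× more (32.2 / 0.051).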
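The worked example can be reproduced with the standard library alone. This is a minimal sketch for a single pixel, using the illustrative weights from the table; `CLASS_WEIGHTS` and `weighted_ce` are hypothetical names.

```python
import math

# Illustrative class weights from the table above.
CLASS_WEIGHTS = {
    "background": 1.0,
    "no_damage": 5.0,
    "minor_damage": 20.0,
    "major_damage": 20.0,
    "destroyed": 20.0,
}


def weighted_ce(p_true, class_name, weights=CLASS_WEIGHTS):
    """Weighted cross-entropy for one pixel: -w * log(p_true)."""
    return -weights[class_name] * math.log(p_true)


destroyed = weighted_ce(0.20, "destroyed")    # 20 * (-log 0.2)
background = weighted_ce(0.95, "background")  # 1 * (-log 0.95)
ratio = destroyed / background

print(f"destroyed:  {destroyed:.2f}")
print(f"background: {background:.3f}")
print(f"ratio:      {ratio:.0f}x")
```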
Solution 2 — Focal Loss (Lin et al., 2017)
Focus on hard examples, ignore easy ones.
FL(p_t) = − α_t · (1 − p_t)^γ · log(p_t)
p_t = model probability of correct class
α_t = class weight (imbalance)
(1−p_t)^γ = modulating factor (suppresses easy examples)
γ = focusing parameter (recommended: 2)

The table below uses γ = 2 and α_t = 1.

| Example | p_t | CE = −log(p_t) | (1−p_t)² | Focal | vs CE |
|---|---|---|---|---|---|
| Easy (bg) | 0.95 | 0.051 | 0.0025 | 0.00013 | ~400× smaller |
| Medium | 0.50 | 0.693 | 0.25 | 0.173 | ~4× smaller |
| Hard (destroyed) | 0.10 | 2.303 | 0.81 | 1.865 | ~1.2× smaller |
| Very hard | 0.05 | 2.996 | 0.9025 | 2.704 | ~1.1× smaller |
The teacher's dilemma. Student A scores 95% and already knows the material. Student B scores 20% and needs focused attention. Focal Loss is the teacher who spends almost all of the lesson on Student B while still checking in briefly on Student A.
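The focal-loss table can be regenerated from the formula directly. The sketch below evaluates a single prediction at a time with γ = 2 and α_t = 1, matching the table; `focal_loss` here is a scalar helper for illustration, not the tensor implementation given later.

```python
import math


def focal_loss(p_t, gamma=2.0, alpha_t=1.0):
    """Focal loss for one prediction: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


examples = [("easy", 0.95), ("medium", 0.50), ("hard", 0.10), ("very hard", 0.05)]
rows = []
for label, p_t in examples:
    ce = -math.log(p_t)   # plain cross-entropy
    fl = focal_loss(p_t)  # cross-entropy scaled by (1 - p_t)^2
    rows.append((label, p_t, ce, fl, ce / fl))
    print(f"{label:>9}  p_t={p_t:.2f}  CE={ce:.3f}  FL={fl:.5f}  CE/FL={ce / fl:.1f}")
```

Setting `gamma=0.0` makes the modulating factor equal to 1, so the function reduces to α-weighted cross-entropy.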
Solution 3 — Dice Loss
Reward region overlap, not per-pixel accuracy.
Dice = 2 · |A ∩ B| / ( |A| + |B| ) ∈ [0, 1]
Dice Loss = 1 − Dice

The rows below assume the predicted region is also 100 px, so |A| + |B| = 200.

| Ground truth | Predicted overlap | Dice | Quality |
|---|---|---|---|
| 100 px | 100 (perfect) | 2·100/(100+100) = 1.00 | Perfect |
| 100 px | 80 | 2·80/200 = 0.80 | Good |
| 100 px | 50 | 2·50/200 = 0.50 | Moderate |
| 100 px | 10 | 2·10/200 = 0.10 | Poor |
| 100 px | 0 | 0.00 | None |
Cross-entropy looks at pixels independently and can miss a building's boundary while still scoring well. Dice rewards correctly drawing the whole shape — essential for disaster mapping where boundary quality matters.
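A small count-based sketch makes the Dice table concrete and shows a case the table does not cover: over-prediction. `dice_score` is a hypothetical helper working on pixel counts rather than masks, and returning 1.0 for two empty masks is a convention assumed here.

```python
def dice_score(overlap, truth_size, pred_size):
    """Dice = 2 * |A ∩ B| / (|A| + |B|), computed from pixel counts."""
    denom = truth_size + pred_size
    if denom == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return 2.0 * overlap / denom


# Rows of the table above: 100 px ground truth, 100 px predicted.
for overlap in (100, 80, 50, 10, 0):
    print(f"overlap={overlap:3d}  dice={dice_score(overlap, 100, 100):.2f}")

# Over-prediction is penalised too: every true pixel is found,
# but the prediction covers 300 px in total.
over_predicted = dice_score(100, 100, 300)
print(f"over-predicted dice={over_predicted:.2f}")
```

Because the predicted size sits in the denominator, a model cannot raise its score by flooding the image with positive predictions.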
Combining Focal + Dice (DAHiTrA)
Each solves a different problem — together they are complete.
Simple : L = L_Focal + L_Dice
Weighted : L = λ₁ · L_Focal + λ₂ · L_Dice
Multi-Scale Loss Supervision
Apply loss at multiple decoder stages — gradients reach early layers faster.
| Stage | Resolution | Loss | λ |
|---|---|---|---|
| Deep supervision | 1/8 | Focal | 0.4 |
| Intermediate | 1/4 | Focal + Dice | 0.6 |
| Final output | Full | Focal + Dice | 1.0 |
| Total | — | Σ λᵢ · Lᵢ | — |
Deep supervision injects a loss signal at multiple resolutions, so layers far from the final output still receive strong gradients. This directly mitigates vanishing gradients in deep encoder/decoder stacks.
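The weighted sum over stages might be wired up roughly as follows. This is a sketch under several assumptions: the decoder returns one logits tensor per stage, the mask is downsampled to each stage with nearest-neighbour interpolation (the source does not say whether masks are downsampled or logits upsampled), and plain cross-entropy stands in for the per-stage Focal and Focal + Dice losses. `STAGES` and `multiscale_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F

# (stride, lambda) per stage, taken from the table above.
STAGES = [(8, 0.4), (4, 0.6), (1, 1.0)]


def multiscale_loss(stage_logits, mask, loss_fn=F.cross_entropy):
    """Return sum_i lambda_i * L_i over the decoder stages.

    stage_logits: list of [B, C, H/s, W/s] tensors, ordered like STAGES.
    mask:         [B, H, W] integer class map at full resolution.
    loss_fn:      stand-in per-stage loss (plain CE here, an assumption).
    """
    total = 0.0
    for logits, (_stride, lam) in zip(stage_logits, STAGES):
        # Nearest-neighbour resizing keeps every label a valid class id.
        target = F.interpolate(
            mask[:, None].float(), size=logits.shape[-2:], mode="nearest"
        )[:, 0].long()
        total = total + lam * loss_fn(logits, target)
    return total


# Toy usage: random tensors stand in for real decoder outputs.
torch.manual_seed(0)
mask = torch.randint(0, 5, (2, 64, 64))
stage_logits = [
    torch.randn(2, 5, 64 // s, 64 // s, requires_grad=True) for s, _ in STAGES
]
loss = multiscale_loss(stage_logits, mask)
loss.backward()  # every stage receives its own gradient signal
```

Passing a different `loss_fn` per stage would be needed to match the table exactly, where the 1/8 stage uses Focal only.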
PyTorch — Focal Loss
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # optional class-weight tensor, shape [C]
        self.gamma = gamma  # focusing parameter

    def forward(self, logits, targets):
        # logits: [B, C, H, W]; targets: [B, H, W] integer class indices
        # 1. Unweighted per-pixel CE (no reduction yet)
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        # 2. Recover p_t from CE: CE = -log(p_t) -> p_t = exp(-CE)
        pt = torch.exp(-ce_loss)
        # 3. Apply modulating factor (1 - p_t)^gamma
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        # 4. Apply alpha_t only after p_t is recovered. Passing alpha as
        #    `weight=` in step 1 would make exp(-CE) equal p_t ** alpha_t.
        if self.alpha is not None:
            focal_loss = self.alpha.to(logits.device)[targets] * focal_loss
        return focal_loss.mean()
```
PyTorch — Dice Loss
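```python
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # prevents division by zero

    def forward(self, preds, targets):
        # preds: raw logits [B, C, H, W]; targets: one-hot [B, C, H, W]
        probs = torch.softmax(preds, dim=1)     # logits -> probabilities
        intersection = (probs * targets).sum()  # soft |A ∩ B|
        total = probs.sum() + targets.sum()     # |A| + |B| (sum of sizes, not a union)
        # Note: the sums pool all classes and pixels into a single score
        dice = (2 * intersection + self.smooth) / (total + self.smooth)
        return 1 - dice
```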
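Because the sums in the Dice loss run over every class at once, the score tends to be dominated by background pixels on a dataset like xBD. One common variant, shown here as a sketch and not as the method used by DAHiTrA, computes Dice per class and averages the results so that each class contributes equally. `PerClassDiceLoss` is a hypothetical name.

```python
import torch
import torch.nn as nn


class PerClassDiceLoss(nn.Module):
    """Dice computed separately per class, then averaged (illustrative variant)."""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # prevents division by zero for absent classes

    def forward(self, logits, targets_onehot):
        # logits, targets_onehot: [B, C, H, W]
        probs = torch.softmax(logits, dim=1)
        dims = (0, 2, 3)  # sum over batch and space, keep the class axis
        intersection = (probs * targets_onehot).sum(dims)       # [C]
        sizes = probs.sum(dims) + targets_onehot.sum(dims)      # [C]
        dice = (2 * intersection + self.smooth) / (sizes + self.smooth)
        return 1 - dice.mean()  # equal weight for every class


# Toy check: near-perfect logits should give a loss close to zero.
torch.manual_seed(0)
labels = torch.randint(0, 5, (2, 8, 8))
onehot = torch.nn.functional.one_hot(labels, num_classes=5).permute(0, 3, 1, 2).float()
perfect_loss = PerClassDiceLoss()(onehot * 50.0, onehot)
random_loss = PerClassDiceLoss()(torch.randn(2, 5, 8, 8), onehot)
```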
PyTorch — Combined Loss & Siamese Training Loop
```python
# Assumes siamese_model, optimizer, and dataloader are defined elsewhere,
# along with the FocalLoss and DiceLoss classes above.
# Class weights from the table: a heavily damped form of inverse frequency
# (raw background-to-class ratios would be about 1 : 36 : 960 : 960 : 960)
class_weights = torch.tensor([1.0, 5.0, 20.0, 20.0, 20.0]).cuda()
focal_fn = FocalLoss(alpha=class_weights, gamma=2)
dice_fn = DiceLoss(smooth=1e-6)

for pre_img, post_img, mask in dataloader:
    pre_img, post_img, mask = pre_img.cuda(), post_img.cuda(), mask.cuda()

    # Siamese forward: both images go through the shared encoder
    logits = siamese_model(pre_img, post_img)  # [B, C, H, W]

    # One-hot encode the mask for Dice: [B, H, W] -> [B, C, H, W]
    mask_onehot = F.one_hot(mask.long(), num_classes=5)
    mask_onehot = mask_onehot.permute(0, 3, 1, 2).float()

    loss_focal = focal_fn(logits, mask.long())
    loss_dice = dice_fn(logits, mask_onehot)
    total_loss = loss_focal + loss_dice  # or λ₁·focal + λ₂·dice

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
```