Lecture L2

Focal Loss, Dice Loss & Class Weights

Advanced loss functions for severely imbalanced segmentation — the foundation of DAHiTrA-style damage assessment on xBD.

Class imbalance · Weighted CE · Focal (γ=2) · Dice overlap · Focal + Dice · Multi-scale supervision · PyTorch

The Problem — Class Imbalance in Disaster Datasets

Why standard accuracy is a misleading metric for xBD.

In xBD, the pixel distribution is roughly:

Class         | % of pixels
Background    | 96.0%
No damage     | 2.7%
Minor damage  | 0.1%
Major damage  | 0.1%
Destroyed     | 0.1%
The accuracy trap
A model that always predicts background scores about 96% pixel accuracy and is useless for damage detection: its recall on every damage class is 0. We need losses that actively penalise missing rare, critical classes.
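A toy check of the trap, in plain Python so it runs anywhere. The 100-pixel mask (96 background, 3 no-damage, 1 destroyed) and the helper names `pixel_accuracy` and `per_class_recall` are illustrative assumptions, not part of any xBD tooling.

```python
def pixel_accuracy(y_true, y_pred):
    """Fraction of pixels whose predicted class matches the label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def per_class_recall(y_true, y_pred, num_classes):
    """Recall TP / (TP + FN) per class; None if the class never occurs in y_true."""
    recalls = []
    for c in range(num_classes):
        preds_on_c = [p for t, p in zip(y_true, y_pred) if t == c]
        if not preds_on_c:
            recalls.append(None)
        else:
            recalls.append(sum(p == c for p in preds_on_c) / len(preds_on_c))
    return recalls

# Hypothetical 100-pixel mask: 0 = background, 1 = no damage, 4 = destroyed
y_true = [0] * 96 + [1] * 3 + [4]
y_pred = [0] * 100  # a model that always predicts background

print(pixel_accuracy(y_true, y_pred))       # 0.96
print(per_class_recall(y_true, y_pred, 5))  # [1.0, 0.0, None, None, 0.0]
```

Accuracy is 0.96, yet recall is 0 for both building classes that actually occur; `None` marks classes absent from this toy mask.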

Solution 1 — Class Weights

Make rare mistakes more expensive.

Standard CE :  L = − Σ  y_i · log(p_i)
Weighted CE :  L = − Σ  w_i · y_i · log(p_i)
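Class         | Count (per 100k) | Weight w | Effect
Background    | 96,000           | 1        | Baseline, very common
No damage     | 2,700            | 5        | Moderately important
Minor damage  | 100              | 20       | High penalty, rare
Major damage  | 100              | 20       | High penalty, rare
Destroyed     | 100              | 20       | High penalty, rare and critical

These weights are hand-set: they grow with rarity but are far smaller than raw inverse frequency, which would give about 36 for no damage and 960 for each of the three rare classes.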
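A sketch of how such weights can be derived from pixel counts. It assumes one common softening (raise inverse frequency to a power, then clip at a cap); the function name `class_weights` and the defaults power 0.5 and cap 20 are illustrative choices, not the lecture's stated recipe.

```python
import math

# Pixel counts per 100k from the table: background, no, minor, major, destroyed
counts = [96_000, 2_700, 100, 100, 100]

def class_weights(counts, power=0.5, cap=20.0):
    """Inverse-frequency weights relative to the most common class.

    power=1.0 gives raw inverse frequency; smaller powers soften it.
    Weights are clipped at `cap` so the rarest classes cannot dominate.
    """
    most = max(counts)
    return [min(cap, (most / c) ** power) for c in counts]

raw = class_weights(counts, power=1.0, cap=math.inf)
soft = class_weights(counts)
print([round(w, 1) for w in raw])   # [1.0, 35.6, 960.0, 960.0, 960.0]
print([round(w, 1) for w in soft])  # [1.0, 6.0, 20.0, 20.0, 20.0]
```

With these assumed settings the result lands close to the hand-set 1 / 5 / 20 weights; raw inverse frequency would make a single rare-class pixel worth almost a thousand background pixels.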

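Worked example. The model assigns p = 0.2 to the true class of a destroyed pixel: CE = −log(0.2) = 1.61, and with w = 20 the loss becomes 20 × 1.61 = 32.2. The same model gets a background pixel right with p = 0.95: CE = −log(0.95) = 0.051, and with w = 1 it stays 0.051. The destroyed mistake now costs about 630× more, versus about 31× without weights.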

Why this isn't enough
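Class weights scale the loss by class, not by difficulty. Every well-classified background pixel still adds a small loss, and an xBD image has about a million pixels (1024 × 1024), almost all of them easy background. Those small terms add up, so the model can keep lowering the loss by polishing pixels it already gets right. We need to suppress easy examples too, which is what Focal Loss does.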
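The arithmetic behind this, sketched in plain Python. It assumes, hypothetically, that every background pixel is predicted at p = 0.95 and every destroyed pixel at p = 0.2, and uses the per-100k counts and weights from the table; `summed_weighted_ce` is an illustrative helper name.

```python
import math

def summed_weighted_ce(n_pixels, p_true, weight):
    """Total weighted CE for n pixels that all give the true class probability p_true."""
    return n_pixels * weight * -math.log(p_true)

# Per pixel (the worked example above)
print(round(summed_weighted_ce(1, 0.2, 20.0), 1),   # destroyed, w = 20
      round(summed_weighted_ce(1, 0.95, 1.0), 3))   # background, w = 1
# 32.2 0.051

# Per 100k pixels, using the table's counts
bg_total = summed_weighted_ce(96_000, 0.95, 1.0)
destroyed_total = summed_weighted_ce(100, 0.2, 20.0)
print(round(bg_total), round(destroyed_total))  # 4924 3219
```

Even with a 20× weight on every destroyed pixel, the easy background contributes more total loss under these assumptions, so the optimiser still gains a lot by refining it.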

Solution 2 — Focal Loss (Lin et al., 2017)

Focus on hard examples, down-weight easy ones.

FL(p_t) = − α_t · (1 − p_t)^γ · log(p_t)

p_t = model probability of correct class
α_t = class weight (imbalance)
(1−p_t)^γ = modulating factor (suppresses easy examples)
γ = focusing parameter (recommended: 2)
Example           | p_t  | CE = −log(p_t) | (1−p_t)² | Focal (γ=2, α_t=1) | vs CE
Easy (bg)         | 0.95 | 0.051          | 0.0025   | 0.00013            | ~400× smaller
Medium            | 0.50 | 0.693          | 0.25     | 0.173              | ~4× smaller
Hard (destroyed)  | 0.10 | 2.303          | 0.81     | 1.865              | ~1.2× smaller
Very hard         | 0.05 | 2.996          | 0.9025   | 2.704              | ~1.1× smaller
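The table can be reproduced in a few lines; α_t = 1 throughout, as in the table. The second part is an illustration under assumed conditions: every background pixel at p_t = 0.95, every destroyed pixel at p_t = 0.2 with α_t = 20, and the per-100k counts from the class-weight table. The helper name `focal` is illustrative.

```python
import math

def focal(p_t, gamma=2.0, alpha_t=1.0):
    """Per-pixel focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

rows = [("Easy (bg)", 0.95), ("Medium", 0.50), ("Hard (destroyed)", 0.10), ("Very hard", 0.05)]
for name, p_t in rows:
    ce = -math.log(p_t)
    fl = focal(p_t)
    print(f"{name:17s} CE={ce:.3f}  focal={fl:.5f}  {ce / fl:.1f}x smaller")

# Totals per 100k pixels under the assumed confidences
bg = 96_000 * focal(0.95, alpha_t=1.0)
destroyed = 100 * focal(0.2, alpha_t=20.0)
print(round(bg), round(destroyed))  # 12 2060
```

With γ = 2 the 96,000 easy background pixels add up to only about 12, while the 100 hard destroyed pixels still contribute about 2,060, so the gradient is now driven by the rare class.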

The teacher's dilemma. Student A scores 95%, so they already know the material. Student B scores 20% and needs focused attention. Focal Loss is the teacher who spends most of their time on Student B while still checking in on Student A occasionally.


Solution 3 — Dice Loss

Reward region overlap, not per-pixel accuracy.

Dice = 2 · |A ∩ B| / ( |A| + |B| )      ∈ [0, 1]
Dice Loss = 1 − Dice
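Ground truth |A| | Overlap |A ∩ B| | Dice                    | Quality
100 px           | 100 (perfect)   | 2·100/(100+100) = 1.00  | Perfect
100 px           | 80              | 2·80/200 = 0.80         | Good
100 px           | 50              | 2·50/200 = 0.50         | Moderate
100 px           | 10              | 2·10/200 = 0.10         | Poor
100 px           | 0               | 2·0/200 = 0.00          | None

In every row the prediction also covers 100 px, so |A| + |B| = 200.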
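The table can be reproduced with binary masks represented as sets of pixel indices. The setup is a toy assumption: a 100-pixel ground-truth region and a 100-pixel prediction slid sideways so the overlap shrinks; `dice` is an illustrative helper, and treating two empty masks as Dice 1.0 is a convention choice.

```python
def dice(gt, pred):
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) for masks given as sets of pixel indices."""
    if not gt and not pred:
        return 1.0  # convention: two empty masks agree perfectly
    return 2 * len(gt & pred) / (len(gt) + len(pred))

gt = set(range(100))              # 100-pixel ground-truth building
for shift in (0, 20, 50, 90, 100):
    pred = set(range(shift, shift + 100))  # 100-pixel prediction, shifted
    print(len(gt & pred), dice(gt, pred))
# 100 1.0
# 80 0.8
# 50 0.5
# 10 0.1
# 0 0.0
```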

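Cross-entropy scores each pixel independently and then averages, so a small building, which covers only a tiny fraction of the image, can be outlined badly while the mean loss barely changes. Dice is normalised by the size of the regions being compared, so it rewards drawing the whole shape correctly. That is essential for disaster mapping, where building boundaries must be accurate.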
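A numeric sketch of this effect, assuming a hypothetical binary (building vs background) tile of 10,000 pixels with one 100-pixel building, where the model is confident on background (0.99) but weak on the building (0.40). All numbers are illustrative.

```python
import math

N_BG, N_BLD = 9_900, 100  # hypothetical tile: 1% of pixels belong to one building
P_BG, P_BLD = 0.99, 0.40  # prob. the model gives the correct class on each pixel type

# Mean per-pixel CE over the whole tile
mean_ce = (N_BG * -math.log(P_BG) + N_BLD * -math.log(P_BLD)) / (N_BG + N_BLD)

# Soft Dice for the building class: predicted building probability is P_BLD
# on building pixels and 1 - P_BG on background pixels
intersection = N_BLD * P_BLD
pred_area = N_BLD * P_BLD + N_BG * (1 - P_BG)
dice_bld = 2 * intersection / (pred_area + N_BLD)

print(f"mean CE = {mean_ce:.3f}, building Dice loss = {1 - dice_bld:.3f}")
# mean CE = 0.019, building Dice loss = 0.665
```

The mean CE looks tiny because the 9,900 confident background pixels dominate the average, while the building's Dice loss stays high and keeps pushing the model to fix the shape.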


Combining Focal + Dice (DAHiTrA)

Each addresses a different problem; together they cover both.

Simple   :  L = L_Focal + L_Dice
Weighted :  L = λ₁ · L_Focal + λ₂ · L_Dice
CE only: pixel-wise; can miss building boundaries.
Dice only: good shapes, but still struggles with imbalance.
Focal + Dice ✓: finds and outlines rare damaged buildings.

Multi-Scale Loss Supervision

Apply the loss at several decoder stages so that gradients reach early layers through a shorter path.

Stage             | Resolution | Loss          | λ
Deep supervision  | 1/8        | Focal         | 0.4
Intermediate      | 1/4        | Focal + Dice  | 0.6
Final output      | Full       | Focal + Dice  | 1.0
Total: L = Σ λᵢ · Lᵢ

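Deep supervision injects a loss signal at several resolutions, so layers far from the final output receive gradient through a short path instead of only through the full decoder. This helps mitigate vanishing gradients in deep encoder/decoder stacks. Each auxiliary loss needs a target at its own resolution: for example, the ground-truth mask downsampled with nearest-neighbour interpolation (so class labels are not blended), or the stage output upsampled to full resolution.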
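A minimal sketch of the weighted multi-scale loss, assuming the model returns a list of stage logits ordered coarse to fine and that targets are obtained by nearest-neighbour downsampling of the mask. The function names `downsample_mask` and `multiscale_loss` are hypothetical, and `F.cross_entropy` stands in for the per-stage Focal or Focal + Dice losses so the block runs on its own.

```python
import torch
import torch.nn.functional as F

def downsample_mask(mask, size):
    """Nearest-neighbour resize of an integer mask [B, H, W] to size (h, w)."""
    resized = F.interpolate(mask[:, None].float(), size=tuple(size), mode="nearest")
    return resized[:, 0].long()

def multiscale_loss(stage_logits, mask, stage_loss_fns, lambdas):
    """Deep supervision: sum of lambda_i * L_i over decoder stages.

    stage_logits:   list of [B, C, h_i, w_i] outputs, coarse to fine
    stage_loss_fns: list of callables loss_fn(logits, target_mask) -> scalar
    lambdas:        stage weights, e.g. [0.4, 0.6, 1.0] as in the table
    """
    total = 0.0
    for logits, loss_fn, lam in zip(stage_logits, stage_loss_fns, lambdas):
        target = downsample_mask(mask, logits.shape[-2:])  # labels at this stage's resolution
        total = total + lam * loss_fn(logits, target)
    return total

# Usage with random tensors standing in for decoder outputs
torch.manual_seed(0)
B, C, H, W = 2, 5, 64, 64
mask = torch.randint(0, C, (B, H, W))
stage_logits = [torch.randn(B, C, H // 8, W // 8, requires_grad=True),
                torch.randn(B, C, H // 4, W // 4, requires_grad=True),
                torch.randn(B, C, H, W, requires_grad=True)]
loss = multiscale_loss(stage_logits, mask, [F.cross_entropy] * 3, [0.4, 0.6, 1.0])
loss.backward()  # every stage receives its own gradient
```

Nearest-neighbour resizing is used because bilinear interpolation of integer labels would produce invalid in-between class values.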


PyTorch — Focal Loss

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # optional class-weight tensor, shape [C]
        self.gamma = gamma  # focusing parameter

    def forward(self, logits, targets):
        # logits: [B, C, H, W]; targets: [B, H, W] integer class indices
        # 1. Per-pixel UNweighted CE = -log(p_t), no reduction yet
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        # 2. Recover p_t from CE: p_t = exp(-CE). This only holds for the
        #    unweighted CE; with weight=alpha, exp(-CE) would be p_t ** w.
        pt = torch.exp(-ce_loss)
        # 3. Apply the modulating factor (1 - p_t)^gamma
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        # 4. Apply alpha_t: look up each pixel's class weight from its target
        if self.alpha is not None:
            alpha_t = self.alpha.to(logits.device)[targets]
            focal_loss = alpha_t * focal_loss
        return focal_loss.mean()
```
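Why the recovery step p_t = exp(−CE) must use the unweighted CE: with `weight=...`, `F.cross_entropy` returns w · (−log p_t) per pixel, so exp(−CE) becomes p_t ** w. The single-pixel check below uses uniform logits over 5 classes (so p_t = 0.2) and the lecture's class weights; the setup is illustrative.

```python
import torch
import torch.nn.functional as F

logits = torch.zeros(1, 5)               # uniform logits -> p_t = 0.2 for every class
target = torch.tensor([4])               # true class: destroyed
weights = torch.tensor([1.0, 5.0, 20.0, 20.0, 20.0])

ce = F.cross_entropy(logits, target, reduction="none")                    # -log(0.2)
ce_w = F.cross_entropy(logits, target, weight=weights, reduction="none")  # -20 * log(0.2)

pt_right = torch.exp(-ce)    # 0.2, the true p_t
pt_wrong = torch.exp(-ce_w)  # 0.2 ** 20, about 1e-14

print((1 - pt_right) ** 2)   # modulating factor 0.64
print((1 - pt_wrong) ** 2)   # modulating factor ~1.0
```

With the weighted CE, the modulating factor stays near 1 on up-weighted classes, so focusing is largely lost exactly where the rare pixels are. Applying α_t after the modulating factor avoids this.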

PyTorch — Dice Loss

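```python
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # prevents division by zero

    def forward(self, logits, targets):
        # logits: [B, C, H, W]; targets: one-hot float [B, C, H, W]
        probs = torch.softmax(logits, dim=1)          # logits -> probabilities
        dims = (0, 2, 3)                              # sum over batch and pixels, keep classes
        intersection = (probs * targets).sum(dims)    # |A ∩ B| per class
        union = probs.sum(dims) + targets.sum(dims)   # |A| + |B| per class
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        # Average per-class Dice so rare classes count as much as background
        return 1 - dice.mean()
```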
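Why Dice is best computed per class and then averaged: pooling intersection and union over all classes at once lets background dominate. The check below uses a hypothetical 10 × 10 tile with 4 destroyed pixels and a model that predicts background everywhere; `global_dice` and `per_class_dice` are illustrative helper names.

```python
import torch
import torch.nn.functional as F

def global_dice(probs, onehot, eps=1e-6):
    """Dice pooled over all classes and pixels at once."""
    inter = (probs * onehot).sum()
    return ((2 * inter + eps) / (probs.sum() + onehot.sum() + eps)).item()

def per_class_dice(probs, onehot, eps=1e-6):
    """Dice computed separately per class (summing over batch and pixels)."""
    dims = (0, 2, 3)
    inter = (probs * onehot).sum(dims)
    return ((2 * inter + eps) / (probs.sum(dims) + onehot.sum(dims) + eps)).tolist()

mask = torch.zeros(1, 10, 10, dtype=torch.long)
mask[0, 0, :4] = 4  # 4 destroyed pixels (class 4), 96 background
onehot = F.one_hot(mask, num_classes=5).permute(0, 3, 1, 2).float()
# Predicted probabilities: background everywhere
probs = F.one_hot(torch.zeros_like(mask), num_classes=5).permute(0, 3, 1, 2).float()

print(global_dice(probs, onehot))     # about 0.96: looks fine
print(per_class_dice(probs, onehot))  # class 4 Dice is about 0: the miss is visible
```

Classes absent from both prediction and target score 1.0 here because of the smoothing term; that is one reason to monitor per-class values rather than only their mean.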

PyTorch — Combined Loss & Siamese Training Loop

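```python
import torch
import torch.nn.functional as F

# Assumes FocalLoss and DiceLoss from above, plus siamese_model (already on
# `device`), dataloader and optimizer defined elsewhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 5

# Hand-set class weights that grow with rarity; raw inverse frequency from the
# pixel counts would be roughly 1 : 36 : 960 : 960 : 960.
class_weights = torch.tensor([1.0, 5.0, 20.0, 20.0, 20.0], device=device)
focal_fn = FocalLoss(alpha=class_weights, gamma=2)
dice_fn = DiceLoss(smooth=1e-6)
lambda_focal, lambda_dice = 1.0, 1.0  # λ₁, λ₂: start equal, tune later

siamese_model.train()
for pre_img, post_img, mask in dataloader:
    pre_img, post_img = pre_img.to(device), post_img.to(device)
    mask = mask.to(device).long()  # [B, H, W] class indices in 0..4

    # Siamese forward: both images go through a shared encoder
    logits = siamese_model(pre_img, post_img)  # [B, C, H, W]

    # One-hot encode the mask for Dice: [B, H, W] -> [B, C, H, W]
    mask_onehot = F.one_hot(mask, num_classes=NUM_CLASSES)
    mask_onehot = mask_onehot.permute(0, 3, 1, 2).float()

    loss_focal = focal_fn(logits, mask)
    loss_dice = dice_fn(logits, mask_onehot)
    total_loss = lambda_focal * loss_focal + lambda_dice * loss_dice

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
```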
Practical recipe
Start with γ=2, softened inverse-frequency class weights (raw inverse frequency would put the rarest classes near 960), and an equally weighted Focal + Dice. Tune λ₁ and λ₂ only after a clean baseline trains. Monitor per-class Dice and recall, not overall accuracy.
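A minimal sketch of the per-class monitoring recommended above, computed from hard predictions (for example `logits.argmax(dim=1)` against the mask). The function name `per_class_dice_recall` and the 1 × 4 toy example are illustrative assumptions.

```python
import math
import torch

def per_class_dice_recall(pred, target, num_classes):
    """Hard per-class Dice = 2TP / (2TP + FP + FN) and recall = TP / (TP + FN).

    pred, target: integer label maps of the same shape.
    NaN marks a class with no pixels in pred or target (Dice) or in target (recall).
    """
    dice, recall = [], []
    for c in range(num_classes):
        p, t = pred == c, target == c
        tp = (p & t).sum().item()
        fp = (p & ~t).sum().item()
        fn = (~p & t).sum().item()
        dice.append(2 * tp / (2 * tp + fp + fn) if tp + fp + fn else math.nan)
        recall.append(tp / (tp + fn) if tp + fn else math.nan)
    return dice, recall

# Toy example: the destroyed pixel (class 4) is predicted as background
target = torch.tensor([[0, 0, 1, 4]])
pred = torch.tensor([[0, 0, 1, 0]])
dice, recall = per_class_dice_recall(pred, target, num_classes=5)
print(dice)    # [0.8, 1.0, nan, nan, 0.0]
print(recall)  # [1.0, 1.0, nan, nan, 0.0]
```

Overall pixel accuracy here is 75%, but the zero Dice and recall for class 4 show immediately that the only destroyed pixel was missed.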