Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions slideflow/mil/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,53 @@ def build_val_dataloader(
dataloader_kwargs=dataloader_kwargs
)

class HierarchicalLoss(nn.Module):
"""Custom hierarchical loss function for multi-level classification.

Handles 3 levels of classification:
- Level 1: Group classification (As/Bs/TC)
- Level 2a: As subtype classification (A/AB) - only for As samples
- Level 2b: Bs subtype classification (B1/B2/B3) - only for Bs samples using ordinal logic
"""
def __init__(self, weight_ce=None, a_weight=1.0, b_weight=1.0):
super().__init__()
self.ce_hi = nn.CrossEntropyLoss(weight=weight_ce)
self.ce = nn.CrossEntropyLoss()
self.bce = nn.BCEWithLogitsLoss()
self.a_weight = a_weight
self.b_weight = b_weight

def forward(self, logits, targets):

# Get targets for each level
level1_target = torch.argmax(targets[:, :3], dim=1) # Get As/Bs/TC from first 3 dims
as_target = torch.argmax(targets[:, 3:5], dim=1) # Get A/AB from next 2 dims
bs_target = targets[:, 5:7] # Get B1/B2/B3 ordinal bits

# Split logits
level1_logits = logits[:, :3] # 3 classes: As, Bs, TC
as_logits = logits[:, 3:5] # 2 classes: A, AB
bs_logits = logits[:, 5:7] # 2 bits for ordinal B1/B2/B3

# Level 1 loss
level1_loss = self.ce_hi(level1_logits, level1_target)

# Level 2 losses - only compute for relevant samples
as_mask = (level1_target == 0) # As samples
bs_mask = (level1_target == 1) # Bs samples

as_loss = torch.tensor(0.0).to(logits.device)
bs_loss = torch.tensor(0.0).to(logits.device)

if as_mask.sum() > 0:
as_loss = self.ce(as_logits[as_mask], as_target[as_mask])

if bs_mask.sum() > 0:
bs_loss = self.bce(bs_logits[bs_mask], bs_target[bs_mask])

total_loss = level1_loss + self.a_weight * as_loss + self.b_weight * bs_loss
return total_loss

def inspect_batch(self, batch) -> Tuple[int, int]:
"""Inspect a batch of data.

Expand Down