Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions configs/data/modify/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Appearance-modification augmentation pool for the test split.
# albumentations.OneOf applies exactly one of the listed transforms per call
# (p: 1.0); the transform classes live in stain_normalization.data.modification.
_target_: albumentations.OneOf
p: 1.0
transforms:
  - _target_: stain_normalization.data.modification.HEDFactor
    h_range: [0.8, 1.2]
    e_range: [0.8, 1.2]
  - _target_: stain_normalization.data.modification.ExposureAdjustment
    brightness_range: [0.8, 1.2]
  - _target_: stain_normalization.data.modification.HSVModification
    hue_shift_range: [-0.4, 0.4]
    saturation_range: [0.8, 1.5]
    value_range: [0.8, 1.3]
  - _target_: stain_normalization.data.modification.CombinedModifications
    od_scale_range: [0.65, 1.35]
    brightness_range: [-0.4, 0.4]
15 changes: 15 additions & 0 deletions configs/data/modify/train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Appearance-modification augmentation pool for the train split.
# albumentations.OneOf applies exactly one of the listed transforms per call
# (p: 1.0); the transform classes live in stain_normalization.data.modification.
# NOTE(review): currently identical to the test-split config.
_target_: albumentations.OneOf
p: 1.0
transforms:
  - _target_: stain_normalization.data.modification.HEDFactor
    h_range: [0.8, 1.2]
    e_range: [0.8, 1.2]
  - _target_: stain_normalization.data.modification.ExposureAdjustment
    brightness_range: [0.8, 1.2]
  - _target_: stain_normalization.data.modification.HSVModification
    hue_shift_range: [-0.4, 0.4]
    saturation_range: [0.8, 1.5]
    value_range: [0.8, 1.3]
  - _target_: stain_normalization.data.modification.CombinedModifications
    od_scale_range: [0.65, 1.35]
    brightness_range: [-0.4, 0.4]
56 changes: 56 additions & 0 deletions configs/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Root Hydra config: composes the hydra/logger defaults and the four dataset
# groups, then declares callbacks, model hyperparameters, trainer settings,
# data-loading options and run metadata.
defaults:
  - hydra: default
  - logger: mlflow
  - /data/datasets@data.train: stain_normalization/train
  - /data/datasets@data.val: stain_normalization/val
  - /data/datasets@data.test: stain_normalization/test
  - /data/datasets@data.predict: stain_normalization/predict
  - _self_

seed: ${random_seed:}  # resolved at launch by the random_seed resolver in __main__
mode: fit              # Trainer method dispatched in __main__ (e.g. fit)
checkpoint: null       # optional ckpt_path forwarded to the trainer


callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    save_top_k: 1
    save_last: true
    monitor: validation/loss
    mode: min

  early_stopping:
    _target_: lightning.pytorch.callbacks.EarlyStopping
    monitor: validation/loss
    patience: 5
    mode: min

model:
  lr: 1e-4
  # Loss-term weights consumed by L1SSIMLoss (via metadata.hyperparams too).
  lambda_dssim: 0.6
  lambda_l1: 0.2
  lambda_lum: 0.2
  lambda_gdl: 0.1

trainer:
  enable_checkpointing: True
  max_epochs: 100
  limit_train_batches: 5000
  log_every_n_steps: 50

  # NOTE(review): assumed to be nested under `trainer` (a top-level
  # `callbacks` key already exists above) — confirm indentation in the repo.
  callbacks:
    - ${callbacks.model_checkpoint}
    - ${callbacks.early_stopping}

data:
  batch_size: 64
  num_workers: 8

metadata:
  user: ???              # required override at launch
  experiment_name: Stain-Normalization
  run_name: ???          # required override at launch
  description: ???       # required override at launch
  hyperparams: ${model}

7 changes: 7 additions & 0 deletions configs/hydra/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Hydra runtime settings: disable Hydra's own logging configuration and its
# output-directory management so the application runs in place.
defaults:
  - _self_
  - override hydra_logging: disabled
  - override job_logging: disabled
output_subdir: null  # do not write the .hydra/ config snapshot
run:
  dir: .  # run in the current working directory instead of a timestamped dir
6 changes: 6 additions & 0 deletions configs/logger/mlflow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# MLflow logger; experiment/run identity and tags are interpolated from the
# `metadata` section of the root config.
_target_: rationai.mlkit.lightning.loggers.MLFlowLogger
experiment_name: ${metadata.experiment_name}
run_name: ${metadata.run_name}
tags:
  mlflow.user: ${metadata.user}
  mlflow.note.content: ${metadata.description}
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ dependencies = [
[dependency-groups]
dev = ["mypy", "ruff"]

[tool.mypy]
# Suppress errors for imported packages that ship without type stubs.
ignore_missing_imports = true

[tool.uv]
environments = ["sys_platform == 'linux'"]

Expand Down
Empty file added stain_normalization/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions stain_normalization/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from random import randint

import hydra
import torch
from lightning import seed_everything
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf
from rationai.mlkit import Trainer, autolog

from stain_normalization.data import DataModule
from stain_normalization.stain_normalization_model import StainNormalizationModel


# Config resolver: `${random_seed:}` yields a random seed at launch;
# use_cache=True makes every interpolation in one run return the same value.
OmegaConf.register_new_resolver(
    "random_seed", lambda: randint(0, 2**31), use_cache=True
)


@hydra.main(config_path="../configs", config_name="default", version_base=None)
@autolog
def main(config: DictConfig, logger: Logger | None) -> None:
    """Build the datamodule, model and trainer from the Hydra config, then
    run the trainer method named by ``config.mode`` (``fit`` by default).
    """
    seed_everything(config.seed, workers=True)
    torch.set_float32_matmul_precision("high")

    datamodule = hydra.utils.instantiate(
        config.data,
        _recursive_=False,  # let the DataModule instantiate its datasets itself
        _target_=DataModule,
    )
    model = hydra.utils.instantiate(config.model, _target_=StainNormalizationModel)
    trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger)

    stage = getattr(trainer, config.mode)
    stage(model, datamodule=datamodule, ckpt_path=config.checkpoint)


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
10 changes: 5 additions & 5 deletions stain_normalization/data/datasets/predict_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def generate_datasets(self) -> Iterable[Dataset[PredictSample]]:
tiles=self.filter_tiles_by_slide(slide["id"]),
normalize=self.normalize,
)
for _, slide in self.slides.iterrows()
for slide in self.slides
)


Expand All @@ -48,10 +48,10 @@ def __init__(
) -> None:
super().__init__()
self.slide_tiles = OpenSlideTilesDataset(
slide_path=slide_metadata.path,
level=slide_metadata.level,
tile_extent_x=slide_metadata.tile_extent_x,
tile_extent_y=slide_metadata.tile_extent_y,
slide_path=slide_metadata["path"],
level=slide_metadata["level"],
tile_extent_x=slide_metadata["tile_extent_x"],
tile_extent_y=slide_metadata["tile_extent_y"],
tiles=tiles,
)

Expand Down
10 changes: 5 additions & 5 deletions stain_normalization/data/datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[PredictSample]]:
modify=self.modify,
normalize=self.normalize,
)
for _, slide in self.slides.iterrows()
for slide in self.slides
)


Expand All @@ -54,10 +54,10 @@ def __init__(
) -> None:
super().__init__()
self.slide_tiles = OpenSlideTilesDataset(
slide_path=slide_metadata.path,
level=slide_metadata.level,
tile_extent_x=slide_metadata.tile_extent_x,
tile_extent_y=slide_metadata.tile_extent_y,
slide_path=slide_metadata["path"],
level=slide_metadata["level"],
tile_extent_x=slide_metadata["tile_extent_x"],
tile_extent_y=slide_metadata["tile_extent_y"],
tiles=tiles,
)
self.modify = modify
Expand Down
10 changes: 5 additions & 5 deletions stain_normalization/data/datasets/train_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[Sample]]:
modify=self.modify,
normalize=self.normalize,
)
for _, slide in self.slides.iterrows()
for slide in self.slides
)


Expand All @@ -54,10 +54,10 @@ def __init__(
) -> None:
super().__init__()
self.slide_tiles = OpenSlideTilesDataset(
slide_path=slide_metadata.path,
level=slide_metadata.level,
tile_extent_x=slide_metadata.tile_extent_x,
tile_extent_y=slide_metadata.tile_extent_y,
slide_path=slide_metadata["path"],
level=slide_metadata["level"],
tile_extent_x=slide_metadata["tile_extent_x"],
tile_extent_y=slide_metadata["tile_extent_y"],
tiles=tiles,
)
self.modify = modify
Expand Down
5 changes: 5 additions & 0 deletions stain_normalization/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from stain_normalization.modeling.l1ssim_loss import L1SSIMLoss
from stain_normalization.modeling.unet import UNet


__all__ = ["L1SSIMLoss", "UNet"]
135 changes: 135 additions & 0 deletions stain_normalization/modeling/l1ssim_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
The SSIM is based on implementation from gaussian-splatting and slightly simplified
(pre-computed windows and removal of unused arguments).
https://github.com/graphdeco-inria/gaussian-splatting/blob/472689c0dc70417448fb451bf529ae532d32c095/utils/loss_utils.py
"""

from math import exp

import torch
import torch.nn as nn
import torch.nn.functional as F


class L1SSIMLoss(nn.Module):
    """Weighted reconstruction loss combining DSSIM, L1, luminance and GDL terms.

    total = lambda_dssim * (1 - SSIM)
          + lambda_l1   * L1
          + lambda_lum  * brightness_loss
          + lambda_gdl  * gradient_loss

    Expects image batches shaped (N, 3, H, W); the SSIM stability constants
    (0.01**2, 0.03**2) follow the standard formulation with dynamic range
    L = 1, i.e. pixel values are assumed to lie roughly in [0, 1].

    Args:
        lambda_dssim: weight of the structural (1 - SSIM) term.
        lambda_l1: weight of the per-pixel L1 term.
        lambda_lum: weight of the per-image mean-brightness term.
        lambda_gdl: weight of the gradient-difference (edge) term.
    """

    def __init__(
        self,
        lambda_dssim: float = 0.6,
        lambda_l1: float = 0.2,
        lambda_lum: float = 0.2,
        lambda_gdl: float = 0.1,
    ):
        super().__init__()
        self.lambda_dssim = lambda_dssim
        self.lambda_l1 = lambda_l1
        self.lambda_lum = lambda_lum
        self.lambda_gdl = lambda_gdl

        # Precompute the SSIM window once to avoid rebuilding it per call.
        # The intermediate 1-D/2-D kernels are locals (the original kept them
        # as instance attributes, where they neither followed .to(device) nor
        # served any purpose after __init__); only the final window is
        # registered as a buffer so it tracks device moves and state_dict.
        self.window_size = 11
        self.channel = 3
        kernel_1d = gaussian(self.window_size, 1.5).unsqueeze(1)
        kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
        self.window: torch.Tensor
        self.register_buffer(
            "window",
            kernel_2d.expand(
                self.channel, 1, self.window_size, self.window_size
            ).contiguous(),
        )

    def forward(self, image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor:
        """Return the scalar weighted loss between *image* and *target_image*."""
        # Per-pixel color fidelity.
        l1_loss = F.l1_loss(image, target_image, reduction="mean")

        # Structural dissimilarity (1 - SSIM).
        ssim_loss = 1.0 - self._ssim(image, target_image, self.window)

        # Edge preservation via first-difference (gradient) maps.
        gdl_loss = gradient_loss(image, target_image)

        # Global per-image brightness agreement.
        brig_loss = brightness_loss(image, target_image)

        # Total weighted loss.
        return (
            self.lambda_l1 * l1_loss
            + self.lambda_dssim * ssim_loss
            + self.lambda_gdl * gdl_loss
            + self.lambda_lum * brig_loss
        )

    # NOTE: upstream decorated this method with @torch.compile at class
    # definition time. Per the PyTorch docs that pattern specializes the
    # compiled graph on `self` and triggers compilation machinery wherever the
    # class is instantiated; prefer torch.compile(module) at the call site.
    # The decorator is removed here — numerical outputs are unchanged.
    def _ssim(
        self, img1: torch.Tensor, img2: torch.Tensor, window: torch.Tensor
    ) -> torch.Tensor:
        """Mean SSIM of two batches using the precomputed depthwise Gaussian window."""
        pad = self.window_size // 2

        # Local means via depthwise Gaussian filtering (one group per channel).
        mu1 = F.conv2d(img1, window, padding=pad, groups=self.channel)
        mu2 = F.conv2d(img2, window, padding=pad, groups=self.channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances: E[x*y] - E[x]E[y].
        sigma1_sq = (
            F.conv2d(img1 * img1, window, padding=pad, groups=self.channel) - mu1_sq
        )
        sigma2_sq = (
            F.conv2d(img2 * img2, window, padding=pad, groups=self.channel) - mu2_sq
        )
        sigma12 = (
            F.conv2d(img1 * img2, window, padding=pad, groups=self.channel) - mu1_mu2
        )

        # Stability constants for dynamic range L = 1.
        c1 = 0.01**2
        c2 = 0.03**2

        ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / (
            (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
        )

        return ssim_map.mean()


def gaussian(window_size: int, sigma: float) -> torch.Tensor:
    """Return a normalized 1-D Gaussian kernel of length *window_size*."""
    center = window_size // 2
    weights = torch.tensor(
        [exp(-((x - center) ** 2) / (2 * sigma**2)) for x in range(window_size)]
    )
    return weights / weights.sum()


def brightness_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """L1 distance between per-image mean intensities (brightness match)."""
    pred_mean = pred.mean(dim=[1, 2, 3])
    target_mean = target.mean(dim=[1, 2, 3])
    return F.l1_loss(pred_mean, target_mean)


def gradient_loss(image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor:
    """Gradient-difference loss: L1 between horizontal and vertical edge maps."""

    def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        dx = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # horizontal gradient
        dy = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # vertical gradient
        return dx, dy

    image_dx, image_dy = gradient(image)
    target_dx, target_dy = gradient(target_image)

    loss_x = F.l1_loss(image_dx, target_dx, reduction="mean")
    loss_y = F.l1_loss(image_dy, target_dy, reduction="mean")

    return loss_x + loss_y
Loading