From bcbd2cb5f2cac1c7f3e7dd2a15175ea42630369f Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sat, 14 Mar 2026 10:25:50 +0100 Subject: [PATCH 1/9] feat: add model and loss function --- configs/data/modify/test.yaml | 15 ++ configs/data/modify/train.yaml | 15 ++ configs/default.yaml | 55 +++++++ configs/hydra/default.yaml | 7 + configs/logger/mlflow.yaml | 6 + pyproject.toml | 3 + stain_normalization/__init__.py | 0 stain_normalization/__main__.py | 36 +++++ stain_normalization/modeling/__init__.py | 5 + stain_normalization/modeling/l1ssim_loss.py | 136 ++++++++++++++++++ stain_normalization/modeling/unet.py | 119 +++++++++++++++ .../stain_normalization_model.py | 81 +++++++++++ 12 files changed, 478 insertions(+) create mode 100644 configs/data/modify/test.yaml create mode 100644 configs/data/modify/train.yaml create mode 100644 configs/default.yaml create mode 100644 configs/hydra/default.yaml create mode 100644 configs/logger/mlflow.yaml create mode 100644 stain_normalization/__init__.py create mode 100644 stain_normalization/__main__.py create mode 100644 stain_normalization/modeling/__init__.py create mode 100644 stain_normalization/modeling/l1ssim_loss.py create mode 100644 stain_normalization/modeling/unet.py create mode 100644 stain_normalization/stain_normalization_model.py diff --git a/configs/data/modify/test.yaml b/configs/data/modify/test.yaml new file mode 100644 index 0000000..955973b --- /dev/null +++ b/configs/data/modify/test.yaml @@ -0,0 +1,15 @@ +_target_: albumentations.OneOf +p: 1.0 +transforms: + - _target_: stain_normalization.data.modification.HEDFactor + h_range: [0.8, 1.2] + e_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.ExposureAdjustment + brightness_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.HVSModification + hue_shift_range: [-0.4, 0.4] + saturation_range: [0.8, 1.5] + value_range: [0.8, 1.3] + - _target_: stain_normalization.data.modification.CombinedModifications + intensity_range: [0.65, 1.35] + brightness_range: [-0.4, 0.4] diff --git a/configs/data/modify/train.yaml b/configs/data/modify/train.yaml new file mode 100644 index 0000000..955973b --- /dev/null +++ b/configs/data/modify/train.yaml @@ -0,0 +1,15 @@ +_target_: albumentations.OneOf +p: 1.0 +transforms: + - _target_: stain_normalization.data.modification.HEDFactor + h_range: [0.8, 1.2] + e_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.ExposureAdjustment + brightness_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.HVSModification + hue_shift_range: [-0.4, 0.4] + saturation_range: [0.8, 1.5] + value_range: [0.8, 1.3] + - _target_: stain_normalization.data.modification.CombinedModifications + intensity_range: [0.65, 1.35] + brightness_range: [-0.4, 0.4] diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 index 0000000..cc75c35 --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,55 @@ +defaults: + - hydra: default + - logger: mlflow + - /data/datasets@data.train: stain_normalization/train + - /data/datasets@data.val: stain_normalization/val + - /data/datasets@data.test: stain_normalization/test + - /data/datasets@data.predict: stain_normalization/predict + - _self_ + +seed: ${random_seed:} +mode: fit +checkpoint: null + + +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + save_top_k: 2 + monitor: validation/loss + mode: min + + early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: validation/loss + patience: 5 + mode: min + +model: + lr: 1e-4 + lambda_dssim: 0.6 + lambda_l1: 0.2 + lambda_lum: 0.2 + lambda_gdl: 0.1 + +trainer: + enable_checkpointing: True + max_epochs: 100 + limit_train_batches: 5000 + log_every_n_steps: 5 + + callbacks: + - ${callbacks.model_checkpoint} + - ${callbacks.early_stopping} + +data: + batch_size: 64 + num_workers: 8 + +metadata: + user: ??? + experiment_name: Stain-Normalization + run_name: ??? + description: ??? + hyperparams: ${model} + diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000..275e331 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,7 @@ +defaults: + - _self_ + - override hydra_logging: disabled + - override job_logging: disabled +output_subdir: null +run: + dir: . diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000..10355c7 --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,6 @@ +_target_: rationai.mlkit.lightning.loggers.MLFlowLogger +experiment_name: ${metadata.experiment_name} +run_name: ${metadata.run_name} +tags: + mlflow.user: ${metadata.user} + mlflow.note.content: ${metadata.description} diff --git a/pyproject.toml b/pyproject.toml index 6b19117..4fcfaf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ dependencies = [ [dependency-groups] dev = ["mypy", "ruff"] +[tool.mypy] +ignore_missing_imports = true + [tool.uv] environments = ["sys_platform == 'linux'"] diff --git a/stain_normalization/__init__.py b/stain_normalization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stain_normalization/__main__.py b/stain_normalization/__main__.py new file mode 100644 index 0000000..3765c5d --- /dev/null +++ b/stain_normalization/__main__.py @@ -0,0 +1,36 @@ +from random import randint + +import hydra +import torch +from lightning import seed_everything +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import Trainer, autolog + +from stain_normalization.data import DataModule +from stain_normalization.stain_normalization_model import StainNormalizationModel + + +OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True +) + + +@hydra.main(config_path="../configs", config_name="default", version_base=None) +@autolog +def main(config: DictConfig, logger: Logger | None) -> None: + torch.set_float32_matmul_precision("high") + seed_everything(config.seed, workers=True) + + data = hydra.utils.instantiate( + config.data, + _recursive_=False, # to avoid instantiating all the datasets + _target_=DataModule, + ) + model = hydra.utils.instantiate(config.model, _target_=StainNormalizationModel) + trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) + getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/stain_normalization/modeling/__init__.py b/stain_normalization/modeling/__init__.py new file mode 100644 index 0000000..7a8746c --- /dev/null +++ b/stain_normalization/modeling/__init__.py @@ -0,0 +1,5 @@ +from stain_normalization.modeling.l1ssim_loss import L1SSIMLoss +from stain_normalization.modeling.unet import UNet + + +__all__ = ["L1SSIMLoss", "UNet"] diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py new file mode 100644 index 0000000..3bbe0de --- /dev/null +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -0,0 +1,136 @@ +"""Original SSIM code based on pytorch-ssim by Evan Su (MIT License). + +https://github.com/Po-Hsun-Su/pytorch-ssim . + +The SSIM is based on implementation from gaussian-splatting and slightly simplified +(pre-computed windows and removal of unused arguments). +https://github.com/graphdeco-inria/gaussian-splatting/blob/472689c0dc70417448fb451bf529ae532d32c095/utils/loss_utils.py +""" + +from math import exp + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L1SSIMLoss(nn.Module): + def __init__( + self, + lambda_dssim: float = 0.6, + lambda_l1: float = 0.2, + lambda_lum: float = 0.2, + lambda_gdl: float = 0.1, + ): + super().__init__() + self.lambda_dssim = lambda_dssim + self.lambda_l1 = lambda_l1 + self.lambda_lum = lambda_lum + self.lambda_gdl = lambda_gdl + + # precompute SSIM windows to avoid repetition + self.window_size = 11 + self.channel = 3 + self._1d_window = gaussian(self.window_size, 1.5).unsqueeze(1) + self._2d_window = ( + self._1d_window.mm(self._1d_window.t()).float().unsqueeze(0).unsqueeze(0) + ) + self.window = self._2d_window.expand( + self.channel, 1, self.window_size, self.window_size + ).contiguous() + + def forward(self, image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: + if self.window.device != image.device: + self.window = self.window.to(image.device) + # L1 color loss + l1_loss = F.l1_loss(image, target_image, reduction="mean") + + # SSIM structural loss + ssim_loss = 1.0 - self._ssim(image, target_image, self.window) + + # Gradient loss for edges + gdl_loss = gradient_loss(image, target_image) + + # Luminance / brightness loss + brig_loss = brightness_loss(image, target_image) + + # total weighted loss + total_loss = ( + self.lambda_l1 * l1_loss + + self.lambda_dssim * ssim_loss + + self.lambda_gdl * gdl_loss + + self.lambda_lum * brig_loss + ) + + return total_loss + + @torch.compile + def _ssim( + self, img1: torch.Tensor, img2: torch.Tensor, window: torch.Tensor + ) -> torch.Tensor: + # Modified _ssim that uses pre-computed window + mu1 = F.conv2d(img1, window, padding=self.window_size // 2, groups=self.channel) + mu2 = F.conv2d(img2, window, padding=self.window_size // 2, groups=self.channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d( + img1 * img1, window, padding=self.window_size // 2, groups=self.channel + ) + - mu1_sq + ) + sigma2_sq = ( + F.conv2d( + img2 * img2, window, padding=self.window_size // 2, groups=self.channel + ) + - mu2_sq + ) + sigma12 = ( + F.conv2d( + img1 * img2, window, padding=self.window_size // 2, groups=self.channel + ) + - mu1_mu2 + ) + + c1 = 0.01**2 + c2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ( + (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ) + + return ssim_map.mean() + + +def gaussian(window_size: int, sigma: float) -> torch.Tensor: + gauss = torch.tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def brightness_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + pred_mean = pred.mean(dim=[1, 2, 3]) + target_mean = target.mean(dim=[1, 2, 3]) + return F.l1_loss(pred_mean, target_mean) + + +def gradient_loss(image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: + def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Horizontal gradient + dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Vertical gradient + return dx, dy + + image_dx, image_dy = gradient(image) + target_dx, target_dy = gradient(target_image) + + loss_x = F.l1_loss(image_dx, target_dx, reduction="mean") + loss_y = F.l1_loss(image_dy, target_dy, reduction="mean") + + return loss_x + loss_y diff --git a/stain_normalization/modeling/unet.py b/stain_normalization/modeling/unet.py new file mode 100644 index 0000000..ebcb3b7 --- /dev/null +++ b/stain_normalization/modeling/unet.py @@ -0,0 +1,119 @@ +"""Adapted U-Net implementation based on the GitHub repository. + +https://github.com/milesial/Pytorch-UNet . +Original U-Net architecture proposed in the paper. +Ronneberger, O., Fischer, P., & Brox, T. (2015). +U-Net: Convolutional Networks for Biomedical Image Segmentation. +arXiv:1505.04597 [cs.CV]. +Retrieved from https://arxiv.org/abs/1505.04597 . +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2.""" + + def __init__( + self, in_channels: int, out_channels: int, mid_channels: int | None = None + ) -> None: + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv.""" + + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv.""" + + def __init__( + self, in_channels: int, out_channels: int, bilinear: bool = True + ) -> None: + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up: nn.Upsample | nn.ConvTranspose2d = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + x1 = self.up(x1) + diffy = x2.size()[2] - x1.size()[2] + diffx = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, diffy // 2, diffy - diffy // 2]) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class UNet(nn.Module): + def __init__( + self, in_channels: int = 3, out_channels: int = 3, bilinear: bool = True + ) -> None: + super().__init__() + self.in_conv = DoubleConv(in_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + + factor = 2 if bilinear else 1 + self.down4 = Down(512, 1024 // factor) + self.up1 = Up(1024, 512 // factor, bilinear) + self.up2 = Up(512, 256 // factor, bilinear) + self.up3 = Up(256, 128 // factor, bilinear) + self.up4 = Up(128, 64, bilinear) + self.out_conv = OutConv(64, out_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.in_conv(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + return self.out_conv(x) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py new file mode 100644 index 0000000..97e8d8f --- /dev/null +++ b/stain_normalization/stain_normalization_model.py @@ -0,0 +1,81 @@ +from lightning import LightningModule +from torch import Tensor, stack +from torch.optim import Adam +from torch.optim.optimizer import Optimizer +from torchmetrics import MetricCollection +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.regression import MeanAbsoluteError + +from stain_normalization.modeling import L1SSIMLoss, UNet +from stain_normalization.type_aliases import Batch, Outputs, PredictBatch + + +class StainNormalizationModel(LightningModule): + def __init__( + self, + lr: float = 1e-4, + lambda_dssim: float = 0.6, + lambda_l1: float = 0.2, + lambda_lum: float = 0.2, + lambda_gdl: float = 0.1, + ) -> None: + super().__init__() + self.save_hyperparameters() + self.lr = lr + self.unet = UNet(in_channels=3, out_channels=3) + self.criterion = L1SSIMLoss( + lambda_dssim=lambda_dssim, + lambda_l1=lambda_l1, + lambda_lum=lambda_lum, + lambda_gdl=lambda_gdl, + ) + + self.val_metrics = MetricCollection( + {"ssim": StructuralSimilarityIndexMeasure(), "l1": MeanAbsoluteError()} + ) + self.test_metrics = self.val_metrics.clone(prefix="test/") + self.val_metrics.prefix = "validation/" + + def forward(self, x: Tensor) -> Outputs: + return self.unet(x) + + def training_step(self, batch: Batch) -> Tensor: + inputs, targets = batch + outputs = self(inputs) + + loss = self.criterion(outputs, targets) + self.log("train/loss", loss, on_step=True, prog_bar=True) + + return loss + + def validation_step(self, batch: Batch) -> None: + inputs, targets = batch + outputs = self(inputs) + + loss = self.criterion(outputs, targets) + self.log("validation/loss", loss, on_step=False, on_epoch=True, logger=True) + self.val_metrics.update(outputs, targets) + self.log_dict( + self.val_metrics, + batch_size=len(inputs), + on_epoch=True, + ) + + def test_step(self, batch: PredictBatch) -> Outputs: + inputs, data = batch + outputs = self(inputs) + targets = stack([item["original_image_tensor"] for item in data]) + self.test_metrics.update(outputs, targets) + self.log_dict( + self.test_metrics, + batch_size=len(inputs), + on_epoch=True, + ) + return outputs + + def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: + inputs = batch[0] + return self(inputs) + + def configure_optimizers(self) -> Optimizer: + return Adam(self.parameters(), lr=self.lr) From dbc461c918bc3920dab5072a384f8d608081d476 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sat, 14 Mar 2026 14:22:42 +0100 Subject: [PATCH 2/9] fix: review ai agent feedback --- configs/data/modify/test.yaml | 6 +++--- configs/data/modify/train.yaml | 6 +++--- stain_normalization/modeling/l1ssim_loss.py | 16 +++++++--------- stain_normalization/stain_normalization_model.py | 3 +-- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/configs/data/modify/test.yaml b/configs/data/modify/test.yaml index 955973b..077faf1 100644 --- a/configs/data/modify/test.yaml +++ b/configs/data/modify/test.yaml @@ -1,15 +1,15 @@ _target_: albumentations.OneOf -p: 1.0 +p: 1.0 transforms: - _target_: stain_normalization.data.modification.HEDFactor h_range: [0.8, 1.2] e_range: [0.8, 1.2] - _target_: stain_normalization.data.modification.ExposureAdjustment brightness_range: [0.8, 1.2] - - _target_: stain_normalization.data.modification.HVSModification + - _target_: stain_normalization.data.modification.HSVModification hue_shift_range: [-0.4, 0.4] saturation_range: [0.8, 1.5] value_range: [0.8, 1.3] - _target_: stain_normalization.data.modification.CombinedModifications - intensity_range: [0.65, 1.35] + od_scale_range: [0.65, 1.35] brightness_range: [-0.4, 0.4] diff --git a/configs/data/modify/train.yaml b/configs/data/modify/train.yaml index 955973b..077faf1 100644 --- a/configs/data/modify/train.yaml +++ b/configs/data/modify/train.yaml @@ -1,15 +1,15 @@ _target_: albumentations.OneOf -p: 1.0 +p: 1.0 transforms: - _target_: stain_normalization.data.modification.HEDFactor h_range: [0.8, 1.2] e_range: [0.8, 1.2] - _target_: stain_normalization.data.modification.ExposureAdjustment brightness_range: [0.8, 1.2] - - _target_: stain_normalization.data.modification.HVSModification + - _target_: stain_normalization.data.modification.HSVModification hue_shift_range: [-0.4, 0.4] saturation_range: [0.8, 1.5] value_range: [0.8, 1.3] - _target_: stain_normalization.data.modification.CombinedModifications - intensity_range: [0.65, 1.35] + od_scale_range: [0.65, 1.35] brightness_range: [-0.4, 0.4] diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py index 3bbe0de..2285100 100644 --- a/stain_normalization/modeling/l1ssim_loss.py +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -1,7 +1,4 @@ -"""Original SSIM code based on pytorch-ssim by Evan Su (MIT License). - -https://github.com/Po-Hsun-Su/pytorch-ssim . - +""" The SSIM is based on implementation from gaussian-splatting and slightly simplified (pre-computed windows and removal of unused arguments). https://github.com/graphdeco-inria/gaussian-splatting/blob/472689c0dc70417448fb451bf529ae532d32c095/utils/loss_utils.py @@ -35,13 +32,14 @@ def __init__( self._2d_window = ( self._1d_window.mm(self._1d_window.t()).float().unsqueeze(0).unsqueeze(0) ) - self.window = self._2d_window.expand( - self.channel, 1, self.window_size, self.window_size - ).contiguous() + self.register_buffer( + "window", + self._2d_window.expand( + self.channel, 1, self.window_size, self.window_size + ).contiguous(), + ) def forward(self, image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: - if self.window.device != image.device: - self.window = self.window.to(image.device) # L1 color loss l1_loss = F.l1_loss(image, target_image, reduction="mean") diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index 97e8d8f..028e329 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -21,7 +21,6 @@ def __init__( ) -> None: super().__init__() self.save_hyperparameters() - self.lr = lr self.unet = UNet(in_channels=3, out_channels=3) self.criterion = L1SSIMLoss( lambda_dssim=lambda_dssim, @@ -78,4 +77,4 @@ def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: return self(inputs) def configure_optimizers(self) -> Optimizer: - return Adam(self.parameters(), lr=self.lr) + return Adam(self.parameters(), lr=self.hparams.lr) From 769864740473d93cafb9f3a5ffc08cde9b5d12d3 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sat, 14 Mar 2026 14:23:46 +0100 Subject: [PATCH 3/9] fix: dataloading with new mlkit version --- stain_normalization/data/datasets/predict_dataset.py | 10 +++++----- stain_normalization/data/datasets/test_dataset.py | 10 +++++----- stain_normalization/data/datasets/train_dataset.py | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/stain_normalization/data/datasets/predict_dataset.py b/stain_normalization/data/datasets/predict_dataset.py index aba3372..25ec978 100644 --- a/stain_normalization/data/datasets/predict_dataset.py +++ b/stain_normalization/data/datasets/predict_dataset.py @@ -35,7 +35,7 @@ def generate_datasets(self) -> Iterable[Dataset[PredictSample]]: tiles=self.filter_tiles_by_slide(slide["id"]), normalize=self.normalize, ) - for _, slide in self.slides.iterrows() + for slide in self.slides ) @@ -48,10 +48,10 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata.path, - level=slide_metadata.level, - tile_extent_x=slide_metadata.tile_extent_x, - tile_extent_y=slide_metadata.tile_extent_y, + slide_path=slide_metadata["path"], + level=slide_metadata["level"], + tile_extent_x=slide_metadata["tile_extent_x"], + tile_extent_y=slide_metadata["tile_extent_y"], tiles=tiles, ) diff --git a/stain_normalization/data/datasets/test_dataset.py b/stain_normalization/data/datasets/test_dataset.py index 5a536b0..941a5bb 100644 --- a/stain_normalization/data/datasets/test_dataset.py +++ b/stain_normalization/data/datasets/test_dataset.py @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[PredictSample]]: modify=self.modify, normalize=self.normalize, ) - for _, slide in self.slides.iterrows() + for slide in self.slides ) @@ -54,10 +54,10 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata.path, - level=slide_metadata.level, - tile_extent_x=slide_metadata.tile_extent_x, - tile_extent_y=slide_metadata.tile_extent_y, + slide_path=slide_metadata["path"], + level=slide_metadata["level"], + tile_extent_x=slide_metadata["tile_extent_x"], + tile_extent_y=slide_metadata["tile_extent_y"], tiles=tiles, ) self.modify = modify diff --git a/stain_normalization/data/datasets/train_dataset.py b/stain_normalization/data/datasets/train_dataset.py index 27e6d43..366c698 100644 --- a/stain_normalization/data/datasets/train_dataset.py +++ b/stain_normalization/data/datasets/train_dataset.py @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[Sample]]: modify=self.modify, normalize=self.normalize, ) - for _, slide in self.slides.iterrows() + for slide in self.slides ) @@ -54,10 +54,10 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata.path, - level=slide_metadata.level, - tile_extent_x=slide_metadata.tile_extent_x, - tile_extent_y=slide_metadata.tile_extent_y, + slide_path=slide_metadata["path"], + level=slide_metadata["level"], + tile_extent_x=slide_metadata["tile_extent_x"], + tile_extent_y=slide_metadata["tile_extent_y"], tiles=tiles, ) self.modify = modify From 1cb501860d9b245114c7003e14349906e4e08ff7 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sat, 14 Mar 2026 14:32:35 +0100 Subject: [PATCH 4/9] fix: mypy and ruff format fixes --- stain_normalization/data/datasets/test_dataset.py | 2 +- stain_normalization/data/datasets/train_dataset.py | 2 +- stain_normalization/modeling/l1ssim_loss.py | 1 + stain_normalization/stain_normalization_model.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/stain_normalization/data/datasets/test_dataset.py b/stain_normalization/data/datasets/test_dataset.py index 941a5bb..26a7ffb 100644 --- a/stain_normalization/data/datasets/test_dataset.py +++ b/stain_normalization/data/datasets/test_dataset.py @@ -54,7 +54,7 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata["path"], + slide_path=slide_metadata["path"], level=slide_metadata["level"], tile_extent_x=slide_metadata["tile_extent_x"], tile_extent_y=slide_metadata["tile_extent_y"], diff --git a/stain_normalization/data/datasets/train_dataset.py b/stain_normalization/data/datasets/train_dataset.py index 366c698..a116d9a 100644 --- a/stain_normalization/data/datasets/train_dataset.py +++ b/stain_normalization/data/datasets/train_dataset.py @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[Sample]]: modify=self.modify, normalize=self.normalize, ) - for slide in self.slides + for slide in self.slides ) diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py index 2285100..2cee12c 100644 --- a/stain_normalization/modeling/l1ssim_loss.py +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -32,6 +32,7 @@ def __init__( self._2d_window = ( self._1d_window.mm(self._1d_window.t()).float().unsqueeze(0).unsqueeze(0) ) + self.window: torch.Tensor self.register_buffer( "window", self._2d_window.expand( diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index 028e329..48b30c8 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -77,4 +77,4 @@ def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: return self(inputs) def configure_optimizers(self) -> Optimizer: - return Adam(self.parameters(), lr=self.hparams.lr) + return Adam(self.parameters(), lr=self.hparams.lr) # type: ignore[attr-defined] From 05da4bdfb025608fa0d073c6d32f82f47efbf034 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sat, 14 Mar 2026 14:39:54 +0100 Subject: [PATCH 5/9] fix: move stacked targets to model device --- stain_normalization/stain_normalization_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index 48b30c8..de2048e 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -63,7 +63,7 @@ def validation_step(self, batch: Batch) -> None: def test_step(self, batch: PredictBatch) -> Outputs: inputs, data = batch outputs = self(inputs) - targets = stack([item["original_image_tensor"] for item in data]) + targets = stack([item["original_image_tensor"] for item in data]).to(outputs.device) self.test_metrics.update(outputs, targets) self.log_dict( self.test_metrics, From 9df92a823dd9e97459880f3fd184183f781bd643 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sat, 14 Mar 2026 14:43:30 +0100 Subject: [PATCH 6/9] chore: ruff formating --- stain_normalization/stain_normalization_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index de2048e..70968bc 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -63,7 +63,9 @@ def validation_step(self, batch: Batch) -> None: def test_step(self, batch: PredictBatch) -> Outputs: inputs, data = batch outputs = self(inputs) - targets = stack([item["original_image_tensor"] for item in data]).to(outputs.device) + targets = stack([item["original_image_tensor"] for item in data]).to( + outputs.device + ) self.test_metrics.update(outputs, targets) self.log_dict( self.test_metrics, From cb9125e7d96d81aebcbec379de7c3a3bdc8dd738 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sun, 15 Mar 2026 23:40:54 +0100 Subject: [PATCH 7/9] fix: review feedback --- configs/default.yaml | 5 +++-- stain_normalization/modeling/l1ssim_loss.py | 4 ++-- stain_normalization/stain_normalization_model.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index cc75c35..8b57808 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -15,7 +15,8 @@ checkpoint: null callbacks: model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - save_top_k: 2 + save_top_k: 1 + save_last: true monitor: validation/loss mode: min @@ -36,7 +37,7 @@ trainer: enable_checkpointing: True max_epochs: 100 limit_train_batches: 5000 - log_every_n_steps: 5 + log_every_n_steps: 50 callbacks: - ${callbacks.model_checkpoint} diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py index 2cee12c..c557459 100644 --- a/stain_normalization/modeling/l1ssim_loss.py +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -122,8 +122,8 @@ def brightness_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: def gradient_loss(image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Horizontal gradient - dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Vertical gradient + dx = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Horizontal gradient + dy = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Vertical gradient return dx, dy image_dx, image_dy = gradient(image) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index 70968bc..7b03e86 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -20,7 +20,7 @@ def __init__( lambda_gdl: float = 0.1, ) -> None: super().__init__() - self.save_hyperparameters() + self.lr = lr self.unet = UNet(in_channels=3, out_channels=3) self.criterion = L1SSIMLoss( lambda_dssim=lambda_dssim, @@ -79,4 +79,4 @@ def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: return self(inputs) def configure_optimizers(self) -> Optimizer: - return Adam(self.parameters(), lr=self.hparams.lr) # type: ignore[attr-defined] + return Adam(self.parameters(), lr=self.lr) From da187a4fa6ed299d58f84eec0c700ee480932335 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sun, 15 Mar 2026 23:44:04 +0100 Subject: [PATCH 8/9] Revert "fix: review feedback" This reverts commit cb9125e7d96d81aebcbec379de7c3a3bdc8dd738. --- configs/default.yaml | 5 ++--- stain_normalization/modeling/l1ssim_loss.py | 4 ++-- stain_normalization/stain_normalization_model.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index 8b57808..cc75c35 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -15,8 +15,7 @@ checkpoint: null callbacks: model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - save_top_k: 1 - save_last: true + save_top_k: 2 monitor: validation/loss mode: min @@ -37,7 +36,7 @@ trainer: enable_checkpointing: True max_epochs: 100 limit_train_batches: 5000 - log_every_n_steps: 50 + log_every_n_steps: 5 callbacks: - ${callbacks.model_checkpoint} diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py index c557459..2cee12c 100644 --- a/stain_normalization/modeling/l1ssim_loss.py +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -122,8 +122,8 @@ def brightness_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: def gradient_loss(image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - dx = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Horizontal gradient - dy = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Vertical gradient + dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Horizontal gradient + dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Vertical gradient return dx, dy image_dx, image_dy = gradient(image) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index 7b03e86..70968bc 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -20,7 +20,7 @@ def __init__( lambda_gdl: float = 0.1, ) -> None: super().__init__() - self.lr = lr + self.save_hyperparameters() self.unet = UNet(in_channels=3, out_channels=3) self.criterion = L1SSIMLoss( lambda_dssim=lambda_dssim, @@ -79,4 +79,4 @@ def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: return self(inputs) def configure_optimizers(self) -> Optimizer: - return Adam(self.parameters(), lr=self.lr) + return Adam(self.parameters(), lr=self.hparams.lr) # type: ignore[attr-defined] From e3fc89eb34b368fc5b7c81141d9fb13f238f9ef0 Mon Sep 17 00:00:00 2001 From: Adam Lopatka <524733@mail.muni.cz> Date: Sun, 15 Mar 2026 23:50:43 +0100 Subject: [PATCH 9/9] fix: review feedback --- configs/default.yaml | 5 +++-- stain_normalization/modeling/l1ssim_loss.py | 4 ++-- stain_normalization/stain_normalization_model.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index cc75c35..8b57808 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -15,7 +15,8 @@ checkpoint: null callbacks: model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - save_top_k: 2 + save_top_k: 1 + save_last: true monitor: validation/loss mode: min @@ -36,7 +37,7 @@ trainer: enable_checkpointing: True max_epochs: 100 limit_train_batches: 5000 - log_every_n_steps: 5 + log_every_n_steps: 50 callbacks: - ${callbacks.model_checkpoint} diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py index 2cee12c..c557459 100644 --- a/stain_normalization/modeling/l1ssim_loss.py +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -122,8 +122,8 @@ def brightness_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: def gradient_loss(image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Horizontal gradient - dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Vertical gradient + dx = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Horizontal gradient + dy = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Vertical gradient return dx, dy image_dx, image_dy = gradient(image) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py index 70968bc..7b03e86 100644 --- a/stain_normalization/stain_normalization_model.py +++ b/stain_normalization/stain_normalization_model.py @@ -20,7 +20,7 @@ def __init__( lambda_gdl: float = 0.1, ) -> None: super().__init__() - self.save_hyperparameters() + self.lr = lr self.unet = UNet(in_channels=3, out_channels=3) self.criterion = L1SSIMLoss( lambda_dssim=lambda_dssim, @@ -79,4 +79,4 @@ def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: return self(inputs) def configure_optimizers(self) -> Optimizer: - return Adam(self.parameters(), lr=self.hparams.lr) # type: ignore[attr-defined] + return Adam(self.parameters(), lr=self.lr)