diff --git a/configs/data/modify/test.yaml b/configs/data/modify/test.yaml new file mode 100644 index 0000000..077faf1 --- /dev/null +++ b/configs/data/modify/test.yaml @@ -0,0 +1,15 @@ +_target_: albumentations.OneOf +p: 1.0 +transforms: + - _target_: stain_normalization.data.modification.HEDFactor + h_range: [0.8, 1.2] + e_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.ExposureAdjustment + brightness_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.HSVModification + hue_shift_range: [-0.4, 0.4] + saturation_range: [0.8, 1.5] + value_range: [0.8, 1.3] + - _target_: stain_normalization.data.modification.CombinedModifications + od_scale_range: [0.65, 1.35] + brightness_range: [-0.4, 0.4] diff --git a/configs/data/modify/train.yaml b/configs/data/modify/train.yaml new file mode 100644 index 0000000..077faf1 --- /dev/null +++ b/configs/data/modify/train.yaml @@ -0,0 +1,15 @@ +_target_: albumentations.OneOf +p: 1.0 +transforms: + - _target_: stain_normalization.data.modification.HEDFactor + h_range: [0.8, 1.2] + e_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.ExposureAdjustment + brightness_range: [0.8, 1.2] + - _target_: stain_normalization.data.modification.HSVModification + hue_shift_range: [-0.4, 0.4] + saturation_range: [0.8, 1.5] + value_range: [0.8, 1.3] + - _target_: stain_normalization.data.modification.CombinedModifications + od_scale_range: [0.65, 1.35] + brightness_range: [-0.4, 0.4] diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 index 0000000..8b57808 --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,56 @@ +defaults: + - hydra: default + - logger: mlflow + - /data/datasets@data.train: stain_normalization/train + - /data/datasets@data.val: stain_normalization/val + - /data/datasets@data.test: stain_normalization/test + - /data/datasets@data.predict: stain_normalization/predict + - _self_ + +seed: ${random_seed:} +mode: fit +checkpoint: null + + +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + save_top_k: 1 + save_last: true + monitor: validation/loss + mode: min + + early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: validation/loss + patience: 5 + mode: min + +model: + lr: 1e-4 + lambda_dssim: 0.6 + lambda_l1: 0.2 + lambda_lum: 0.2 + lambda_gdl: 0.1 + +trainer: + enable_checkpointing: True + max_epochs: 100 + limit_train_batches: 5000 + log_every_n_steps: 50 + + callbacks: + - ${callbacks.model_checkpoint} + - ${callbacks.early_stopping} + +data: + batch_size: 64 + num_workers: 8 + +metadata: + user: ??? + experiment_name: Stain-Normalization + run_name: ??? + description: ??? + hyperparams: ${model} + diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000..275e331 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,7 @@ +defaults: + - _self_ + - override hydra_logging: disabled + - override job_logging: disabled +output_subdir: null +run: + dir: . diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000..10355c7 --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,6 @@ +_target_: rationai.mlkit.lightning.loggers.MLFlowLogger +experiment_name: ${metadata.experiment_name} +run_name: ${metadata.run_name} +tags: + mlflow.user: ${metadata.user} + mlflow.note.content: ${metadata.description} diff --git a/pyproject.toml b/pyproject.toml index 6b19117..4fcfaf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ dependencies = [ [dependency-groups] dev = ["mypy", "ruff"] +[tool.mypy] +ignore_missing_imports = true + [tool.uv] environments = ["sys_platform == 'linux'"] diff --git a/stain_normalization/__init__.py b/stain_normalization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stain_normalization/__main__.py b/stain_normalization/__main__.py new file mode 100644 index 0000000..3765c5d --- /dev/null +++ b/stain_normalization/__main__.py @@ -0,0 +1,36 @@ +from random import randint + +import hydra +import torch +from lightning import seed_everything +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import Trainer, autolog + +from stain_normalization.data import DataModule +from stain_normalization.stain_normalization_model import StainNormalizationModel + + +OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True +) + + +@hydra.main(config_path="../configs", config_name="default", version_base=None) +@autolog +def main(config: DictConfig, logger: Logger | None) -> None: + torch.set_float32_matmul_precision("high") + seed_everything(config.seed, workers=True) + + data = hydra.utils.instantiate( + config.data, + _recursive_=False, # to avoid instantiating all the datasets + _target_=DataModule, + ) + model = hydra.utils.instantiate(config.model, _target_=StainNormalizationModel) + trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) + getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/stain_normalization/data/datasets/predict_dataset.py b/stain_normalization/data/datasets/predict_dataset.py index aba3372..25ec978 100644 --- a/stain_normalization/data/datasets/predict_dataset.py +++ b/stain_normalization/data/datasets/predict_dataset.py @@ -35,7 +35,7 @@ def generate_datasets(self) -> Iterable[Dataset[PredictSample]]: tiles=self.filter_tiles_by_slide(slide["id"]), normalize=self.normalize, ) - for _, slide in self.slides.iterrows() + for slide in self.slides ) @@ -48,10 +48,10 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata.path, - level=slide_metadata.level, - tile_extent_x=slide_metadata.tile_extent_x, - tile_extent_y=slide_metadata.tile_extent_y, + slide_path=slide_metadata["path"], + level=slide_metadata["level"], + tile_extent_x=slide_metadata["tile_extent_x"], + tile_extent_y=slide_metadata["tile_extent_y"], tiles=tiles, ) diff --git a/stain_normalization/data/datasets/test_dataset.py b/stain_normalization/data/datasets/test_dataset.py index 5a536b0..26a7ffb 100644 --- a/stain_normalization/data/datasets/test_dataset.py +++ b/stain_normalization/data/datasets/test_dataset.py @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[PredictSample]]: modify=self.modify, normalize=self.normalize, ) - for _, slide in self.slides.iterrows() + for slide in self.slides ) @@ -54,10 +54,10 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata.path, - level=slide_metadata.level, - tile_extent_x=slide_metadata.tile_extent_x, - tile_extent_y=slide_metadata.tile_extent_y, + slide_path=slide_metadata["path"], + level=slide_metadata["level"], + tile_extent_x=slide_metadata["tile_extent_x"], + tile_extent_y=slide_metadata["tile_extent_y"], tiles=tiles, ) self.modify = modify diff --git a/stain_normalization/data/datasets/train_dataset.py b/stain_normalization/data/datasets/train_dataset.py index 27e6d43..a116d9a 100644 --- a/stain_normalization/data/datasets/train_dataset.py +++ b/stain_normalization/data/datasets/train_dataset.py @@ -40,7 +40,7 @@ def generate_datasets(self) -> Iterable[Dataset[Sample]]: modify=self.modify, normalize=self.normalize, ) - for _, slide in self.slides.iterrows() + for slide in self.slides ) @@ -54,10 +54,10 @@ def __init__( ) -> None: super().__init__() self.slide_tiles = OpenSlideTilesDataset( - slide_path=slide_metadata.path, - level=slide_metadata.level, - tile_extent_x=slide_metadata.tile_extent_x, - tile_extent_y=slide_metadata.tile_extent_y, + slide_path=slide_metadata["path"], + level=slide_metadata["level"], + tile_extent_x=slide_metadata["tile_extent_x"], + tile_extent_y=slide_metadata["tile_extent_y"], tiles=tiles, ) self.modify = modify diff --git a/stain_normalization/modeling/__init__.py b/stain_normalization/modeling/__init__.py new file mode 100644 index 0000000..7a8746c --- /dev/null +++ b/stain_normalization/modeling/__init__.py @@ -0,0 +1,5 @@ +from stain_normalization.modeling.l1ssim_loss import L1SSIMLoss +from stain_normalization.modeling.unet import UNet + + +__all__ = ["L1SSIMLoss", "UNet"] diff --git a/stain_normalization/modeling/l1ssim_loss.py b/stain_normalization/modeling/l1ssim_loss.py new file mode 100644 index 0000000..c557459 --- /dev/null +++ b/stain_normalization/modeling/l1ssim_loss.py @@ -0,0 +1,135 @@ +""" +The SSIM is based on implementation from gaussian-splatting and slightly simplified +(pre-computed windows and removal of unused arguments). +https://github.com/graphdeco-inria/gaussian-splatting/blob/472689c0dc70417448fb451bf529ae532d32c095/utils/loss_utils.py +""" + +from math import exp + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L1SSIMLoss(nn.Module): + def __init__( + self, + lambda_dssim: float = 0.6, + lambda_l1: float = 0.2, + lambda_lum: float = 0.2, + lambda_gdl: float = 0.1, + ): + super().__init__() + self.lambda_dssim = lambda_dssim + self.lambda_l1 = lambda_l1 + self.lambda_lum = lambda_lum + self.lambda_gdl = lambda_gdl + + # precompute SSIM windows to avoid repetition + self.window_size = 11 + self.channel = 3 + self._1d_window = gaussian(self.window_size, 1.5).unsqueeze(1) + self._2d_window = ( + self._1d_window.mm(self._1d_window.t()).float().unsqueeze(0).unsqueeze(0) + ) + self.window: torch.Tensor + self.register_buffer( + "window", + self._2d_window.expand( + self.channel, 1, self.window_size, self.window_size + ).contiguous(), + ) + + def forward(self, image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: + # L1 color loss + l1_loss = F.l1_loss(image, target_image, reduction="mean") + + # SSIM structural loss + ssim_loss = 1.0 - self._ssim(image, target_image, self.window) + + # Gradient loss for edges + gdl_loss = gradient_loss(image, target_image) + + # Luminance / brightness loss + brig_loss = brightness_loss(image, target_image) + + # total weighted loss + total_loss = ( + self.lambda_l1 * l1_loss + + self.lambda_dssim * ssim_loss + + self.lambda_gdl * gdl_loss + + self.lambda_lum * brig_loss + ) + + return total_loss + + @torch.compile + def _ssim( + self, img1: torch.Tensor, img2: torch.Tensor, window: torch.Tensor + ) -> torch.Tensor: + # Modified _ssim that uses pre-computed window + mu1 = F.conv2d(img1, window, padding=self.window_size // 2, groups=self.channel) + mu2 = F.conv2d(img2, window, padding=self.window_size // 2, groups=self.channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d( + img1 * img1, window, padding=self.window_size // 2, groups=self.channel + ) + - mu1_sq + ) + sigma2_sq = ( + F.conv2d( + img2 * img2, window, padding=self.window_size // 2, groups=self.channel + ) + - mu2_sq + ) + sigma12 = ( + F.conv2d( + img1 * img2, window, padding=self.window_size // 2, groups=self.channel + ) + - mu1_mu2 + ) + + c1 = 0.01**2 + c2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ( + (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ) + + return ssim_map.mean() + + +def gaussian(window_size: int, sigma: float) -> torch.Tensor: + gauss = torch.tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def brightness_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + pred_mean = pred.mean(dim=[1, 2, 3]) + target_mean = target.mean(dim=[1, 2, 3]) + return F.l1_loss(pred_mean, target_mean) + + +def gradient_loss(image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor: + def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + dx = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) # Horizontal gradient + dy = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) # Vertical gradient + return dx, dy + + image_dx, image_dy = gradient(image) + target_dx, target_dy = gradient(target_image) + + loss_x = F.l1_loss(image_dx, target_dx, reduction="mean") + loss_y = F.l1_loss(image_dy, target_dy, reduction="mean") + + return loss_x + loss_y diff --git a/stain_normalization/modeling/unet.py b/stain_normalization/modeling/unet.py new file mode 100644 index 0000000..ebcb3b7 --- /dev/null +++ b/stain_normalization/modeling/unet.py @@ -0,0 +1,119 @@ +"""Adapted U-Net implementation based on the GitHub repository. + +https://github.com/milesial/Pytorch-UNet . +Original U-Net architecture proposed in the paper. +Ronneberger, O., Fischer, P., & Brox, T. (2015). +U-Net: Convolutional Networks for Biomedical Image Segmentation. +arXiv:1505.04597 [cs.CV]. +Retrieved from https://arxiv.org/abs/1505.04597 . +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2.""" + + def __init__( + self, in_channels: int, out_channels: int, mid_channels: int | None = None + ) -> None: + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv.""" + + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv.""" + + def __init__( + self, in_channels: int, out_channels: int, bilinear: bool = True + ) -> None: + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up: nn.Upsample | nn.ConvTranspose2d = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + x1 = self.up(x1) + diffy = x2.size()[2] - x1.size()[2] + diffx = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, diffy // 2, diffy - diffy // 2]) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class UNet(nn.Module): + def __init__( + self, in_channels: int = 3, out_channels: int = 3, bilinear: bool = True + ) -> None: + super().__init__() + self.in_conv = DoubleConv(in_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + + factor = 2 if bilinear else 1 + self.down4 = Down(512, 1024 // factor) + self.up1 = Up(1024, 512 // factor, bilinear) + self.up2 = Up(512, 256 // factor, bilinear) + self.up3 = Up(256, 128 // factor, bilinear) + self.up4 = Up(128, 64, bilinear) + self.out_conv = OutConv(64, out_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.in_conv(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + return self.out_conv(x) diff --git a/stain_normalization/stain_normalization_model.py b/stain_normalization/stain_normalization_model.py new file mode 100644 index 0000000..7b03e86 --- /dev/null +++ b/stain_normalization/stain_normalization_model.py @@ -0,0 +1,82 @@ +from lightning import LightningModule +from torch import Tensor, stack +from torch.optim import Adam +from torch.optim.optimizer import Optimizer +from torchmetrics import MetricCollection +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.regression import MeanAbsoluteError + +from stain_normalization.modeling import L1SSIMLoss, UNet +from stain_normalization.type_aliases import Batch, Outputs, PredictBatch + + +class StainNormalizationModel(LightningModule): + def __init__( + self, + lr: float = 1e-4, + lambda_dssim: float = 0.6, + lambda_l1: float = 0.2, + lambda_lum: float = 0.2, + lambda_gdl: float = 0.1, + ) -> None: + super().__init__() + self.lr = lr + self.unet = UNet(in_channels=3, out_channels=3) + self.criterion = L1SSIMLoss( + lambda_dssim=lambda_dssim, + lambda_l1=lambda_l1, + lambda_lum=lambda_lum, + lambda_gdl=lambda_gdl, + ) + + self.val_metrics = MetricCollection( + {"ssim": StructuralSimilarityIndexMeasure(), "l1": MeanAbsoluteError()} + ) + self.test_metrics = self.val_metrics.clone(prefix="test/") + self.val_metrics.prefix = "validation/" + + def forward(self, x: Tensor) -> Outputs: + return self.unet(x) + + def training_step(self, batch: Batch) -> Tensor: + inputs, targets = batch + outputs = self(inputs) + + loss = self.criterion(outputs, targets) + self.log("train/loss", loss, on_step=True, prog_bar=True) + + return loss + + def validation_step(self, batch: Batch) -> None: + inputs, targets = batch + outputs = self(inputs) + + loss = self.criterion(outputs, targets) + self.log("validation/loss", loss, on_step=False, on_epoch=True, logger=True) + self.val_metrics.update(outputs, targets) + self.log_dict( + self.val_metrics, + batch_size=len(inputs), + on_epoch=True, + ) + + def test_step(self, batch: PredictBatch) -> Outputs: + inputs, data = batch + outputs = self(inputs) + targets = stack([item["original_image_tensor"] for item in data]).to( + outputs.device + ) + self.test_metrics.update(outputs, targets) + self.log_dict( + self.test_metrics, + batch_size=len(inputs), + on_epoch=True, + ) + return outputs + + def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs: + inputs = batch[0] + return self(inputs) + + def configure_optimizers(self) -> Optimizer: + return Adam(self.parameters(), lr=self.lr)