
feat: add model and loss function #5

Closed
LAdam-ix wants to merge 1 commit into main from feature/ml-model

Conversation


@LAdam-ix LAdam-ix commented Mar 11, 2026

Uses an implementation of UNet.

Total loss function:

  • L1 loss (weight 0.2)
  • SSIM loss (weight 0.6)
  • gradient loss (weight 0.1)
  • brightness loss (weight 0.2)

Brightness loss was added after the initial training because outputs were slightly darker, so I fine-tuned with it instead of retraining from scratch. After the code review I will probably retrain with brightness included from the start. It might also be worth dropping the gradient loss to speed training up a little; now that brightness is included, the gradient term probably doesn't contribute much to the total loss.
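For intuition, the weighted combination described above can be sketched roughly as follows. This is a minimal stand-in, not this PR's actual L1SSIMLoss: the function names are mine, and the real SSIM term is computed with a Gaussian window rather than the global image statistics used here.

```python
import torch

def gradient_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """L1 difference of vertical/horizontal finite-difference image gradients."""
    dy = (pred[..., 1:, :] - pred[..., :-1, :]) - (target[..., 1:, :] - target[..., :-1, :])
    dx = (pred[..., :, 1:] - pred[..., :, :-1]) - (target[..., :, 1:] - target[..., :, :-1])
    return dy.abs().mean() + dx.abs().mean()

def simplified_dssim(pred: torch.Tensor, target: torch.Tensor,
                     c1: float = 0.01 ** 2, c2: float = 0.03 ** 2) -> torch.Tensor:
    """Global-statistics DSSIM, (1 - SSIM) / 2, without the Gaussian window."""
    mu_p, mu_t = pred.mean(), target.mean()
    var_p = ((pred - mu_p) ** 2).mean()
    var_t = ((target - mu_t) ** 2).mean()
    cov = ((pred - mu_p) * (target - mu_t)).mean()
    ssim = ((2 * mu_p * mu_t + c1) * (2 * cov + c2)) / (
        (mu_p ** 2 + mu_t ** 2 + c1) * (var_p + var_t + c2)
    )
    return (1 - ssim) / 2

def total_loss(pred: torch.Tensor, target: torch.Tensor,
               w_l1: float = 0.2, w_dssim: float = 0.6,
               w_gdl: float = 0.1, w_lum: float = 0.2) -> torch.Tensor:
    l1 = (pred - target).abs().mean()
    # Brightness term: match the per-channel mean intensity of the target.
    lum = (pred.mean(dim=(-2, -1)) - target.mean(dim=(-2, -1))).abs().mean()
    return (w_l1 * l1 + w_dssim * simplified_dssim(pred, target)
            + w_gdl * gradient_loss(pred, target) + w_lum * lum)
```

For identical images every term vanishes, so the total is zero. Note that the listed weights sum to 1.1 rather than 1.0, which rescales the loss but does not change the optimum.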

Training setup mostly follows the ML template defaults. I added more frequent validation because of long training runs.
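For reference, more frequent validation in PyTorch Lightning is typically a single trainer setting; a hypothetical config fragment (key names assumed, not taken from this PR's configs):

```yaml
# Hypothetical trainer fragment — illustrative only, not this PR's config.
trainer:
  max_epochs: 50
  val_check_interval: 0.25  # run validation four times per training epoch
```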

P.S. The mypy errors are mostly from the project being split into multiple PRs.

@LAdam-ix LAdam-ix requested a review from matejpekar March 11, 2026 06:52
@LAdam-ix LAdam-ix self-assigned this Mar 11, 2026
@LAdam-ix LAdam-ix requested review from a team and JakubPekar March 11, 2026 06:52

coderabbitai bot commented Mar 11, 2026

Warning

Rate limit exceeded

@LAdam-ix has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 18 minutes and 15 seconds before requesting another review.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 854a55c7-ff17-4669-a9d0-a8de3178582d

📥 Commits

Reviewing files that changed from the base of the PR and between 0ce77d2 and d694135.

⛔ Files ignored due to path filters (2)
  • pdm.lock is excluded by !**/*.lock
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (19)
  • .gitignore
  • configs/data/datasets/stain_normalization/predict.yaml
  • configs/data/datasets/stain_normalization/test.yaml
  • configs/data/datasets/stain_normalization/train.yaml
  • configs/data/datasets/stain_normalization/val.yaml
  • configs/data/modify/test.yaml
  • configs/data/modify/train.yaml
  • configs/data/normalize/default.yaml
  • configs/default.yaml
  • configs/hydra/default.yaml
  • configs/logger/mlflow.yaml
  • pyproject.toml
  • stain_normalization/__init__.py
  • stain_normalization/__main__.py
  • stain_normalization/modeling/__init__.py
  • stain_normalization/modeling/l1ssim_loss.py
  • stain_normalization/modeling/unet.py
  • stain_normalization/stain_normalization_model.py
  • stain_normalization/type_aliases.py


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes the foundational machine learning components for a stain normalization project. It integrates a U-Net architecture and a sophisticated multi-component loss function to enable effective image processing. The changes also set up a comprehensive configuration system using Hydra, streamlining data handling, model training, and experiment logging, which is crucial for reproducibility and scalability in deep learning workflows.

Highlights

  • Model Architecture: Introduced a U-Net model for stain normalization, adapted from a known PyTorch implementation, designed for image-to-image translation tasks.
  • Comprehensive Loss Function: Implemented a custom L1SSIMLoss combining L1, SSIM, gradient, and brightness losses with configurable weights, allowing for fine-grained control over model optimization.
  • Project Configuration and Structure: Established a robust project structure using Hydra for configuration management, defining datasets, data modifications, normalization parameters, and MLflow logging.
  • Training and Evaluation Setup: Integrated the model and loss function into a PyTorch Lightning module, including training, validation, testing, and prediction steps with relevant metrics (SSIM, L1 MAE).
Changelog
  • .gitignore
    • Added a comprehensive .gitignore file to exclude common Python-related temporary and environment files, as well as project-specific outputs like MLflow runs and PyCharm settings.
  • configs/data/datasets/stain_normalization/predict.yaml
    • Added a new configuration file for the prediction dataset, specifying its target class and MLflow artifact URI.
  • configs/data/datasets/stain_normalization/test.yaml
    • Added a new configuration file for the test dataset, including data modification and normalization defaults, and an MLflow artifact URI.
  • configs/data/datasets/stain_normalization/train.yaml
    • Added a new configuration file for the training dataset, including data modification and normalization defaults, and an MLflow artifact URI.
  • configs/data/datasets/stain_normalization/val.yaml
    • Added a new configuration file for the validation dataset, including data modification and normalization defaults, and an MLflow artifact URI.
  • configs/data/modify/test.yaml
    • Added a new configuration file defining a set of image modification transforms (HEDFactor, ExposureAdjustment, HVSModification, CombinedModifications) for testing purposes.
  • configs/data/modify/train.yaml
    • Added a new configuration file defining a set of image modification transforms (HEDFactor, ExposureAdjustment, HVSModification, CombinedModifications) for training purposes.
  • configs/data/normalize/default.yaml
    • Added a new configuration file specifying default normalization parameters (mean, std, max_pixel_value) for image preprocessing.
  • configs/default.yaml
    • Added the main default configuration file, orchestrating various components like logger, datasets, callbacks, trainer settings, and metadata.
  • configs/hydra/default.yaml
    • Added a new configuration file for Hydra, disabling logging and setting the run directory.
  • configs/logger/mlflow.yaml
    • Added a new configuration file for MLflow logging, defining experiment name, run name, and tags based on metadata.
  • pyproject.toml
    • Added project metadata, specified Python version, listed core and development dependencies, and defined PDM scripts for various tasks like training, linting, and formatting.
  • stain_normalization/__main__.py
    • Added the main entry point for the application, integrating Hydra for configuration loading, PyTorch Lightning for training, and MLflow for logging.
  • stain_normalization/modeling/__init__.py
    • Added an __init__.py file to expose the L1SSIMLoss and UNet classes from the modeling module.
  • stain_normalization/modeling/l1ssim_loss.py
    • Added a custom L1SSIMLoss class that combines L1, SSIM, gradient, and brightness losses, along with helper functions for Gaussian window and brightness/gradient calculations.
  • stain_normalization/modeling/unet.py
    • Added the UNet model architecture, including DoubleConv, Down, Up, and OutConv modules, adapted for image-to-image tasks.
  • stain_normalization/stain_normalization_model.py
    • Added the StainNormalizationModel class, a PyTorch Lightning module that encapsulates the UNet model, L1SSIMLoss criterion, and defines training, validation, test, and predict steps with associated metrics.
  • stain_normalization/type_aliases.py
    • Added a new file defining type aliases for better readability and type hinting consistency across the project, specifically for samples and batches.
Activity
  • The pull request was created by LAdam-ix.
  • The author implemented a new UNet model and a custom loss function.
  • Initial training was conducted, after which brightness loss was incorporated for fine-tuning due to darker outputs.
  • The author plans to retrain the model with brightness loss included from the start and may consider reducing the gradient loss component.
  • Validation frequency was increased in the training setup to better monitor long training runs.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a U-Net model and a custom loss function for a stain normalization task, using PyTorch Lightning and Hydra. The overall structure is sound. My feedback focuses on improving configurability, reproducibility, and code clarity. The most significant suggestions involve making the model and its components configurable through Hydra to leverage its full potential, and improving the project setup for better collaboration. I have also included some suggestions for the loss function implementation.

name = "stain-normalization"
version = "0.1.0"
authors = [{name = "Adam Lopatka"}]
requires-python = "==3.12.5"


high

Pinning the Python version to an exact patch (==3.12.5) is overly restrictive and can cause setup issues for developers with a different patch version. It's better to specify a compatible range, for example requires-python = ">=3.12" or requires-python = "~=3.12".

Suggested change
requires-python = "==3.12.5"
requires-python = "~=3.12"

Comment on lines +13 to +67
class StainNormalizationModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.unet = UNet(in_channels=3, out_channels=3)
        self.criterion = L1SSIMLoss()

        self.val_metrics = MetricCollection(
            {"ssim": StructuralSimilarityIndexMeasure(), "l1": MeanAbsoluteError()}
        )
        self.test_metrics = self.val_metrics.clone(prefix="test/")
        self.val_metrics.prefix = "validation/"

    def forward(self, x: Tensor) -> Outputs:
        return self.unet(x)

    def training_step(self, batch: Batch) -> Tensor:
        inputs, targets = batch
        outputs = self(inputs)

        loss = self.criterion(outputs, targets)
        self.log("train/loss", loss, on_step=True, prog_bar=True)

        return loss

    def validation_step(self, batch: Batch) -> None:
        inputs, targets = batch
        outputs = self(inputs)

        loss = self.criterion(outputs, targets)
        self.log("validation/loss", loss, on_step=False, on_epoch=True, logger=True)
        self.val_metrics.update(outputs, targets)
        self.log_dict(
            self.val_metrics,
            batch_size=len(inputs),
            on_epoch=True,
        )

    def test_step(self, batch: PredictBatch) -> Outputs:
        inputs, data = batch
        outputs = self(inputs)
        targets = stack([item["original_image_tensor"] for item in data])
        self.test_metrics.update(outputs, targets)
        self.log_dict(
            self.test_metrics,
            batch_size=len(inputs),
            on_epoch=True,
        )
        return outputs

    def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs:
        inputs = batch[0]
        return self(inputs)

    def configure_optimizers(self) -> Optimizer:
        return Adam(self.parameters(), lr=1e-4)


high

The model, loss function, and optimizer are hardcoded within the StainNormalizationModel. This reduces flexibility and goes against the principles of a Hydra-based configuration. It would be better to instantiate these components via Hydra configs, allowing you to easily swap them or tune their hyperparameters without changing the code.

For example, you could change the __init__ to accept configurations and instantiate them:

import hydra
from omegaconf import DictConfig

class StainNormalizationModel(LightningModule):
    def __init__(
        self,
        net: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer_config: DictConfig
    ) -> None:
        super().__init__()
        self.net = net
        self.criterion = criterion
        self.optimizer_config = optimizer_config
        # ...
    
    def configure_optimizers(self) -> Optimizer:
        return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())

And in your main config, you would define net, criterion, and optimizer_config as instantiable objects. This would make your StainNormalizationModel a more generic training harness.
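A matching config sketch could look like this (`_target_` is standard Hydra; the surrounding key names and import paths are assumptions, not taken from this PR):

```yaml
# Hypothetical Hydra config for instantiating the model's components.
model:
  net:
    _target_: stain_normalization.modeling.UNet
    in_channels: 3
    out_channels: 3
  criterion:
    _target_: stain_normalization.modeling.L1SSIMLoss
  optimizer_config:
    _target_: torch.optim.Adam
    lr: 1e-4
```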


seed: ${random_seed:}
mode: fit
checkpoint: mlflow-artifacts:/79/978f5d5e54844be3b42509cce76793d7/artifacts/checkpoints/epoch=7-step=152375/checkpoint.ckpt


medium

Hardcoding a specific checkpoint URI makes this configuration difficult to reuse. For resuming or fine-tuning, it's better to pass this path as a command-line argument, e.g., python train.py checkpoint=path/to/ckpt. For a fresh run, this should probably be null.

num_workers: 10

metadata:
  user: xlopatka


medium

Hardcoding the username is not ideal for collaboration. It's better to fetch this from an environment variable or let the user configure it. For example, you could use an OmegaConf resolver: user: ${env:USER,default_user}.

Comment on lines +15 to +20
"rationai-mlkit @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/mlkit.git",
"rationai-masks @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/masks.git",
"rationai-tiling @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/tiling.git",
"scikit-image>=0.25.2",
"openslide-bin>=4.0.0.6",
"rationai-staining @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/staining.git",


medium

Dependencies from Git should be pinned to a specific commit hash or tag to ensure reproducible builds. Without pinning, pdm install might pull a newer, potentially breaking version of the library. For example: rationai-mlkit @ git+https://...git@<commit_hash_or_tag>.
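A pinned entry would look like the following sketch, where `v0.1.0` is a placeholder — substitute whatever tag or commit hash the project actually targets:

```toml
# Placeholder ref — replace v0.1.0 with a real tag or commit hash.
dependencies = [
    "rationai-mlkit @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/mlkit.git@v0.1.0",
]
```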

Comment on lines +18 to +29
def __init__(
    self,
    lambda_dssim: float = 0.6,
    lambda_l1: float = 0.2,
    lambda_lum: float = 0.2,
    lambda_gdl: float = 0.1,
):
    super().__init__()
    self.lambda_dssim = lambda_dssim
    self.lambda_l1 = lambda_l1
    self.lambda_lum = lambda_lum
    self.lambda_gdl = lambda_gdl


medium

Using lambda as part of a variable name (e.g., lambda_dssim) is confusing because lambda is a Python keyword. It's better to use a more descriptive name like dssim_weight to improve readability. This change should be propagated throughout the class.

Suggested change
def __init__(
    self,
    lambda_dssim: float = 0.6,
    lambda_l1: float = 0.2,
    lambda_lum: float = 0.2,
    lambda_gdl: float = 0.1,
):
    super().__init__()
    self.lambda_dssim = lambda_dssim
    self.lambda_l1 = lambda_l1
    self.lambda_lum = lambda_lum
    self.lambda_gdl = lambda_gdl

def __init__(
    self,
    dssim_weight: float = 0.6,
    l1_weight: float = 0.2,
    lum_weight: float = 0.2,
    gdl_weight: float = 0.1,
):
    super().__init__()
    self.dssim_weight = dssim_weight
    self.l1_weight = l1_weight
    self.lum_weight = lum_weight
    self.gdl_weight = gdl_weight



def gaussian(window_size: int, sigma: float) -> torch.Tensor:
    gauss = torch.Tensor(


medium

torch.Tensor is an alias for torch.FloatTensor and its behavior can be surprising. It's recommended to use torch.tensor which infers the dtype or allows explicit setting, providing more predictable behavior. For floating point tensors, you can use torch.tensor(..., dtype=torch.float32).

Suggested change
gauss = torch.Tensor(
gauss = torch.tensor(
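The difference is easy to demonstrate in isolation:

```python
import torch

a = torch.Tensor([1, 2, 3])  # legacy constructor: always produces float32
b = torch.tensor([1, 2, 3])  # factory function: infers int64 from the Python ints
c = torch.tensor([1, 2, 3], dtype=torch.float32)  # explicit dtype

print(a.dtype, b.dtype, c.dtype)  # torch.float32 torch.int64 torch.float32
```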

) -> torch.Tensor:
    device = pred.device
    if he_weights is None:
        he_weights = [0.33, 0.33, 0.33]


medium

The weights [0.33, 0.33, 0.33] for calculating brightness are unusual, and their sum is 0.99, not 1.0. If equal weighting is desired, it's better to use [1/3, 1/3, 1/3]. For a more perceptually accurate brightness calculation, consider using standard luminance weights like [0.299, 0.587, 0.114].
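A quick check of the arithmetic behind this point (plain Python, no dependencies):

```python
# The three weighting schemes mentioned in the comment above.
approx_equal = [0.33, 0.33, 0.33]    # sums to 0.99 — scales brightness down by 1%
exact_equal = [1 / 3, 1 / 3, 1 / 3]  # sums to 1.0 (up to float rounding)
bt601_luma = [0.299, 0.587, 0.114]   # ITU-R BT.601 luminance weights; sums to 1.0

print(sum(approx_equal), sum(exact_equal), sum(bt601_luma))
```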

Comment on lines +73 to +74
diffy = x2.size()[2] - x1.size()[2]
diffx = x2.size()[3] - x1.size()[3]


medium

While .size() works, the more modern and conventional way to get tensor dimensions in PyTorch is by using the .shape attribute. It's generally preferred for consistency and readability.

Suggested change
diffy = x2.size()[2] - x1.size()[2]
diffx = x2.size()[3] - x1.size()[3]
diffy = x2.shape[2] - x1.shape[2]
diffx = x2.shape[3] - x1.shape[3]

@LAdam-ix LAdam-ix closed this Mar 11, 2026