
feat: add model and loss function #8

Merged
vejtek merged 9 commits into main from feature/ml-model
Mar 16, 2026

Conversation

LAdam-ix (Collaborator) commented Mar 14, 2026

Using an existing implementation of UNet.

Total loss function:

L1 loss (0.2), SSIM loss (0.6), gradient loss (0.1), brightness loss (0.2)

Brightness loss was added after the initial training because outputs were slightly darker, so I fine-tuned with it instead of retraining from scratch. After the code review I will try training again with brightness included from the start. It might also be worth trying to drop the gradient loss; now that brightness is included and the weights have changed, it probably won't have much impact on the total loss.

Training setup mostly follows the ML template defaults. As the dataset is quite big (1.5 million tiles), I adjusted the epoch size to be smaller.
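The weighted combination described above can be sketched as follows. This is a minimal illustration only: the function names are hypothetical, and the whole-image SSIM here is a simplification of the Gaussian-window SSIM in l1ssim_loss.py.

```python
import torch
import torch.nn.functional as F


def ssim_global(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Simplified whole-image SSIM (no sliding Gaussian window); assumes
    # pixel values in [0, 1], matching the c1/c2 stabilizers below.
    c1, c2 = 0.01 ** 2, 0.03 ** 2
    mx, my = x.mean(), y.mean()
    vx = ((x - mx) ** 2).mean()
    vy = ((y - my) ** 2).mean()
    cov = ((x - mx) * (y - my)).mean()
    return ((2 * mx * my + c1) * (2 * cov + c2)) / (
        (mx ** 2 + my ** 2 + c1) * (vx + vy + c2)
    )


def gradient_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # L1 distance between finite-difference gradients along height and width.
    dh = (pred[..., 1:, :] - pred[..., :-1, :]) - (target[..., 1:, :] - target[..., :-1, :])
    dw = (pred[..., :, 1:] - pred[..., :, :-1]) - (target[..., :, 1:] - target[..., :, :-1])
    return dh.abs().mean() + dw.abs().mean()


def total_loss(pred, target, w_l1=0.2, w_ssim=0.6, w_grad=0.1, w_bright=0.2):
    # Weighted sum of the four terms with the weights from the PR description.
    l1 = F.l1_loss(pred, target)
    ssim_term = 1.0 - ssim_global(pred, target)
    grad = gradient_loss(pred, target)
    brightness = (pred.mean() - target.mean()).abs()  # penalizes global darkening
    return w_l1 * l1 + w_ssim * ssim_term + w_grad * grad + w_bright * brightness
```

Identical images give a loss near zero, while a globally darkened prediction is penalized by the L1 and brightness terms.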

Summary by CodeRabbit

  • New Features

    • End-to-end stain-normalization training workflow: UNet model, composite SSIM/L1 loss, Lightning training module, CLI entry point, and MLflow logging.
    • Configurable augmentation pipelines for train and test that probabilistically apply color/brightness/contrast modifications.
    • Trainer defaults including checkpointing and early stopping.
  • Refactor

    • Simplified dataset iteration and consistent metadata access across train/val/test/predict.
  • Chores

    • Static type (mypy) configuration added.

LAdam-ix requested a review from matejpekar March 14, 2026 11:45
LAdam-ix self-assigned this Mar 14, 2026
LAdam-ix requested review from a team and ejdam87 March 14, 2026 11:45
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the stain normalization project by integrating a U-Net based model and a sophisticated custom loss function. The new loss function combines L1, SSIM, gradient, and brightness components to improve image quality and address specific issues like output darkness. The changes also include a robust configuration system for training and logging, streamlining the development and experimentation process for machine learning models.

Highlights

  • New Model Architecture: Introduced a U-Net model for stain normalization, adapted from an existing PyTorch implementation.
  • Custom Loss Function: Implemented a combined loss function, L1SSIMLoss, which integrates L1 loss, SSIM loss, gradient loss, and brightness loss with configurable weights (0.2, 0.6, 0.1, 0.2 respectively).
  • Training and Configuration Setup: Established a comprehensive training pipeline using PyTorch Lightning and Hydra, including default configurations for data modification, model parameters, trainer settings, and MLFlow logging.
  • Brightness Loss Integration: Added brightness loss to address issues with darker outputs during initial training, allowing for fine-tuning without retraining from scratch.
Changelog
  • configs/data/modify/test.yaml
    • Added configuration for data modification transforms for testing, including HEDFactor, ExposureAdjustment, HSVModification, and CombinedModifications.
  • configs/data/modify/train.yaml
    • Added configuration for data modification transforms for training, mirroring the test configurations.
  • configs/default.yaml
    • Added a new default configuration file, defining global settings for Hydra, MLFlow logger, dataset paths, random seed, training mode, checkpointing, callbacks (ModelCheckpoint, EarlyStopping), model hyperparameters (learning rate, loss weights), trainer parameters (epochs, batch limits), and metadata.
  • configs/hydra/default.yaml
    • Added a default Hydra configuration, disabling logging and setting the output subdirectory to null.
  • configs/logger/mlflow.yaml
    • Added MLFlow logger configuration, specifying experiment name, run name, user, and description tags.
  • pyproject.toml
    • Updated the mypy configuration to ignore missing imports.
  • stain_normalization/__main__.py
    • Added the main entry point for the application, setting up Hydra for configuration management, PyTorch for tensor operations, and PyTorch Lightning for training. It includes a random seed resolver and instantiates the DataModule, StainNormalizationModel, and Trainer.
  • stain_normalization/modeling/__init__.py
    • Added an __init__.py file to expose L1SSIMLoss and UNet from the modeling subpackage.
  • stain_normalization/modeling/l1ssim_loss.py
    • Added the L1SSIMLoss class, which combines L1, SSIM, gradient, and brightness losses with configurable weights. It includes helper functions for Gaussian window generation, brightness loss calculation, and gradient loss calculation.
  • stain_normalization/modeling/unet.py
    • Added the UNet model implementation, including DoubleConv, Down, Up, and OutConv modules, adapted from a known PyTorch U-Net repository.
  • stain_normalization/stain_normalization_model.py
    • Added the StainNormalizationModel class, a PyTorch Lightning module that integrates the UNet model and the L1SSIMLoss criterion. It defines the forward pass, training, validation, test, and predict steps, along with metric tracking (SSIM, L1 MAE) and optimizer configuration.


coderabbitai bot commented Mar 14, 2026

📝 Walkthrough

Walkthrough

Adds Hydra configs and MLflow logger, two albumentations augmentation YAMLs, mypy settings, a CLI entrypoint, UNet and composite L1+SSIM loss implementations, a LightningModule for training/validation/test, and small dataset iteration and metadata-access adjustments.

Changes

Cohort / File(s) Summary
Configuration & Augmentation
configs/data/modify/test.yaml, configs/data/modify/train.yaml, configs/default.yaml, configs/hydra/default.yaml, configs/logger/mlflow.yaml
New Hydra defaults and logging config; two albumentations OneOf augmentation configs referencing HEDFactor, ExposureAdjustment, HSVModification, CombinedModifications.
Build Configuration
pyproject.toml
Added [tool.mypy] block with ignore_missing_imports = true.
CLI / Orchestration
stain_normalization/__main__.py
New entrypoint: registers OmegaConf resolver, seeds RNG, instantiates DataModule, StainNormalizationModel, Trainer, and dispatches train/validate/test/predict flows.
Modeling — Architecture
stain_normalization/modeling/unet.py
Adds DoubleConv, Down, Up, OutConv and UNet implementation.
Modeling — Loss
stain_normalization/modeling/l1ssim_loss.py
Adds L1SSIMLoss combining L1, SSIM (precomputed Gaussian window), gradient-difference loss, and luminance loss plus helper functions.
Training Model
stain_normalization/stain_normalization_model.py
Adds StainNormalizationModel LightningModule wiring UNet, L1SSIMLoss, metric collections, train/val/test/predict steps and optimizer config.
Module Exports
stain_normalization/modeling/__init__.py
Exports L1SSIMLoss and UNet via __all__.
Datasets
stain_normalization/data/datasets/...
stain_normalization/data/datasets/train_dataset.py, stain_normalization/data/datasets/test_dataset.py, stain_normalization/data/datasets/predict_dataset.py
Switched DataFrame iteration to direct slide iteration and changed slide metadata access from attribute-style to dict-style keys.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Hydra as "Hydra Config"
    participant DataModule
    participant Trainer as "Lightning Trainer"
    participant Model as "StainNormalizationModel"
    participant Loss as "L1SSIMLoss"

    User->>Hydra: load configs (defaults, data, model, trainer)
    Hydra->>DataModule: instantiate data module
    Hydra->>Model: instantiate StainNormalizationModel
    Hydra->>Trainer: instantiate Trainer (with logger)

    loop per training step
        DataModule->>Trainer: provide batch
        Trainer->>Model: training_step(batch)
        Model->>Loss: compute (L1, SSIM, GDL, luminance)
        Loss->>Model: return scalar loss
        Model->>Trainer: return loss -> optimizer step
    end

    loop validation
        DataModule->>Trainer: val batch
        Trainer->>Model: validation_step(batch) & log metrics
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • feat: add data loading #4: Introduced the albumentations modification classes referenced by the new augmentation configs and dataset usage.

Suggested reviewers

  • vejtek
  • 172454

Poem

🐰 I hopped through configs, code, and light,
UNet stitched channels through day and night,
Losses counted edges, hue, and gleam,
Trainer danced with metrics and dream,
Now stains are tamed — a rabbit's delight. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed — The title 'feat: add model and loss function' accurately describes the main changes, which include adding a UNet-based StainNormalizationModel and a composite L1SSIMLoss.


LAdam-ix removed the request for review from ejdam87 March 14, 2026 11:48

gemini-code-assist bot left a comment


Code Review

This pull request introduces the core model (U-Net), loss function (a combination of L1, SSIM, gradient, and brightness losses), and the necessary configuration files for a new stain normalization project. The implementation is well-structured, leveraging PyTorch Lightning and Hydra for configuration management. I've identified a critical issue in the data augmentation configuration that would prevent the code from running, along with several suggestions to improve performance, code clarity, and adherence to best practices in PyTorch Lightning. One comment was modified to align with the repository's preference for hardcoded default configuration values over variable placeholders in YAML files. Overall, this is a solid foundation for the project.


coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (7)
stain_normalization/__main__.py (1)

32-32: Validate config.mode before dynamic dispatch.

A simple whitelist check gives a clearer error than a late AttributeError if mode is misconfigured.

♻️ Suggested refactor
-    getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint)
+    allowed_modes = {"fit", "validate", "test", "predict"}
+    if config.mode not in allowed_modes:
+        raise ValueError(f"Unsupported mode '{config.mode}'. Expected one of {sorted(allowed_modes)}.")
+    getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/__main__.py` at line 32, Validate config.mode against an
explicit whitelist before doing the dynamic dispatch to avoid a late
AttributeError; e.g., define allowed modes (e.g.,
'fit','validate','test','predict' or the actual supported methods on the Trainer
class), check that config.mode is in that list (or use hasattr(trainer,
config.mode) plus whitelist), and raise a clear ValueError if not valid, then
call getattr(trainer, config.mode)(...) only after the check; reference the
dynamic call site using getattr(trainer, config.mode) and the config.mode value
when composing the error message.
stain_normalization/modeling/l1ssim_loss.py (3)

126-127: Gradient direction comments are swapped.

The comments appear to be reversed. For a tensor with shape [B, C, H, W]:

  • x[:, :, :-1, :] - x[:, :, 1:, :] computes differences along dim 2 (height/rows), which is the vertical gradient
  • x[:, :, :, :-1] - x[:, :, :, 1:] computes differences along dim 3 (width/columns), which is the horizontal gradient
📝 Proposed fix for comments
     def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # Horizontal gradient
-        dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # Vertical gradient
+        dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # Vertical gradient (along height)
+        dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # Horizontal gradient (along width)
         return dx, dy
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/modeling/l1ssim_loss.py` around lines 126 - 127, The
inline comments for dx and dy are reversed: in l1ssim_loss.py the expression dx
= torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) computes differences along dim=2
(height) so update its comment to "Vertical gradient", and the expression dy =
torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) computes differences along dim=3
(width) so update its comment to "Horizontal gradient"; locate and swap/adjust
the two comments next to the dx and dy assignments accordingly.
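To make the direction convention concrete, here is a quick standalone check (not from the PR): an image with a horizontal edge produces a nonzero response only in the dim=2 differences, confirming they are the vertical gradient.

```python
import torch

# Single-channel NCHW image with a horizontal edge between rows 1 and 2.
x = torch.zeros(1, 1, 4, 4)
x[:, :, 2:, :] = 1.0

dx = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs()  # differences along dim=2 (height)
dy = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs()  # differences along dim=3 (width)
# dx picks up the horizontal edge (vertical gradient); dy stays all-zero.
```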

98-99: SSIM constants assume pixel values in [0, 1] range.

The constants c1 = 0.01**2 and c2 = 0.03**2 are derived from the standard SSIM formula where c1 = (k1*L)**2 with L being the dynamic range. These values assume L=1 (i.e., inputs normalized to [0, 1]). Consider adding a brief docstring or assertion to document this assumption.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/modeling/l1ssim_loss.py` around lines 98 - 99, The SSIM
constants c1 and c2 in l1ssim_loss.py (currently set as c1 = 0.01**2 and c2 =
0.03**2) assume input pixel values are in [0,1]; update the code by either (a)
adding a brief docstring to the SSIM-related function/class (e.g., the
SSIMLoss/compute_ssim routine) documenting that inputs must be normalized to
[0,1], or (b) adding a runtime assertion/check (e.g., assert inputs.min() >= 0
and inputs.max() <= 1) to enforce the assumption, or (c) make c1 and c2 depend
on a dynamic range parameter L (c1=(k1*L)**2, c2=(k2*L)**2) and expose L/k1/k2
as arguments to the SSIM function so the constants are correct for other ranges;
reference c1 and c2 in l1ssim_loss.py when making the change.

31-44: Register the SSIM window as a buffer to avoid device sync issues.

The window tensor is stored as a plain attribute and moved to the input's device on every forward pass. This approach has issues:

  1. The tensor won't be properly moved when calling .to(device) or .cuda() on the module
  2. The in-place reassignment (self.window = self.window.to(...)) mutates module state during forward, which can cause problems with torch.compile and distributed training
♻️ Proposed fix: register as buffer
         self._2d_window = (
             self._1d_window.mm(self._1d_window.t()).float().unsqueeze(0).unsqueeze(0)
         )
-        self.window = self._2d_window.expand(
+        window = self._2d_window.expand(
             self.channel, 1, self.window_size, self.window_size
         ).contiguous()
+        self.register_buffer("window", window)

     def forward(self, image: torch.Tensor, target_image: torch.Tensor) -> torch.Tensor:
-        if self.window.device != image.device:
-            self.window = self.window.to(image.device)
         # L1 color loss
         l1_loss = F.l1_loss(image, target_image, reduction="mean")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/modeling/l1ssim_loss.py` around lines 31 - 44, The
precomputed SSIM window should be registered as a module buffer instead of a
plain attribute and must not be reassigned during forward; after computing
_2d_window and expanding it, call self.register_buffer('window', window_tensor)
(use the same name self.window) in the __init__ where window is created, then
remove the in-place reassignment self.window = self.window.to(image.device) from
forward; if you still need a device-safe tensor in forward, read it into a local
variable (e.g., window = self.window if self.window.device == image.device else
self.window.to(image.device)) rather than mutating self.window.
stain_normalization/stain_normalization_model.py (2)

33-37: Metric prefix assignment order may cause issues.

The prefix is set on val_metrics after cloning to test_metrics. Since clone() copies the current state (including an empty prefix), test_metrics correctly gets "test/". However, this pattern is fragile - if someone reorders these lines, it would break.

Consider setting the prefix before cloning or using the prefix parameter in the constructor:

♻️ Clearer initialization
-        self.val_metrics = MetricCollection(
-            {"ssim": StructuralSimilarityIndexMeasure(), "l1": MeanAbsoluteError()}
+        self.val_metrics = MetricCollection(
+            {"ssim": StructuralSimilarityIndexMeasure(), "l1": MeanAbsoluteError()},
+            prefix="validation/",
         )
-        self.test_metrics = self.val_metrics.clone(prefix="test/")
-        self.val_metrics.prefix = "validation/"
+        self.test_metrics = self.val_metrics.clone(prefix="test/")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/stain_normalization_model.py` around lines 33 - 37, The
metric prefixing is fragile because val_metrics.prefix is set after cloning
test_metrics; update initialization so the prefix is assigned before cloning (or
pass prefix argument to MetricCollection) to ensure test_metrics receives the
intended base state—adjust the MetricCollection creation for val_metrics (the
MetricCollection instance referenced by val_metrics) and then call
val_metrics.clone(...) to create test_metrics (using clone or
clone(prefix="test/")) so the prefixes for val_metrics and test_metrics
(validation/ and test/) are set deterministically.

80-81: Consider adding a learning rate scheduler for training stability.

The optimizer setup is functional, but for a large dataset (~1.5 million tiles as mentioned in PR objectives), a learning rate scheduler could help with convergence. This is optional and can be added later based on training observations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/stain_normalization_model.py` around lines 80 - 81, The
configure_optimizers method currently returns a raw Adam optimizer; add an LR
scheduler for training stability by updating configure_optimizers (in the
StainNormalizationModel class) to return both optimizer and scheduler (e.g., a
StepLR or ReduceLROnPlateau) in the PyTorch-Lightning-compatible dict format:
create the Adam optimizer with self.parameters() and self.lr, instantiate a
scheduler (torch.optim.lr_scheduler.StepLR or ReduceLROnPlateau) with sensible
defaults (step_size/gamma or patience/factor), and return {"optimizer":
optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor":
"<metric-if-needed>", "interval": "epoch", "frequency": 1}}; also import
torch.optim.lr_scheduler at the top.
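The suggested scheduler wiring could look roughly like this. It is a sketch, not the PR's code: the monitored metric name "validation/ssim" and the scheduler hyperparameters are assumptions based on the metric prefixes mentioned in this review.

```python
import torch


def configure_optimizers(parameters, lr: float = 1e-3):
    # Adam plus a plateau scheduler, returned in the Lightning-compatible
    # dict format described in the comment above.
    optimizer = torch.optim.Adam(parameters, lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "validation/ssim",  # assumed metric name
            "interval": "epoch",
            "frequency": 1,
        },
    }
```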
stain_normalization/modeling/unet.py (1)

108-119: UNet output is unbounded.

The network output has no activation function, so values can exceed the [0, 1] range. While this is fine for training (L1 loss will penalize deviations), consider whether inference outputs should be clamped for downstream use (e.g., visualization, saving images). This could be handled at the inference layer if needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/modeling/unet.py` around lines 108 - 119, The UNet
forward output (forward -> out_conv) is unbounded; add an explicit output
activation or clamping for inference: either apply a final sigmoid (or
torch.clamp(..., 0, 1)) to the tensor returned from out_conv inside forward, or
add a separate inference helper (e.g., UNet.infer or UNet.predict) that calls
out_conv then applies torch.sigmoid / torch.clamp(0,1) before returning for
visualization/saving; update tests/usage to call the new infer/predict when
bounded outputs are required.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@configs/data/modify/test.yaml`:
- Around line 13-15: The test config uses CombinedModifications with the wrong
parameter name; replace the intensity_range key with od_scale_range so that
stain_normalization.data.modification.CombinedModifications receives the
expected parameter (keep the same value [0.65, 1.35]), leaving brightness_range
unchanged; update any references to intensity_range in the test YAML so Hydra
can instantiate CombinedModifications correctly.

In `@configs/data/modify/train.yaml`:
- Around line 13-15: The YAML uses the wrong constructor argument for
CombinedModifications: replace the incorrect intensity_range key with
od_scale_range to match the CombinedModifications signature (keep
brightness_range as-is); update any occurrences of intensity_range that
configure the CombinedModifications instantiation so the runtime will
successfully instantiate the CombinedModifications class with od_scale_range.

In `@stain_normalization/__main__.py`:
- Around line 14-16: Wrap the resolver registration for "random_seed" with a
guard that checks OmegaConf.has_resolver("random_seed") before calling
OmegaConf.register_new_resolver; if the resolver does not exist, register it
with the same lambda randint(0, 2**31) and use_cache=True. This avoids
ValueError on repeated imports/reloads while preserving the existing resolver
behavior in __main__.py.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 717354b7-ea6d-49a4-ae6a-2aa066e18a9c

📥 Commits

Reviewing files that changed from the base of the PR and between babaf36 and bcbd2cb.

📒 Files selected for processing (12)
  • configs/data/modify/test.yaml
  • configs/data/modify/train.yaml
  • configs/default.yaml
  • configs/hydra/default.yaml
  • configs/logger/mlflow.yaml
  • pyproject.toml
  • stain_normalization/__init__.py
  • stain_normalization/__main__.py
  • stain_normalization/modeling/__init__.py
  • stain_normalization/modeling/l1ssim_loss.py
  • stain_normalization/modeling/unet.py
  • stain_normalization/stain_normalization_model.py


coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (1)
stain_normalization/modeling/l1ssim_loss.py (1)

122-126: Gradient direction comments are swapped.

In NCHW format, dim=2 is height (rows) and dim=3 is width (columns). The current computation of dx along dim=2 is a vertical gradient (row differences), and dy along dim=3 is a horizontal gradient (column differences). The comments have these labels reversed.

📝 Suggested comment fix
     def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # Horizontal gradient
-        dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # Vertical gradient
+        dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # Vertical gradient (along height)
+        dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # Horizontal gradient (along width)
         return dx, dy
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/modeling/l1ssim_loss.py` around lines 122 - 126, The
inline comments in gradient_loss's inner function gradient are incorrect: in
NCHW tensors the index dim=2 is height (rows) so the computation dx =
torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) is the vertical gradient (row
differences) and dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) is the
horizontal gradient (column differences); update the comments in the gradient
function (referencing function gradient inside gradient_loss) to label dx as
"Vertical gradient" and dy as "Horizontal gradient" accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@configs/data/modify/test.yaml`:
- Around line 1-2: The test pipeline currently uses a nondeterministic
augmentation (_target_: albumentations.OneOf with p: 1.0) which will modify
every test sample; update the test.yaml referenced by
configs/data/datasets/stain_normalization/test.yaml (the /data/modify@modify:
test inclusion) to remove the albumentations.OneOf entry or set its p to 0.0, or
alternatively move this augmentation into a separate robustness evaluation
config so that the standard test pipeline is deterministic and reproducible.

In `@stain_normalization/data/datasets/test_dataset.py`:
- Around line 57-60: The line setting slide_path in test_dataset.py has one
extra space causing a Ruff indentation error; align its indentation with the
following argument lines by reducing the leading spaces so the slide_path=...
argument uses the same indentation level as level=..., tile_extent_x=..., and
tile_extent_y=... (adjust the line containing slide_path and
slide_metadata["path"] to match the other arguments).

In `@stain_normalization/data/datasets/train_dataset.py`:
- Line 43: Fix the inconsistent indentation on the line containing "for slide in
self.slides": change its indentation from 13 spaces to 12 so it matches the
surrounding block and passes Ruff formatting. Locate the occurrence of "for
slide in self.slides" within the method (e.g., the loop or list comprehension in
the train dataset class) and adjust the leading whitespace to align with the
enclosing block indentation level.

In `@stain_normalization/modeling/l1ssim_loss.py`:
- Around line 35-40: The buffer created via self.register_buffer returns type
Tensor | Module which mypy widens, so add an explicit attribute type to narrow
it to torch.Tensor (e.g., declare self.window: Tensor before or annotate the
attribute when assigning) so that the _ssim function receives a Tensor; do the
same for the other register_buffer call on line referenced (i.e., annotate that
buffer attribute to torch.Tensor as well). Ensure you reference the attribute
names used with register_buffer (self.window and the other buffer) and keep the
register_buffer call intact while adding the explicit type annotation.
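A minimal sketch of the suggested annotation (class and attribute names assumed from the review; the real `L1SSIMLoss` has more state than shown here):

```python
import torch
from torch import nn

class L1SSIMLoss(nn.Module):
    # Class-level annotation narrows the buffer's type for mypy,
    # which otherwise infers Tensor | Module from register_buffer access
    window: torch.Tensor

    def __init__(self, window_size: int = 11) -> None:
        super().__init__()
        # register_buffer keeps the window out of parameters() but moves it with .to(device)
        self.register_buffer(
            "window", torch.ones(1, 1, window_size, window_size) / window_size**2
        )

loss = L1SSIMLoss()
print(tuple(loss.window.shape))  # (1, 1, 11, 11)
```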

In `@stain_normalization/stain_normalization_model.py`:
- Around line 79-80: The type error comes from accessing self.hparams.lr
directly; change configure_optimizers in stain_normalization_model.py to read
the learning rate via dictionary-style access (e.g. self.hparams["lr"]) or add
an explicit typed attribute for hparams after calling save_hyperparameters() so
mypy knows lr exists; update the return to use Adam(self.parameters(),
lr=self.hparams["lr"]) (or annotate hparams with a dataclass/TypedDict and use
that typed attribute) in the configure_optimizers method.
- Around line 63-73: In test_step, the stacked targets are left on CPU causing a
device mismatch with model outputs; after creating targets in test_step
(stacking item["original_image_tensor"] from data), move them to the
model/output device before calling self.test_metrics.update by calling
.to(outputs.device) on the stacked tensor so targets and outputs are on the same
device.
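Both `stain_normalization_model.py` fixes above can be sketched together (class, method, and key names assumed from the review; a plain `nn.Module` with a dict stands in for the LightningModule and its `save_hyperparameters()` state):

```python
import torch
from torch import nn

class StainNormalizationModel(nn.Module):
    """Stand-in for the Lightning module described in the review."""

    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.net = nn.Conv2d(3, 3, kernel_size=1)
        # Lightning's save_hyperparameters() would populate self.hparams; a dict stands in here
        self.hparams = {"lr": lr}

    def configure_optimizers(self) -> torch.optim.Adam:
        # Dictionary-style access avoids the mypy error on attribute access
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def test_step(self, data: list[dict[str, torch.Tensor]]) -> torch.Tensor:
        inputs = torch.stack([item["original_image_tensor"] for item in data])
        outputs = self.net(inputs.to(next(self.parameters()).device))
        # Move the stacked CPU targets to the output device before any metric update
        targets = torch.stack(
            [item["original_image_tensor"] for item in data]
        ).to(outputs.device)
        assert targets.device == outputs.device
        return outputs

model = StainNormalizationModel()
opt = model.configure_optimizers()
batch = [{"original_image_tensor": torch.rand(3, 32, 32)} for _ in range(4)]
out = model.test_step(batch)
print(opt.param_groups[0]["lr"], tuple(out.shape))
```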

---

Nitpick comments:
In `@stain_normalization/modeling/l1ssim_loss.py`:
- Around line 122-126: The inline comments in gradient_loss's inner function
gradient are incorrect: in NCHW tensors the index dim=2 is height (rows) so the
computation dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]) is the vertical
gradient (row differences) and dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])
is the horizontal gradient (column differences); update the comments in the
gradient function (referencing function gradient inside gradient_loss) to label
dx as "Vertical gradient" and dy as "Horizontal gradient" accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 94f474fc-c451-4efa-9ab8-370ef55f8c07

📥 Commits

Reviewing files that changed from the base of the PR and between bcbd2cb and 7698647.

📒 Files selected for processing (7)
  • configs/data/modify/test.yaml
  • configs/data/modify/train.yaml
  • stain_normalization/data/datasets/predict_dataset.py
  • stain_normalization/data/datasets/test_dataset.py
  • stain_normalization/data/datasets/train_dataset.py
  • stain_normalization/modeling/l1ssim_loss.py
  • stain_normalization/stain_normalization_model.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • configs/data/modify/train.yaml


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
stain_normalization/modeling/l1ssim_loss.py (1)

124-127: Swapped gradient direction comments.

The comments are reversed: dx computes differences along dim=2 (height/rows), which is the vertical gradient, while dy computes differences along dim=3 (width/columns), which is the horizontal gradient. This doesn't affect correctness since both are summed, but could confuse future readers.

📝 Proposed fix
     def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # Horizontal gradient
-        dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # Vertical gradient
+        dx = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])  # Vertical gradient (along height)
+        dy = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])  # Horizontal gradient (along width)
         return dx, dy
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/modeling/l1ssim_loss.py` around lines 124 - 127, The
inline comments in the gradient function are reversed: update the comment for dx
in def gradient(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: to
indicate it computes the vertical gradient (differences along dim=2 / height)
and update the comment for dy to indicate it computes the horizontal gradient
(differences along dim=3 / width); keep the computations (dx, dy) unchanged and
only correct the descriptive comments next to the variables dx and dy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@stain_normalization/stain_normalization_model.py`:
- Around line 75-80: The ruff formatter failed due to a long line in the
test_step method; run ruff format (or manually wrap the long expression in
test_step) to comply with line-length rules and fix CI: locate the test_step
method in stain_normalization_model.py (function name test_step) and break the
offending long line into multiple concatenated parts or use implicit line
continuation with parentheses, then re-run ruff format to apply any other style
fixes before committing.

---

Nitpick comments:
In `@stain_normalization/modeling/l1ssim_loss.py`:
- Around line 124-127: The inline comments in the gradient function are
reversed: update the comment for dx in def gradient(x: torch.Tensor) ->
tuple[torch.Tensor, torch.Tensor]: to indicate it computes the vertical gradient
(differences along dim=2 / height) and update the comment for dy to indicate it
computes the horizontal gradient (differences along dim=3 / width); keep the
computations (dx, dy) unchanged and only correct the descriptive comments next
to the variables dx and dy.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6e77d637-a88e-497c-9010-461f2c4a1471

📥 Commits

Reviewing files that changed from the base of the PR and between 7698647 and 05da4bd.

📒 Files selected for processing (4)
  • stain_normalization/data/datasets/test_dataset.py
  • stain_normalization/data/datasets/train_dataset.py
  • stain_normalization/modeling/l1ssim_loss.py
  • stain_normalization/stain_normalization_model.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • stain_normalization/data/datasets/train_dataset.py

@matejpekar matejpekar requested review from 172454 and vejtek March 15, 2026 16:23

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
stain_normalization/stain_normalization_model.py (1)

32-36: Consider clamping outputs for SSIM metric compatibility.

StructuralSimilarityIndexMeasure typically expects inputs normalized to a consistent range (commonly [0, 1]). Per the context, UNet's final layer (OutConv) applies only a 1x1 convolution without clamping or activation, so outputs may exceed this range during early training or with outlier inputs.

If metrics behave unexpectedly, consider clamping outputs before metric updates: outputs.clamp(0, 1).
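If clamping is adopted, the pattern is a one-liner (the outputs here are random stand-ins for raw UNet activations, not values from the PR):

```python
import torch

torch.manual_seed(0)
outputs = torch.randn(2, 3, 16, 16) * 2  # raw conv outputs can fall outside [0, 1]
clamped = outputs.clamp(0, 1)  # bound the range SSIM with data_range=1.0 expects
print(float(clamped.min()), float(clamped.max()))  # both within [0.0, 1.0]
```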

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@stain_normalization/stain_normalization_model.py` around lines 32 - 36, The
SSIM metric (StructuralSimilarityIndexMeasure) expects inputs in a bounded
range, but UNet's final layer (OutConv) produces raw unconstrained outputs;
clamp model outputs to [0,1] before feeding metrics to avoid spurious
values—e.g., in the validation/test step(s) where self.val_metrics or
self.test_metrics are updated, replace uses of the raw network output with a
clamped tensor (outputs.clamp(0, 1)) so that both self.val_metrics (and its
clone test/ prefixed metrics) receive values in the expected range.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@stain_normalization/stain_normalization_model.py`:
- Around line 32-36: The SSIM metric (StructuralSimilarityIndexMeasure) expects
inputs in a bounded range, but UNet's final layer (OutConv) produces raw
unconstrained outputs; clamp model outputs to [0,1] before feeding metrics to
avoid spurious values—e.g., in the validation/test step(s) where
self.val_metrics or self.test_metrics are updated, replace uses of the raw
network output with a clamped tensor (outputs.clamp(0, 1)) so that both
self.val_metrics (and its clone test/ prefixed metrics) receive values in the
expected range.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ce0747a8-6b3f-4ef4-b22a-4c4d21c2ac62

📥 Commits

Reviewing files that changed from the base of the PR and between 05da4bd and e3fc89e.

📒 Files selected for processing (3)
  • configs/default.yaml
  • stain_normalization/modeling/l1ssim_loss.py
  • stain_normalization/stain_normalization_model.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • stain_normalization/modeling/l1ssim_loss.py

@LAdam-ix LAdam-ix requested a review from matejpekar March 15, 2026 22:59
@vejtek vejtek merged commit d410604 into main Mar 16, 2026
3 checks passed
@vejtek vejtek deleted the feature/ml-model branch March 16, 2026 16:42