Summary of Changes (Gemini Code Assist)

This pull request establishes the foundational machine learning components for a stain normalization project. It integrates a U-Net architecture and a multi-component loss function to enable effective image processing. The changes also set up a comprehensive configuration system using Hydra, streamlining data handling, model training, and experiment logging, which is crucial for reproducibility and scalability in deep learning workflows.
Code Review
This pull request introduces a U-Net model and a custom loss function for a stain normalization task, using PyTorch Lightning and Hydra. The overall structure is sound. My feedback focuses on improving configurability, reproducibility, and code clarity. The most significant suggestions involve making the model and its components configurable through Hydra to leverage its full potential, and improving the project setup for better collaboration. I have also included some suggestions for the loss function implementation.
```toml
name = "stain-normalization"
version = "0.1.0"
authors = [{name = "Adam Lopatka"}]
requires-python = "==3.12.5"
```
Pinning the Python version to an exact patch (==3.12.5) is overly restrictive and can cause setup issues for developers with a different patch version. It's better to specify a compatible range, for example requires-python = ">=3.12" or requires-python = "~=3.12".
```diff
-requires-python = "==3.12.5"
+requires-python = "~=3.12"
```
```python
class StainNormalizationModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.unet = UNet(in_channels=3, out_channels=3)
        self.criterion = L1SSIMLoss()

        self.val_metrics = MetricCollection(
            {"ssim": StructuralSimilarityIndexMeasure(), "l1": MeanAbsoluteError()}
        )
        self.test_metrics = self.val_metrics.clone(prefix="test/")
        self.val_metrics.prefix = "validation/"

    def forward(self, x: Tensor) -> Outputs:
        return self.unet(x)

    def training_step(self, batch: Batch) -> Tensor:
        inputs, targets = batch
        outputs = self(inputs)

        loss = self.criterion(outputs, targets)
        self.log("train/loss", loss, on_step=True, prog_bar=True)

        return loss

    def validation_step(self, batch: Batch) -> None:
        inputs, targets = batch
        outputs = self(inputs)

        loss = self.criterion(outputs, targets)
        self.log("validation/loss", loss, on_step=False, on_epoch=True, logger=True)
        self.val_metrics.update(outputs, targets)
        self.log_dict(
            self.val_metrics,
            batch_size=len(inputs),
            on_epoch=True,
        )

    def test_step(self, batch: PredictBatch) -> Outputs:
        inputs, data = batch
        outputs = self(inputs)
        targets = stack([item["original_image_tensor"] for item in data])
        self.test_metrics.update(outputs, targets)
        self.log_dict(
            self.test_metrics,
            batch_size=len(inputs),
            on_epoch=True,
        )
        return outputs

    def predict_step(self, batch: PredictBatch, batch_idx: int) -> Outputs:
        inputs = batch[0]
        return self(inputs)

    def configure_optimizers(self) -> Optimizer:
        return Adam(self.parameters(), lr=1e-4)
```
The model, loss function, and optimizer are hardcoded within the StainNormalizationModel. This reduces flexibility and goes against the principles of a Hydra-based configuration. It would be better to instantiate these components via Hydra configs, allowing you to easily swap them or tune their hyperparameters without changing the code.
For example, you could change the __init__ to accept configurations and instantiate them:
```python
import hydra
from omegaconf import DictConfig


class StainNormalizationModel(LightningModule):
    def __init__(
        self,
        net: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer_config: DictConfig,
    ) -> None:
        super().__init__()
        self.net = net
        self.criterion = criterion
        self.optimizer_config = optimizer_config
        # ...

    def configure_optimizers(self) -> Optimizer:
        return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())
```

And in your main config, you would define `net`, `criterion`, and `optimizer_config` as instantiable objects. This would make your `StainNormalizationModel` a more generic training harness.
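Concretely, the matching Hydra config group might look something like the sketch below. The `_target_` module paths for `UNet` and `L1SSIMLoss` are assumptions for illustration (only `torch.optim.Adam` is a known real path); adjust them to wherever the classes actually live in this repository.

```yaml
# model.yaml -- hypothetical config group for the training harness
net:
  _target_: stain_normalization.unet.UNet      # assumed module path
  in_channels: 3
  out_channels: 3
criterion:
  _target_: stain_normalization.losses.L1SSIMLoss  # assumed module path
optimizer_config:
  _target_: torch.optim.Adam
  lr: 1.0e-4
```

With such a config, `hydra.utils.instantiate(cfg.net)` and `hydra.utils.instantiate(cfg.criterion)` build the components, and the optimizer is instantiated lazily in `configure_optimizers` once the model parameters exist.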
```yaml
seed: ${random_seed:}
mode: fit
checkpoint: mlflow-artifacts:/79/978f5d5e54844be3b42509cce76793d7/artifacts/checkpoints/epoch=7-step=152375/checkpoint.ckpt
```
```yaml
num_workers: 10

metadata:
  user: xlopatka
```
```toml
"rationai-mlkit @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/mlkit.git",
"rationai-masks @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/masks.git",
"rationai-tiling @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/tiling.git",
"scikit-image>=0.25.2",
"openslide-bin>=4.0.0.6",
"rationai-staining @ git+https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/staining.git",
```
```python
def __init__(
    self,
    lambda_dssim: float = 0.6,
    lambda_l1: float = 0.2,
    lambda_lum: float = 0.2,
    lambda_gdl: float = 0.1,
):
    super().__init__()
    self.lambda_dssim = lambda_dssim
    self.lambda_l1 = lambda_l1
    self.lambda_lum = lambda_lum
    self.lambda_gdl = lambda_gdl
```
Using lambda as part of a variable name (e.g., lambda_dssim) is confusing because lambda is a Python keyword. It's better to use a more descriptive name like dssim_weight to improve readability. This change should be propagated throughout the class.
```diff
     def __init__(
         self,
-        lambda_dssim: float = 0.6,
-        lambda_l1: float = 0.2,
-        lambda_lum: float = 0.2,
-        lambda_gdl: float = 0.1,
+        dssim_weight: float = 0.6,
+        l1_weight: float = 0.2,
+        lum_weight: float = 0.2,
+        gdl_weight: float = 0.1,
     ):
         super().__init__()
-        self.lambda_dssim = lambda_dssim
-        self.lambda_l1 = lambda_l1
-        self.lambda_lum = lambda_lum
-        self.lambda_gdl = lambda_gdl
+        self.dssim_weight = dssim_weight
+        self.l1_weight = l1_weight
+        self.lum_weight = lum_weight
+        self.gdl_weight = gdl_weight
```
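For context, the four weights combine the loss terms as a weighted sum. A minimal standalone sketch (the `combine_loss_terms` helper and its inputs are hypothetical placeholders; the real DSSIM, L1, luminance, and gradient terms are computed elsewhere in the PR):

```python
import torch


def combine_loss_terms(
    dssim: torch.Tensor,
    l1: torch.Tensor,
    lum: torch.Tensor,
    gdl: torch.Tensor,
    dssim_weight: float = 0.6,
    l1_weight: float = 0.2,
    lum_weight: float = 0.2,
    gdl_weight: float = 0.1,
) -> torch.Tensor:
    # Weighted sum of the four loss components, using the renamed weights.
    return (
        dssim_weight * dssim
        + l1_weight * l1
        + lum_weight * lum
        + gdl_weight * gdl
    )


# Placeholder component values, just to show the arithmetic.
total = combine_loss_terms(
    torch.tensor(0.5), torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3)
)
```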
```python
def gaussian(window_size: int, sigma: float) -> torch.Tensor:
    gauss = torch.Tensor(
```
torch.Tensor is an alias for torch.FloatTensor and its behavior can be surprising. It's recommended to use torch.tensor which infers the dtype or allows explicit setting, providing more predictable behavior. For floating point tensors, you can use torch.tensor(..., dtype=torch.float32).
```diff
-    gauss = torch.Tensor(
+    gauss = torch.tensor(
```
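To illustrate the difference between the two constructors with a quick standalone example: `torch.Tensor` always produces `float32` (and interprets a bare integer argument as a size), while `torch.tensor` infers the dtype from its data or accepts it explicitly:

```python
import torch

a = torch.Tensor([1, 2, 3])   # legacy constructor: always float32
b = torch.tensor([1, 2, 3])   # dtype inferred from the data: int64
c = torch.tensor([1, 2, 3], dtype=torch.float32)  # explicit dtype

print(a.dtype)  # torch.float32
print(b.dtype)  # torch.int64
print(c.dtype)  # torch.float32

# A more surprising case: torch.Tensor(3) allocates an uninitialized
# 1-D tensor of size 3, while torch.tensor(3) creates a 0-D scalar.
print(torch.Tensor(3).shape)  # torch.Size([3])
print(torch.tensor(3).shape)  # torch.Size([])
```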
```python
) -> torch.Tensor:
    device = pred.device
    if he_weights is None:
        he_weights = [0.33, 0.33, 0.33]
```
```python
diffy = x2.size()[2] - x1.size()[2]
diffx = x2.size()[3] - x1.size()[3]
```
While .size() works, the more modern and conventional way to get tensor dimensions in PyTorch is by using the .shape attribute. It's generally preferred for consistency and readability.
```diff
-diffy = x2.size()[2] - x1.size()[2]
-diffx = x2.size()[3] - x1.size()[3]
+diffy = x2.shape[2] - x1.shape[2]
+diffx = x2.shape[3] - x1.shape[3]
```
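For context, this size difference is the standard U-Net skip-connection alignment step. A typical implementation (a sketch under that assumption, not necessarily the exact code in this PR) pads the upsampled decoder tensor to match the encoder features before concatenation:

```python
import torch
import torch.nn.functional as F

x1 = torch.zeros(1, 64, 60, 60)  # upsampled decoder features
x2 = torch.zeros(1, 64, 64, 64)  # encoder skip connection

diffy = x2.shape[2] - x1.shape[2]
diffx = x2.shape[3] - x1.shape[3]

# Pad x1 symmetrically (left, right, top, bottom) so it matches x2.
x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, diffy // 2, diffy - diffy // 2])
out = torch.cat([x2, x1], dim=1)  # concatenate along the channel axis
print(out.shape)  # torch.Size([1, 128, 64, 64])
```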
The U-Net follows an existing implementation of UNet.
Total loss function (reconstructed from the weights in `L1SSIMLoss`): L_total = lambda_dssim * DSSIM + lambda_l1 * L1 + lambda_lum * L_lum + lambda_gdl * L_gdl, with defaults 0.6, 0.2, 0.2, and 0.1 respectively.
The brightness (luminance) term was added after the initial training because the outputs were slightly darker than the targets, so I fine-tuned with it instead of retraining from scratch. After the code review I will probably retrain with brightness included from the start. It might also be worth dropping the gradient loss to speed training up a little; now that the brightness term is in, it probably won't contribute much to the total loss.
Training setup mostly follows the ML template defaults. I added more frequent validation because of long training runs.
P.S. The mypy errors are mostly from the project being split into multiple PRs.