From c6692f1419808b26ca205aded069e9f62436e908 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 15:19:18 -0800 Subject: [PATCH 01/75] refactor data loading into its own module --- examples/demo_dlmbl/debug_log_graph.py | 2 +- examples/demo_dlmbl/solution.py | 2 +- tests/light/test_data.py | 2 +- viscy/cli/cli.py | 2 +- viscy/data/__init__.py | 0 viscy/{light/data.py => data/hcs.py} | 0 viscy/light/engine.py | 2 +- viscy/light/predict_writer.py | 2 +- viscy/scripts/profiling.py | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) create mode 100644 viscy/data/__init__.py rename viscy/{light/data.py => data/hcs.py} (100%) diff --git a/examples/demo_dlmbl/debug_log_graph.py b/examples/demo_dlmbl/debug_log_graph.py index 1819b02fd..ec9871187 100644 --- a/examples/demo_dlmbl/debug_log_graph.py +++ b/examples/demo_dlmbl/debug_log_graph.py @@ -19,7 +19,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # Trainer class and UNet. from viscy.light.engine import VSUNet diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 933f939df..2c81aa6fe 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -83,7 +83,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # training augmentations from viscy.transforms import ( diff --git a/tests/light/test_data.py b/tests/light/test_data.py index 263f8f90b..153f175f6 100644 --- a/tests/light/test_data.py +++ b/tests/light/test_data.py @@ -4,7 +4,7 @@ from iohub import open_ome_zarr from pytest import mark -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.trainer import VSTrainer diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py index 0946bb0f3..f9a55f128 100644 --- a/viscy/cli/cli.py +++ b/viscy/cli/cli.py @@ -9,7 +9,7 @@ from lightning.pytorch.cli import LightningCLI from lightning.pytorch.loggers import TensorBoardLogger -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.engine import VSUNet from viscy.light.trainer import VSTrainer diff --git a/viscy/data/__init__.py b/viscy/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/viscy/light/data.py b/viscy/data/hcs.py similarity index 100% rename from viscy/light/data.py rename to viscy/data/hcs.py diff --git a/viscy/light/engine.py b/viscy/light/engine.py index f165a056f..74f14aaad 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -25,8 +25,8 @@ structural_similarity_index_measure, ) +from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d -from viscy.light.data import Sample from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index a6ae88cb5..7a58009c7 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -9,7 +9,7 @@ from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import DTypeLike, NDArray -from viscy.light.data import HCSDataModule, Sample +from viscy.data.hcs import HCSDataModule, Sample __all__ = ["HCSPredictionWriter"] _logger = logging.getLogger("lightning.pytorch") diff --git a/viscy/scripts/profiling.py b/viscy/scripts/profiling.py index 0c947f458..a0c3ca6d8 100644 --- a/viscy/scripts/profiling.py +++ b/viscy/scripts/profiling.py @@ -2,7 +2,7 @@ from profilehooks import profile -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule dataset = "/path/to/dataset.zarr" From 3d8e7e2646a10e9483120ad4e12be736342cf621 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 15:26:59 -0800 Subject: [PATCH 02/75] update type annotations --- viscy/unet/networks/Unet21D.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 7c32e34b8..51ed98395 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -1,11 +1,11 @@ -from typing import Callable, Literal, Optional, Sequence, Union +from typing import Callable, Literal, Sequence import timm import torch from monai.networks.blocks import Convolution, ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer from monai.networks.utils import normal_init -from torch import nn +from torch import Tensor, nn def icnr_init( @@ -45,7 +45,7 @@ def _get_convnext_stage( in_channels: int, out_channels: int, depth: int, - upsample_factor: Optional[int] = None, + upsample_factor: int | None = None, ) -> nn.Module: stage = timm.models.convnext.ConvNeXtStage( in_chs=in_channels, @@ -83,7 +83,7 @@ def __init__( stride=kernel_size, ) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -101,7 +101,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, norm_name: str, - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() spatial_dims = 2 @@ -145,11 +145,11 @@ def __init__( upsample_factor=conv_weight_init_factor, ) - def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + def forward(self, inp: Tensor, skip: Tensor) -> Tensor: """ - :param torch.Tensor inp: Low resolution features - :param torch.Tensor skip: High resolution skip connection features - :return torch.Tensor: High resolution features + :param Tensor inp: Low resolution features + :param Tensor skip: High resolution skip connection features + :return Tensor: High resolution features """ inp = self.upsample(inp) inp = torch.cat([inp, skip], dim=1) @@ -192,7 +192,7 @@ def __init__( self.out = nn.PixelShuffle(2) self.out_stack_depth = out_stack_depth - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.upsample(x) d = self.out_stack_depth + 2 b, c, h, w = x.shape @@ -209,7 +209,7 @@ class UnsqueezeHead(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = x.unsqueeze(2) return x @@ -222,7 +222,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, strides: list[int], - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() self.decoder_stages = nn.ModuleList([]) @@ -240,7 +240,7 @@ def __init__( ) self.decoder_stages.append(stage) - def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor: + def forward(self, features: Sequence[Tensor]) -> Tensor: feat = features[0] # padding features.append(None) @@ -328,7 +328,7 @@ def num_blocks(self) -> int: """2-times downscaling factor of the smallest feature map""" return 6 - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.stem(x) x: list = self.encoder_stages(x) x.reverse() From fdcbf5536133291cee298c654fac5645ca4acfab Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 16:01:28 -0800 Subject: [PATCH 03/75] move the logging module out --- viscy/unet/{utils => }/logging.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename viscy/unet/{utils => }/logging.py (100%) diff --git a/viscy/unet/utils/logging.py b/viscy/unet/logging.py similarity index 100% rename from viscy/unet/utils/logging.py rename to viscy/unet/logging.py From a2913817e0c432933ba5c81c3662993c579d7e66 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 16:03:10 -0800 Subject: [PATCH 04/75] move old logging into utils --- viscy/{unet => utils}/logging.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename viscy/{unet => utils}/logging.py (100%) diff --git a/viscy/unet/logging.py b/viscy/utils/logging.py similarity index 100% rename from viscy/unet/logging.py rename to viscy/utils/logging.py From 3cf8fa23c73ce754e27498a07879ccb23db7d170 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 09:31:21 -0800 Subject: [PATCH 05/75] rename tests to match module name --- tests/{torch_unet => unet}/networks/Unet25D_tests.py | 0 tests/{torch_unet => unet}/networks/Unet2D_tests.py | 0 tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py | 0 tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/{torch_unet => unet}/networks/Unet25D_tests.py (100%) rename tests/{torch_unet => unet}/networks/Unet2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py (100%) diff --git a/tests/torch_unet/networks/Unet25D_tests.py b/tests/unet/networks/Unet25D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet25D_tests.py rename to tests/unet/networks/Unet25D_tests.py diff --git a/tests/torch_unet/networks/Unet2D_tests.py b/tests/unet/networks/Unet2D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet2D_tests.py rename to tests/unet/networks/Unet2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock2D_tests.py b/tests/unet/networks/layers/ConvBlock2D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock2D_tests.py rename to tests/unet/networks/layers/ConvBlock2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock3D_tests.py b/tests/unet/networks/layers/ConvBlock3D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock3D_tests.py rename to tests/unet/networks/layers/ConvBlock3D_tests.py From d4cd41db42ecf62b94ab26e5bbc9a4d7feecfcac Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 09:31:30 -0800 Subject: [PATCH 06/75] bump torch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8d60ee1de..b60cd5346 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ "iohub==0.1.0rc0", - "torch>=2.0.0", + "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.0.1", From e87d3969617de3bc7a0b47e136b8e1270dad1ea6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 16:35:30 -0800 Subject: [PATCH 07/75] draft fcmae encoder --- tests/unet/__init__.py | 0 tests/unet/test_fcmae.py | 43 ++++++ viscy/unet/networks/Unet21D.py | 2 +- viscy/unet/networks/fcmae.py | 235 +++++++++++++++++++++++++++++++++ 4 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 tests/unet/__init__.py create mode 100644 tests/unet/test_fcmae.py create mode 100644 viscy/unet/networks/fcmae.py diff --git a/tests/unet/__init__.py b/tests/unet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py new file mode 100644 index 000000000..ae8e0ec6c --- /dev/null +++ b/tests/unet/test_fcmae.py @@ -0,0 +1,43 @@ +import torch + +from viscy.unet.networks.fcmae import ( + MaskedConvNeXtV2Block, + MaskedConvNeXtV2Stage, + MaskedGlobalResponseNorm, +) + + +def test_masked_grn() -> None: + x = torch.rand(2, 3, 4, 5) + grn = MaskedGlobalResponseNorm(3, channels_last=False) + grn.gamma.data = torch.ones_like(grn.gamma.data) + mask = torch.ones((1, 1, 4, 5), dtype=torch.bool) + mask[:, :, 2:, 2:] = False + normalized = grn(x) + assert not torch.allclose(normalized, x) + assert torch.allclose(grn(x, mask)[:, :, 2:, 2:], grn(x[:, :, 2:, 2:])) + grn = MaskedGlobalResponseNorm(5, channels_last=True) + grn.gamma.data = torch.ones_like(grn.gamma.data) + mask = torch.ones((1, 3, 4, 1), dtype=torch.bool) + mask[:, 1:, 2:, :] = False + assert torch.allclose(grn(x, mask)[:, 1:, 2:, :], grn(x[:, 1:, 2:, :])) + + +def test_masked_convnextv2_block() -> None: + x = torch.rand(2, 3, 4, 5) + mask = x[0, 0] > 0.5 + block = MaskedConvNeXtV2Block(3, 3 * 2) + assert len(block(x).unique()) == x.numel() * 2 + block = MaskedConvNeXtV2Block(3, 3) + masked_out = block(x, mask) + assert len(masked_out[:, :, mask].unique()) == x.shape[1] + + +def test_masked_convnextv2_stage() -> None: + x = torch.rand(2, 3, 16, 16) + mask = torch.rand(4, 4) > 0.5 + stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) + out = stage(x) + assert out.shape == (2, 3, 8, 8) + masked_out = stage(x, mask) + assert not torch.allclose(masked_out, out) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 51ed98395..c43202403 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -12,7 +12,7 @@ def icnr_init( conv: nn.Module, upsample_factor: int, upsample_dims: int, - init=nn.init.kaiming_normal_, + init: Callable = nn.init.kaiming_normal_, ): """ ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py new file mode 100644 index 000000000..818e8f883 --- /dev/null +++ b/viscy/unet/networks/fcmae.py @@ -0,0 +1,235 @@ +""" +Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 +based on the official JAX example in +https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax +also referring to timm's dense implementation of the encoder in ``timm.models.convnext`` +""" + + +from typing import Callable, Literal, Sequence + +import torch +from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ +from timm.models.convnext import Downsample +from torch import BoolTensor, Tensor, nn + + +def _upsample_mask(mask: BoolTensor, features: Tensor) -> BoolTensor: + mask = mask[..., :, :][None, None] + if features.shape[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(features.shape[-2:], mask.shape[-2:])): + raise ValueError( + f"feature map shape {features.shape} must be divisible by " + f"mask shape {mask.shape}." + ) + mask = mask.repeat_interleave( + features.shape[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(features.shape[-1] // mask.shape[-1], dim=-1) + return mask + + +class MaskedGlobalResponseNorm(nn.Module): + """ + Masked Global Response Normalization. + + :param int dim: number of input channels + :param float eps: small value added for numerical stability, + defaults to 1e-6 + :param bool channels_last: BHWC (True) or BCHW (False) dimension ordering, + defaults to False + """ + + def __init__( + self, dim: int, eps: float = 1e-6, channels_last: bool = False + ) -> None: + super().__init__() + if channels_last: + self.spatial_dim = (1, 2) + self.channel_dim = -1 + weights_shape = (1, 1, 1, dim) + else: + self.spatial_dim = (2, 3) + self.channel_dim = 1 + weights_shape = (1, dim, 1, 1) + self.gamma = nn.Parameter(torch.zeros(weights_shape)) + self.beta = nn.Parameter(torch.zeros(weights_shape)) + self.eps = eps + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor, BHWC or BCHW + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: normalized tensor + """ + samples = x if mask is None else x * ~mask + g_x = samples.norm(p=2, dim=self.spatial_dim, keepdim=True) + n_x = g_x / (g_x.mean(dim=self.channel_dim, keepdim=True) + self.eps) + return x + torch.addcmul(self.beta, self.gamma, x * n_x) + + +class MaskedConvNeXtV2Block(nn.Module): + """Masked ConvNeXt V2 Block. + + :param int in_channels: input channels + :param int | None out_channels: output channels, defaults to None + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsample stride, defaults to 1 + :param int mlp_ratio: MLP expansion ratio, defaults to 4 + :param float drop_path: drop path rate, defaults to 0.0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + kernel_size: int = 7, + stride: int = 1, + mlp_ratio: int = 4, + drop_path: float = 0.0, + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + self.dwconv = create_conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + depthwise=True, + ) + self.layernorm = LayerNorm2d(out_channels) + self.pwconv1 = nn.Conv2d(out_channels, mlp_ratio * out_channels, kernel_size=1) + self.act = nn.GELU() + self.grn = MaskedGlobalResponseNorm(mlp_ratio * out_channels) + self.pwconv2 = nn.Conv2d(mlp_ratio * out_channels, out_channels, kernel_size=1) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if in_channels != out_channels or stride > 1: + self.shortcut = Downsample(in_channels, out_channels, stride=stride) + else: + self.shortcut = nn.Identity() + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + shortcut = self.shortcut(x) + if mask is not None: + x *= ~mask + x = self.dwconv(x) + if mask is not None: + x *= ~mask + x = self.layernorm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x, mask) + x = self.pwconv2(x) + x = self.drop_path(x) + shortcut + return x + + +class MaskedConvNeXtV2Stage(nn.Module): + """Masked ConvNeXt V2 Stage. + + :param int in_channels: input channels + :param int out_channels: output channels + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsampling factor of this stage, defaults to 2 + :param int num_blocks: number of residual blocks, defaults to 2 + :param Sequence[float] | None drop_path_rates: drop path rates of each block, + defaults to None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 7, + stride: int = 2, + num_blocks: int = 2, + drop_path_rates: Sequence[float] | None = None, + ) -> None: + super().__init__() + if drop_path_rates is None: + drop_path_rates = [0.0] * num_blocks + elif len(drop_path_rates) != num_blocks: + raise ValueError( + "length of drop_path_rates must be equal to " + f"the number of blocks {num_blocks}, got {len(drop_path_rates)}." + ) + if in_channels != out_channels or stride > 1: + downsample_kernel_size = stride if stride > 1 else 1 + self.downsample = nn.Sequential( + LayerNorm2d(in_channels), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=downsample_kernel_size, + stride=stride, + padding=0, + ), + ) + in_channels = out_channels + else: + self.downsample = nn.Identity() + self.blocks = nn.ModuleList() + for i in range(num_blocks): + self.blocks.append( + MaskedConvNeXtV2Block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + drop_path=drop_path_rates[i], + ) + ) + in_channels = out_channels + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + x = self.downsample(x) + if mask is not None: + mask = _upsample_mask(mask, x) + for block in self.blocks: + x = block(x, mask) + return x + + +class MaskedMultiscaleEncoder(nn.Module): + def __init__( + self, + in_channels: int, + stage_blocks: Sequence[int] = (3, 3, 9, 3), + dims: Sequence[int] = (96, 192, 384, 768), + drop_path_rate: float = 0.0, + ) -> None: + super().__init__() + self.stages = nn.ModuleList() + chs = [in_channels, *dims] + for i, num_blocks in enumerate(stage_blocks): + self.stages.append( + MaskedConvNeXtV2Stage( + chs[i], + chs[i + 1], + kernel_size=7, + stride=2, + num_blocks=num_blocks, + drop_path_rates=[drop_path_rate] * num_blocks, + ) + ) + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + features = [] + for stage in self.stages: + x = stage(x, mask) + features.append(x) + return features From dccce5f785581300dd4387f2d5f0548be50af5bf Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Jan 2024 15:14:36 -0800 Subject: [PATCH 08/75] add stem to the encoder --- tests/unet/test_fcmae.py | 37 +++++++++++++ viscy/unet/networks/fcmae.py | 101 ++++++++++++++++++++++++++++++----- 2 files changed, 125 insertions(+), 13 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index ae8e0ec6c..73dc5920b 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,9 +1,12 @@ import torch from viscy.unet.networks.fcmae import ( + AdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedGlobalResponseNorm, + MaskedMultiscaleEncoder, + upsample_mask, ) @@ -41,3 +44,37 @@ def test_masked_convnextv2_stage() -> None: assert out.shape == (2, 3, 8, 8) masked_out = stage(x, mask) assert not torch.allclose(masked_out, out) + + +def test_adaptive_projection() -> None: + proj = AdaptiveProjection(3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5) + assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) + assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + proj = AdaptiveProjection( + 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 + ) + assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) + + +def test_masked_multiscale_encoder() -> None: + xy_size = 64 + dims = [12, 24, 48, 96] + x = torch.rand(2, 3, 5, xy_size, xy_size) + encoder = MaskedMultiscaleEncoder(3, dims=dims) + # auto_masked_features, mask = encoder(x, mask_ratio=0.5) + auto_masked_features = encoder(x) + target_shape = list(x.shape) + target_shape.pop(1) + pre_masked_features = encoder(x) #encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) + assert len(auto_masked_features) == len(pre_masked_features) == 4 + for i, (dim, afeat, pfeat) in enumerate( + zip(dims, auto_masked_features, pre_masked_features) + ): + assert afeat.shape[0] == x.shape[0] + assert afeat.shape[1] == dim + stride = 2 * 2 ** (i + 1) + assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + assert torch.allclose(afeat, pfeat, rtol=1e-1, atol=5e-2), ( + i, + (afeat - pfeat).abs().max(), + ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 818e8f883..71644955d 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -11,20 +11,19 @@ import torch from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ from timm.models.convnext import Downsample -from torch import BoolTensor, Tensor, nn +from torch import BoolTensor, Size, Tensor, nn -def _upsample_mask(mask: BoolTensor, features: Tensor) -> BoolTensor: - mask = mask[..., :, :][None, None] - if features.shape[-2:] != mask.shape[-2:]: - if not all(i % j == 0 for i, j in zip(features.shape[-2:], mask.shape[-2:])): +def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + if target[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(target, mask.shape)): raise ValueError( - f"feature map shape {features.shape} must be divisible by " + f"feature map shape {target} must be divisible by " f"mask shape {mask.shape}." ) mask = mask.repeat_interleave( - features.shape[-2] // mask.shape[-2], dim=-2 - ).repeat_interleave(features.shape[-1] // mask.shape[-1], dim=-1) + target[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(target[-1] // mask.shape[-1], dim=-1) return mask @@ -193,12 +192,64 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: """ x = self.downsample(x) if mask is not None: - mask = _upsample_mask(mask, x) + mask = upsample_mask(mask, x.shape) for block in self.blocks: x = block(x, mask) return x +class AdaptiveProjection(nn.Module): + """ + Patchifying layer for projecting 2D or 3D input into 2D feature maps. + Masking is not needed because the mask will cover entire patches. + + :param int in_channels: input channels + :param int out_channels: output channels + :param Sequence[int, int] | int kernel_size_2d: kernel width and height + :param int kernel_depth: kernel depth for 3D input + :param int in_stack_depth: input stack depth for 3D input + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size_2d: tuple[int, int] | int = 4, + kernel_depth: int = 5, + in_stack_depth: int = 5, + ) -> None: + super().__init__() + ratio = in_stack_depth // kernel_depth + if isinstance(kernel_size_2d, int): + kernel_size_2d = [kernel_size_2d] * 2 + kernel_size_3d = [kernel_depth, *kernel_size_2d] + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels // ratio, + kernel_size=kernel_size_3d, + stride=kernel_size_3d, + ) + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size_2d, + stride=kernel_size_2d, + ) + + def forward(self, x: Tensor) -> Tensor: + """ + :param Tensor x: input tensor (BCDHW) + :return Tensor: output tensor (BCHW) + """ + if x.shape[2] > 1: + x = self.conv3d(x) + b, c, d, h, w = x.shape + # project Z/depth into channels + # return a view when possible (contiguous) + return x.reshape(b, c * d, h, w) + return self.conv2d(x.squeeze(2)) + + class MaskedMultiscaleEncoder(nn.Module): def __init__( self, @@ -208,28 +259,52 @@ def __init__( drop_path_rate: float = 0.0, ) -> None: super().__init__() + stem_kernel_size_2d = 4 + self.stem = nn.Sequential( + AdaptiveProjection( + in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 + ), + LayerNorm2d(dims[0]), + ) self.stages = nn.ModuleList() - chs = [in_channels, *dims] + chs = [dims[0], *dims] for i, num_blocks in enumerate(stage_blocks): + stride = 1 if i == 0 else 2 self.stages.append( MaskedConvNeXtV2Stage( chs[i], chs[i + 1], kernel_size=7, - stride=2, + stride=stride, num_blocks=num_blocks, drop_path_rates=[drop_path_rate] * num_blocks, ) ) + self.total_stride = stem_kernel_size_2d * 2 ** (len(self.stages) - 1) - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param float mask_ratio: ratio of the feature maps to mask, + defaults to 0.0 (no masking) :return Tensor: output tensor (BCHW) """ + if mask_ratio > 0.0: + noise = torch.rand( + x.shape[0], + 1, + x.shape[-2] // self.total_stride, + x.shape[-1] // self.total_stride, + device=x.device, + ) + mask = noise > mask_ratio + else: + mask = None + x = self.stem(x) features = [] for stage in self.stages: x = stage(x, mask) features.append(x) + if mask is not None: + return features, mask return features From 55087315f6783417acad17500e0fe3b47899b125 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Jan 2024 15:56:23 -0800 Subject: [PATCH 09/75] wip: masked stem layernorm --- tests/unet/test_fcmae.py | 15 +++++++------- viscy/unet/networks/fcmae.py | 38 ++++++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 73dc5920b..b9a3d389a 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,7 +1,7 @@ import torch from viscy.unet.networks.fcmae import ( - AdaptiveProjection, + MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedGlobalResponseNorm, @@ -47,10 +47,12 @@ def test_masked_convnextv2_stage() -> None: def test_adaptive_projection() -> None: - proj = AdaptiveProjection(3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5) + proj = MaskedAdaptiveProjection( + 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 + ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - proj = AdaptiveProjection( + proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) @@ -61,11 +63,10 @@ def test_masked_multiscale_encoder() -> None: dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) encoder = MaskedMultiscaleEncoder(3, dims=dims) - # auto_masked_features, mask = encoder(x, mask_ratio=0.5) - auto_masked_features = encoder(x) + auto_masked_features, mask = encoder(x, mask_ratio=0.5) target_shape = list(x.shape) target_shape.pop(1) - pre_masked_features = encoder(x) #encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) + pre_masked_features = encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) assert len(auto_masked_features) == len(pre_masked_features) == 4 for i, (dim, afeat, pfeat) in enumerate( zip(dims, auto_masked_features, pre_masked_features) @@ -74,7 +75,7 @@ def test_masked_multiscale_encoder() -> None: assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride - assert torch.allclose(afeat, pfeat, rtol=1e-1, atol=5e-2), ( + assert torch.allclose(afeat, pfeat, rtol=5e-2, atol=5e-2), ( i, (afeat - pfeat).abs().max(), ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 71644955d..416c50ad7 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -9,7 +9,14 @@ from typing import Callable, Literal, Sequence import torch -from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ +from timm.layers import ( + DropPath, + GlobalResponseNormMlp, + LayerNorm2d, + LayerNorm, + create_conv2d, + trunc_normal_, +) from timm.models.convnext import Downsample from torch import BoolTensor, Size, Tensor, nn @@ -198,10 +205,9 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: return x -class AdaptiveProjection(nn.Module): +class MaskedAdaptiveProjection(nn.Module): """ - Patchifying layer for projecting 2D or 3D input into 2D feature maps. - Masking is not needed because the mask will cover entire patches. + Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. :param int in_channels: input channels :param int out_channels: output channels @@ -235,19 +241,35 @@ def __init__( kernel_size=kernel_size_2d, stride=kernel_size_2d, ) + self.norm = nn.LayerNorm(out_channels) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: """ :param Tensor x: input tensor (BCDHW) + :param BoolTensor mask: boolean mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) """ + # no need to mask before convolutions since patches do not spill over if x.shape[2] > 1: x = self.conv3d(x) b, c, d, h, w = x.shape # project Z/depth into channels # return a view when possible (contiguous) - return x.reshape(b, c * d, h, w) - return self.conv2d(x.squeeze(2)) + x = x.reshape(b, c * d, h, w) + else: + x = self.conv2d(x.squeeze(2)) + out_shape = x.shape + if mask is not None: + mask = upsample_mask(mask, x.shape) + x = x[mask] + else: + x = x.flatten(2) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + if mask is not None: + out = torch.zeros(out_shape, device=x.device) + out[mask] = x class MaskedMultiscaleEncoder(nn.Module): @@ -261,7 +283,7 @@ def __init__( super().__init__() stem_kernel_size_2d = 4 self.stem = nn.Sequential( - AdaptiveProjection( + MaskedAdaptiveProjection( in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 ), LayerNorm2d(dims[0]), From 3eec48ed78908eb44edf8cd96991da2b79c8cece Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 16 Jan 2024 20:23:32 -0800 Subject: [PATCH 10/75] wip: patchify masked features for linear --- tests/unet/test_fcmae.py | 51 +++++++++++++-- viscy/unet/networks/fcmae.py | 122 +++++++++++++++++++++-------------- 2 files changed, 119 insertions(+), 54 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index b9a3d389a..fc5349812 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -4,13 +4,52 @@ MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, - MaskedGlobalResponseNorm, + # MaskedGlobalResponseNorm, MaskedMultiscaleEncoder, + generate_mask, + masked_patchify, + masked_unpatchify, upsample_mask, ) -def test_masked_grn() -> None: +def test_generate_mask(): + w = 64 + s = 16 + m = 0.75 + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m) + assert mask.shape == (2, 1, w // s, w // s) + assert mask.dtype == torch.bool + ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] + assert torch.allclose(ratio, torch.ones_like(ratio) * m) + + +def test_masked_patchify(): + b, c, h, w = 2, 3, 4, 8 + x = torch.rand(b, c, h, w) + mask_ratio = 0.75 + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) + mask = upsample_mask(mask, x.shape) + feat = masked_patchify(x, mask) + assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) + + +def test_unmasked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + y = masked_unpatchify(masked_patchify(x), out_shape=x.shape) + assert torch.allclose(x, y) + + +def test_masked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) + mask = upsample_mask(mask, x.shape) + y = masked_unpatchify(masked_patchify(x, mask), out_shape=x.shape, mask=mask) + assert torch.all((y == 0) ^ (x == y)) + assert torch.all((y == 0)[:, 0:1] == mask) + + +def test_masked_grn(): x = torch.rand(2, 3, 4, 5) grn = MaskedGlobalResponseNorm(3, channels_last=False) grn.gamma.data = torch.ones_like(grn.gamma.data) @@ -36,7 +75,7 @@ def test_masked_convnextv2_block() -> None: assert len(masked_out[:, :, mask].unique()) == x.shape[1] -def test_masked_convnextv2_stage() -> None: +def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) mask = torch.rand(4, 4) > 0.5 stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) @@ -46,19 +85,21 @@ def test_masked_convnextv2_stage() -> None: assert not torch.allclose(masked_out, out) -def test_adaptive_projection() -> None: +def test_adaptive_projection(): proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + mask = torch.rand(2, 1, 2, 2) > 0.5 + masked_out = proj(torch.rand(2, 3, 5, 16, 16), mask) proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) -def test_masked_multiscale_encoder() -> None: +def test_masked_multiscale_encoder(): xy_size = 64 dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 416c50ad7..d852f780b 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -2,18 +2,18 @@ Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 based on the official JAX example in https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax -also referring to timm's dense implementation of the encoder in ``timm.models.convnext`` +and timm's dense implementation of the encoder in ``timm.models.convnext`` """ from typing import Callable, Literal, Sequence import torch +import torch.nn.functional as F from timm.layers import ( DropPath, GlobalResponseNormMlp, LayerNorm2d, - LayerNorm, create_conv2d, trunc_normal_, ) @@ -21,7 +21,27 @@ from torch import BoolTensor, Size, Tensor, nn +def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: + """ + :param Size target: target shape + :param int stride: total stride + :param float mask_ratio: ratio of the pixels to mask + :return BoolTensor: boolean mask (N, H*W) + """ + m_height = target[-2] // stride + m_width = target[-1] // stride + mask_numel = m_height * m_width + masked_elements = int(mask_numel * mask_ratio) + mask = torch.rand(target[0], mask_numel).argsort(1) < masked_elements + return mask.reshape(target[0], 1, m_height, m_width) + + def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + """ + :param BoolTensor mask: low-resolution boolean mask (B1HW) + :param Size target: target size (BCHW) + :return BoolTensor: upsampled boolean mask (B1HW) + """ if target[-2:] != mask.shape[-2:]: if not all(i % j == 0 for i, j in zip(target, mask.shape)): raise ValueError( @@ -34,43 +54,48 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: return mask -class MaskedGlobalResponseNorm(nn.Module): +def masked_patchify(features: Tensor, mask: BoolTensor | None = None) -> Tensor: """ - Masked Global Response Normalization. - - :param int dim: number of input channels - :param float eps: small value added for numerical stability, - defaults to 1e-6 - :param bool channels_last: BHWC (True) or BCHW (False) dimension ordering, - defaults to False + :param Tensor features: input image features (BCHW) + :param BoolTensor mask: boolean mask (B1HW) + :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) """ + if mask is None: + return features.flatten(2).permute(0, 2, 1) + b, c = features.shape[:2] + # (B, C, H, W) -> (B, H, W, C) + features = features.permute(0, 2, 3, 1) + # (B, H, W, C) -> (B * L, C) -> (B, L, C) + features = features[~mask[:, 0]].reshape(b, -1, c) - def __init__( - self, dim: int, eps: float = 1e-6, channels_last: bool = False - ) -> None: - super().__init__() - if channels_last: - self.spatial_dim = (1, 2) - self.channel_dim = -1 - weights_shape = (1, 1, 1, dim) - else: - self.spatial_dim = (2, 3) - self.channel_dim = 1 - weights_shape = (1, dim, 1, 1) - self.gamma = nn.Parameter(torch.zeros(weights_shape)) - self.beta = nn.Parameter(torch.zeros(weights_shape)) - self.eps = eps + # kernel_size = tuple(features.shape[-i] // mask.shape[-i] for i in (2, 1)) + # # (B, C, H, W) -> (B, C * H_patch * Wp, H_grid * Wg) + # features = F.unfold(features, kernel_size=kernel_size, stride=kernel_size) + # patch_size = kernel_size[0] * kernel_size[1] + # # (B, C * Hp * Wp, Hg * Wg) -> (B, C, Hp * Wp, Hg * Wg) -> (B, Hg * Wg, C, Hp * Wp) + # features = features.view(b, c, patch_size, -1).permute(0, 3, 1, 2) + # # (B, 1, Hg, Wg) -> (B, Hg*Wg) + # idx = ~mask.flatten(1) + # # (B, Hg * Wg, C, Hp * Wp) -> (B * L, C, Hp * Wp) -> (B, L, C, Hp * Wp) + # features = features[idx].view(b, -1, c, patch_size) + # # (B, L, C, Hp * Wp) -> (B, L, Hp * Wp, C) -> (B, L * Hp * Wp, C) + # features = features.permute(0, 1, 3, 2).reshape(b, -1, c) + return features - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: - """ - :param Tensor x: input tensor, BHWC or BCHW - :param BoolTensor | None mask: boolean mask, defaults to None - :return Tensor: normalized tensor - """ - samples = x if mask is None else x * ~mask - g_x = samples.norm(p=2, dim=self.spatial_dim, keepdim=True) - n_x = g_x / (g_x.mean(dim=self.channel_dim, keepdim=True) + self.eps) - return x + torch.addcmul(self.beta, self.gamma, x * n_x) + +def masked_unpatchify( + features: Tensor, out_shape: Size, mask: BoolTensor | None = None +) -> Tensor: + if mask is None: + # (B, L, C) -> (B, C, L) -> (B, C, H, W) + return features.permute(0, 2, 1).reshape(out_shape) + b, c, w, h = out_shape + out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) + # (B, L, C) -> (B * L, C) + features = features.reshape(-1, c) + out[~mask[:, 0]] = features + # (B, H, W, C) -> (B, C, H, W) + return out.permute(0, 3, 1, 2) class MaskedConvNeXtV2Block(nn.Module): @@ -102,11 +127,13 @@ def __init__( stride=stride, depthwise=True, ) - self.layernorm = LayerNorm2d(out_channels) - self.pwconv1 = nn.Conv2d(out_channels, mlp_ratio * out_channels, kernel_size=1) - self.act = nn.GELU() - self.grn = MaskedGlobalResponseNorm(mlp_ratio * out_channels) - self.pwconv2 = nn.Conv2d(mlp_ratio * out_channels, out_channels, kernel_size=1) + self.layernorm = nn.LayerNorm(out_channels) + mid_channels = mlp_ratio * out_channels + self.mlp = GlobalResponseNormMlp( + in_features=out_channels, + hidden_features=mid_channels, + out_features=out_channels, + ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() if in_channels != out_channels or stride > 1: self.shortcut = Downsample(in_channels, out_channels, stride=stride) @@ -125,6 +152,8 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: x = self.dwconv(x) if mask is not None: x *= ~mask + out_shape = x.shape + x = masked_project(x, mask) x = self.layernorm(x) x = self.pwconv1(x) x = self.act(x) @@ -268,8 +297,10 @@ def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: x = self.norm(x) x = x.permute(0, 2, 1) if mask is not None: - out = torch.zeros(out_shape, device=x.device) + out = torch.zeros(out_shape, device=x.device, dtype=x.dtype) out[mask] = x + return out + return x.reshape(out_shape) class MaskedMultiscaleEncoder(nn.Module): @@ -312,14 +343,7 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: :return Tensor: output tensor (BCHW) """ if mask_ratio > 0.0: - noise = torch.rand( - x.shape[0], - 1, - x.shape[-2] // self.total_stride, - x.shape[-1] // self.total_stride, - device=x.device, - ) - mask = noise > mask_ratio + mask = generate_mask(x.shape, self.total_stride, mask_ratio) else: mask = None x = self.stem(x) From 8c54febcf71f8074fe6e7c40198ac1a673cf7678 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 16 Jan 2024 21:37:39 -0800 Subject: [PATCH 11/75] use mlp from timm --- tests/unet/test_fcmae.py | 51 ++++++------------- viscy/unet/networks/fcmae.py | 95 +++++++++++++++--------------------- 2 files changed, 55 insertions(+), 91 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index fc5349812..ba0d7a244 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -4,7 +4,6 @@ MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, - # MaskedGlobalResponseNorm, MaskedMultiscaleEncoder, generate_mask, masked_patchify, @@ -30,7 +29,7 @@ def test_masked_patchify(): mask_ratio = 0.75 mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) mask = upsample_mask(mask, x.shape) - feat = masked_patchify(x, mask) + feat = masked_patchify(x, ~mask) assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) @@ -44,40 +43,28 @@ def test_masked_patchify_roundtrip(): x = torch.rand(2, 3, 4, 8) mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) mask = upsample_mask(mask, x.shape) - y = masked_unpatchify(masked_patchify(x, mask), out_shape=x.shape, mask=mask) + y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) assert torch.all((y == 0) ^ (x == y)) assert torch.all((y == 0)[:, 0:1] == mask) -def test_masked_grn(): - x = torch.rand(2, 3, 4, 5) - grn = MaskedGlobalResponseNorm(3, channels_last=False) - grn.gamma.data = torch.ones_like(grn.gamma.data) - mask = torch.ones((1, 1, 4, 5), dtype=torch.bool) - mask[:, :, 2:, 2:] = False - normalized = grn(x) - assert not torch.allclose(normalized, x) - assert torch.allclose(grn(x, mask)[:, :, 2:, 2:], grn(x[:, :, 2:, 2:])) - grn = MaskedGlobalResponseNorm(5, channels_last=True) - grn.gamma.data = torch.ones_like(grn.gamma.data) - mask = torch.ones((1, 3, 4, 1), dtype=torch.bool) - mask[:, 1:, 2:, :] = False - assert torch.allclose(grn(x, mask)[:, 1:, 2:, :], grn(x[:, 1:, 2:, :])) - - def test_masked_convnextv2_block() -> None: x = torch.rand(2, 3, 4, 5) - mask = x[0, 0] > 0.5 + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5) block = MaskedConvNeXtV2Block(3, 3 * 2) - assert len(block(x).unique()) == x.numel() * 2 + unmasked_out = block(x) + assert len(unmasked_out.unique()) == x.numel() * 2 + all_unmasked = torch.ones_like(mask) + empty_masked_out = block(x, all_unmasked) + assert torch.allclose(unmasked_out, empty_masked_out) block = MaskedConvNeXtV2Block(3, 3) masked_out = block(x, mask) - assert len(masked_out[:, :, mask].unique()) == x.shape[1] + assert len(masked_out.unique()) == mask.sum() * x.shape[1] + 1 def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) - mask = torch.rand(4, 4) > 0.5 + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5) stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) out = stage(x) assert out.shape == (2, 3, 8, 8) @@ -91,8 +78,9 @@ def test_adaptive_projection(): ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - mask = torch.rand(2, 1, 2, 2) > 0.5 - masked_out = proj(torch.rand(2, 3, 5, 16, 16), mask) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6) + masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) + assert masked_out.shape == (1, 12, 4, 4) proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) @@ -104,19 +92,12 @@ def test_masked_multiscale_encoder(): dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) encoder = MaskedMultiscaleEncoder(3, dims=dims) - auto_masked_features, mask = encoder(x, mask_ratio=0.5) + auto_masked_features, _ = encoder(x, mask_ratio=0.5) target_shape = list(x.shape) target_shape.pop(1) - pre_masked_features = encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) - assert len(auto_masked_features) == len(pre_masked_features) == 4 - for i, (dim, afeat, pfeat) in enumerate( - zip(dims, auto_masked_features, pre_masked_features) - ): + assert len(auto_masked_features) == 4 + for i, (dim, afeat) in enumerate(zip(dims, auto_masked_features)): assert afeat.shape[0] == x.shape[0] assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride - assert torch.allclose(afeat, pfeat, rtol=5e-2, atol=5e-2), ( - i, - (afeat - pfeat).abs().max(), - ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index d852f780b..a2e6849ec 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -54,46 +54,38 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: return mask -def masked_patchify(features: Tensor, mask: BoolTensor | None = None) -> Tensor: +def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor features: input image features (BCHW) - :param BoolTensor mask: boolean mask (B1HW) + :param BoolTensor unmasked: boolean foreground mask (B1HW) :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) """ - if mask is None: + if unmasked is None: return features.flatten(2).permute(0, 2, 1) b, c = features.shape[:2] # (B, C, H, W) -> (B, H, W, C) features = features.permute(0, 2, 3, 1) # (B, H, W, C) -> (B * L, C) -> (B, L, C) - features = features[~mask[:, 0]].reshape(b, -1, c) - - # kernel_size = tuple(features.shape[-i] // mask.shape[-i] for i in (2, 1)) - # # (B, C, H, W) -> (B, C * H_patch * Wp, H_grid * Wg) - # features = F.unfold(features, kernel_size=kernel_size, stride=kernel_size) - # patch_size = kernel_size[0] * kernel_size[1] - # # (B, C * Hp * Wp, Hg * Wg) -> (B, C, Hp * Wp, Hg * Wg) -> (B, Hg * Wg, C, Hp * Wp) - # features = features.view(b, c, patch_size, -1).permute(0, 3, 1, 2) - # # (B, 1, Hg, Wg) -> (B, Hg*Wg) - # idx = ~mask.flatten(1) - # # (B, Hg * Wg, C, Hp * Wp) -> (B * L, C, Hp * Wp) -> (B, L, C, Hp * Wp) - # features = features[idx].view(b, -1, c, patch_size) - # # (B, L, C, Hp * Wp) -> (B, L, Hp * Wp, C) -> (B, L * Hp * Wp, C) - # features = features.permute(0, 1, 3, 2).reshape(b, -1, c) + features = features[unmasked[:, 0]].reshape(b, -1, c) return features def masked_unpatchify( - features: Tensor, out_shape: Size, mask: BoolTensor | None = None + features: Tensor, out_shape: Size, unmasked: BoolTensor | None = None ) -> Tensor: - if mask is None: - # (B, L, C) -> (B, C, L) -> (B, C, H, W) + """ + :param Tensor features: dense channel-last features (BLC) + :param Size out_shape: output shape (BCHW) + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None + :return Tensor: masked features (BCHW) + """ + if unmasked is None: return features.permute(0, 2, 1).reshape(out_shape) b, c, w, h = out_shape out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) # (B, L, C) -> (B * L, C) features = features.reshape(-1, c) - out[~mask[:, 0]] = features + out[unmasked[:, 0]] = features # (B, H, W, C) -> (B, C, H, W) return out.permute(0, 3, 1, 2) @@ -140,25 +132,23 @@ def __init__( else: self.shortcut = nn.Identity() - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) """ shortcut = self.shortcut(x) - if mask is not None: - x *= ~mask + if unmasked is not None: + x *= unmasked x = self.dwconv(x) - if mask is not None: - x *= ~mask + if unmasked is not None: + x *= unmasked out_shape = x.shape - x = masked_project(x, mask) + x = masked_patchify(x, unmasked=unmasked) x = self.layernorm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x, mask) - x = self.pwconv2(x) + x = self.mlp(x.unsqueeze(1)).squeeze(1) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) x = self.drop_path(x) + shortcut return x @@ -220,17 +210,17 @@ def __init__( ) in_channels = out_channels - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) """ x = self.downsample(x) - if mask is not None: - mask = upsample_mask(mask, x.shape) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) for block in self.blocks: - x = block(x, mask) + x = block(x, unmasked) return x @@ -272,10 +262,10 @@ def __init__( ) self.norm = nn.LayerNorm(out_channels) - def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: """ :param Tensor x: input tensor (BCDHW) - :param BoolTensor mask: boolean mask (B1HW), defaults to None + :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) """ # no need to mask before convolutions since patches do not spill over @@ -288,19 +278,12 @@ def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: else: x = self.conv2d(x.squeeze(2)) out_shape = x.shape - if mask is not None: - mask = upsample_mask(mask, x.shape) - x = x[mask] - else: - x = x.flatten(2) - x = x.permute(0, 2, 1) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) + x = masked_patchify(x, unmasked=unmasked) x = self.norm(x) - x = x.permute(0, 2, 1) - if mask is not None: - out = torch.zeros(out_shape, device=x.device, dtype=x.dtype) - out[mask] = x - return out - return x.reshape(out_shape) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) + return x class MaskedMultiscaleEncoder(nn.Module): @@ -343,14 +326,14 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: :return Tensor: output tensor (BCHW) """ if mask_ratio > 0.0: - mask = generate_mask(x.shape, self.total_stride, mask_ratio) + unmasked = ~generate_mask(x.shape, self.total_stride, mask_ratio) else: - mask = None + unmasked = None x = self.stem(x) features = [] for stage in self.stages: - x = stage(x, mask) + x = stage(x, unmasked=unmasked) features.append(x) - if mask is not None: - return features, mask + if unmasked is not None: + return features, unmasked return features From 83ecf4a7fcc138fcc9cab7f6b4c1ab6c5ce149a0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 00:14:58 -0800 Subject: [PATCH 12/75] hack: POC training script for FCMAE --- tests/light/test_engine.py | 10 +++ tests/unet/test_fcmae.py | 8 +++ viscy/light/engine.py | 44 +++++++++++++ viscy/scripts/train_fcmae.py | 66 ++++++++++++++++++++ viscy/unet/networks/fcmae.py | 117 +++++++++++++++++++++++++++++------ 5 files changed, 225 insertions(+), 20 deletions(-) create mode 100644 tests/light/test_engine.py create mode 100644 viscy/scripts/train_fcmae.py diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py new file mode 100644 index 000000000..c60133658 --- /dev/null +++ b/tests/light/test_engine.py @@ -0,0 +1,10 @@ +from viscy.light.engine import FcmaeUNet + + +def test_fcmae_vsunet() -> None: + model = FcmaeUNet( + architecture="fcmae", + model_config=dict(in_channels=3), + train_mask_ratio=0.6, + ) + diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index ba0d7a244..870f1138a 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,6 +1,7 @@ import torch from viscy.unet.networks.fcmae import ( + FullyConvolutionalMAE, MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, @@ -101,3 +102,10 @@ def test_masked_multiscale_encoder(): assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + + +def test_fcmae(): + x = torch.rand(2, 3, 5, 128, 128) + model = FullyConvolutionalMAE(3) + assert model(x).shape == x.shape + assert model(x, mask_ratio=0.6).shape == x.shape diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 74f14aaad..0262cc7df 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -27,6 +27,7 @@ from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d +from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d @@ -43,6 +44,7 @@ # same class with out_stack_depth > 1 "2.2D": Unet21d, "2.5D": Unet25d, + "fcmae": FullyConvolutionalMAE, } @@ -367,3 +369,45 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + +class FcmaeUNet(VSUNet): + def __init__(self, train_mask_ratio: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.train_mask_ratio = train_mask_ratio + + def forward(self, x, mask_ratio: float = 0.0): + return self.model(x, mask_ratio) + + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss * mask).sum() / mask.sum() + self.log( + "loss/train", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + return loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss.mean(2) * mask).sum() / mask.sum() + self.log("loss/validate", loss, sync_dist=True) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py new file mode 100644 index 000000000..692bef6db --- /dev/null +++ b/viscy/scripts/train_fcmae.py @@ -0,0 +1,66 @@ +# %% +from lightning.pytorch.loggers import TensorBoardLogger +from torch import set_float32_matmul_precision + +from viscy.data.hcs import HCSDataModule +from viscy.light.engine import FcmaeUNet +from viscy.light.trainer import VSTrainer +from viscy.transforms import ( + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) + +# %% +model = FcmaeUNet( + architecture="fcmae", + model_config=dict(in_channels=1), + train_mask_ratio=0.6, +) + +# %% +ch = "reconstructed-labelfree" + +data = HCSDataModule( + data_path="/hpc/projects/comp.micro/virtual_staining/datasets/training/raw-and-reconstructed.zarr", + source_channel=ch, + target_channel=ch, + z_window_size=5, + batch_size=64, + num_workers=12, + architecture="3D", + augmentations=[ + RandWeightedCropd(ch, ch, spatial_size=[-1, 512, 512], num_samples=2), + RandAffined( + ch, + prob=0.5, + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.05, 0.05], + scale_range=[0.2, 0.3, 0.3], + ), + RandAdjustContrastd(ch, prob=0.3, gamma=[0.75, 1.5]), + RandScaleIntensityd(ch, prob=0.3, factors=0.5), + RandGaussianNoised(ch, prob=0.5, mean=0.0, std=5.0), + RandGaussianSmoothd( + ch, prob=0.5, sigma_z=[0.25, 1.5], sigma_y=[0.25, 1.5], sigma_x=[0.25, 1.5] + ), + ], +) + + +# %% +set_float32_matmul_precision("high") + +trainer = VSTrainer( + fast_dev_run=False, + max_epochs=50, + logger=TensorBoardLogger( + save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_0", log_graph=False + ), +) +trainer.fit(model, data) + +# %% diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index a2e6849ec..ad9d95592 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -9,19 +9,36 @@ from typing import Callable, Literal, Sequence import torch -import torch.nn.functional as F -from timm.layers import ( +from timm.models.convnext import ( + Downsample, DropPath, GlobalResponseNormMlp, LayerNorm2d, create_conv2d, trunc_normal_, ) -from timm.models.convnext import Downsample from torch import BoolTensor, Size, Tensor, nn +from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead -def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: + +def _init_weights(module: nn.Module) -> None: + """Initialize weights of the given module.""" + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + +def generate_mask( + target: Size, stride: int, mask_ratio: float, device: str +) -> BoolTensor: """ :param Size target: target shape :param int stride: total stride @@ -32,7 +49,7 @@ def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: m_width = target[-1] // stride mask_numel = m_height * m_width masked_elements = int(mask_numel * mask_ratio) - mask = torch.rand(target[0], mask_numel).argsort(1) < masked_elements + mask = torch.rand(target[0], mask_numel, device=device).argsort(1) < masked_elements return mask.reshape(target[0], 1, m_height, m_width) @@ -293,14 +310,16 @@ def __init__( stage_blocks: Sequence[int] = (3, 3, 9, 3), dims: Sequence[int] = (96, 192, 384, 768), drop_path_rate: float = 0.0, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, ) -> None: super().__init__() - stem_kernel_size_2d = 4 - self.stem = nn.Sequential( - MaskedAdaptiveProjection( - in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 - ), - LayerNorm2d(dims[0]), + self.stem = MaskedAdaptiveProjection( + in_channels, + dims[0], + kernel_size_2d=stem_kernel_size[1:], + kernel_depth=stem_kernel_size[0], + in_stack_depth=in_stack_depth, ) self.stages = nn.ModuleList() chs = [dims[0], *dims] @@ -316,24 +335,82 @@ def __init__( drop_path_rates=[drop_path_rate] * num_blocks, ) ) - self.total_stride = stem_kernel_size_2d * 2 ** (len(self.stages) - 1) + self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) + self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: """ - :param Tensor x: input tensor (BCHW) + :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, defaults to 0.0 (no masking) - :return Tensor: output tensor (BCHW) + :return list[Tensor]: output tensors (list of BCHW) + :return BoolTensor | None: boolean foreground mask, None if no masking """ if mask_ratio > 0.0: - unmasked = ~generate_mask(x.shape, self.total_stride, mask_ratio) + mask = generate_mask( + x.shape, self.total_stride, mask_ratio, device=x.device + ) + b, c, d, h, w = x.shape + unmasked = ~mask + mask = upsample_mask(mask, (b, d, h, w)) else: - unmasked = None + mask = unmasked = None x = self.stem(x) features = [] for stage in self.stages: x = stage(x, unmasked=unmasked) features.append(x) - if unmasked is not None: - return features, unmasked - return features + return features, mask + + +class FullyConvolutionalMAE(nn.Module): + def __init__( + self, + in_channels: int, + encoder_blocks: Sequence[int] = [3, 3, 9, 3], + dims: Sequence[int] = [96, 192, 384, 768], + encoder_drop_path_rate: float = 0.0, + head_expansion_ratio: int = 4, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, + ) -> None: + super().__init__() + self.encoder = MaskedMultiscaleEncoder( + in_channels=in_channels, + stage_blocks=encoder_blocks, + dims=dims, + drop_path_rate=encoder_drop_path_rate, + stem_kernel_size=stem_kernel_size, + in_stack_depth=in_stack_depth, + ) + decoder_channels = list(dims) + decoder_channels.reverse() + decoder_channels[-1] = ( + (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio + ) + self.decoder = Unet2dDecoder( + decoder_channels, + norm_name="instance", + mode="pixelshuffle", + conv_blocks=1, + strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], + upsample_pre_conv=None, + ) + if in_stack_depth == 1: + self.head = UnsqueezeHead() + else: + self.head = PixelToVoxelHead( + in_channels=decoder_channels[-1], + out_channels=in_channels, + out_stack_depth=in_stack_depth, + expansion_ratio=head_expansion_ratio, + pool=True, + ) + self.out_stack_depth = in_stack_depth + self.num_blocks = 6 + + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + x, mask = self.encoder(x, mask_ratio=mask_ratio) + x.reverse() + x = self.decoder(x) + return self.head(x), mask From 2fffc9928ae6499d6a4850b59f7cfdd1f6994fe5 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 10:25:08 -0800 Subject: [PATCH 13/75] fix mask for fitting --- tests/unet/test_fcmae.py | 8 ++++++-- viscy/light/engine.py | 24 ++++++++++++------------ viscy/scripts/train_fcmae.py | 23 +++++++++++++++++------ viscy/unet/networks/fcmae.py | 6 +++--- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 870f1138a..36fb673ee 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -107,5 +107,9 @@ def test_masked_multiscale_encoder(): def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) model = FullyConvolutionalMAE(3) - assert model(x).shape == x.shape - assert model(x, mask_ratio=0.6).shape == x.shape + y, m = model(x) + assert y.shape == x.shape + assert m is None + y, m = model(x, mask_ratio=0.6) + assert y.shape == x.shape + assert m.shape == (2, 1, 128, 128) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 0262cc7df..852540778 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -372,19 +372,23 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): class FcmaeUNet(VSUNet): - def __init__(self, train_mask_ratio: float = 0.0, **kwargs): + def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(**kwargs) - self.train_mask_ratio = train_mask_ratio + self.fit_mask_ratio = fit_mask_ratio def forward(self, x, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def training_step(self, batch: Sample, batch_idx: int): + def forward_fit(self, batch: Sample): source = batch["source"] target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) loss = F.mse_loss(pred, target, reduction="none") - loss = (loss * mask).sum() / mask.sum() + loss = (loss.mean(2) * mask).sum() / mask.sum() + return source, target, pred, mask, loss + + def training_step(self, batch: Sample, batch_idx: int): + source, target, pred, mask, loss = self.forward_fit(batch) self.log( "loss/train", loss, @@ -396,18 +400,14 @@ def training_step(self, batch: Sample, batch_idx: int): ) if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) return loss def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] - target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) - loss = F.mse_loss(pred, target, reduction="none") - loss = (loss.mean(2) * mask).sum() / mask.sum() + source, target, pred, mask, loss = self.forward_fit(batch) self.log("loss/validate", loss, sync_dist=True) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py index 692bef6db..0c0984548 100644 --- a/viscy/scripts/train_fcmae.py +++ b/viscy/scripts/train_fcmae.py @@ -1,4 +1,5 @@ # %% +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import set_float32_matmul_precision @@ -17,8 +18,11 @@ # %% model = FcmaeUNet( architecture="fcmae", - model_config=dict(in_channels=1), - train_mask_ratio=0.6, + model_config=dict( + in_channels=1, encoder_blocks=[3, 3, 27, 3], dims=[128, 256, 512, 1024] + ), + fit_mask_ratio=0.6, + schedule="WarmupCosine", ) # %% @@ -32,8 +36,10 @@ batch_size=64, num_workers=12, architecture="3D", + yx_patch_size=[384, 384], + normalize_source=True, augmentations=[ - RandWeightedCropd(ch, ch, spatial_size=[-1, 512, 512], num_samples=2), + RandWeightedCropd(ch, ch, spatial_size=[-1, 768, 768], num_samples=2), RandAffined( ch, prob=0.5, @@ -55,11 +61,16 @@ set_float32_matmul_precision("high") trainer = VSTrainer( - fast_dev_run=False, - max_epochs=50, + fast_dev_run=True, + precision="16-mixed", + max_epochs=100, logger=TensorBoardLogger( - save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_0", log_graph=False + save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_1", log_graph=False ), + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint(monitor="loss/validate", save_top_k=5, every_n_epochs=1), + ], ) trainer.fit(model, data) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index ad9d95592..7f69cf8fd 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -6,7 +6,7 @@ """ -from typing import Callable, Literal, Sequence +from typing import Sequence import torch from timm.models.convnext import ( @@ -43,7 +43,7 @@ def generate_mask( :param Size target: target shape :param int stride: total stride :param float mask_ratio: ratio of the pixels to mask - :return BoolTensor: boolean mask (N, H*W) + :return BoolTensor: boolean mask (B1HW) """ m_height = target[-2] // stride m_width = target[-1] // stride @@ -352,7 +352,7 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: ) b, c, d, h, w = x.shape unmasked = ~mask - mask = upsample_mask(mask, (b, d, h, w)) + mask = upsample_mask(mask, (b, 1, h, w)) else: mask = unmasked = None x = self.stem(x) From 2a598b28a38acc14fa185026936b757b7695acc9 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 10:29:58 -0800 Subject: [PATCH 14/75] remove training script --- viscy/scripts/train_fcmae.py | 77 ------------------------------------ 1 file changed, 77 deletions(-) delete mode 100644 viscy/scripts/train_fcmae.py diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py deleted file mode 100644 index 0c0984548..000000000 --- a/viscy/scripts/train_fcmae.py +++ /dev/null @@ -1,77 +0,0 @@ -# %% -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers import TensorBoardLogger -from torch import set_float32_matmul_precision - -from viscy.data.hcs import HCSDataModule -from viscy.light.engine import FcmaeUNet -from viscy.light.trainer import VSTrainer -from viscy.transforms import ( - RandAdjustContrastd, - RandAffined, - RandGaussianNoised, - RandGaussianSmoothd, - RandScaleIntensityd, - RandWeightedCropd, -) - -# %% -model = FcmaeUNet( - architecture="fcmae", - model_config=dict( - in_channels=1, encoder_blocks=[3, 3, 27, 3], dims=[128, 256, 512, 1024] - ), - fit_mask_ratio=0.6, - schedule="WarmupCosine", -) - -# %% -ch = "reconstructed-labelfree" - -data = HCSDataModule( - data_path="/hpc/projects/comp.micro/virtual_staining/datasets/training/raw-and-reconstructed.zarr", - source_channel=ch, - target_channel=ch, - z_window_size=5, - batch_size=64, - num_workers=12, - architecture="3D", - yx_patch_size=[384, 384], - normalize_source=True, - augmentations=[ - RandWeightedCropd(ch, ch, spatial_size=[-1, 768, 768], num_samples=2), - RandAffined( - ch, - prob=0.5, - rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.05, 0.05], - scale_range=[0.2, 0.3, 0.3], - ), - RandAdjustContrastd(ch, prob=0.3, gamma=[0.75, 1.5]), - RandScaleIntensityd(ch, prob=0.3, factors=0.5), - RandGaussianNoised(ch, prob=0.5, mean=0.0, std=5.0), - RandGaussianSmoothd( - ch, prob=0.5, sigma_z=[0.25, 1.5], sigma_y=[0.25, 1.5], sigma_x=[0.25, 1.5] - ), - ], -) - - -# %% -set_float32_matmul_precision("high") - -trainer = VSTrainer( - fast_dev_run=True, - precision="16-mixed", - max_epochs=100, - logger=TensorBoardLogger( - save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_1", log_graph=False - ), - callbacks=[ - LearningRateMonitor(logging_interval="step"), - ModelCheckpoint(monitor="loss/validate", save_top_k=5, every_n_epochs=1), - ], -) -trainer.fit(model, data) - -# %% From b9b188067221c8b156627cf537c7e2496510ec67 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 14:11:54 -0800 Subject: [PATCH 15/75] default architecture --- viscy/light/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 852540778..e1f699eba 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -373,7 +373,7 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): - super().__init__(**kwargs) + super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio def forward(self, x, mask_ratio: float = 0.0): From fd7700d0ea70339f467c0c431eca4f0c78201f5b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 22 Jan 2024 15:04:03 -0800 Subject: [PATCH 16/75] fine-tuning options --- viscy/unet/networks/fcmae.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 7f69cf8fd..0799d8fb7 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -367,12 +367,15 @@ class FullyConvolutionalMAE(nn.Module): def __init__( self, in_channels: int, + out_channels: int, encoder_blocks: Sequence[int] = [3, 3, 9, 3], dims: Sequence[int] = [96, 192, 384, 768], encoder_drop_path_rate: float = 0.0, head_expansion_ratio: int = 4, stem_kernel_size: Sequence[int] = (5, 4, 4), in_stack_depth: int = 5, + decoder_conv_blocks: int = 1, + pretraining: bool = True, ) -> None: super().__init__() self.encoder = MaskedMultiscaleEncoder( @@ -392,7 +395,7 @@ def __init__( decoder_channels, norm_name="instance", mode="pixelshuffle", - conv_blocks=1, + conv_blocks=decoder_conv_blocks, strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], upsample_pre_conv=None, ) @@ -401,16 +404,20 @@ def __init__( else: self.head = PixelToVoxelHead( in_channels=decoder_channels[-1], - out_channels=in_channels, + out_channels=out_channels, out_stack_depth=in_stack_depth, expansion_ratio=head_expansion_ratio, pool=True, ) self.out_stack_depth = in_stack_depth self.num_blocks = 6 + self.pretraining = pretraining def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: x, mask = self.encoder(x, mask_ratio=mask_ratio) x.reverse() x = self.decoder(x) - return self.head(x), mask + x = self.head(x) + if self.pretraining: + return x, mask + return x From 054249f14e7dac4e4040edf53d55232831ef3fe6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 24 Jan 2024 14:12:19 -0800 Subject: [PATCH 17/75] fix cli for finetuning --- viscy/data/hcs.py | 4 ++-- viscy/light/engine.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 01191db11..f8bb6a22c 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -334,7 +334,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] = "2.5D", + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), augmentations: Optional[list[MapTransform]] = None, caching: bool = False, @@ -348,7 +348,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["2.2D", "3D"] else True + self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size diff --git a/viscy/light/engine.py b/viscy/light/engine.py index e1f699eba..e6a2dfa4e 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -118,11 +118,12 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"], + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", + freeze_encoder: bool = False, log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, example_input_yx_shape: Sequence[int] = (256, 256), @@ -162,6 +163,7 @@ def __init__( self.test_cellpose_model_path = test_cellpose_model_path self.test_cellpose_diameter = test_cellpose_diameter self.test_evaluate_cellpose = test_evaluate_cellpose + self.freeze_encoder = freeze_encoder def forward(self, x) -> torch.Tensor: return self.model(x) @@ -331,6 +333,9 @@ def on_predict_start(self): self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def configure_optimizers(self): + if self.freeze_encoder: + self.model: FullyConvolutionalMAE + self.model.encoder.requires_grad_(False) optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) if self.schedule == "WarmupCosine": scheduler = WarmupCosineSchedule( From d867e101b3e006ed9dc819722280b2bae8ea5560 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 24 Jan 2024 14:56:10 -0800 Subject: [PATCH 18/75] draft combined data module --- viscy/data/combined.py | 62 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 viscy/data/combined.py diff --git a/viscy/data/combined.py b/viscy/data/combined.py new file mode 100644 index 000000000..6b8dd63ca --- /dev/null +++ b/viscy/data/combined.py @@ -0,0 +1,62 @@ +from typing import Literal, Sequence + +from lightning.pytorch import LightningDataModule +from lightning.pytorch.utilities import combined_loader + +_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] + + +class CombinedDataModule(LightningDataModule): + """Wrapper for combining multiple data modules. + For supported modes, see ``lightning.pytorch.utilities.combined_loader``. + + :param Sequence[LightningDataModule] data_modules: data modules to combine + :param str train_mode: mode in training stage, defaults to "max_size_cycle" + :param str val_mode: mode in validation stage, defaults to "sequential" + :param str test_mode: mode in testing stage, defaults to "sequential" + :param str predict_mode: mode in prediction stage, defaults to "sequential" + """ + + def __init__( + self, + data_modules: Sequence[LightningDataModule], + train_mode: _MODES = "max_size_cycle", + val_mode: _MODES = "sequential", + test_mode: _MODES = "sequential", + predict_mode: _MODES = "sequential", + ): + super().__init__() + self.data_modules = data_modules + self.train_mode = train_mode + self.val_mode = val_mode + self.test_mode = test_mode + self.predict_mode = predict_mode + + def prepare_data(self): + for dm in self.data_modules: + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + for dm in self.data_modules: + dm.setup(stage) + + def train_dataloader(self): + return combined_loader( + [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode + ) + + def val_dataloader(self): + return combined_loader( + [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode + ) + + def test_dataloader(self): + return combined_loader( + [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode + ) + + def predict_dataloader(self): + return combined_loader( + [dm.predict_dataloader() for dm in self.data_modules], + mode=self.predict_mode, + ) From b06a30077c83402b71390de160f1ce404ca98240 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 25 Jan 2024 15:52:42 -0800 Subject: [PATCH 19/75] fix import --- viscy/data/combined.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 6b8dd63ca..5da700dd0 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,7 +1,7 @@ from typing import Literal, Sequence from lightning.pytorch import LightningDataModule -from lightning.pytorch.utilities import combined_loader +from lightning.pytorch.utilities.combined_loader import CombinedLoader _MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] @@ -41,22 +41,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): dm.setup(stage) def train_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode ) def val_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode ) def test_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode ) def predict_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, ) From 39eafab77f97a046a20c5bc4944bf9f24dc11ca1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 26 Jan 2024 21:35:49 -0800 Subject: [PATCH 20/75] manual validation loss reduction --- viscy/light/engine.py | 48 ++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index e6a2dfa4e..ebd2fd608 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -10,7 +10,7 @@ from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad from skimage.exposure import rescale_intensity -from torch import nn +from torch import Tensor, nn from torch.nn import functional as F from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -165,7 +165,7 @@ def __init__( self.test_evaluate_cellpose = test_evaluate_cellpose self.freeze_encoder = freeze_encoder - def forward(self, x) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: return self.model(x) def training_step(self, batch: Sample, batch_idx: int): @@ -230,7 +230,7 @@ def test_step(self, batch: Sample, batch_idx: int): else: self._log_segmentation_metrics(None, None) - def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): + def _log_regression_metrics(self, pred: Tensor, target: Tensor): # paired image translation metrics self.log_dict( { @@ -253,7 +253,7 @@ def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): on_epoch=True, ) - def _cellpose_predict(self, pred: torch.Tensor, name: str) -> torch.ShortTensor: + def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: pred_labels_np = self.cellpose_model.eval( pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter )[0].astype(np.int16) @@ -350,7 +350,7 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _detach_sample(self, imgs: Sequence[torch.Tensor]): + def _detach_sample(self, imgs: Sequence[Tensor]): num_samples = min(imgs[0].shape[0], self.log_samples_per_batch) return [ [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] @@ -380,11 +380,12 @@ class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio + self.validation_losses = [] - def forward(self, x, mask_ratio: float = 0.0): + def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Sample): + def forward_fit(self, batch: Sample) -> tuple[Tensor]: source = batch["source"] target = batch["target"] pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) @@ -392,27 +393,40 @@ def forward_fit(self, batch: Sample): loss = (loss.mean(2) * mask).sum() / mask.sum() return source, target, pred, mask, loss - def training_step(self, batch: Sample, batch_idx: int): - source, target, pred, mask, loss = self.forward_fit(batch) + def training_step(self, batch: Sequence[Sample], batch_idx: int): + losses = [] + batch_size = 0 + for b in batch: + source, target, pred, mask, loss = self.forward_fit(b) + losses.append(loss) + batch_size += source.shape[0] + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target * mask.unsqueeze(2), pred)) + ) + loss_step = torch.stack(losses).mean() self.log( "loss/train", - loss, + loss_step, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, + batch_size=batch_size, ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target * mask.unsqueeze(2), pred)) - ) - return loss + return loss_step - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source, target, pred, mask, loss = self.forward_fit(batch) - self.log("loss/validate", loss, sync_dist=True) + self.validation_losses.append(loss.detach()) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) + + def on_validation_epoch_end(self): + super().on_validation_epoch_end() + self.log( + "loss/validate", torch.stack(self.validation_losses).mean(), sync_dist=True + ) From 9fbf7a551e0613e0173d7de05ba6f9dfd911d709 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 2 Feb 2024 09:55:29 -0800 Subject: [PATCH 21/75] update linting new black version has different rules --- pyproject.toml | 17 +++++++++++------ viscy/evaluation/evaluation_metrics.py | 1 + viscy/light/engine.py | 26 +++++++++++++------------- viscy/preprocessing/generate_masks.py | 1 + viscy/unet/networks/fcmae.py | 1 - viscy/utils/image_utils.py | 4 +--- viscy/utils/normalize.py | 1 + 7 files changed, 28 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b60cd5346..67142b4ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,15 @@ metrics = [ "ptflops>=0.7", ] visual = ["ipykernel", "graphviz", "torchview"] -dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"] +dev = [ + "pytest", + "pytest-cov", + "hypothesis", + "ruff", + "black", + "profilehooks", + "onnxruntime", +] [project.scripts] viscy = "viscy.cli.cli:main" @@ -39,12 +47,9 @@ viscy = "viscy.cli.cli:main" write_to = "viscy/_version.py" [tool.black] -src = ["viscy"] line-length = 88 [tool.ruff] src = ["viscy", "tests"] -extend-select = ["I001"] - -[tool.ruff.isort] -known-first-party = ["viscy"] +lint.extend-select = ["I001"] +lint.isort.known-first-party = ["viscy"] diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index 589370bd5..fb83c06bd 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -1,4 +1,5 @@ """Metrics for model evaluation""" + from typing import Sequence, Union from warnings import warn diff --git a/viscy/light/engine.py b/viscy/light/engine.py index ebd2fd608..c6197a9a8 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -272,19 +272,19 @@ def _log_segmentation_metrics( self.log_dict( { # semantic segmentation - "test_metrics/accuracy": accuracy( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, - "test_metrics/dice": dice(pred_binary, target_binary) - if compute - else -1, - "test_metrics/jaccard": jaccard_index( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, + "test_metrics/accuracy": ( + accuracy(pred_binary, target_binary, task="binary") + if compute + else -1 + ), + "test_metrics/dice": ( + dice(pred_binary, target_binary) if compute else -1 + ), + "test_metrics/jaccard": ( + jaccard_index(pred_binary, target_binary, task="binary") + if compute + else -1 + ), "test_metrics/mAP": coco_metrics["map"] if compute else -1, "test_metrics/mAP_50": coco_metrics["map_50"] if compute else -1, "test_metrics/mAP_75": coco_metrics["map_75"] if compute else -1, diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index f88f8fbef..491bc4069 100644 --- a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,4 +1,5 @@ """Generate masks from sum of flurophore channels""" + import iohub.ngff as ngff import viscy.utils.aux_utils as aux_utils diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 0799d8fb7..97771365a 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -5,7 +5,6 @@ and timm's dense implementation of the encoder in ``timm.models.convnext`` """ - from typing import Sequence import torch diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py index f9020dc93..a95691162 100644 --- a/viscy/utils/image_utils.py +++ b/viscy/utils/image_utils.py @@ -21,9 +21,7 @@ def im_bit_convert(im, bit=16, norm=False, limit=[]): / (limit[1] - limit[0] + sys.float_info.epsilon) * (2**bit - 1) ) - im = np.clip( - im, 0, 2**bit - 1 - ) # clip the values to avoid wrap-around by np.astype + im = np.clip(im, 0, 2**bit - 1) # clip the values to avoid wrap-around by np.astype if bit == 8: im = im.astype(np.uint8, copy=False) # convert to 8 bit else: diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py index 93c11713c..73753acb7 100644 --- a/viscy/utils/normalize.py +++ b/viscy/utils/normalize.py @@ -1,4 +1,5 @@ """Image normalization related functions""" + import sys import numpy as np From e00f5f3bd0415c1c3b8db924bd600cd2354e4cb7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 2 Feb 2024 10:01:36 -0800 Subject: [PATCH 22/75] update development guide --- CONTRIBUTING.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3b40b075a..44db5bbc7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,19 @@ then make an editable installation with all the optional dependencies: pip install -e ".[dev,visual,metrics]" ``` -## Testing +## CI requirements + +Lint with Ruff: + +```sh +ruff check viscy +``` + +Format the code with Black: + +```sh +black viscy +``` Run tests with `pytest`: From 9e345b6c3b59a70a3b7c0bcde8bce184e46c3833 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 13 Feb 2024 15:27:26 -0800 Subject: [PATCH 23/75] update type hints --- viscy/data/hcs.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index f8bb6a22c..218ea414b 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -23,10 +23,11 @@ MultiSampleTrait, RandAffined, ) +from torch import Tensor from torch.utils.data import DataLoader, Dataset -def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]): +def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ Ensure channel argument is a list of strings. @@ -67,9 +68,9 @@ class Sample(TypedDict, total=False): index: tuple[str, int, int] # optional - source: Union[torch.Tensor, Sequence[torch.Tensor]] - target: Union[torch.Tensor, Sequence[torch.Tensor]] - labels: Union[torch.Tensor, Sequence[torch.Tensor]] + source: Union[Tensor, Sequence[Tensor]] + target: Union[Tensor, Sequence[Tensor]] + labels: Union[Tensor, Sequence[Tensor]] def _collate_samples(batch: Sequence[Sample]) -> Sample: @@ -83,7 +84,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: elemment = batch[0] collated = {} for key in elemment.keys(): - data: list[list[torch.Tensor]] = [sample[key] for sample in batch] + data: list[list[Tensor]] = [sample[key] for sample in batch] collated[key] = collate_meta_tensor([im for imgs in data for im in imgs]) return collated @@ -108,13 +109,13 @@ def _stat(self, key: str) -> dict: # FIXME: hard-coded key return self.norm_meta[key]["dataset_statistics"] - def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: d = dict(data) for key in self.keys: d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] return d - def inverse(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]: d = dict(data) for key in self.keys: d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] @@ -128,7 +129,7 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -137,7 +138,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, ) -> None: super().__init__() self.positions = positions @@ -178,14 +179,14 @@ def _find_window(self, index: int) -> tuple[int, int]: def _read_img_window( self, img: ImageArray, ch_idx: list[str], tz: int - ) -> tuple[tuple[torch.Tensor], tuple[str, int, int]]: + ) -> tuple[tuple[Tensor], tuple[str, int, int]]: """Read image window as tensor. :param ImageArray img: NGFF image array :param list[int] channels: list of channel indices to read, output channel ordering will reflect the sequence :param int tz: window index within the FOV, counted Z-first - :return tuple[torch.Tensor], tuple[str, int, int]: + :return tuple[Tensor], tuple[str, int, int]: tuple of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index """ @@ -203,8 +204,8 @@ def __len__(self) -> int: return self._max_window def _stack_channels( - self, sample_images: list[dict[str, torch.Tensor]], key: str - ) -> torch.Tensor: + self, sample_images: list[dict[str, Tensor]], key: str + ) -> Tensor: """Stack single-channel images into a multi-channel tensor.""" if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in self.channels[key]]) @@ -258,7 +259,7 @@ class MaskTestDataset(SlidingWindowDataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -267,7 +268,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, ground_truth_masks: str = None, ) -> None: super().__init__(positions, channels, z_window_size, transform) @@ -527,7 +528,7 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample if self.trainer: if self.trainer.predicting: predicting = True - if predicting or isinstance(batch, torch.Tensor): + if predicting or isinstance(batch, Tensor): # skipping example input array return batch if self.target_2d: From 96deca5f0020fb9a99388f37d0640e656253145e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 20 Feb 2024 14:44:02 -0800 Subject: [PATCH 24/75] bump iohub --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 67142b4ff..8f6978de7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.10" license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ - "iohub==0.1.0rc0", + "iohub==0.1.0", "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", From e06aa574634dd504755dd21e40346071ea7a6b00 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 21:42:29 -0800 Subject: [PATCH 25/75] draft ctmc v1 dataset --- viscy/data/ctmc_v1.py | 67 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 viscy/data/ctmc_v1.py diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py new file mode 100644 index 000000000..8c42f85d5 --- /dev/null +++ b/viscy/data/ctmc_v1.py @@ -0,0 +1,67 @@ +import logging +from pathlib import Path + +import numpy as np +from iohub.ngff import ImageArray, Plate, Position, TransformationMeta, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, MapTransform +from torch import Tensor +from torch.utils.data import DataLoader + +from viscy.data.hcs import ChannelMap, SlidingWindowDataset + + +class CTMCv1DataModule(LightningDataModule): + """ + Autoregression data module for the CTMCv1 dataset. + Training and validation datasets are stored in separate HCS OME-Zarr stores. + """ + + def __init__( + self, + train_data_path: str | Path, + val_data_path: str | Path, + train_transforms: list[MapTransform], + val_transforms: list[MapTransform], + batch_size: int = 16, + num_workers: int = 8, + channel_name: str = "DIC", + ) -> None: + super().__init__() + self.train_data_path = Path(train_data_path) + self.val_data_path = Path(val_data_path) + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.channel_map = ChannelMap(source=channel_name, target=channel_name) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + train_plate = open_ome_zarr(self.train_data_path, mode="r") + val_plate = open_ome_zarr(self.val_data_path, mode="r") + train_positions = [p for _, p in train_plate.positions()] + val_positions = [p for _, p in val_plate.positions()] + self.train_dataset = SlidingWindowDataset( + train_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.train_transform), + ) + self.val_dataset = SlidingWindowDataset( + val_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.val_transform), + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) From 72de113f8c5a678f4383d374da925517d1cace6b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 22:33:41 -0800 Subject: [PATCH 26/75] update tests --- tests/light/test_engine.py | 5 +---- tests/unet/test_fcmae.py | 14 +++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py index c60133658..9ce182f5f 100644 --- a/tests/light/test_engine.py +++ b/tests/light/test_engine.py @@ -3,8 +3,5 @@ def test_fcmae_vsunet() -> None: model = FcmaeUNet( - architecture="fcmae", - model_config=dict(in_channels=3), - train_mask_ratio=0.6, + model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6 ) - diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 36fb673ee..4ed441b4a 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -17,7 +17,7 @@ def test_generate_mask(): w = 64 s = 16 m = 0.75 - mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m) + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m, device="cpu") assert mask.shape == (2, 1, w // s, w // s) assert mask.dtype == torch.bool ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] @@ -28,7 +28,7 @@ def test_masked_patchify(): b, c, h, w = 2, 3, 4, 8 x = torch.rand(b, c, h, w) mask_ratio = 0.75 - mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio, device=x.device) mask = upsample_mask(mask, x.shape) feat = masked_patchify(x, ~mask) assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) @@ -42,7 +42,7 @@ def test_unmasked_patchify_roundtrip(): def test_masked_patchify_roundtrip(): x = torch.rand(2, 3, 4, 8) - mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5, device=x.device) mask = upsample_mask(mask, x.shape) y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) assert torch.all((y == 0) ^ (x == y)) @@ -51,7 +51,7 @@ def test_masked_patchify_roundtrip(): def test_masked_convnextv2_block() -> None: x = torch.rand(2, 3, 4, 5) - mask = generate_mask(x.shape, stride=1, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5, device=x.device) block = MaskedConvNeXtV2Block(3, 3 * 2) unmasked_out = block(x) assert len(unmasked_out.unique()) == x.numel() * 2 @@ -65,7 +65,7 @@ def test_masked_convnextv2_block() -> None: def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) - mask = generate_mask(x.shape, stride=4, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5, device=x.device) stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) out = stage(x) assert out.shape == (2, 3, 8, 8) @@ -79,7 +79,7 @@ def test_adaptive_projection(): ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6, device="cpu") masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) assert masked_out.shape == (1, 12, 4, 4) proj = MaskedAdaptiveProjection( @@ -106,7 +106,7 @@ def test_masked_multiscale_encoder(): def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) - model = FullyConvolutionalMAE(3) + model = FullyConvolutionalMAE(3, 3) y, m = model(x) assert y.shape == x.shape assert m is None From 13d0aa0574665d0da4f17407033fc964da00e602 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 23:47:56 -0800 Subject: [PATCH 27/75] move test_data --- tests/data/__init__.py | 0 tests/{light => data}/test_data.py | 0 viscy/data/ctmc_v1.py | 24 ++++++++++++++++-------- 3 files changed, 16 insertions(+), 8 deletions(-) create mode 100644 tests/data/__init__.py rename tests/{light => data}/test_data.py (100%) diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/light/test_data.py b/tests/data/test_data.py similarity index 100% rename from tests/light/test_data.py rename to tests/data/test_data.py diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 8c42f85d5..df1d3223e 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,11 +1,8 @@ -import logging from pathlib import Path -import numpy as np -from iohub.ngff import ImageArray, Plate, Position, TransformationMeta, open_ome_zarr +from iohub.ngff import open_ome_zarr from lightning.pytorch import LightningDataModule from monai.transforms import Compose, MapTransform -from torch import Tensor from torch.utils.data import DataLoader from viscy.data.hcs import ChannelMap, SlidingWindowDataset @@ -39,8 +36,11 @@ def __init__( def setup(self, stage: str) -> None: if stage != "fit": raise NotImplementedError("Only fit stage is supported") - train_plate = open_ome_zarr(self.train_data_path, mode="r") - val_plate = open_ome_zarr(self.val_data_path, mode="r") + self._setup_fit() + + def _setup_fit(self) -> None: + train_plate = open_ome_zarr(self.train_data_path) + val_plate = open_ome_zarr(self.val_data_path) train_positions = [p for _, p in train_plate.positions()] val_positions = [p for _, p in val_plate.positions()] self.train_dataset = SlidingWindowDataset( @@ -58,10 +58,18 @@ def setup(self, stage: str) -> None: def train_dataloader(self) -> DataLoader: return DataLoader( - self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, ) def val_dataloader(self) -> DataLoader: return DataLoader( - self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=False, ) From 78aed971aa2bea34e89c0db20bde883ebb98a06e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 23:53:15 -0800 Subject: [PATCH 28/75] remove path conversion --- viscy/data/ctmc_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index df1d3223e..0d65a36ab 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -25,8 +25,8 @@ def __init__( channel_name: str = "DIC", ) -> None: super().__init__() - self.train_data_path = Path(train_data_path) - self.val_data_path = Path(val_data_path) + self.train_data_path = train_data_path + self.val_data_path = val_data_path self.train_transforms = train_transforms self.val_transforms = val_transforms self.channel_map = ChannelMap(source=channel_name, target=channel_name) From 74e7db3633aed6d04c0995aee6f1db70abb51045 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 26 Feb 2024 09:31:49 -0800 Subject: [PATCH 29/75] configurable normalizations (#68) * inital commit adding the normalization. * adding dataset_statistics to each fov to facilitate the configurable augmentations * fix indentation * ruff * test preprocessing * remove redundant field * cleanup --------- Co-authored-by: Ziwen Liu --- examples/configs/fit_example.yml | 13 +++ tests/conftest.py | 2 + tests/data/test_data.py | 33 ++---- viscy/data/hcs.py | 146 +++++++-------------------- viscy/data/typing.py | 22 ++++ viscy/preprocessing/preprocessing.md | 16 ++- viscy/transforms.py | 39 +++++++ 7 files changed, 139 insertions(+), 132 deletions(-) create mode 100644 viscy/data/typing.py diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index 017c57f03..fd17071e9 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -37,6 +37,19 @@ data: batch_size: 32 num_workers: 16 yx_patch_size: [256, 256] + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [source] + level: 'fov_statistics', + subtrahend: 'mean' + divisor: 'std' + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [target_1] + level: 'fov_statistics', + subtrahend: 'median' + divisor: 'iqr' augmentations: - class_path: viscy.transforms.RandWeightedCropd init_args: diff --git a/tests/conftest.py b/tests/conftest.py index 9ad6630c2..198e51ac5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names} with open_ome_zarr(dataset_path, mode="r+") as dataset: dataset.zattrs["normalization"] = norm_meta + for _, fov in dataset.positions(): + fov.zattrs["normalization"] = norm_meta return dataset_path diff --git a/tests/data/test_data.py b/tests/data/test_data.py index 153f175f6..fb3d8620a 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -18,6 +18,16 @@ def test_preprocess(small_hcs_dataset: Path, default_channels: bool): channel_names = dataset.channel_names trainer = VSTrainer(accelerator="cpu") trainer.preprocess(data_path, channel_names=channel_names, num_workers=2) + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + for channel in channel_names: + assert "dataset_statistics" in dataset.zattrs["normalization"][channel] + for _, fov in dataset.positions(): + norm_metadata = fov.zattrs["normalization"] + for channel in channel_names: + assert channel in norm_metadata + assert "dataset_statistics" in norm_metadata[channel] + assert "fov_statistics" in norm_metadata[channel] def test_datamodule_setup_predict(preprocessed_hcs_dataset): @@ -45,26 +55,3 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset): img.height, img.width, ) - - -def test_datamodule_predict_scales(preprocessed_hcs_dataset): - data_path = preprocessed_hcs_dataset - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - - def get_normalized_stack(predict_scale_source): - factor = 1 if predict_scale_source is None else predict_scale_source - dm = HCSDataModule( - data_path=data_path, - source_channel=channel_names[:2], - target_channel=channel_names[2:], - z_window_size=5, - batch_size=2, - num_workers=0, - predict_scale_source=predict_scale_source, - normalize_source=True, - ) - dm.setup(stage="predict") - return dm.predict_dataset[0]["source"] / factor - - assert torch.allclose(get_normalized_stack(None), get_normalized_stack(2)) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 218ea414b..bb0be09cb 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -5,7 +5,7 @@ import tempfile from glob import glob from pathlib import Path -from typing import Callable, Iterable, Literal, Optional, Sequence, TypedDict, Union +from typing import Callable, Literal, Optional, Sequence, Union import numpy as np import torch @@ -18,7 +18,6 @@ from monai.transforms import ( CenterSpatialCropd, Compose, - InvertibleTransform, MapTransform, MultiSampleTrait, RandAffined, @@ -26,6 +25,8 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.typing import ChannelMap, Sample + def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ @@ -55,24 +56,6 @@ def _search_int_in_str(pattern: str, file_name: str) -> str: raise ValueError(f"Cannot find pattern {pattern} in {file_name}.") -class ChannelMap(TypedDict, total=False): - """Source and target channel names.""" - - source: Union[str, Sequence[str]] - # optional - target: Union[str, Sequence[str]] - - -class Sample(TypedDict, total=False): - """Image sample type for mini-batches.""" - - index: tuple[str, int, int] - # optional - source: Union[Tensor, Sequence[Tensor]] - target: Union[Tensor, Sequence[Tensor]] - labels: Union[Tensor, Sequence[Tensor]] - - def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. @@ -89,38 +72,6 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: return collated -class NormalizeSampled(MapTransform, InvertibleTransform): - """Dictionary transform to only normalize target (fluorescence) channel. - - :param Union[str, Iterable[str]] keys: keys to normalize - :param dict[str, dict] norm_meta: Plate normalization metadata - written in preprocessing - """ - - def __init__( - self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict] - ) -> None: - if set(keys) > set(norm_meta.keys()): - raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}") - super().__init__(keys, allow_missing_keys=False) - self.norm_meta = norm_meta - - def _stat(self, key: str) -> dict: - # FIXME: hard-coded key - return self.norm_meta[key]["dataset_statistics"] - - def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] - return d - - def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] - - class SlidingWindowDataset(Dataset): """Torch dataset where each element is a window of (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. @@ -161,6 +112,7 @@ def _get_windows(self) -> None: w = 0 self.window_keys = [] self.window_arrays = [] + self.window_norm_meta = [] for fov in self.positions: img_arr = fov["0"] ts = img_arr.frames @@ -168,6 +120,7 @@ def _get_windows(self) -> None: w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) + self.window_norm_meta.append(fov.zattrs["normalization"]) self._max_window = w def _find_window(self, index: int) -> tuple[int, int]: @@ -175,7 +128,8 @@ def _find_window(self, index: int) -> tuple[int, int]: window_idx = sorted(self.window_keys + [index + 1]).index(index + 1) w = self.window_keys[window_idx] tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index - return self.window_arrays[self.window_keys.index(w)], tz + norm_meta = self.window_norm_meta[self.window_keys.index(w)] + return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta) def _read_img_window( self, img: ImageArray, ch_idx: list[str], tz: int @@ -216,7 +170,7 @@ def _stack_channels( ] def __getitem__(self, index: int) -> Sample: - img, tz = self._find_window(index) + img, tz, norm_meta = self._find_window(index) ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() if self.target_ch_idx is not None: @@ -229,6 +183,7 @@ def __getitem__(self, index: int) -> Sample: # since adding a reference to a tensor does not copy # maybe write a weight map in preprocessing to use more information? sample_images["weight"] = sample_images[self.channels["target"][0]] + sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) # if isinstance(sample_images, list): @@ -238,6 +193,7 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), + "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") @@ -312,18 +268,16 @@ class HCSDataModule(LightningDataModule): defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) + :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms + applied to selected channels, defaults to None (no normalization) :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms applied to the training set, defaults to None (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False - :param bool normalize_source: whether to normalize the source channel, - defaults to False :param Optional[Path] ground_truth_masks: path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None - :param Optional[float] predict_scale_source: scale the source channel intensity, - defaults to None (no scaling) """ def __init__( @@ -337,11 +291,10 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), + normalizations: Optional[list[MapTransform]] = None, augmentations: Optional[list[MapTransform]] = None, caching: bool = False, - normalize_source: bool = False, ground_truth_masks: Optional[Path] = None, - predict_scale_source: Optional[float] = None, ): super().__init__() self.data_path = Path(data_path) @@ -353,21 +306,11 @@ def __init__( self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size + self.normalizations = normalizations self.augmentations = augmentations self.caching = caching - self.normalize_source = normalize_source self.ground_truth_masks = ground_truth_masks self.tmp_zarr = None - if predict_scale_source is not None: - if not normalize_source: - raise ValueError( - "Intensity scaling must be applied to normalized source channels." - ) - if predict_scale_source <= 0: - raise ValueError( - f"Intensity scaling {predict_scale_source} should be positive." - ) - self.predict_scale_source = predict_scale_source def prepare_data(self): if not self.caching: @@ -419,31 +362,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): else: raise NotImplementedError(f"{stage} stage") - def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]: - """Setup stages where the target is available (evaluating performance).""" - dataset_settings["channels"]["target"] = self.target_channel - data_path = self.tmp_zarr if self.tmp_zarr else self.data_path - plate = open_ome_zarr(data_path, mode="r") - # disable metadata tracking in MONAI for performance - set_track_meta(False) - # define training stage transforms - norm_keys = self.target_channel.copy() - if self.normalize_source: - norm_keys += self.source_channel - normalize_transform = NormalizeSampled( - norm_keys, - plate.zattrs["normalization"], - ) - return plate, normalize_transform - def _setup_fit(self, dataset_settings: dict): """Set up the training and validation datasets.""" - plate, normalize_transform = self._setup_eval(dataset_settings) + # Setup the transformations + # TODO: These have a fixed order for now... (normalization->augmentation->fit_transform) fit_transform = self._fit_transform() train_transform = Compose( - [normalize_transform] + self._train_transform() + fit_transform + self.normalizations + self._train_transform() + fit_transform ) - val_transform = Compose([normalize_transform] + fit_transform) + val_transform = Compose(self.normalizations + fit_transform) + + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + plate = open_ome_zarr(data_path, mode="r") + + # disable metadata tracking in MONAI for performance + set_track_meta(False) # shuffle positions, randomness is handled globally positions = [pos for _, pos in plate.positions()] shuffled_indices = torch.randperm(len(positions)) @@ -465,26 +399,31 @@ def _setup_fit(self, dataset_settings: dict): **train_dataset_settings, ) self.val_dataset = SlidingWindowDataset( - positions[num_train_fovs:], transform=val_transform, **dataset_settings + positions[num_train_fovs:], + transform=val_transform, + **dataset_settings, ) def _setup_test(self, dataset_settings: dict): """Set up the test stage.""" if self.batch_size != 1: logging.warning(f"Ignoring batch size {self.batch_size} in test stage.") - plate, normalize_transform = self._setup_eval(dataset_settings) + + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + plate = open_ome_zarr(data_path, mode="r") if self.ground_truth_masks: self.test_dataset = MaskTestDataset( [p for _, p in plate.positions()], - transform=normalize_transform, + transform=self.normalizations, ground_truth_masks=self.ground_truth_masks, - **dataset_settings, + norm_meta=plate.zattrs["normalization"] ** dataset_settings, ) else: self.test_dataset = SlidingWindowDataset( [p for _, p in plate.positions()], - transform=normalize_transform, - **dataset_settings, + transform=self.normalizations, + norm_meta=plate.zattrs["normalization"] ** dataset_settings, ) def _setup_predict(self, dataset_settings: dict): @@ -506,16 +445,9 @@ def _setup_predict(self, dataset_settings: dict): positions = [plate[fov_name]] elif isinstance(dataset, Plate): positions = [p for _, p in dataset.positions()] - norm_meta = dataset.zattrs["normalization"].copy() - if self.predict_scale_source is not None: - for ch in self.source_channel: - # FIXME: hard-coded key - norm_meta[ch]["dataset_statistics"]["iqr"] /= self.predict_scale_source - predict_transform = ( - NormalizeSampled(self.source_channel, norm_meta) - if self.normalize_source - else None - ) + + predict_transform = self.normalizations + self.predict_dataset = SlidingWindowDataset( positions=positions, transform=predict_transform, diff --git a/viscy/data/typing.py b/viscy/data/typing.py new file mode 100644 index 000000000..c6b7c32f9 --- /dev/null +++ b/viscy/data/typing.py @@ -0,0 +1,22 @@ +from typing import Sequence, TypedDict, Union + +from torch import Tensor + + +class Sample(TypedDict, total=False): + """Image sample type for mini-batches.""" + + index: tuple[str, int, int] + # optional + source: Union[Tensor, Sequence[Tensor]] + target: Union[Tensor, Sequence[Tensor]] + labels: Union[Tensor, Sequence[Tensor]] + norm_meta: dict[str, dict] + + +class ChannelMap(TypedDict, total=False): + """Source and target channel names.""" + + source: Union[str, Sequence[str]] + # optional + target: Union[str, Sequence[str]] diff --git a/viscy/preprocessing/preprocessing.md b/viscy/preprocessing/preprocessing.md index 76d508c5b..809b456f7 100644 --- a/viscy/preprocessing/preprocessing.md +++ b/viscy/preprocessing/preprocessing.md @@ -87,11 +87,17 @@ The statistics are added as dictionaries into the .zattrs file. An example of pl } ``` -FOV level statistics added to every position: +FOV level statistics added to every position as well as the dataset_statistics to read dataset statistics: ```json "normalization": { "Deconvolved-Nuc": { + "dataset_statistics": { + "iqr": 149.7620086669922, + "mean": 262.2070617675781, + "median": 65.5246353149414, + "std": 890.0471801757812 + }, "fov_statistics": { "iqr": 450.4745788574219, "mean": 486.3854064941406, @@ -99,7 +105,13 @@ FOV level statistics added to every position: "std": 976.02392578125 } }, - "Phase3D": { + "Phase3D": { + "dataset_statistics": { + "iqr": 0.0011349652777425945, + "mean": -1.9603044165705796e-06, + "median": 3.388232289580628e-05, + "std": 0.005480962339788675 + }, "fov_statistics": { "iqr": 0.006403466919437051, "mean": 0.0010083537781611085, diff --git a/viscy/transforms.py b/viscy/transforms.py index cb3d2622e..7ce192af0 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -3,6 +3,7 @@ from typing import Sequence, Union from monai.transforms import ( + MapTransform, RandAdjustContrastd, RandAffined, RandGaussianNoised, @@ -10,6 +11,9 @@ RandScaleIntensityd, RandWeightedCropd, ) +from typing_extensions import Iterable, Literal + +from viscy.data.typing import Sample class RandWeightedCropd(RandWeightedCropd): @@ -118,3 +122,38 @@ def __init__( sigma_z=sigma_z, **kwargs, ) + + +class NormalizeSampled(MapTransform): + """ + Normalize the sample + :param Union[str, Iterable[str]] keys: keys to normalize + :param str fov: fov path with respect to Plate + :param str subtrahend: subtrahend for normalization, defaults to "mean" + :param str divisor: divisor for normalization, defaults to "std" + """ + + def __init__( + self, + keys: Union[str, Iterable[str]], + level: Literal["fov_statistics", "dataset_statistics"], + subtrahend="mean", + divisor="std", + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self.subtrahend = subtrahend + self.divisor = divisor + self.level = level + + # TODO: need to implement the case where the preprocessing already exists + def __call__(self, sample: Sample) -> Sample: + for key in self.keys: + if key in self.keys: + level_meta = sample["norm_meta"][key][self.level] + subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val + return sample + + def _normalize(): + NotImplementedError("_normalization() not implemented") From 9b3b032100b480f9340b8aa8b124e8116f232820 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 27 Feb 2024 17:33:53 -0800 Subject: [PATCH 30/75] fix ctmc dataloading --- viscy/data/ctmc_v1.py | 6 +++--- viscy/data/hcs.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 0d65a36ab..47844d68a 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -29,7 +29,7 @@ def __init__( self.val_data_path = val_data_path self.train_transforms = train_transforms self.val_transforms = val_transforms - self.channel_map = ChannelMap(source=channel_name, target=channel_name) + self.channel_map = ChannelMap(source=[channel_name], target=[channel_name]) self.batch_size = batch_size self.num_workers = num_workers @@ -47,13 +47,13 @@ def _setup_fit(self) -> None: train_positions, channels=self.channel_map, z_window_size=1, - transform=Compose(self.train_transform), + transform=Compose(self.train_transforms), ) self.val_dataset = SlidingWindowDataset( val_positions, channels=self.channel_map, z_window_size=1, - transform=Compose(self.val_transform), + transform=Compose(self.val_transforms), ) def train_dataloader(self) -> DataLoader: diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index bb0be09cb..2c7397c07 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -120,7 +120,7 @@ def _get_windows(self) -> None: w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) - self.window_norm_meta.append(fov.zattrs["normalization"]) + self.window_norm_meta.append(fov.zattrs.get("normalization", 0)) self._max_window = w def _find_window(self, index: int) -> tuple[int, int]: From a3569364ac18897858471c78d4a4c6f3381c6d1c Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 27 Feb 2024 17:34:30 -0800 Subject: [PATCH 31/75] add example ctmc v1 loading script --- viscy/scripts/load_ctmc_v1.py | 68 +++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 viscy/scripts/load_ctmc_v1.py diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py new file mode 100644 index 000000000..e5c190948 --- /dev/null +++ b/viscy/scripts/load_ctmc_v1.py @@ -0,0 +1,68 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCropd, + NormalizeIntensityd, + RandAffined, + RandScaleIntensityd, +) +from tqdm import tqdm + +from viscy.data.ctmc_v1 import CTMCv1DataModule + +# %% +data_path = Path("") + +normalize_transform = NormalizeIntensityd(keys=["DIC"], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=["DIC"], roi_size=[1, 256, 256]) + +data = CTMCv1DataModule( + train_data_path=data_path / "CTMCV1_test.zarr", + val_data_path=data_path / "CTMCV1_train.zarr", + train_transforms=[ + normalize_transform, + RandAffined( + keys=["DIC"], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.3, 0.3], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + ), + RandScaleIntensityd(keys=["DIC"], factors=0.3, prob=0.5), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=4, + num_workers=0, + channel_name="DIC", +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% From bac26bedeb1037bf0eec44fb7f1b65fd3da7b653 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 28 Feb 2024 15:52:41 -0800 Subject: [PATCH 32/75] changing the normalization and augmentations default from None to empty list. --- viscy/data/hcs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 2c7397c07..af9a03a8a 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -269,9 +269,9 @@ class HCSDataModule(LightningDataModule): :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to None (no normalization) + applied to selected channels, defaults to [] (no normalization) :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms - applied to the training set, defaults to None (no augmentation) + applied to the training set, defaults to [] (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False @@ -291,8 +291,8 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), - normalizations: Optional[list[MapTransform]] = None, - augmentations: Optional[list[MapTransform]] = None, + normalizations: Optional[list[MapTransform]] = [], + augmentations: Optional[list[MapTransform]] = [], caching: bool = False, ground_truth_masks: Optional[Path] = None, ): From 0b598c7e1b9fc0bbbc2aa08307964300f66b7f8a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:03:58 -0800 Subject: [PATCH 33/75] invert intensity transform --- viscy/transforms.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index 7ce192af0..88e7f738f 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -8,9 +8,12 @@ RandAffined, RandGaussianNoised, RandGaussianSmoothd, + RandomizableTransform, RandScaleIntensityd, RandWeightedCropd, ) +from monai.transforms.transform import Randomizable +from numpy.random.mtrand import RandomState as RandomState from typing_extensions import Iterable, Literal from viscy.data.typing import Sample @@ -148,12 +151,34 @@ def __init__( # TODO: need to implement the case where the preprocessing already exists def __call__(self, sample: Sample) -> Sample: for key in self.keys: - if key in self.keys: - level_meta = sample["norm_meta"][key][self.level] - subtrahend_val = level_meta[self.subtrahend] - divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero - sample[key] = (sample[key] - subtrahend_val) / divisor_val + level_meta = sample["norm_meta"][key][self.level] + subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val return sample def _normalize(): NotImplementedError("_normalization() not implemented") + + +class RandInvertIntensityd(MapTransform, RandomizableTransform): + """ + Randomly invert the intensity of the image. + """ + + def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None: + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) + + def __call__(self, sample: Sample) -> Sample: + self.randomize(None) + for key in self.keys: + if key in sample: + sample[key] = -sample[key] + return sample + + def set_random_state( + self, seed: int | None = None, state: RandomState | None = None + ) -> Randomizable: + super().set_random_state(seed, state) + return self From ddb30e9d05ebb7378a0ec1ee29acab6a33b32a14 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:04:17 -0800 Subject: [PATCH 34/75] concatenated data module --- viscy/data/combined.py | 72 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 5 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 5da700dd0..45072909f 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,9 +1,19 @@ +from enum import Enum from typing import Literal, Sequence from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader +from torch import Tensor +from torch.utils.data import ConcatDataset, DataLoader -_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] +from viscy.data.hcs import _collate_samples + + +class CombineMode(Enum): + MIN_SIZE = "min_size" + MAX_SIZE_CYCLE = "max_size_cycle" + MAX_SIZE = "max_size" + SEQUENTIAL = "sequential" class CombinedDataModule(LightningDataModule): @@ -20,10 +30,10 @@ class CombinedDataModule(LightningDataModule): def __init__( self, data_modules: Sequence[LightningDataModule], - train_mode: _MODES = "max_size_cycle", - val_mode: _MODES = "sequential", - test_mode: _MODES = "sequential", - predict_mode: _MODES = "sequential", + train_mode: CombineMode = CombineMode.MAX_SIZE_CYCLE, + val_mode: CombineMode = CombineMode.SEQUENTIAL, + test_mode: CombineMode = CombineMode.SEQUENTIAL, + predict_mode: CombineMode = CombineMode.SEQUENTIAL, ): super().__init__() self.data_modules = data_modules @@ -60,3 +70,55 @@ def predict_dataloader(self): [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, ) + + +class ConcatDataModule(LightningDataModule): + def __init__(self, data_modules: Sequence[LightningDataModule]): + super().__init__() + self.data_modules = data_modules + self.num_workers = data_modules[0].num_workers + self.batch_size = data_modules[0].batch_size + for dm in data_modules: + if dm.num_workers != self.num_workers: + raise ValueError("Inconsistent number of workers") + if dm.batch_size != self.batch_size: + raise ValueError("Inconsistent batch size") + + def prepare_data(self): + for dm in self.data_modules: + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + self.train_patches_per_stack = 0 + for dm in self.data_modules: + dm.setup(stage) + if patches := getattr(dm, "train_patches_per_stack", 0): + if self.train_patches_per_stack == 0: + self.train_patches_per_stack = patches + elif self.train_patches_per_stack != patches: + raise ValueError("Inconsistent patches per stack") + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self.train_dataset = ConcatDataset( + [dm.train_dataset for dm in self.data_modules] + ) + self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules]) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size // self.train_patches_per_stack, + num_workers=self.num_workers, + shuffle=True, + persistent_workers=bool(self.num_workers), + collate_fn=_collate_samples, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=bool(self.num_workers), + ) From 950475584f15534638c4c83d6e3fcf21314bb1e7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:04:37 -0800 Subject: [PATCH 35/75] subsample videos --- viscy/data/ctmc_v1.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 47844d68a..d666fdcb5 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -6,12 +6,33 @@ from torch.utils.data import DataLoader from viscy.data.hcs import ChannelMap, SlidingWindowDataset +from viscy.data.typing import Sample + + +class CTMCv1ValidationDataset(SlidingWindowDataset): + subsample_rate: int = 30 + + def __len__(self) -> int: + # sample every 30th frame in the videos + return super().__len__() // self.subsample_rate + + def __getitem__(self, index: int) -> Sample: + index = index * self.subsample_rate + return super().__getitem__(index) class CTMCv1DataModule(LightningDataModule): """ Autoregression data module for the CTMCv1 dataset. Training and validation datasets are stored in separate HCS OME-Zarr stores. + + :param str | Path train_data_path: Path to the training dataset + :param str | Path val_data_path: Path to the validation dataset + :param list[MapTransform] train_transforms: List of transforms for training + :param list[MapTransform] val_transforms: List of transforms for validation + :param int batch_size: Batch size, defaults to 16 + :param int num_workers: Number of workers, defaults to 8 + :param str channel_name: Name of the DIC channel, defaults to "DIC" """ def __init__( @@ -49,7 +70,7 @@ def _setup_fit(self) -> None: z_window_size=1, transform=Compose(self.train_transforms), ) - self.val_dataset = SlidingWindowDataset( + self.val_dataset = CTMCv1ValidationDataset( val_positions, channels=self.channel_map, z_window_size=1, From 808e39c02763f21c6cb07b79d2c60c2501220021 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:04:48 -0800 Subject: [PATCH 36/75] livecell dataset --- viscy/data/livecell.py | 98 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 viscy/data/livecell.py diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py new file mode 100644 index 000000000..5d83f099d --- /dev/null +++ b/viscy/data/livecell.py @@ -0,0 +1,98 @@ +import json +from pathlib import Path + +import torch +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, Transform +from tifffile import imread +from torch.utils.data import DataLoader, Dataset + +from viscy.data.typing import Sample + + +class LiveCellDataset(Dataset): + """ + LiveCell dataset. + + :param list[Path] images: List of paths to single-page, single-channel TIFF files. + :param Transform | Compose transform: Transform to apply to the dataset + """ + + def __init__(self, images: list[Path], transform: Transform | Compose) -> None: + self.images = images + self.transform = transform + + def __len__(self) -> int: + return len(self.images) + + def __getitem__(self, idx: int) -> Sample: + image = imread(self.images[idx])[None, None] + image = torch.from_numpy(image).to(torch.float32) + image = self.transform(image) + return {"source": image, "target": image} + + +class LiveCellDataModule(LightningDataModule): + def __init__( + self, + train_val_images: Path, + train_annotations: Path, + val_annotations: Path, + train_transforms: list[Transform], + val_transforms: list[Transform], + batch_size: int = 16, + num_workers: int = 8, + ) -> None: + super().__init__() + self.train_val_images = Path(train_val_images) + if not self.train_val_images.is_dir(): + raise NotADirectoryError(str(train_val_images)) + self.train_annotations = Path(train_annotations) + if not self.train_annotations.is_file(): + raise FileNotFoundError(str(train_annotations)) + self.val_annotations = Path(val_annotations) + if not self.val_annotations.is_file(): + raise FileNotFoundError(str(val_annotations)) + self.train_transforms = Compose(train_transforms) + self.val_transforms = Compose(val_transforms) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self._setup_fit() + + def _parse_image_names(self, annotations: Path) -> list[Path]: + with open(annotations) as f: + images = [f["file_name"] for f in json.load(f)["images"]] + return sorted(images) + + def _setup_fit(self) -> None: + train_images = self._parse_image_names(self.train_annotations) + val_images = self._parse_image_names(self.val_annotations) + self.train_dataset = LiveCellDataset( + [self.train_val_images / f for f in train_images], + transform=self.train_transforms, + ) + self.val_dataset = LiveCellDataset( + [self.train_val_images / f for f in val_images], + transform=self.val_transforms, + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + ) From 43d641db2e448336be64a6ccd17ecd4a8c218b95 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:05:04 -0800 Subject: [PATCH 37/75] all sample fields are optional --- viscy/data/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/typing.py b/viscy/data/typing.py index c6b7c32f9..aef7dea73 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -6,8 +6,8 @@ class Sample(TypedDict, total=False): """Image sample type for mini-batches.""" + # all optional index: tuple[str, int, int] - # optional source: Union[Tensor, Sequence[Tensor]] target: Union[Tensor, Sequence[Tensor]] labels: Union[Tensor, Sequence[Tensor]] From 42f81cfd2093e020e97db1238cc897e623fedcb1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:05:19 -0800 Subject: [PATCH 38/75] fix multi-dataloader validation --- viscy/light/engine.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 4d18e9c4d..6c2849549 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -194,12 +194,12 @@ def training_step(self, batch: Sample, batch_idx: int): ) return loss - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = batch["source"] target = batch["target"] pred = self.forward(source) loss = self.loss_function(pred, target) - self.log("loss/validate", loss, sync_dist=True) + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) @@ -425,7 +425,15 @@ def training_step(self, batch: Sequence[Sample], batch_idx: int): def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source, target, pred, mask, loss = self.forward_fit(batch) - self.validation_losses.append(loss.detach()) + if dataloader_idx + 1 > len(self.validation_losses): + self.validation_losses.append([]) + self.validation_losses[dataloader_idx].append(loss.detach()) + self.log( + f"loss/val/{dataloader_idx}", + loss, + sync_dist=True, + batch_size=source.shape[0], + ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) @@ -433,6 +441,6 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 def on_validation_epoch_end(self): super().on_validation_epoch_end() - self.log( - "loss/validate", torch.stack(self.validation_losses).mean(), sync_dist=True - ) + # average within each dataloader + loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] + self.log("loss/validate", torch.tensor(loss_means).mean(), sync_dist=True) From 4546fc77b8ee469b1a93f9689404b3fc47cc622d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:08:26 -0800 Subject: [PATCH 39/75] lint --- viscy/data/combined.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 45072909f..d70b93332 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -3,7 +3,6 @@ from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader -from torch import Tensor from torch.utils.data import ConcatDataset, DataLoader from viscy.data.hcs import _collate_samples From 306f3efadce651298647a1a8e60dcdd95eccb6d0 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 29 Feb 2024 13:13:25 -0800 Subject: [PATCH 40/75] fixing preprocessing for varying array shapes (i.e aics dataset) --- viscy/utils/meta_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index d644dadfc..961b66967 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -104,8 +104,9 @@ def generate_normalization_metadata( positions, fov_sample_values = mp_utils.mp_sample_im_pixels( this_channels_args, num_workers ) - dataset_sample_values = np.stack(fov_sample_values, 0) - + dataset_sample_values = np.concatenate( + [arr.flatten() for arr in fov_sample_values] + ) fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers) dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values) From 1a0e3ced8711bcdae7c6698c899aff45f7bdc777 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 20:51:50 -0800 Subject: [PATCH 41/75] update loading scripts --- viscy/scripts/load_ctmc_v1.py | 38 ++++++++++----- viscy/scripts/load_livecell.py | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 11 deletions(-) create mode 100644 viscy/scripts/load_livecell.py diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py index e5c190948..41cef698b 100644 --- a/viscy/scripts/load_ctmc_v1.py +++ b/viscy/scripts/load_ctmc_v1.py @@ -5,7 +5,11 @@ from monai.transforms import ( CenterSpatialCropd, NormalizeIntensityd, + RandAdjustContrastd, RandAffined, + RandFlipd, + RandGaussianNoised, + RandGaussianSmoothd, RandScaleIntensityd, ) from tqdm import tqdm @@ -13,10 +17,11 @@ from viscy.data.ctmc_v1 import CTMCv1DataModule # %% -data_path = Path("") +channel = "DIC" +data_path = Path("/hpc/reference/imaging/ctmc") -normalize_transform = NormalizeIntensityd(keys=["DIC"], channel_wise=True) -crop_transform = CenterSpatialCropd(keys=["DIC"], roi_size=[1, 256, 256]) +normalize_transform = NormalizeIntensityd(keys=[channel], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=[channel], roi_size=[1, 224, 224]) data = CTMCv1DataModule( train_data_path=data_path / "CTMCV1_test.zarr", @@ -24,19 +29,29 @@ train_transforms=[ normalize_transform, RandAffined( - keys=["DIC"], + keys=[channel], rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.3, 0.3], - scale_range=[0.0, 0.3, 0.3], + scale_range=[0.0, [-0.6, 0.1], [-0.6, 0.1]], prob=0.8, + padding_mode="zeros", + ), + RandFlipd(keys=[channel], prob=0.5, spatial_axis=(1,2)), + RandAdjustContrastd(keys=[channel], prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=[channel], factors=0.3, prob=0.5), + RandGaussianNoised(keys=[channel], prob=0.5, mean=0.0, std=0.2), + RandGaussianSmoothd( + keys=[channel], + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, ), - RandScaleIntensityd(keys=["DIC"], factors=0.3, prob=0.5), crop_transform, ], val_transforms=[normalize_transform, crop_transform], - batch_size=4, + batch_size=32, num_workers=0, - channel_name="DIC", + channel_name=channel, ) # %% @@ -47,7 +62,8 @@ # %% for batch in tqdm(dmt): img = batch["source"] - f, ax = plt.subplots(4, 4, figsize=(12, 12)) + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(5, 5, figsize=(15, 15)) for sample, a in zip(img, ax.flatten()): a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) a.axis("off") @@ -57,7 +73,7 @@ # %% for batch in tqdm(dmv): img = batch["source"] - f, ax = plt.subplots(4, 4, figsize=(12, 12)) + f, ax = plt.subplots(5, 5, figsize=(15, 15)) for sample, a in zip(img, ax.flatten()): a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) a.axis("off") diff --git a/viscy/scripts/load_livecell.py b/viscy/scripts/load_livecell.py new file mode 100644 index 000000000..cfaf2dfed --- /dev/null +++ b/viscy/scripts/load_livecell.py @@ -0,0 +1,85 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCrop, + NormalizeIntensity, + RandAdjustContrast, + RandAffine, + RandFlip, + RandGaussianNoise, + RandGaussianSmooth, + RandScaleIntensity, + RandSpatialCrop, +) +from tqdm import tqdm + +from viscy.data.livecell import LiveCellDataModule + +# %% +data_path = Path("/hpc/reference/imaging/livecell") + +normalize_transform = NormalizeIntensity(channel_wise=True) +crop_transform = CenterSpatialCrop(roi_size=[1, 224, 224]) + +data = LiveCellDataModule( + train_val_images=data_path / "images" / "livecell_train_val_images", + train_annotations=data_path + / "annotations" + / "livecell_coco_train_images_only.json", + val_annotations=data_path / "annotations" / "livecell_coco_val_images_only.json", + train_transforms=[ + normalize_transform, + RandSpatialCrop(roi_size=[1, 384, 384]), + RandAffine( + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, [-0.2, 0.8], [-0.2, 0.8]], + prob=0.8, + padding_mode="zeros", + ), + RandFlip(prob=0.5, spatial_axis=(1, 2)), + RandAdjustContrast(prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensity(factors=0.3, prob=0.5), + RandGaussianNoise(prob=0.5, mean=0.0, std=0.3), + RandGaussianSmooth( + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, + ), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=16, + num_workers=0, +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["target"] + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(4, 4, figsize=(15, 15)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% From d3ec94d2c0142bf073b2019ab9ac6eba4312eddd Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 21:26:12 -0800 Subject: [PATCH 42/75] fix CombineMode --- viscy/data/combined.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index d70b93332..13e64e214 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -36,10 +36,10 @@ def __init__( ): super().__init__() self.data_modules = data_modules - self.train_mode = train_mode - self.val_mode = val_mode - self.test_mode = test_mode - self.predict_mode = predict_mode + self.train_mode = CombineMode(train_mode).value + self.val_mode = CombineMode(val_mode).value + self.test_mode = CombineMode(test_mode).value + self.predict_mode = CombineMode(predict_mode).value def prepare_data(self): for dm in self.data_modules: From dd3471229f54a94b0162d93da667580217b773b5 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 4 Mar 2024 11:13:36 -0800 Subject: [PATCH 43/75] added model and annotation code draft --- .../Infection_annotator.py | 55 +++++++ .../Infection_classification_model.py | 142 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 examples/infection_phenotyping/Infection_annotator.py create mode 100644 examples/infection_phenotyping/Infection_classification_model.py diff --git a/examples/infection_phenotyping/Infection_annotator.py b/examples/infection_phenotyping/Infection_annotator.py new file mode 100644 index 000000000..e933a773f --- /dev/null +++ b/examples/infection_phenotyping/Infection_annotator.py @@ -0,0 +1,55 @@ + + +#%% use napari to annotate infected cells in segmented data + +import napari +from iohub.ngff import open_ome_zarr +import numpy as np + +file_in_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2.zarr' +zarr_input = open_ome_zarr( + file_in_path, + layout="hcs", + mode="r+", +) +chan_names = zarr_input.channel_names +# zarr_input.append_channel('Inf_mask',resize_arrays=True) + +file_out_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2_infMarked_rev2.zarr' +zarr_output = open_ome_zarr( + file_out_path, + layout="hcs", + mode="w-", + channel_names=['Sensor','Nucl_mask','Inf_mask'], +) + +v = napari.Viewer() + + +#%% Load label image to napari +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + if well_name == 'A' and well_no == '2': + + for pos_name, pos_data in well_data.positions(): + # if int(pos_name) > 1: + v.layers.clear() + data = pos_data.data + + FITC = data[0,0,...] + v.add_image(FITC, name='FITC', colormap='green', blending='additive') + Inf_mask = data[0,1,...].astype(int) + v.add_labels(Inf_mask) + input("Press Enter") + + label_layer = v.layers['Inf_mask'] + label_array = label_layer.data + label_array = np.expand_dims(label_array, axis=(0, 1)) + # zarr_input.create_image('Inf_mask',label_array) + out_data = np.concatenate((data, label_array), axis=1) + position = zarr_output.create_position(well_name, well_no, pos_name) + position["0"] = out_data + + +# %% diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py new file mode 100644 index 000000000..7f8667fa5 --- /dev/null +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -0,0 +1,142 @@ + +# %% +import torch +import sys +from viscy.data.hcs import HCSDataModule +# from lightning.pytorch import CustomDataset +from lightning.pytorch.callbacks import Callback +from lightning.pytorch import LightningDataModule +# import cv2 +import numpy as np +import torch.nn as nn +import torchvision.models as models +import lightning.pytorch as pl +import torch.nn.functional as F +from viscy.light.engine import VSUNet + +# %% Create a dataloader and visualize the batches. +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" + +# Create an instance of HCSDataModule +data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[256, 256], split_ratio=0.8, z_window_size=1, architecture = '2D') + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage = "fit") + +# Create a dataloader +dataloader = data_module.train_dataloader() + +# Visualize the dataset and the batch using napari +import napari +from pytorch_lightning.loggers import TensorBoardLogger +# import os + +# Set the display +# os.environ['DISPLAY'] = ':1' + +# Create a napari viewer +viewer = napari.Viewer() + +# Add the dataset to the viewer +for batch in dataloader: + if isinstance(batch, dict): + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# Start the napari event loop +napari.run() + +# %% use 2D Unet from viscy with a softmax layer at end for 4 label classification +# use for image translation from instance segmentation to annotated image + +# use diceloss function from here: https://gist.github.com/weiliu620/52d140b22685cf9552da4899e2160183 +def dice_loss(pred, target): + """This definition generalize to real valued pred and target vector. This should be differentiable. + pred: tensor with first dimension as batch + target: tensor with first dimension as batch + """ + + smooth = 1. + + # have to use contiguous since they may from a torch.view op + iflat = pred.contiguous().view(-1) + tflat = target.contiguous().view(-1) + intersection = (iflat * tflat).sum() + + A_sum = torch.sum(tflat * iflat) + B_sum = torch.sum(tflat * tflat) + + return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) + +unet_model = VSUNet( + architecture='2D', + loss_function=dice_loss, + lr=1e-3, + example_input_xy_shape=(64,64) + ) + +# Define the optimizer +optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) + +# Iterate over the batches +for batch in dataloader: + # Extract the input and target from the batch + input_data, target = batch['source'], batch['target'] + + # Forward pass through the model + output = unet_model(input_data) + + # Apply softmax activation to the output + output = F.softmax(output, dim=1) + + # Calculate the loss + loss = dice_loss(output, target) + + # Perform backpropagation and update the model's parameters + loss.backward() + optimizer.step() + optimizer.zero_grad() + +#%% use the batch for training the unet model using the lightning module + +# Train the model +# Create a TensorBoard logger +logger = TensorBoardLogger("logs", name="infection_classification_model") + +# Pass the logger to the Trainer +trainer = pl.Trainer(gpus=1, logger=logger) + +# Fit the model +trainer.fit(unet_model, data_module) + +# %% test the model on the test set +# Load the test dataset +test_dataloader = data_module.test_dataloader() + +# Set the model to evaluation mode +unet_model.eval() + +# Create a list to store the predictions +predictions = [] + +# Iterate over the test batches +for batch in test_dataloader: + # Extract the input from the batch + input_data = batch['source'] + + # Forward pass through the model + output = unet_model(input_data) + + # Append the predictions to the list + predictions.append(output.detach().cpu().numpy()) + +# Convert the predictions to a numpy array +predictions = np.stack(predictions) + +# Save the predictions as added channel in zarr format +zarr.save('predictions.zarr', predictions) From 5fc9da2756285dd7bcc35f576910e873e1e9bdbf Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 4 Mar 2024 19:35:06 -0800 Subject: [PATCH 44/75] chnaged to simple unet model --- .../Infection_classification_model.py | 104 ++++++++++++++---- 1 file changed, 84 insertions(+), 20 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index 7f8667fa5..f038cc50e 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -1,25 +1,24 @@ # %% import torch -import sys from viscy.data.hcs import HCSDataModule -# from lightning.pytorch import CustomDataset -from lightning.pytorch.callbacks import Callback -from lightning.pytorch import LightningDataModule -# import cv2 + import numpy as np import torch.nn as nn -import torchvision.models as models import lightning.pytorch as pl import torch.nn.functional as F -from viscy.light.engine import VSUNet + +import napari +from pytorch_lightning.loggers import TensorBoardLogger +from monai.transforms import Zoom +from monai.transforms import Compose, RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" # Create an instance of HCSDataModule -data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[256, 256], split_ratio=0.8, z_window_size=1, architecture = '2D') +data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[128,128], split_ratio=0.8, z_window_size=1, architecture = '2D') # Prepare the data data_module.prepare_data() @@ -31,10 +30,6 @@ dataloader = data_module.train_dataloader() # Visualize the dataset and the batch using napari -import napari -from pytorch_lightning.loggers import TensorBoardLogger -# import os - # Set the display # os.environ['DISPLAY'] = ':1' @@ -73,23 +68,86 @@ def dice_loss(pred, target): return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) -unet_model = VSUNet( - architecture='2D', - loss_function=dice_loss, - lr=1e-3, - example_input_xy_shape=(64,64) - ) +# load 2D UNet from viscy +# unet_model = VSUNet( +# architecture='2D', +# model_config={"in_channels": 2, "out_channels": 1}, +# loss_function=dice_loss, +# lr=1e-3, +# example_input_xy_shape=(128,128), +# ) + +# Define the data augmentations +# Define the augmentations +# transforms = Compose([ +# RandRotate(range_x=15, prob=0.5), +# Resize(spatial_size=[64, 64],mode='linear'), +# Zoom([0.5,2], mode='bilinear'), +# Flip(spatial_axis=[0,1]), +# RandFlip(spatial_axis=[0,1], prob=0.5), +# RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), +# RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), +# RandGaussianNoise(prob=0.5), +# ]) + +transforms = Compose([ + Flip(spatial_axis=[0,1]), +]) + +# create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image +class UNet(nn.Module): + def __init__(self, in_channels, out_channels): + super(UNet, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.ReLU(inplace=True) + ) + + # Define the decoder part of the U-Net architecture + self.decoder = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, out_channels, kernel_size=1) + ) + + def forward(self, x): + # Apply the encoder to the input + x = self.encoder(x) + + # Apply the decoder to the output of the encoder + x = self.decoder(x) + + return x + +unet_model = UNet(in_channels=2, out_channels=1) # Define the optimizer optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) -# Iterate over the batches +#%% Iterate over the batches for batch in dataloader: # Extract the input and target from the batch input_data, target = batch['source'], batch['target'] + # Apply the augmentations to your data + augmented_input = transforms(input_data) + # Forward pass through the model - output = unet_model(input_data) + output = unet_model(augmented_input,target) # Apply softmax activation to the output output = F.softmax(output, dim=1) @@ -102,6 +160,11 @@ def dice_loss(pred, target): optimizer.step() optimizer.zero_grad() +# Visualize sample of the augmented data using napari +for i in range(augmented_data.shape[0]): + viewer.add_image(augmented_data[i].cpu().numpy().astype(np.float32)) + + #%% use the batch for training the unet model using the lightning module # Train the model @@ -139,4 +202,5 @@ def dice_loss(pred, target): predictions = np.stack(predictions) # Save the predictions as added channel in zarr format +# use iohub or viscy to save the predictions!!! zarr.save('predictions.zarr', predictions) From e6274887bdda8f4a6fd718f510fd783d05a0a1f3 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 6 Mar 2024 10:21:25 -0800 Subject: [PATCH 45/75] start with lesser augmentations --- .../Infection_classification_model.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index f038cc50e..91a7078e1 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -10,7 +10,6 @@ import napari from pytorch_lightning.loggers import TensorBoardLogger -from monai.transforms import Zoom from monai.transforms import Compose, RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised # %% Create a dataloader and visualize the batches. @@ -91,7 +90,11 @@ def dice_loss(pred, target): # ]) transforms = Compose([ + RandRotate(range_x=15, prob=0.5), Flip(spatial_axis=[0,1]), + RandFlip(spatial_axis=[0,1], prob=0.5), + RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), + RandGaussianNoise(prob=0.5), ]) # create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image @@ -121,7 +124,12 @@ def __init__(self, in_channels, out_channels): nn.ReLU(inplace=True), nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(64, out_channels, kernel_size=1) + nn.Conv2d(128, out_channels, kernel_size=1), + nn.Softmax(dim=1), + nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Softmax(dim=1), + nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Softmax(dim=1) ) def forward(self, x): @@ -145,9 +153,10 @@ def forward(self, x): # Apply the augmentations to your data augmented_input = transforms(input_data) + viewer.add_image(augmented_input.cpu().numpy().astype(np.float32)) # Forward pass through the model - output = unet_model(augmented_input,target) + output = unet_model(augmented_input) # Apply softmax activation to the output output = F.softmax(output, dim=1) @@ -161,8 +170,8 @@ def forward(self, x): optimizer.zero_grad() # Visualize sample of the augmented data using napari -for i in range(augmented_data.shape[0]): - viewer.add_image(augmented_data[i].cpu().numpy().astype(np.float32)) +# for i in range(augmented_input.shape[0]): +# viewer.add_image(augmented_input[i].cpu().numpy().astype(np.float32)) #%% use the batch for training the unet model using the lightning module From 310ba7091444683a4c1b3e52995746f62568848c Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 6 Mar 2024 10:22:59 -0800 Subject: [PATCH 46/75] added readme file --- examples/infection_phenotyping/readme.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 examples/infection_phenotyping/readme.md diff --git a/examples/infection_phenotyping/readme.md b/examples/infection_phenotyping/readme.md new file mode 100644 index 000000000..74dbc5000 --- /dev/null +++ b/examples/infection_phenotyping/readme.md @@ -0,0 +1,7 @@ +# Infection Classification Model + +This repository contains the code for the infection classification model (`infection_classification_model.py`) used in the infection phenotyping project. + +## Overview + +The `infection_classification_model.py` file implements a machine learning model for classifying infections based on various features. The model is trained on a labeled dataset, either fluorescence or label-free images, and can be used to predict the infection type for new samples. \ No newline at end of file From 34b81b952f34c8a00e190e354aa544b51fee5f82 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 6 Mar 2024 15:56:07 -0800 Subject: [PATCH 47/75] added tensorboard logging --- .../Infection_classification_model.py | 174 ++++++++++-------- 1 file changed, 96 insertions(+), 78 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index 91a7078e1..aab76139a 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -10,14 +10,26 @@ import napari from pytorch_lightning.loggers import TensorBoardLogger -from monai.transforms import Compose, RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised +from monai.transforms import RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised +from pytorch_lightning.callbacks import ModelCheckpoint # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" # Create an instance of HCSDataModule -data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[128,128], split_ratio=0.8, z_window_size=1, architecture = '2D') +data_module = HCSDataModule( + dataset_path, + source_channel=['Sensor','Nucl_mask'], + target_channel=['Inf_mask'], + yx_patch_size=[128,128], + split_ratio=0.8, + z_window_size=1, + architecture = '2D', + num_workers=1, + batch_size=12, + augmentations=[], +) # Prepare the data data_module.prepare_data() @@ -32,41 +44,41 @@ # Set the display # os.environ['DISPLAY'] = ':1' -# Create a napari viewer -viewer = napari.Viewer() +# # Create a napari viewer +# viewer = napari.Viewer() -# Add the dataset to the viewer -for batch in dataloader: - if isinstance(batch, dict): - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - viewer.add_image(v.cpu().numpy().astype(np.float32)) +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) -# Start the napari event loop -napari.run() +# # Start the napari event loop +# napari.run() # %% use 2D Unet from viscy with a softmax layer at end for 4 label classification # use for image translation from instance segmentation to annotated image -# use diceloss function from here: https://gist.github.com/weiliu620/52d140b22685cf9552da4899e2160183 -def dice_loss(pred, target): - """This definition generalize to real valued pred and target vector. This should be differentiable. - pred: tensor with first dimension as batch - target: tensor with first dimension as batch - """ - - smooth = 1. - - # have to use contiguous since they may from a torch.view op - iflat = pred.contiguous().view(-1) - tflat = target.contiguous().view(-1) - intersection = (iflat * tflat).sum() - - A_sum = torch.sum(tflat * iflat) - B_sum = torch.sum(tflat * tflat) +# use diceloss function from here: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch +class DiceLoss(nn.Module): + def __init__(self, weight=None, size_average=True): + super(DiceLoss, self).__init__() + + def forward(self, inputs, targets, smooth=1): + + #comment out if your model contains a sigmoid or equivalent activation layer + inputs = F.sigmoid(inputs) + + #flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) + + intersection = (inputs * targets).sum() + dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) + + return 1 - dice - return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) - # load 2D UNet from viscy # unet_model = VSUNet( # architecture='2D', @@ -76,67 +88,43 @@ def dice_loss(pred, target): # example_input_xy_shape=(128,128), # ) -# Define the data augmentations -# Define the augmentations -# transforms = Compose([ -# RandRotate(range_x=15, prob=0.5), -# Resize(spatial_size=[64, 64],mode='linear'), -# Zoom([0.5,2], mode='bilinear'), -# Flip(spatial_axis=[0,1]), -# RandFlip(spatial_axis=[0,1], prob=0.5), -# RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), -# RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), -# RandGaussianNoise(prob=0.5), -# ]) - -transforms = Compose([ - RandRotate(range_x=15, prob=0.5), - Flip(spatial_axis=[0,1]), - RandFlip(spatial_axis=[0,1], prob=0.5), - RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), - RandGaussianNoise(prob=0.5), -]) - # create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image class UNet(nn.Module): def __init__(self, in_channels, out_channels): super(UNet, self).__init__() self.encoder = nn.Sequential( - nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), + nn.Conv3d(in_channels, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.MaxPool3d(kernel_size=1, stride=1), + nn.Conv3d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.MaxPool3d(kernel_size=1, stride=1), + nn.Conv3d(128, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.MaxPool3d(kernel_size=1, stride=1), + nn.Conv3d(256, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) - # Define the decoder part of the U-Net architecture self.decoder = nn.Sequential( - nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.Conv3d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.Conv3d(256, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.Conv3d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(128, out_channels, kernel_size=1), + nn.Conv3d(64, out_channels, kernel_size=1), nn.Softmax(dim=1), - nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Conv3d(out_channels, out_channels, kernel_size=1), nn.Softmax(dim=1), - nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Conv3d(out_channels, out_channels, kernel_size=1), nn.Softmax(dim=1) ) def forward(self, x): - # Apply the encoder to the input x = self.encoder(x) - # Apply the decoder to the output of the encoder x = self.decoder(x) return x @@ -150,19 +138,13 @@ def forward(self, x): for batch in dataloader: # Extract the input and target from the batch input_data, target = batch['source'], batch['target'] - - # Apply the augmentations to your data - augmented_input = transforms(input_data) - viewer.add_image(augmented_input.cpu().numpy().astype(np.float32)) + # viewer.add_image(input_data.cpu().numpy().astype(np.float32)) # Forward pass through the model - output = unet_model(augmented_input) - - # Apply softmax activation to the output - output = F.softmax(output, dim=1) + output = unet_model(input_data) # Calculate the loss - loss = dice_loss(output, target) + loss = DiceLoss()(output, target) # Perform backpropagation and update the model's parameters loss.backward() @@ -178,10 +160,46 @@ def forward(self, x): # Train the model # Create a TensorBoard logger -logger = TensorBoardLogger("logs", name="infection_classification_model") +class LightningUNet(pl.LightningModule): + def __init__(self, in_channels, out_channels): + super(LightningUNet, self).__init__() + self.unet_model = UNet(in_channels, out_channels) + + def forward(self, x): + return self.unet_model(x) + + def training_step(self, batch, batch_idx): + input_data, target = batch['source'], batch['target'] + output = self(input_data) + loss = DiceLoss()(output, target) + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + +# Create an instance of the LightningUNet class +unet_model = LightningUNet(in_channels=2, out_channels=1) + +# Define the logger +logger = TensorBoardLogger("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", name="infection_classification_model") # Pass the logger to the Trainer -trainer = pl.Trainer(gpus=1, logger=logger) +trainer = pl.Trainer(logger=logger, max_epochs=10, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath='/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints', + filename='checkpoint_{epoch:02d}', + save_top_k=-1, + verbose=True, + monitor='val_loss', + mode='min' +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) # Fit the model trainer.fit(unet_model, data_module) From a4e2f0d683553c7971d536c11ef1ee5486c6ad71 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 7 Mar 2024 14:05:57 -0800 Subject: [PATCH 48/75] added validation step --- .../Infection_classification_model.py | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index aab76139a..cab56b5ce 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -38,7 +38,9 @@ data_module.setup(stage = "fit") # Create a dataloader -dataloader = data_module.train_dataloader() +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() # Visualize the dataset and the batch using napari # Set the display @@ -135,7 +137,7 @@ def forward(self, x): optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) #%% Iterate over the batches -for batch in dataloader: +for batch in train_dm: # Extract the input and target from the batch input_data, target = batch['source'], batch['target'] # viewer.add_image(input_data.cpu().numpy().astype(np.float32)) @@ -151,6 +153,17 @@ def forward(self, x): optimizer.step() optimizer.zero_grad() +for batch in val_dm: + # Extract the input and target from the batch + input_data, target = batch['source'], batch['target'] + + # Forward pass through the model + output = unet_model(input_data) + + # Calculate the loss + loss = DiceLoss()(output, target) + + # Visualize sample of the augmented data using napari # for i in range(augmented_input.shape[0]): # viewer.add_image(augmented_input[i].cpu().numpy().astype(np.float32)) @@ -175,6 +188,12 @@ def training_step(self, batch, batch_idx): self.log('train_loss', loss) return loss + def validation_step(self, batch, batch_idx): + input_data, target = batch['source'], batch['target'] + output = self(input_data) + loss = DiceLoss()(output, target) + self.log('val_loss', loss) + def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer @@ -186,7 +205,7 @@ def configure_optimizers(self): logger = TensorBoardLogger("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", name="infection_classification_model") # Pass the logger to the Trainer -trainer = pl.Trainer(logger=logger, max_epochs=10, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) +trainer = pl.Trainer(logger=logger, max_epochs=30, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( @@ -205,8 +224,14 @@ def configure_optimizers(self): trainer.fit(unet_model, data_module) # %% test the model on the test set -# Load the test dataset -test_dataloader = data_module.test_dataloader() +test_datapath = '/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr' + +test_dm = HCSDataModule( + test_datapath, + source_channel=['Sensor','Nuclei_mask'], +) +# Load the predict dataset +test_dataloader = test_dm.test_dataloader() # Set the model to evaluation mode unet_model.eval() From 0ebb5df40e6597d9af2b797e6ba89cb25bb8491a Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 11 Mar 2024 08:14:18 -0700 Subject: [PATCH 49/75] chnaged to viscy 2d unet --- .../Infection_classification_model.py | 76 ++----------------- 1 file changed, 7 insertions(+), 69 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index cab56b5ce..f97e654ca 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -12,6 +12,8 @@ from pytorch_lightning.loggers import TensorBoardLogger from monai.transforms import RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised from pytorch_lightning.callbacks import ModelCheckpoint +from monai.losses import DiceLoss +from viscy.light.engine import VSUNet # %% Create a dataloader and visualize the batches. # Set the path to the dataset @@ -61,77 +63,13 @@ # %% use 2D Unet from viscy with a softmax layer at end for 4 label classification # use for image translation from instance segmentation to annotated image - -# use diceloss function from here: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch -class DiceLoss(nn.Module): - def __init__(self, weight=None, size_average=True): - super(DiceLoss, self).__init__() - - def forward(self, inputs, targets, smooth=1): - - #comment out if your model contains a sigmoid or equivalent activation layer - inputs = F.sigmoid(inputs) - - #flatten label and prediction tensors - inputs = inputs.view(-1) - targets = targets.view(-1) - - intersection = (inputs * targets).sum() - dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) - - return 1 - dice # load 2D UNet from viscy -# unet_model = VSUNet( -# architecture='2D', -# model_config={"in_channels": 2, "out_channels": 1}, -# loss_function=dice_loss, -# lr=1e-3, -# example_input_xy_shape=(128,128), -# ) - -# create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image -class UNet(nn.Module): - def __init__(self, in_channels, out_channels): - super(UNet, self).__init__() - - self.encoder = nn.Sequential( - nn.Conv3d(in_channels, 64, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool3d(kernel_size=1, stride=1), - nn.Conv3d(64, 128, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool3d(kernel_size=1, stride=1), - nn.Conv3d(128, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool3d(kernel_size=1, stride=1), - nn.Conv3d(256, 512, kernel_size=3, padding=1), - nn.ReLU(inplace=True) - ) - - self.decoder = nn.Sequential( - nn.Conv3d(512, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv3d(256, 128, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv3d(128, 64, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv3d(64, out_channels, kernel_size=1), - nn.Softmax(dim=1), - nn.Conv3d(out_channels, out_channels, kernel_size=1), - nn.Softmax(dim=1), - nn.Conv3d(out_channels, out_channels, kernel_size=1), - nn.Softmax(dim=1) - ) - - def forward(self, x): - x = self.encoder(x) - - x = self.decoder(x) - - return x - -unet_model = UNet(in_channels=2, out_channels=1) +unet_model = VSUNet( + architecture='2D', + model_config={"in_channels": 2, "out_channels": 4, "task": "reg"}, + lr=1e-3, +) # Define the optimizer optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) From a0e426a8edb9b09f3dc839dd029c40849f1c462f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 12 Mar 2024 12:22:36 -0700 Subject: [PATCH 50/75] used crossentropyloss with one-hot encoding --- .../Infection_classification_model.py | 211 ++++++++++-------- 1 file changed, 124 insertions(+), 87 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index f97e654ca..b2230d9ba 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -1,4 +1,3 @@ - # %% import torch from viscy.data.hcs import HCSDataModule @@ -7,13 +6,31 @@ import torch.nn as nn import lightning.pytorch as pl import torch.nn.functional as F +import torchview +from typing import Literal, Union -import napari +# import napari from pytorch_lightning.loggers import TensorBoardLogger -from monai.transforms import RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised +from monai.transforms import ( + RandRotate, + Resize, + Zoom, + Flip, + RandFlip, + RandZoom, + RandRotate90, + RandRotate, + RandAffine, + Rand2DElastic, + Rand3DElastic, + RandGaussianNoise, + RandGaussianNoised, +) from pytorch_lightning.callbacks import ModelCheckpoint from monai.losses import DiceLoss from viscy.light.engine import VSUNet +from viscy.unet.networks.Unet2D import Unet2d +from viscy.data.hcs import Sample # %% Create a dataloader and visualize the batches. # Set the path to the dataset @@ -21,13 +38,13 @@ # Create an instance of HCSDataModule data_module = HCSDataModule( - dataset_path, - source_channel=['Sensor','Nucl_mask'], - target_channel=['Inf_mask'], - yx_patch_size=[128,128], - split_ratio=0.8, - z_window_size=1, - architecture = '2D', + dataset_path, + source_channel=["Sensor"], + target_channel=["Inf_mask"], + yx_patch_size=[128, 128], + split_ratio=0.8, + z_window_size=1, + architecture="2D", num_workers=1, batch_size=12, augmentations=[], @@ -37,7 +54,7 @@ data_module.prepare_data() # Setup the data -data_module.setup(stage = "fit") +data_module.setup(stage="fit") # Create a dataloader train_dm = data_module.train_dataloader() @@ -61,112 +78,132 @@ # # Start the napari event loop # napari.run() -# %% use 2D Unet from viscy with a softmax layer at end for 4 label classification -# use for image translation from instance segmentation to annotated image - -# load 2D UNet from viscy -unet_model = VSUNet( - architecture='2D', - model_config={"in_channels": 2, "out_channels": 4, "task": "reg"}, - lr=1e-3, -) - -# Define the optimizer -optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) - -#%% Iterate over the batches -for batch in train_dm: - # Extract the input and target from the batch - input_data, target = batch['source'], batch['target'] - # viewer.add_image(input_data.cpu().numpy().astype(np.float32)) - - # Forward pass through the model - output = unet_model(input_data) - - # Calculate the loss - loss = DiceLoss()(output, target) - - # Perform backpropagation and update the model's parameters - loss.backward() - optimizer.step() - optimizer.zero_grad() - -for batch in val_dm: - # Extract the input and target from the batch - input_data, target = batch['source'], batch['target'] - - # Forward pass through the model - output = unet_model(input_data) - - # Calculate the loss - loss = DiceLoss()(output, target) - - -# Visualize sample of the augmented data using napari -# for i in range(augmented_input.shape[0]): -# viewer.add_image(augmented_input[i].cpu().numpy().astype(np.float32)) +# %% use 2D Unet and Lightning module -#%% use the batch for training the unet model using the lightning module - # Train the model # Create a TensorBoard logger class LightningUNet(pl.LightningModule): - def __init__(self, in_channels, out_channels): + def __init__( + self, + in_channels, + out_channels, + lr: float = 1e-3, + loss_function: nn.CrossEntropyLoss = None, + schedule: Literal["WarmupCosine", "Constant"] = "Constant", + log_batches_per_epoch: int = 2, + log_samples_per_batch: int = 1, + ): super(LightningUNet, self).__init__() - self.unet_model = UNet(in_channels, out_channels) + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + self.lr = lr + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch + self.training_step_outputs = [] + self.validation_step_outputs = [] def forward(self, x): return self.unet_model(x) - def training_step(self, batch, batch_idx): - input_data, target = batch['source'], batch['target'] - output = self(input_data) - loss = DiceLoss()(output, target) - self.log('train_loss', loss) - return loss - - def validation_step(self, batch, batch_idx): - input_data, target = batch['source'], batch['target'] - output = self(input_data) - loss = DiceLoss()(output, target) - self.log('val_loss', loss) - def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer -# Create an instance of the LightningUNet class -unet_model = LightningUNet(in_channels=2, out_channels=1) + def training_step(self, batch: Sample, batch_idx: int): + + # Extract the input and target from the batch + source = batch["source"] + target = batch["target"] + pred = self.forward(source) + + # Convert the target image to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert target to float type + # Calculate the loss + train_loss = self.loss_function(pred, target_one_hot) + # if batch_idx < self.log_batches_per_epoch: + # self.training_step_outputs.extend( + # self._detach_sample((source, target_one_hot, pred)) + # ) + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss + + def validation_step(self, batch: Sample, batch_idx: int): + + # Extract the input and target from the batch + source = batch["source"] + target = batch["target"] + pred = self.forward(source) + + # Convert the target image to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert target to float type + # Calculate the loss + loss = self.loss_function(pred, target_one_hot) + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) + # if batch_idx < self.log_batches_per_epoch: + # self.validation_step_outputs.extend( + # self._detach_sample((source, target, pred)) + # ) + return loss + -# Define the logger -logger = TensorBoardLogger("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", name="infection_classification_model") +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", + name="infection_classification_model", +) # Pass the logger to the Trainer -trainer = pl.Trainer(logger=logger, max_epochs=30, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) +trainer = pl.Trainer( + logger=logger, + max_epochs=30, + default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", + log_every_n_steps=1, +) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( - dirpath='/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints', - filename='checkpoint_{epoch:02d}', + dirpath="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints", + filename="checkpoint_{epoch:02d}", save_top_k=-1, verbose=True, - monitor='val_loss', - mode='min' + monitor="loss/validate", + mode="min", ) # Add the checkpoint callback to the trainer trainer.callbacks.append(checkpoint_callback) # Fit the model -trainer.fit(unet_model, data_module) +model = LightningUNet( + in_channels=1, + out_channels=4, + loss_function=nn.CrossEntropyLoss(), +) +trainer.fit(model, data_module) + # %% test the model on the test set -test_datapath = '/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr' +test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr" test_dm = HCSDataModule( - test_datapath, - source_channel=['Sensor','Nuclei_mask'], + test_datapath, + source_channel=["Sensor", "Nuclei_mask"], ) # Load the predict dataset test_dataloader = test_dm.test_dataloader() @@ -180,7 +217,7 @@ def configure_optimizers(self): # Iterate over the test batches for batch in test_dataloader: # Extract the input from the batch - input_data = batch['source'] + input_data = batch["source"] # Forward pass through the model output = unet_model(input_data) @@ -193,4 +230,4 @@ def configure_optimizers(self): # Save the predictions as added channel in zarr format # use iohub or viscy to save the predictions!!! -zarr.save('predictions.zarr', predictions) +zarr.save("predictions.zarr", predictions) From 5ecbde022611e06314190d26fb3bfe2fef2e6a37 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 12 Mar 2024 16:17:53 -0700 Subject: [PATCH 51/75] added sample image logging --- .../Infection_classification_model.py | 69 +++++++++++++++---- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index b2230d9ba..30af733f6 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -7,10 +7,13 @@ import lightning.pytorch as pl import torch.nn.functional as F import torchview -from typing import Literal, Union +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap # import napari from pytorch_lightning.loggers import TensorBoardLogger +from torch import Tensor from monai.transforms import ( RandRotate, Resize, @@ -125,10 +128,10 @@ def training_step(self, batch: Sample, batch_idx: int): target_one_hot = target_one_hot.float() # Convert target to float type # Calculate the loss train_loss = self.loss_function(pred, target_one_hot) - # if batch_idx < self.log_batches_per_epoch: - # self.training_step_outputs.extend( - # self._detach_sample((source, target_one_hot, pred)) - # ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) self.log( "loss/train", train_loss, @@ -154,13 +157,55 @@ def validation_step(self, batch: Sample, batch_idx: int): target_one_hot = target_one_hot.float() # Convert target to float type # Calculate the loss loss = self.loss_function(pred, target_one_hot) - self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) - # if batch_idx < self.log_batches_per_epoch: - # self.validation_step_outputs.extend( - # self._detach_sample((source, target, pred)) - # ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + self.log( + "loss/validate", + loss, + sync_dist=True, + add_dataloader_idx=False, + ) return loss + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) + return self._predict_pad.inverse(self.forward(source)) + + def on_train_epoch_end(self): + self._log_samples("train_samples", self.training_step_outputs) + self.training_step_outputs = [] + + def on_validation_epoch_end(self): + self._log_samples("val_samples", self.validation_step_outputs) + self.validation_step_outputs = [] + + def _detach_sample(self, imgs: Sequence[Tensor]): + num_samples = 2 # min(imgs[0].shape[0], self.log_samples_per_batch) + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] + for sample_images in imgs: + images_row = [] + for i, image in enumerate(sample_images): + cm_name = "gray" if i == 0 else "inferno" + if image.ndim == 2: + image = image[np.newaxis] + for channel in image: + channel = rescale_intensity(channel, out_range=(0, 1)) + render = get_cmap(cm_name)(channel, bytes=True)[..., :3] + images_row.append(render) + images_grid.append(np.concatenate(images_row, axis=1)) + grid = np.concatenate(images_grid, axis=0) + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + # %% Define the logger logger = TensorBoardLogger( @@ -171,7 +216,7 @@ def validation_step(self, batch: Sample, batch_idx: int): # Pass the logger to the Trainer trainer = pl.Trainer( logger=logger, - max_epochs=30, + max_epochs=50, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1, ) @@ -193,7 +238,7 @@ def validation_step(self, batch: Sample, batch_idx: int): model = LightningUNet( in_channels=1, out_channels=4, - loss_function=nn.CrossEntropyLoss(), + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.4, 0.4, 0.1])), ) trainer.fit(model, data_module) From 58b7fa56fae4efec6975fea03a4e19d710cfbe7c Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 12 Mar 2024 19:38:01 -0700 Subject: [PATCH 52/75] attempt to build magicgui annotation --- .../Infection_annotator.py | 86 +++++++++++++++---- 1 file changed, 71 insertions(+), 15 deletions(-) diff --git a/examples/infection_phenotyping/Infection_annotator.py b/examples/infection_phenotyping/Infection_annotator.py index e933a773f..5117ce918 100644 --- a/examples/infection_phenotyping/Infection_annotator.py +++ b/examples/infection_phenotyping/Infection_annotator.py @@ -1,55 +1,111 @@ +# %% Run this to display napari on the remote server while running the script in local IDE +import os - -#%% use napari to annotate infected cells in segmented data +os.environ["DISPLAY"] = ":1" +# %% use napari to annotate infected cells in segmented data import napari from iohub.ngff import open_ome_zarr import numpy as np +from pathlib import Path + +dataset_folder = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets" +) + +input_file = dataset_folder / "Exp_2023_09_28_DENV_A2.zarr" +output_file = ( + dataset_folder / "Exp_2023_09_28_DENV_A2_infMarked_test_annotation_pipeline.zarr" +) -file_in_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2.zarr' zarr_input = open_ome_zarr( - file_in_path, + input_file, layout="hcs", mode="r+", ) chan_names = zarr_input.channel_names # zarr_input.append_channel('Inf_mask',resize_arrays=True) -file_out_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2_infMarked_rev2.zarr' zarr_output = open_ome_zarr( - file_out_path, + output_file, layout="hcs", - mode="w-", - channel_names=['Sensor','Nucl_mask','Inf_mask'], + mode="w", + channel_names=["Sensor", "Nucl_mask", "Inf_mask"], ) v = napari.Viewer() -#%% Load label image to napari +# %% Load label image to napari for well_id, well_data in zarr_input.wells(): well_name, well_no = well_id.split("/") - if well_name == 'A' and well_no == '2': + if well_name == "A" and well_no == "2": for pos_name, pos_data in well_data.positions(): # if int(pos_name) > 1: v.layers.clear() data = pos_data.data - FITC = data[0,0,...] - v.add_image(FITC, name='FITC', colormap='green', blending='additive') - Inf_mask = data[0,1,...].astype(int) + FITC = data[0, 0, ...] + v.add_image(FITC, name="FITC", colormap="green", blending="additive") + Inf_mask = data[0, 1, ...].astype(int) v.add_labels(Inf_mask) input("Press Enter") - label_layer = v.layers['Inf_mask'] + label_layer = v.layers["Inf_mask"] label_array = label_layer.data label_array = np.expand_dims(label_array, axis=(0, 1)) # zarr_input.create_image('Inf_mask',label_array) out_data = np.concatenate((data, label_array), axis=1) position = zarr_output.create_position(well_name, well_no, pos_name) position["0"] = out_data - + +# %% Template for magicgui based annotation workflow. +from magicgui import magicgui +from napari.types import ImageData + + +# Create an enumeration of all wells +wells = list(w[0] for w in zarr_input.wells()) +well_id, well_data = next(zarr_input.wells()) +positions = list(p[0] for p in well_data.positions()) +channel_names = zarr_input.channel_names + + +@magicgui( + call_button="load data", + wells={"choices", ["A/1", "A/2", "A/3", "A/4", "A/5"]}, + positions={"choices", ["0", "1", "2", "3", "4"]}, +) # defines the widget. +def load_well(well: str, position: str): # defines the callback. + # Load all data from specified well and position + for well_id, well_data in zarr_input.wells(): + if well_id == well: + for pos_name, pos_data in well_data.positions(): + if pos_name == position: + for i, ch in enumerate(channel_names): + data = pos_data.data + v.add_image( + data[0, i, ...], + name=ch, + colormap="gray", + blending="additive", + ) + break + break + + +@magicgui(call_button="save annotations") # defines the widget. +def save_annotations( + annotation_layer: ImageData, output_path: Path +): # defines the callback. + # Save the output to the specified path + print("save") + + +# Add both widgets to napari +v.window.add_dock_widget(load_well(wells, "0")) +v.window.add_dock_widget(save_annotations) # %% From 35ead0cc7aef5d001ad64f3ab06a004adee2df00 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 13 Mar 2024 16:03:51 -0700 Subject: [PATCH 53/75] renamed infection annotation tool --- .../{Infection_annotator.py => Infection_annotation_refiner.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/infection_phenotyping/{Infection_annotator.py => Infection_annotation_refiner.py} (100%) diff --git a/examples/infection_phenotyping/Infection_annotator.py b/examples/infection_phenotyping/Infection_annotation_refiner.py similarity index 100% rename from examples/infection_phenotyping/Infection_annotator.py rename to examples/infection_phenotyping/Infection_annotation_refiner.py From 802ebc33ed705c3942ec41475beeaa9e52e0eaa9 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Sat, 23 Mar 2024 10:06:22 -0700 Subject: [PATCH 54/75] added normalization and augmentations --- .../Infection_classification_model.py | 101 +++++++----------- 1 file changed, 37 insertions(+), 64 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index 30af733f6..d9a045cc2 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -6,7 +6,8 @@ import torch.nn as nn import lightning.pytorch as pl import torch.nn.functional as F -import torchview + +# import torchview from typing import Literal, Sequence from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap @@ -14,43 +15,47 @@ # import napari from pytorch_lightning.loggers import TensorBoardLogger from torch import Tensor -from monai.transforms import ( - RandRotate, - Resize, - Zoom, - Flip, - RandFlip, - RandZoom, - RandRotate90, - RandRotate, - RandAffine, - Rand2DElastic, - Rand3DElastic, - RandGaussianNoise, - RandGaussianNoised, -) from pytorch_lightning.callbacks import ModelCheckpoint -from monai.losses import DiceLoss -from viscy.light.engine import VSUNet + +# from monai.losses import DiceLoss +# from viscy.light.engine import VSUNet from viscy.unet.networks.Unet2D import Unet2d from viscy.data.hcs import Sample +from viscy.transforms import RandWeightedCropd, RandGaussianNoised +from viscy.transforms import NormalizeSampled # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_27_DENV_A2_infMarked_refined.zarr" # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["Sensor"], + source_channel=["Sensor", "Phase"], target_channel=["Inf_mask"], yx_patch_size=[128, 128], split_ratio=0.8, z_window_size=1, architecture="2D", num_workers=1, - batch_size=12, - augmentations=[], + batch_size=64, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ), + RandGaussianNoised(keys=["Sensor", "Phase"], mean=0.0, std=1.0, prob=0.5), + ], ) # Prepare the data @@ -159,13 +164,14 @@ def validation_step(self, batch: Sample, batch_idx: int): loss = self.loss_function(pred, target_one_hot) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target_one_hot, pred)) ) self.log( "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, + logger=True, ) return loss @@ -209,21 +215,21 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # %% Define the logger logger = TensorBoardLogger( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", - name="infection_classification_model", + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", + name="logs_wPhase", ) # Pass the logger to the Trainer trainer = pl.Trainer( logger=logger, - max_epochs=50, - default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", + max_epochs=100, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", log_every_n_steps=1, ) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints", + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/", filename="checkpoint_{epoch:02d}", save_top_k=-1, verbose=True, @@ -236,43 +242,10 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Fit the model model = LightningUNet( - in_channels=1, + in_channels=2, out_channels=4, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.4, 0.4, 0.1])), + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.3, 0.3, 0.3])), ) trainer.fit(model, data_module) - -# %% test the model on the test set -test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr" - -test_dm = HCSDataModule( - test_datapath, - source_channel=["Sensor", "Nuclei_mask"], -) -# Load the predict dataset -test_dataloader = test_dm.test_dataloader() - -# Set the model to evaluation mode -unet_model.eval() - -# Create a list to store the predictions -predictions = [] - -# Iterate over the test batches -for batch in test_dataloader: - # Extract the input from the batch - input_data = batch["source"] - - # Forward pass through the model - output = unet_model(input_data) - - # Append the predictions to the list - predictions.append(output.detach().cpu().numpy()) - -# Convert the predictions to a numpy array -predictions = np.stack(predictions) - -# Save the predictions as added channel in zarr format -# use iohub or viscy to save the predictions!!! -zarr.save("predictions.zarr", predictions) +# %% From 908039a82a53a7c5e4100d309a5c30f58ccbda7f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 16:09:05 -0700 Subject: [PATCH 55/75] added model testing code --- .../test_infection_classifier.py | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 examples/infection_phenotyping/test_infection_classifier.py diff --git a/examples/infection_phenotyping/test_infection_classifier.py b/examples/infection_phenotyping/test_infection_classifier.py new file mode 100644 index 000000000..3fed63194 --- /dev/null +++ b/examples/infection_phenotyping/test_infection_classifier.py @@ -0,0 +1,96 @@ +# %% +import numpy as np +from viscy.data.hcs import HCSDataModule +from viscy.transforms import NormalizeSampled +from viscy.unet.networks.Unet2D import Unet2d +from viscy.data.hcs import Sample +import lightning.pytorch as pl +import torch + +from viscy.light.predict_writer import HCSPredictionWriter +from monai.transforms import DivisiblePad + +# %% test the model on the test set +test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" + +data_module = HCSDataModule( + test_datapath, + source_channel=["Sensor", "Phase"], + target_channel=[], + split_ratio=0.8, + z_window_size=1, + architecture="2D", + num_workers=1, + batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +data_module.setup(stage="predict") +test_dm = data_module.test_dataloader() +sample = next(iter(test_dm)) + +# %% +class LightningUNet(pl.LightningModule): + def __init__( + self, + in_channels, + out_channels, + ckpt_path, + ): + super(LightningUNet, self).__init__() + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + + def forward(self, x): + return self.unet_model(x) + + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) + pred_class = self.forward(source) + pred_int = torch.argmax(pred_class, dim=4, keepdim=True) + return self._predict_pad.inverse(pred_int) + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + +# %% create trainer and input + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + callbacks=[HCSPredictionWriter(output_path, write_input=True)], +) +model = LightningUNet( + in_channels=2, + out_channels=3, + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", +) + +trainer.predict( + model=model, + datamodule=data_module, + return_predictions=True, +) + +# %% test the model on the test set and write to zarr store \ No newline at end of file From 88615d5f5f3006e9956e86cdb452600990c8f151 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 16:28:07 -0700 Subject: [PATCH 56/75] removed annotation refiner --- .../Infection_annotation_refiner.py | 111 ------------------ 1 file changed, 111 deletions(-) delete mode 100644 examples/infection_phenotyping/Infection_annotation_refiner.py diff --git a/examples/infection_phenotyping/Infection_annotation_refiner.py b/examples/infection_phenotyping/Infection_annotation_refiner.py deleted file mode 100644 index 5117ce918..000000000 --- a/examples/infection_phenotyping/Infection_annotation_refiner.py +++ /dev/null @@ -1,111 +0,0 @@ -# %% Run this to display napari on the remote server while running the script in local IDE -import os - -os.environ["DISPLAY"] = ":1" -# %% use napari to annotate infected cells in segmented data - -import napari -from iohub.ngff import open_ome_zarr -import numpy as np -from pathlib import Path - -dataset_folder = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets" -) - -input_file = dataset_folder / "Exp_2023_09_28_DENV_A2.zarr" -output_file = ( - dataset_folder / "Exp_2023_09_28_DENV_A2_infMarked_test_annotation_pipeline.zarr" -) - -zarr_input = open_ome_zarr( - input_file, - layout="hcs", - mode="r+", -) -chan_names = zarr_input.channel_names -# zarr_input.append_channel('Inf_mask',resize_arrays=True) - -zarr_output = open_ome_zarr( - output_file, - layout="hcs", - mode="w", - channel_names=["Sensor", "Nucl_mask", "Inf_mask"], -) - -v = napari.Viewer() - - -# %% Load label image to napari -for well_id, well_data in zarr_input.wells(): - well_name, well_no = well_id.split("/") - - if well_name == "A" and well_no == "2": - - for pos_name, pos_data in well_data.positions(): - # if int(pos_name) > 1: - v.layers.clear() - data = pos_data.data - - FITC = data[0, 0, ...] - v.add_image(FITC, name="FITC", colormap="green", blending="additive") - Inf_mask = data[0, 1, ...].astype(int) - v.add_labels(Inf_mask) - input("Press Enter") - - label_layer = v.layers["Inf_mask"] - label_array = label_layer.data - label_array = np.expand_dims(label_array, axis=(0, 1)) - # zarr_input.create_image('Inf_mask',label_array) - out_data = np.concatenate((data, label_array), axis=1) - position = zarr_output.create_position(well_name, well_no, pos_name) - position["0"] = out_data - - -# %% Template for magicgui based annotation workflow. -from magicgui import magicgui -from napari.types import ImageData - - -# Create an enumeration of all wells -wells = list(w[0] for w in zarr_input.wells()) -well_id, well_data = next(zarr_input.wells()) -positions = list(p[0] for p in well_data.positions()) -channel_names = zarr_input.channel_names - - -@magicgui( - call_button="load data", - wells={"choices", ["A/1", "A/2", "A/3", "A/4", "A/5"]}, - positions={"choices", ["0", "1", "2", "3", "4"]}, -) # defines the widget. -def load_well(well: str, position: str): # defines the callback. - # Load all data from specified well and position - for well_id, well_data in zarr_input.wells(): - if well_id == well: - for pos_name, pos_data in well_data.positions(): - if pos_name == position: - for i, ch in enumerate(channel_names): - data = pos_data.data - v.add_image( - data[0, i, ...], - name=ch, - colormap="gray", - blending="additive", - ) - break - break - - -@magicgui(call_button="save annotations") # defines the widget. -def save_annotations( - annotation_layer: ImageData, output_path: Path -): # defines the callback. - # Save the output to the specified path - print("save") - - -# Add both widgets to napari -v.window.add_dock_widget(load_well(wells, "0")) -v.window.add_dock_widget(save_annotations) -# %% From 82428ed4b988087cc386d25673af44c56534179f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 17:47:17 -0700 Subject: [PATCH 57/75] corrected conversion of class to int --- .../Infection_classification_model.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index d9a045cc2..b6ef2b1f3 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -6,10 +6,12 @@ import torch.nn as nn import lightning.pytorch as pl import torch.nn.functional as F +import torchmetrics # import torchview from typing import Literal, Sequence from skimage.exposure import rescale_intensity +from sklearn.metrics import ConfusionMatrixDisplay from matplotlib.cm import get_cmap # import napari @@ -26,7 +28,7 @@ # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_27_DENV_A2_infMarked_refined.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" # Create an instance of HCSDataModule data_module = HCSDataModule( @@ -38,7 +40,7 @@ z_window_size=1, architecture="2D", num_workers=1, - batch_size=64, + batch_size=128, normalizations=[ NormalizeSampled( keys=["Sensor", "Phase"], @@ -53,8 +55,7 @@ spatial_size=[-1, 128, 128], keys=["Sensor", "Phase", "Inf_mask"], w_key="Inf_mask", - ), - RandGaussianNoised(keys=["Sensor", "Phase"], mean=0.0, std=1.0, prob=0.5), + ) ], ) @@ -111,6 +112,9 @@ def __init__( self.log_samples_per_batch = log_samples_per_batch self.training_step_outputs = [] self.validation_step_outputs = [] + self.val_cm = torchmetrics.classification.ConfusionMatrix( + task="multiclass", num_classes=self.n_classes + ) def forward(self, x): return self.unet_model(x) @@ -127,7 +131,7 @@ def training_step(self, batch: Sample, batch_idx: int): pred = self.forward(source) # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( 0, 4, 1, 2, 3 ) target_one_hot = target_one_hot.float() # Convert target to float type @@ -156,12 +160,13 @@ def validation_step(self, batch: Sample, batch_idx: int): pred = self.forward(source) # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( 0, 4, 1, 2, 3 ) target_one_hot = target_one_hot.float() # Convert target to float type # Calculate the loss loss = self.loss_function(pred, target_one_hot) + self.val_cm(target_one_hot, pred) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target_one_hot, pred)) @@ -187,6 +192,16 @@ def on_validation_epoch_end(self): self._log_samples("val_samples", self.validation_step_outputs) self.validation_step_outputs = [] + # Log the confusion matrix at the end of the epoch + confusion_matrix = self.val_cm.compute().cpu().numpy() + + self.logger.experiment.add_figure( + "Validation Confusion Matrix", + ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), + self.current_epoch, + ) + self.val_cm.reset() + def _detach_sample(self, imgs: Sequence[Tensor]): num_samples = 2 # min(imgs[0].shape[0], self.log_samples_per_batch) return [ @@ -243,8 +258,8 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Fit the model model = LightningUNet( in_channels=2, - out_channels=4, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.3, 0.3, 0.3])), + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), ) trainer.fit(model, data_module) From b470ed1fd87850145e66d8b9832b492a561c8b14 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 17:55:36 -0700 Subject: [PATCH 58/75] corrected prediction module --- .../test_infection_classifier.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/infection_phenotyping/test_infection_classifier.py b/examples/infection_phenotyping/test_infection_classifier.py index 3fed63194..d8918f829 100644 --- a/examples/infection_phenotyping/test_infection_classifier.py +++ b/examples/infection_phenotyping/test_infection_classifier.py @@ -6,7 +6,7 @@ from viscy.data.hcs import Sample import lightning.pytorch as pl import torch - +import torchmetrics from viscy.light.predict_writer import HCSPredictionWriter from monai.transforms import DivisiblePad @@ -16,7 +16,7 @@ data_module = HCSDataModule( test_datapath, source_channel=["Sensor", "Phase"], - target_channel=[], + target_channel=["inf_mask"], split_ratio=0.8, z_window_size=1, architecture="2D", @@ -36,8 +36,6 @@ data_module.prepare_data() data_module.setup(stage="predict") -test_dm = data_module.test_dataloader() -sample = next(iter(test_dm)) # %% class LightningUNet(pl.LightningModule): @@ -49,6 +47,9 @@ def __init__( ): super(LightningUNet, self).__init__() self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + # self.pred_cm = torchmetrics.classification.ConfusionMatrix( + # task="multiclass", num_classes=self.n_classes + # ) if ckpt_path is not None: state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ "state_dict" @@ -62,8 +63,8 @@ def forward(self, x): def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) pred_class = self.forward(source) - pred_int = torch.argmax(pred_class, dim=4, keepdim=True) - return self._predict_pad.inverse(pred_int) + pred_int = torch.argmax(pred_class, dim=1, keepdim=True) + return pred_int def on_predict_start(self): """Pad the input shape to be divisible by the downsampling factor. @@ -79,7 +80,7 @@ def on_predict_start(self): trainer = pl.Trainer( default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", - callbacks=[HCSPredictionWriter(output_path, write_input=True)], + callbacks=[HCSPredictionWriter(output_path, write_input=False)], ) model = LightningUNet( in_channels=2, From f3746f89a271e6c907292bf7fdbdd47d2d85ea6a Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 06:35:49 -0700 Subject: [PATCH 59/75] cleaned up the code and comments for the LightningUNet --- .../Infection_classification_model.py | 196 ++++++++---------- 1 file changed, 85 insertions(+), 111 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index b6ef2b1f3..f03595d7b 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -25,7 +25,7 @@ from viscy.data.hcs import Sample from viscy.transforms import RandWeightedCropd, RandGaussianNoised from viscy.transforms import NormalizeSampled - +] # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" @@ -93,140 +93,114 @@ # Train the model # Create a TensorBoard logger class LightningUNet(pl.LightningModule): + # Initialize the class def __init__( self, - in_channels, - out_channels, - lr: float = 1e-3, - loss_function: nn.CrossEntropyLoss = None, - schedule: Literal["WarmupCosine", "Constant"] = "Constant", - log_batches_per_epoch: int = 2, - log_samples_per_batch: int = 1, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.CrossEntropyLoss = None, # Loss function + schedule: Literal["WarmupCosine", "Constant"] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch ): - super(LightningUNet, self).__init__() + super(LightningUNet, self).__init__() # Call the superclass initializer + # Initialize the UNet model self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) - self.lr = lr + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule - self.log_batches_per_epoch = log_batches_per_epoch - self.log_samples_per_batch = log_samples_per_batch - self.training_step_outputs = [] - self.validation_step_outputs = [] - self.val_cm = torchmetrics.classification.ConfusionMatrix( - task="multiclass", num_classes=self.n_classes - ) - + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = log_batches_per_epoch # Set the number of batches to log per epoch + self.log_samples_per_batch = log_samples_per_batch # Set the number of samples to log per batch + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = [] # Initialize the list of validation step outputs + # Initialize the confusion matrix for validation + self.val_cm = torchmetrics.classification.ConfusionMatrix(task="multiclass", num_classes=out_channels) + + # Define the forward pass def forward(self, x): - return self.unet_model(x) + return self.unet_model(x) # Pass the input through the UNet model + # Define the optimizer def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) # Use the Adam optimizer return optimizer + # Define the training step def training_step(self, batch: Sample, batch_idx: int): - - # Extract the input and target from the batch - source = batch["source"] - target = batch["target"] - pred = self.forward(source) - - # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert target to float type - # Calculate the loss - train_loss = self.loss_function(pred, target_one_hot) + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss + self.training_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + # Log the training loss + self.log("loss/train", train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + return train_loss # Return the training loss def validation_step(self, batch: Sample, batch_idx: int): - - # Extract the input and target from the batch - source = batch["source"] - target = batch["target"] - pred = self.forward(source) - - # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert target to float type - # Calculate the loss - loss = self.loss_function(pred, target_one_hot) - self.val_cm(target_one_hot, pred) + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + self.val_cm(target_one_hot, pred) # Update the confusion matrix + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - self.log( - "loss/validate", - loss, - sync_dist=True, - add_dataloader_idx=False, - logger=True, - ) - return loss + self.validation_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + # Log the validation loss + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True) + return loss # Return the validation loss + # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) - return self._predict_pad.inverse(self.forward(source)) + source = self._predict_pad(batch["source"]) # Pad the source + return self._predict_pad.inverse(self.forward(source)) # Make a prediction and remove the padding + # Define what happens at the end of a training epoch def on_train_epoch_end(self): - self._log_samples("train_samples", self.training_step_outputs) - self.training_step_outputs = [] + self._log_samples("train_samples", self.training_step_outputs) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + # Define what happens at the end of a validation epoch def on_validation_epoch_end(self): - self._log_samples("val_samples", self.validation_step_outputs) - self.validation_step_outputs = [] - - # Log the confusion matrix at the end of the epoch - confusion_matrix = self.val_cm.compute().cpu().numpy() - - self.logger.experiment.add_figure( - "Validation Confusion Matrix", - ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), - self.current_epoch, - ) - self.val_cm.reset() - + self._log_samples("val_samples", self.validation_step_outputs) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + # Compute the confusion matrix + confusion_matrix = self.val_cm.compute().cpu().numpy() + # Log the confusion matrix + self.logger.experiment.add_figure("Validation Confusion Matrix", ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), self.current_epoch) + self.val_cm.reset() # Reset the confusion matrix + + # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): - num_samples = 2 # min(imgs[0].shape[0], self.log_samples_per_batch) - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] + # Detach the images and convert them to numpy arrays + return [[np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] for i in range(self.log_samples_per_batch)] + # Define a method to log samples def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] - for sample_images in imgs: - images_row = [] - for i, image in enumerate(sample_images): - cm_name = "gray" if i == 0 else "inferno" - if image.ndim == 2: - image = image[np.newaxis] - for channel in image: - channel = rescale_intensity(channel, out_range=(0, 1)) - render = get_cmap(cm_name)(channel, bytes=True)[..., :3] - images_row.append(render) - images_grid.append(np.concatenate(images_row, axis=1)) - grid = np.concatenate(images_grid, axis=0) - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate(sample_images): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity(channel, out_range=(0, 1)) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[..., :3] # Render the channel + images_row.append(render) # Append the render to the list of image rows + images_grid.append(np.concatenate(images_row, axis=1)) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image(key, grid, self.current_epoch, dataformats="HWC") # %% Define the logger logger = TensorBoardLogger( From 20655d6411a569be1f2f5e4bab3063d1895c5af9 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 07:20:22 -0700 Subject: [PATCH 60/75] removed confusion matrix code, finding runtime error with model --- .../Infection_classification_model.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index f03595d7b..5b5412566 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -25,8 +25,9 @@ from viscy.data.hcs import Sample from viscy.transforms import RandWeightedCropd, RandGaussianNoised from viscy.transforms import NormalizeSampled -] + # %% Create a dataloader and visualize the batches. + # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" @@ -87,24 +88,24 @@ # # Start the napari event loop # napari.run() -# %% use 2D Unet and Lightning module +# %% +# Define a 2D UNet model for semantic segmentation as a lightning module. -# Train the model -# Create a TensorBoard logger -class LightningUNet(pl.LightningModule): - # Initialize the class + +class SemanticSegUNet2D(pl.LightningModule): + # Model for semantic segmentation. def __init__( self, in_channels: int, # Number of input channels out_channels: int, # Number of output channels lr: float = 1e-3, # Learning rate - loss_function: nn.CrossEntropyLoss = None, # Loss function + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function schedule: Literal["WarmupCosine", "Constant"] = "Constant", # Learning rate schedule log_batches_per_epoch: int = 2, # Number of batches to log per epoch log_samples_per_batch: int = 2, # Number of samples to log per batch ): - super(LightningUNet, self).__init__() # Call the superclass initializer + super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer # Initialize the UNet model self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) self.lr = lr # Set the learning rate @@ -115,8 +116,7 @@ def __init__( self.log_samples_per_batch = log_samples_per_batch # Set the number of samples to log per batch self.training_step_outputs = [] # Initialize the list of training step outputs self.validation_step_outputs = [] # Initialize the list of validation step outputs - # Initialize the confusion matrix for validation - self.val_cm = torchmetrics.classification.ConfusionMatrix(task="multiclass", num_classes=out_channels) + # Define the forward pass def forward(self, x): @@ -151,7 +151,6 @@ def validation_step(self, batch: Sample, batch_idx: int): target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) target_one_hot = target_one_hot.float() # Convert the target to float type loss = self.loss_function(pred, target_one_hot) # Calculate the loss - self.val_cm(target_one_hot, pred) # Update the confusion matrix # Log the validation step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) @@ -173,11 +172,7 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self): self._log_samples("val_samples", self.validation_step_outputs) # Log the validation samples self.validation_step_outputs = [] # Reset the list of validation step outputs - # Compute the confusion matrix - confusion_matrix = self.val_cm.compute().cpu().numpy() - # Log the confusion matrix - self.logger.experiment.add_figure("Validation Confusion Matrix", ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), self.current_epoch) - self.val_cm.reset() # Reset the confusion matrix + # TODO: Log the confusion matrix # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): @@ -230,11 +225,15 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): trainer.callbacks.append(checkpoint_callback) # Fit the model -model = LightningUNet( +model = SemanticSegUNet2D( in_channels=2, out_channels=3, loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), ) -trainer.fit(model, data_module) + +print(model) +# %% +# Run training. +# trainer.fit(model, data_module) # %% From d022dae9b61f9683964960fb2e72f4f5dff7f707 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 08:05:39 -0700 Subject: [PATCH 61/75] moved scripts to viscy.scripts.infection_phenotyping module to enable imports across scripts --- .../infection_phenotyping/Infection_classification_model.py | 0 {examples => viscy/scripts}/infection_phenotyping/readme.md | 0 .../scripts}/infection_phenotyping/test_infection_classifier.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {examples => viscy/scripts}/infection_phenotyping/Infection_classification_model.py (100%) rename {examples => viscy/scripts}/infection_phenotyping/readme.md (100%) rename {examples => viscy/scripts}/infection_phenotyping/test_infection_classifier.py (100%) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py similarity index 100% rename from examples/infection_phenotyping/Infection_classification_model.py rename to viscy/scripts/infection_phenotyping/Infection_classification_model.py diff --git a/examples/infection_phenotyping/readme.md b/viscy/scripts/infection_phenotyping/readme.md similarity index 100% rename from examples/infection_phenotyping/readme.md rename to viscy/scripts/infection_phenotyping/readme.md diff --git a/examples/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py similarity index 100% rename from examples/infection_phenotyping/test_infection_classifier.py rename to viscy/scripts/infection_phenotyping/test_infection_classifier.py From 901fd70c846542e155a33eaa24102287954d43be Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 10:03:26 -0700 Subject: [PATCH 62/75] combine the lightning modules for training and prediction, fix the DDP exception --- .../Infection_classification_model.py | 121 ++++++++++++++---- .../test_infection_classifier.py | 49 ++----- 2 files changed, 104 insertions(+), 66 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index 5b5412566..4e442e0ff 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -20,6 +20,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint # from monai.losses import DiceLoss +from monai.transforms import DivisiblePad + # from viscy.light.engine import VSUNet from viscy.unet.networks.Unet2D import Unet2d from viscy.data.hcs import Sample @@ -88,7 +90,7 @@ # # Start the napari event loop # napari.run() -# %% +# %% # Define a 2D UNet model for semantic segmentation as a lightning module. @@ -101,9 +103,12 @@ def __init__( out_channels: int, # Number of output channels lr: float = 1e-3, # Learning rate loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal["WarmupCosine", "Constant"] = "Constant", # Learning rate schedule + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule log_batches_per_epoch: int = 2, # Number of batches to log per epoch log_samples_per_batch: int = 2, # Number of samples to log per batch + checkpoint_path: str = None, # Path to the checkpoint ): super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer # Initialize the UNet model @@ -112,11 +117,23 @@ def __init__( # Set the loss function to CrossEntropyLoss if none is provided self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = log_batches_per_epoch # Set the number of batches to log per epoch - self.log_samples_per_batch = log_samples_per_batch # Set the number of samples to log per batch + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = [] # Initialize the list of validation step outputs - + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + if checkpoint_path is not None: + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights # Define the forward pass def forward(self, x): @@ -124,7 +141,9 @@ def forward(self, x): # Define the optimizer def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) # Use the Adam optimizer + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer return optimizer # Define the training step @@ -133,14 +152,26 @@ def training_step(self, batch: Sample, batch_idx: int): target = batch["target"] # Extract the target from the batch pred = self.forward(source) # Make a prediction using the source # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) target_one_hot = target_one_hot.float() # Convert the target to float type train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss # Log the training step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) # Log the training loss - self.log("loss/train", train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) return train_loss # Return the training loss def validation_step(self, batch: Sample, batch_idx: int): @@ -148,54 +179,92 @@ def validation_step(self, batch: Sample, batch_idx: int): target = batch["target"] # Extract the target from the batch pred = self.forward(source) # Make a prediction using the source # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) target_one_hot = target_one_hot.float() # Convert the target to float type loss = self.loss_function(pred, target_one_hot) # Calculate the loss # Log the validation step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) # Log the validation loss - self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True) + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) return loss # Return the validation loss + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source - return self._predict_pad.inverse(self.forward(source)) # Make a prediction and remove the padding + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_map = F.softmax(logits, dim=1) # Calculate the probabilities + return prob_map # return the probabilities for computing metrics. # Define what happens at the end of a training epoch def on_train_epoch_end(self): - self._log_samples("train_samples", self.training_step_outputs) # Log the training samples + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples self.training_step_outputs = [] # Reset the list of training step outputs # Define what happens at the end of a validation epoch def on_validation_epoch_end(self): - self._log_samples("val_samples", self.validation_step_outputs) # Log the validation samples + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples self.validation_step_outputs = [] # Reset the list of validation step outputs # TODO: Log the confusion matrix # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): # Detach the images and convert them to numpy arrays - return [[np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] for i in range(self.log_samples_per_batch)] + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] # Define a method to log samples def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): images_grid = [] # Initialize the list of image grids for sample_images in imgs: # For each sample image images_row = [] # Initialize the list of image rows - for i, image in enumerate(sample_images): # For each image in the sample images + for i, image in enumerate( + sample_images + ): # For each image in the sample images cm_name = "gray" if i == 0 else "inferno" # Set the colormap name if image.ndim == 2: # If the image is 2D image = image[np.newaxis] # Add a new axis for channel in image: # For each channel in the image - channel = rescale_intensity(channel, out_range=(0, 1)) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[..., :3] # Render the channel - images_row.append(render) # Append the render to the list of image rows - images_grid.append(np.concatenate(images_row, axis=1)) # Append the concatenated image rows to the list of image grids + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids # Log the image grid - self.logger.experiment.add_image(key, grid, self.current_epoch, dataformats="HWC") + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + # %% Define the logger logger = TensorBoardLogger( @@ -209,6 +278,7 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): max_epochs=100, default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) # Define the checkpoint callback @@ -232,8 +302,9 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): ) print(model) -# %% +# %% # Run training. -# trainer.fit(model, data_module) + +trainer.fit(model, data_module) # %% diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index d8918f829..0780b33e5 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -8,7 +8,9 @@ import torch import torchmetrics from viscy.light.predict_writer import HCSPredictionWriter -from monai.transforms import DivisiblePad +from viscy.scripts.infection_phenotyping.Infection_classification_model import ( + SemanticSegUNet2D, +) # %% test the model on the test set test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" @@ -37,52 +39,17 @@ data_module.setup(stage="predict") -# %% -class LightningUNet(pl.LightningModule): - def __init__( - self, - in_channels, - out_channels, - ckpt_path, - ): - super(LightningUNet, self).__init__() - self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) - # self.pred_cm = torchmetrics.classification.ConfusionMatrix( - # task="multiclass", num_classes=self.n_classes - # ) - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - - def forward(self, x): - return self.unet_model(x) - - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) - pred_class = self.forward(source) - pred_int = torch.argmax(pred_class, dim=1, keepdim=True) - return pred_int - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # %% create trainer and input -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SM.zarr" trainer = pl.Trainer( default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) -model = LightningUNet( + +model = SemanticSegUNet2D( in_channels=2, out_channels=3, ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", @@ -94,4 +61,4 @@ def on_predict_start(self): return_predictions=True, ) -# %% test the model on the test set and write to zarr store \ No newline at end of file +# %% test the model on the test set and write to zarr store From 708a67ab3990dc31607b412cded0598dc47a4683 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 12:34:43 -0700 Subject: [PATCH 63/75] all the stubs for computing and logging confusion matrix per cell --- .../Infection_classification_model.py | 128 +++++++++++++++++- .../test_infection_classifier.py | 2 - 2 files changed, 125 insertions(+), 5 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index 4e442e0ff..ee60fd057 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -21,6 +21,7 @@ # from monai.losses import DiceLoss from monai.transforms import DivisiblePad +from skimage.measure import regionprops # from viscy.light.engine import VSUNet from viscy.unet.networks.Unet2D import Unet2d @@ -128,6 +129,9 @@ def __init__( [] ) # Initialize the list of validation step outputs + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Background", "Infected", "Uninfected"] + if checkpoint_path is not None: state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ "state_dict" @@ -205,11 +209,29 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source + target = batch["target"] # Extract the target from the batch logits = self._predict_pad.inverse( self.forward(source) ) # Predict and remove padding. - prob_map = F.softmax(logits, dim=1) # Calculate the probabilities - return prob_map # return the probabilities for computing metrics. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels + labels_target = torch.argmax(target, dim=1) # Calculate the target labels + # FIXME: Check if compliant with lightning API + self.pred_cm = confusion_matrix_per_cell( + labels_target, labels_pred, num_classes=3 + ) + + return prob_pred # log the probabilities instead of logits. + + # Accumulate the confusion matrix at the end of prediction epoch and log. + def on_predict_epoch_end(self): + confusion_matrix = self.pred_cm.compute().cpu().numpy() + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), + self.current_epoch, + ) # Define what happens at the end of a training epoch def on_train_epoch_end(self): @@ -307,4 +329,104 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): trainer.fit(model, data_module) -# %% +# %% Methods to compute confusion matrix per cell using torchmetrics + + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) + # Compute the confusion matrix per cell + confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( + nuclei_true(nuclei_true > 0), # indexing just non-background pixels. + nuclei_pred(nuclei_true > 0), + num_classes=num_classes, + task="multi_class", + ) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def image_class_to_nuclei_class( + y_true: torch.Tonser, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + nuclei_true = torch.zeros_like(y_true) + nuclie_pred = torch.zeros_like(y_pred) + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + for i in range(batch_size): + regions = regionprops(y_true[i].cpu().numpy()) + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + centroid = region.centroid + pixel_ids = region.coords + # Find the class of the nuclei in the ground truth and prediction. + pix_labels_true = y_true[i, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_true = np.mode(pix_labels_true[:]) + + pix_labels_pred = y_pred[i, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_pred = np.mode(pix_labels_pred[:]) + nuclei_true[i, centroid[0], centroid[1]] = consensus_class_true + nuclei_pred[i, centroid[0], centroid[1]] = consensus_class_pred + + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + + return nuclei_true, nuclei_pred + + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + return fig diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index 0780b33e5..14bff35c7 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -60,5 +60,3 @@ datamodule=data_module, return_predictions=True, ) - -# %% test the model on the test set and write to zarr store From 6bb9ca38fe981ba4fe4036ec4fdac7bf80ebf282 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 1 Apr 2024 15:02:05 -0700 Subject: [PATCH 64/75] separated training and test scripts --- .../Infection_classification_model.py | 324 +----------------- .../test_infection_classifier.py | 52 ++- 2 files changed, 29 insertions(+), 347 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index ee60fd057..d8056a044 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -1,33 +1,15 @@ # %% import torch -from viscy.data.hcs import HCSDataModule - -import numpy as np -import torch.nn as nn import lightning.pytorch as pl -import torch.nn.functional as F -import torchmetrics - -# import torchview -from typing import Literal, Sequence -from skimage.exposure import rescale_intensity -from sklearn.metrics import ConfusionMatrixDisplay -from matplotlib.cm import get_cmap +import torch.nn as nn -# import napari from pytorch_lightning.loggers import TensorBoardLogger -from torch import Tensor from pytorch_lightning.callbacks import ModelCheckpoint -# from monai.losses import DiceLoss -from monai.transforms import DivisiblePad -from skimage.measure import regionprops - -# from viscy.light.engine import VSUNet -from viscy.unet.networks.Unet2D import Unet2d -from viscy.data.hcs import Sample -from viscy.transforms import RandWeightedCropd, RandGaussianNoised +from viscy.transforms import RandWeightedCropd from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D # %% Create a dataloader and visualize the batches. @@ -91,202 +73,6 @@ # # Start the napari event loop # napari.run() -# %% - -# Define a 2D UNet model for semantic segmentation as a lightning module. - - -class SemanticSegUNet2D(pl.LightningModule): - # Model for semantic segmentation. - def __init__( - self, - in_channels: int, # Number of input channels - out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate - loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal[ - "WarmupCosine", "Constant" - ] = "Constant", # Learning rate schedule - log_batches_per_epoch: int = 2, # Number of batches to log per epoch - log_samples_per_batch: int = 2, # Number of samples to log per batch - checkpoint_path: str = None, # Path to the checkpoint - ): - super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer - # Initialize the UNet model - self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) - self.lr = lr # Set the learning rate - # Set the loss function to CrossEntropyLoss if none is provided - self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = ( - log_batches_per_epoch # Set the number of batches to log per epoch - ) - self.log_samples_per_batch = ( - log_samples_per_batch # Set the number of samples to log per batch - ) - self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = ( - [] - ) # Initialize the list of validation step outputs - - self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Background", "Infected", "Uninfected"] - - if checkpoint_path is not None: - state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - - # Define the forward pass - def forward(self, x): - return self.unet_model(x) # Pass the input through the UNet model - - # Define the optimizer - def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr - ) # Use the Adam optimizer - return optimizer - - # Define the training step - def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the training step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the training loss - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss # Return the training loss - - def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the validation step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the validation loss - self.log( - "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True - ) - return loss # Return the validation loss - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # Define the prediction step - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) # Pad the source - target = batch["target"] # Extract the target from the batch - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels - labels_target = torch.argmax(target, dim=1) # Calculate the target labels - # FIXME: Check if compliant with lightning API - self.pred_cm = confusion_matrix_per_cell( - labels_target, labels_pred, num_classes=3 - ) - - return prob_pred # log the probabilities instead of logits. - - # Accumulate the confusion matrix at the end of prediction epoch and log. - def on_predict_epoch_end(self): - confusion_matrix = self.pred_cm.compute().cpu().numpy() - self.logger.experiment.add_figure( - "Confusion Matrix per Cell", - plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), - self.current_epoch, - ) - - # Define what happens at the end of a training epoch - def on_train_epoch_end(self): - self._log_samples( - "train_samples", self.training_step_outputs - ) # Log the training samples - self.training_step_outputs = [] # Reset the list of training step outputs - - # Define what happens at the end of a validation epoch - def on_validation_epoch_end(self): - self._log_samples( - "val_samples", self.validation_step_outputs - ) # Log the validation samples - self.validation_step_outputs = [] # Reset the list of validation step outputs - # TODO: Log the confusion matrix - - # Define a method to detach a sample - def _detach_sample(self, imgs: Sequence[Tensor]): - # Detach the images and convert them to numpy arrays - num_samples = 3 - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] - - # Define a method to log samples - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] # Initialize the list of image grids - for sample_images in imgs: # For each sample image - images_row = [] # Initialize the list of image rows - for i, image in enumerate( - sample_images - ): # For each image in the sample images - cm_name = "gray" if i == 0 else "inferno" # Set the colormap name - if image.ndim == 2: # If the image is 2D - image = image[np.newaxis] # Add a new axis - for channel in image: # For each channel in the image - channel = rescale_intensity( - channel, out_range=(0, 1) - ) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[ - ..., :3 - ] # Render the channel - images_row.append( - render - ) # Append the render to the list of image rows - images_grid.append( - np.concatenate(images_row, axis=1) - ) # Append the concatenated image rows to the list of image grids - grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids - # Log the image grid - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - # %% Define the logger logger = TensorBoardLogger( @@ -328,105 +114,3 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Run training. trainer.fit(model, data_module) - -# %% Methods to compute confusion matrix per cell using torchmetrics - - -# The confusion matrix at the single-cell resolution. -def confusion_matrix_per_cell( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Compute confusion matrix per cell. - - Args: - y_true (torch.Tensor): Ground truth label image (BXHXW). - y_pred (torch.Tensor): Predicted label image (BXHXW). - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Confusion matrix per cell (BXCXC). - """ - # Convert the image class to the nuclei class - nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) - # Compute the confusion matrix per cell - confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( - nuclei_true(nuclei_true > 0), # indexing just non-background pixels. - nuclei_pred(nuclei_true > 0), - num_classes=num_classes, - task="multi_class", - ) - return confusion_matrix_per_cell - - -# These images can be logged with prediction. -def image_class_to_nuclei_class( - y_true: torch.Tonser, y_pred: torch.Tensor, num_classes: int -): - """Convert the class of the image to the class of the nuclei. - - Args: - label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Label images with a consensus class at the centroid of nuclei. - """ - nuclei_true = torch.zeros_like(y_true) - nuclie_pred = torch.zeros_like(y_pred) - batch_size = y_true.size(0) - # find centroids of nuclei from y_true - for i in range(batch_size): - regions = regionprops(y_true[i].cpu().numpy()) - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - centroid = region.centroid - pixel_ids = region.coords - # Find the class of the nuclei in the ground truth and prediction. - pix_labels_true = y_true[i, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_true = np.mode(pix_labels_true[:]) - - pix_labels_pred = y_pred[i, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_pred = np.mode(pix_labels_pred[:]) - nuclei_true[i, centroid[0], centroid[1]] = consensus_class_true - nuclei_pred[i, centroid[0], centroid[1]] = consensus_class_pred - - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - - return nuclei_true, nuclei_pred - - -def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - return fig diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index 14bff35c7..c12552ca5 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -1,51 +1,37 @@ # %% -import numpy as np from viscy.data.hcs import HCSDataModule -from viscy.transforms import NormalizeSampled -from viscy.unet.networks.Unet2D import Unet2d -from viscy.data.hcs import Sample import lightning.pytorch as pl -import torch -import torchmetrics from viscy.light.predict_writer import HCSPredictionWriter -from viscy.scripts.infection_phenotyping.Infection_classification_model import ( - SemanticSegUNet2D, -) +from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from pytorch_lightning.loggers import TensorBoardLogger # %% test the model on the test set test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" data_module = HCSDataModule( - test_datapath, - source_channel=["Sensor", "Phase"], - target_channel=["inf_mask"], + data_path=test_datapath, + source_channel=['Sensor','Phase'], + target_channel=['Inf_mask'], split_ratio=0.8, z_window_size=1, architecture="2D", - num_workers=1, + num_workers=0, batch_size=1, - normalizations=[ - NormalizeSampled( - keys=["Sensor", "Phase"], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ) - ], ) -# Prepare the data -data_module.prepare_data() - -data_module.setup(stage="predict") +data_module.setup(stage="test") # %% create trainer and input -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SM.zarr" +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", + name="logs_wPhase", +) trainer = pl.Trainer( + logger=logger, default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", - callbacks=[HCSPredictionWriter(output_path, write_input=False)], + log_every_n_steps=1, devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) @@ -55,6 +41,18 @@ ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", ) +trainer.test(model=model, datamodule=data_module) + +# %% predict the test set + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + trainer.predict( model=model, datamodule=data_module, From 99a387645882e7b66d439d649a35316736856c6c Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 1 Apr 2024 15:14:16 -0700 Subject: [PATCH 65/75] lightning module --- .../classify_infection.py | 332 ++++++++++++++++++ 1 file changed, 332 insertions(+) create mode 100644 viscy/scripts/infection_phenotyping/classify_infection.py diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py new file mode 100644 index 000000000..10f263404 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -0,0 +1,332 @@ + +import torch +import torch.nn as nn +import lightning.pytorch as pl +import torch.nn.functional as F +from torch import Tensor +import torchmetrics +from statistics import mode + +# import torchview +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap +from skimage.measure import regionprops +import numpy as np +import matplotlib.pyplot as plt + +from monai.transforms import DivisiblePad +from viscy.unet.networks.Unet2D import Unet2d +from viscy.data.hcs import Sample + + +# %% Methods to compute confusion matrix per cell using torchmetrics + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) + # Compute the confusion matrix per cell + confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( + nuclei_true[nuclei_true > 0], # indexing just non-background pixels. + nuclei_pred[nuclei_true > 0], + num_classes=num_classes, + task="multiclass", # Fix: Change "multi_class" to "multiclass" + ) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def image_class_to_nuclei_class( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + nuclei_true = torch.zeros_like(y_true[:, 0, 0, :, :]) + nuclei_pred = torch.zeros_like(y_pred[:, 0, : , :]) + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + print(y_true_reshaped.shape) + regions = regionprops(y_true_reshaped.astype(int)) + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + centroid = region.centroid + pixel_ids = region.coords + # Find the class of the nuclei in the ground truth and prediction. + pix_labels_true = y_true[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_true = mode(pix_labels_true[:]) + + pix_labels_pred = y_pred[i, 0, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_pred = mode(pix_labels_pred[:]) + nuclei_true[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) + nuclei_pred[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) + + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + + return nuclei_true, nuclei_pred + +# Define a 2d unet model for infection classification as a lightning module. + +class SemanticSegUNet2D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Background", "Infected", "Uninfected"] + + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def test_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels + self.pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=3 + ) + + return self.pred_cm + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels + + return labels_pred # log the class predicted image + + # Accumulate the confusion matrix at the end of prediction epoch and log. + def on_test_epoch_end(self): + confusion_matrix = self.pred_cm.compute().cpu().numpy() + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + # TODO: Log the confusion matrix + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + plt.show(fig) # Show the figure + return fig \ No newline at end of file From 000a966bec2df456f2eeb42bc87ec017ec39669d Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 2 Apr 2024 16:38:55 -0700 Subject: [PATCH 66/75] corrected test cm compute --- .../Infection_classification_model.py | 2 + .../classify_infection.py | 163 ++++++++++-------- .../test_infection_classifier.py | 11 +- 3 files changed, 105 insertions(+), 71 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index d8056a044..37f606cb2 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -114,3 +114,5 @@ # Run training. trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py index 10f263404..51cac9851 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -6,6 +6,7 @@ from torch import Tensor import torchmetrics from statistics import mode +# import napari # import torchview from typing import Literal, Sequence @@ -43,7 +44,7 @@ def confusion_matrix_per_cell( nuclei_true[nuclei_true > 0], # indexing just non-background pixels. nuclei_pred[nuclei_true > 0], num_classes=num_classes, - task="multiclass", # Fix: Change "multi_class" to "multiclass" + task="multiclass", ) return confusion_matrix_per_cell @@ -62,7 +63,8 @@ def image_class_to_nuclei_class( torch.Tensor: Label images with a consensus class at the centroid of nuclei. """ nuclei_true = torch.zeros_like(y_true[:, 0, 0, :, :]) - nuclei_pred = torch.zeros_like(y_pred[:, 0, : , :]) + nuclei_pred = torch.zeros_like(y_pred[:, 0, 0, :, :]) + batch_size = y_true.size(0) # find centroids of nuclei from y_true for i in range(batch_size): @@ -78,16 +80,51 @@ def image_class_to_nuclei_class( pix_labels_true = y_true[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] consensus_class_true = mode(pix_labels_true[:]) - pix_labels_pred = y_pred[i, 0, pixel_ids[:, 0], pixel_ids[:, 1]] + pix_labels_pred = y_pred[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] consensus_class_pred = mode(pix_labels_pred[:]) - nuclei_true[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) - nuclei_pred[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) + nuclei_true[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) + nuclei_pred[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. return nuclei_true, nuclei_pred +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig # Define a 2d unet model for infection classification as a lightning module. class SemanticSegUNet2D(pl.LightningModule): @@ -108,6 +145,12 @@ def __init__( super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer # Initialize the UNet model self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights self.lr = lr # Set the learning rate # Set the loss function to CrossEntropyLoss if none is provided self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() @@ -126,12 +169,7 @@ def __init__( self.pred_cm = None # Initialize the confusion matrix self.index_to_label_dict = ["Background", "Infected", "Uninfected"] - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights + # Define the forward pass def forward(self, x): @@ -193,23 +231,6 @@ def validation_step(self, batch: Sample, batch_idx: int): ) return loss # Return the validation loss - def test_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels - self.pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=3 - ) - - return self.pred_cm - def on_predict_start(self): """Pad the input shape to be divisible by the downsampling factor. The inverse of this transform crops the prediction to original shape. @@ -230,16 +251,55 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels return labels_pred # log the class predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((3, 3)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self.forward(source) + # prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(logits, dim=1, keepdim=True) # Calculate the predicted labels + + # pred_img = logits.detach().cpu().numpy() + # v = napari.Viewer() + # v.add_image(pred_img) + # napari.run() - # Accumulate the confusion matrix at the end of prediction epoch and log. - def on_test_epoch_end(self): - confusion_matrix = self.pred_cm.compute().cpu().numpy() + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=3 + ) # Calculate the confusion matrix per cell + + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm self.logger.experiment.add_figure( "Confusion Matrix per Cell", - plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), self.current_epoch, ) + # def on_test_batch_end(self): + # # confusion_matrix_sum = torch.zeros((3, 3)) # Initialize the sum of confusion matrices + # # for pred_cm in self.pred_cm: # For each confusion matrix + # # confusion_matrix_sum += pred_cm # Accumulate the sum + # # confusion_matrix_sum = confusion_matrix_sum.cpu().numpy() # Convert to numpy array + # confusion_matrix_sum = torch.sum(torch.stack([tensor.cpu() for tensor in self.pred_cm], dim=0), dim=0) + # self.logger.experiment.add_figure( + # "Confusion Matrix batch-wise", + # plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + # self.current_epoch, + # ) # Define what happens at the end of a training epoch def on_train_epoch_end(self): @@ -254,7 +314,6 @@ def on_validation_epoch_end(self): "val_samples", self.validation_step_outputs ) # Log the validation samples self.validation_step_outputs = [] # Reset the list of validation step outputs - # TODO: Log the confusion matrix # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): @@ -293,40 +352,4 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Log the image grid self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" - ) - - def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - plt.show(fig) # Show the figure - return fig \ No newline at end of file + ) \ No newline at end of file diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index c12552ca5..fadd3e750 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -4,6 +4,7 @@ from viscy.light.predict_writer import HCSPredictionWriter from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D from pytorch_lightning.loggers import TensorBoardLogger +from viscy.transforms import NormalizeSampled # %% test the model on the test set test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" @@ -17,6 +18,14 @@ architecture="2D", num_workers=0, batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], ) data_module.setup(stage="test") @@ -38,7 +47,7 @@ model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", ) trainer.test(model=model, datamodule=data_module) From 688336e0d80953a1598cc7b1cb738a93b9817efd Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 3 Apr 2024 11:16:31 -0700 Subject: [PATCH 67/75] corrected test module --- .../classify_infection.py | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py index 51cac9851..497efe961 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -4,9 +4,10 @@ import lightning.pytorch as pl import torch.nn.functional as F from torch import Tensor -import torchmetrics +# from torchmetrics.functional import confusion_matrix from statistics import mode # import napari +from sklearn.metrics import ConfusionMatrixDisplay # import torchview from typing import Literal, Sequence @@ -39,13 +40,17 @@ def confusion_matrix_per_cell( """ # Convert the image class to the nuclei class nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) + + nuclei_true_np = nuclei_true.cpu().numpy() + nuclei_pred_np = nuclei_pred.cpu().numpy() + # Compute the confusion matrix per cell - confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( - nuclei_true[nuclei_true > 0], # indexing just non-background pixels. - nuclei_pred[nuclei_true > 0], - num_classes=num_classes, - task="multiclass", + confusion_matrix_per_cell = ConfusionMatrixDisplay.from_predictions( + nuclei_true_np[nuclei_true_np > 0], # indexing just non-background pixels. + nuclei_pred_np[nuclei_true_np > 0], + labels=range(num_classes), ) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) return confusion_matrix_per_cell @@ -70,7 +75,6 @@ def image_class_to_nuclei_class( for i in range(batch_size): y_true_cpu = y_true[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - print(y_true_reshaped.shape) regions = regionprops(y_true_reshaped.astype(int)) # Find centroids, pixel coordinates from the ground truth. for region in regions: @@ -263,16 +267,10 @@ def test_step(self, batch: Sample): # prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities labels_pred = torch.argmax(logits, dim=1, keepdim=True) # Calculate the predicted labels - # pred_img = logits.detach().cpu().numpy() - # v = napari.Viewer() - # v.add_image(pred_img) - # napari.run() - target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=3 ) # Calculate the confusion matrix per cell - self.pred_cm += pred_cm # Append the confusion matrix to pred_cm self.logger.experiment.add_figure( @@ -289,17 +287,6 @@ def on_test_end(self): plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), self.current_epoch, ) - # def on_test_batch_end(self): - # # confusion_matrix_sum = torch.zeros((3, 3)) # Initialize the sum of confusion matrices - # # for pred_cm in self.pred_cm: # For each confusion matrix - # # confusion_matrix_sum += pred_cm # Accumulate the sum - # # confusion_matrix_sum = confusion_matrix_sum.cpu().numpy() # Convert to numpy array - # confusion_matrix_sum = torch.sum(torch.stack([tensor.cpu() for tensor in self.pred_cm], dim=0), dim=0) - # self.logger.experiment.add_figure( - # "Confusion Matrix batch-wise", - # plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), - # self.current_epoch, - # ) # Define what happens at the end of a training epoch def on_train_epoch_end(self): From 6b58f34dd951ad7ddccde50b8e3a6ba0b4bb31e9 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 3 Apr 2024 11:18:09 -0700 Subject: [PATCH 68/75] separated test and prediction scripts --- .../predict_infection_classifier.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 viscy/scripts/infection_phenotyping/predict_infection_classifier.py diff --git a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py new file mode 100644 index 000000000..14600325e --- /dev/null +++ b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py @@ -0,0 +1,54 @@ + + +from viscy.light.predict_writer import HCSPredictionWriter +from viscy.data.hcs import HCSDataModule +import lightning.pytorch as pl +from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from viscy.transforms import NormalizeSampled + +# %% # %% write the predictions to a zarr file + +pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" + +data_module = HCSDataModule( + data_path=pred_datapath, + source_channel=['Sensor','Phase'], + target_channel=['Inf_mask'], + split_ratio=0.8, + z_window_size=1, + architecture="2D", + num_workers=0, + batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +data_module.setup(stage="predict") + +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", +) + +# %% perform prediction + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +trainer.predict( + model=model, + datamodule=data_module, + return_predictions=True, +) From b6ad254bf5bcb7e6fe78b6cd4d4c38d075c4a0b2 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 4 Apr 2024 21:56:44 -0700 Subject: [PATCH 69/75] changed confusion matrix compute --- .../classify_infection.py | 88 ++++++++--------- .../predict_infection_classifier.py | 12 ++- .../test_infection_classifier.py | 96 ++++++++++++++++--- 3 files changed, 129 insertions(+), 67 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py index 497efe961..ee904434a 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -4,16 +4,13 @@ import lightning.pytorch as pl import torch.nn.functional as F from torch import Tensor -# from torchmetrics.functional import confusion_matrix -from statistics import mode -# import napari -from sklearn.metrics import ConfusionMatrixDisplay +import cv2 # import torchview from typing import Literal, Sequence from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap -from skimage.measure import regionprops +from skimage.measure import regionprops, label import numpy as np import matplotlib.pyplot as plt @@ -21,7 +18,7 @@ from viscy.unet.networks.Unet2D import Unet2d from viscy.data.hcs import Sample - +# # %% Methods to compute confusion matrix per cell using torchmetrics # The confusion matrix at the single-cell resolution. @@ -39,23 +36,13 @@ def confusion_matrix_per_cell( torch.Tensor: Confusion matrix per cell (BXCXC). """ # Convert the image class to the nuclei class - nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) - - nuclei_true_np = nuclei_true.cpu().numpy() - nuclei_pred_np = nuclei_pred.cpu().numpy() - - # Compute the confusion matrix per cell - confusion_matrix_per_cell = ConfusionMatrixDisplay.from_predictions( - nuclei_true_np[nuclei_true_np > 0], # indexing just non-background pixels. - nuclei_pred_np[nuclei_true_np > 0], - labels=range(num_classes), - ) + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) return confusion_matrix_per_cell # These images can be logged with prediction. -def image_class_to_nuclei_class( +def compute_confusion_matrix( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int ): """Convert the class of the image to the class of the nuclei. @@ -67,32 +54,40 @@ def image_class_to_nuclei_class( Returns: torch.Tensor: Label images with a consensus class at the centroid of nuclei. """ - nuclei_true = torch.zeros_like(y_true[:, 0, 0, :, :]) - nuclei_pred = torch.zeros_like(y_pred[:, 0, 0, :, :]) batch_size = y_true.size(0) # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) for i in range(batch_size): y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - regions = regionprops(y_true_reshaped.astype(int)) - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - centroid = region.centroid - pixel_ids = region.coords - # Find the class of the nuclei in the ground truth and prediction. - pix_labels_true = y_true[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_true = mode(pix_labels_true[:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) - pix_labels_pred = y_pred[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_pred = mode(pix_labels_pred[:]) - nuclei_true[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) - nuclei_pred[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1,1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0,1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1,0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0,0] += 1 # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - - return nuclei_true, nuclei_pred + return conf_mat def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # Create a figure and axis to plot the confusion matrix @@ -171,7 +166,7 @@ def __init__( ) # Initialize the list of validation step outputs self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Background", "Infected", "Uninfected"] + self.index_to_label_dict = ["Infected", "Uninfected"] @@ -244,32 +239,28 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels - + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + return labels_pred # log the class predicted image def on_test_start(self): - self.pred_cm = torch.zeros((3, 3)) + self.pred_cm = torch.zeros((2,2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source - logits = self.forward(source) - # prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(logits, dim=1, keepdim=True) # Calculate the predicted labels + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=3 + target, labels_pred, num_classes=2 ) # Calculate the confusion matrix per cell self.pred_cm += pred_cm # Append the confusion matrix to pred_cm @@ -339,4 +330,5 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Log the image grid self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" - ) \ No newline at end of file + ) +# %% diff --git a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py index 14600325e..783c13340 100644 --- a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py @@ -1,4 +1,4 @@ - +# %% from viscy.light.predict_writer import HCSPredictionWriter from viscy.data.hcs import HCSDataModule @@ -17,24 +17,24 @@ split_ratio=0.8, z_window_size=1, architecture="2D", - num_workers=0, + num_workers=1, batch_size=1, normalizations=[ NormalizeSampled( - keys=["Sensor", "Phase"], + keys=["Phase", "Sensor"], level="fov_statistics", subtrahend="median", divisor="iqr", ) ], ) - +data_module.prepare_data() data_module.setup(stage="predict") model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", ) # %% perform prediction @@ -52,3 +52,5 @@ datamodule=data_module, return_predictions=True, ) + +# %% diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index fadd3e750..5ed140946 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -1,7 +1,6 @@ # %% from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl -from viscy.light.predict_writer import HCSPredictionWriter from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D from pytorch_lightning.loggers import TensorBoardLogger from viscy.transforms import NormalizeSampled @@ -47,23 +46,92 @@ model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", ) trainer.test(model=model, datamodule=data_module) -# %% predict the test set -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" -trainer = pl.Trainer( - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", - callbacks=[HCSPredictionWriter(output_path, write_input=False)], - devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs -) -trainer.predict( - model=model, - datamodule=data_module, - return_predictions=True, -) +# # %% script to develop confusion matrix for infected cell classifier + +# from iohub.ngff import open_ome_zarr +# import numpy as np +# from skimage.measure import regionprops, label +# import cv2 +# import seaborn as sns +# import matplotlib.pyplot as plt + +# # %% load the predicted zarr and the human-in-loop annotations zarr + +# pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" +# test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" + +# pred_dataset = open_ome_zarr( +# pred_datapath, +# layout="hcs", +# mode="r+", +# ) +# chan_pred = pred_dataset.channel_names + +# test_dataset = open_ome_zarr( +# test_datapath, +# layout="hcs", +# mode="r+", +# ) +# chan_test = test_dataset.channel_names + +# # %% compute confusion matrix for one image +# for well_id, well_data in pred_dataset.wells(): +# well_name, well_no = well_id.split("/") + +# for pos_name, pos_data in well_data.positions(): + +# pred_data = pos_data.data +# pred_pos_data = pred_data.numpy() +# T,C,Z,X,Y = pred_pos_data.shape + +# test_data = test_dataset[well_id + "/" + pos_name + "/0"] +# test_pos_data = test_data.numpy() + +# # compute confusion matrix for each time point and add to total +# conf_mat = np.zeros((2, 2)) +# for time in range(T): +# pred_img = pred_pos_data[time, chan_pred.index("Inf_mask_prediction"), 0, : , :] +# test_img = test_pos_data[time, chan_test.index("Inf_mask"), 0, : , :] + +# test_img_rz = cv2.resize(test_img, dsize=(2016,2048), interpolation=cv2.INTER_NEAREST)# pred_img = +# pred_img = np.where(test_img_rz > 0, pred_img, 0) + +# # find objects in every image +# label_img = label(test_img_rz) +# regions_label = regionprops(label_img) + +# # store pixel id for every label in pred and test imgs +# for region in regions_label: +# if region.area > 0: +# row, col = region.centroid +# pred_id = pred_img[int(row), int(col)] +# test_id = test_img_rz[int(row), int(col)] +# if pred_id == 1 and test_id == 1: +# conf_mat[1,1] += 1 +# if pred_id == 1 and test_id == 2: +# conf_mat[1,0] += 1 +# if pred_id == 2 and test_id == 1: +# conf_mat[0,1] += 1 +# if pred_id == 2 and test_id == 2: +# conf_mat[0,0] += 1 + +# # display the confusion matrix +# ax= plt.subplot() +# sns.heatmap(conf_mat, annot=True, fmt='g', ax=ax); #annot=True to annotate cells, ftm='g' to disable scientific notation + +# # labels, title and ticks +# ax.set_xlabel('annotated labels');ax.set_ylabel('predicted labels'); +# ax.set_title('Confusion Matrix'); +# ax.xaxis.set_ticklabels(['infected', 'uninfected']); ax.yaxis.set_ticklabels(['infected', 'uninfected']); + + +# # %% +# %% From 9c9ce41b27f0ada2e3cc50f255e702f60e955654 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Apr 2024 16:12:12 -0700 Subject: [PATCH 70/75] fix merge error --- viscy/data/hcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 27992bfc1..f33b61217 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -427,7 +427,7 @@ def _setup_test(self, dataset_settings: dict): [p for _, p in plate.positions()], transform=test_transform, ground_truth_masks=self.ground_truth_masks, - norm_meta=plate.zattrs["normalization"] ** dataset_settings, + **dataset_settings, ) else: self.test_dataset = SlidingWindowDataset( From 6b0a42d525c18f740054f7e82897f32a2cf70210 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 23 May 2024 14:01:11 -0700 Subject: [PATCH 71/75] split 2D and 2.5D model scripts --- .../Infection_classification_25DModel.py | 154 ++++++++ ...py => Infection_classification_2Dmodel.py} | 2 +- .../classify_infection_25D.py | 335 ++++++++++++++++++ ..._infection.py => classify_infection_2D.py} | 5 +- 4 files changed, 494 insertions(+), 2 deletions(-) create mode 100644 viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py rename viscy/scripts/infection_phenotyping/{Infection_classification_model.py => Infection_classification_2Dmodel.py} (97%) create mode 100644 viscy/scripts/infection_phenotyping/classify_infection_25D.py rename viscy/scripts/infection_phenotyping/{classify_infection.py => classify_infection_2D.py} (98%) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py new file mode 100644 index 000000000..91702497c --- /dev/null +++ b/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py @@ -0,0 +1,154 @@ +# %% +import torch +import lightning.pytorch as pl +import torch.nn as nn + +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from viscy.transforms import RandWeightedCropd +from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_25D import SemanticSegUNet25D + +from iohub.ngff import open_ome_zarr + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" + +# find ratio of background, uninfected and infected pixels +zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", +) +in_chan_names = zarr_input.channel_names + +num_pixels_bkg = 0 +num_pixels_uninf = 0 +num_pixels_inf = 0 +num_pixels = 0 +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T,C,Z,Y,X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' + num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() + num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() + num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() + num_pixels = num_pixels + Z*X*Y + +pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_sum = sum(pixel_ratio_1) +pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] + +# %% craete data module + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Phase", "HSP90"], + target_channel=["Inf_mask"], + yx_patch_size=[512, 512], + split_ratio=0.8, + z_window_size=5, + architecture="2.5D", + num_workers=3, + batch_size=32, + normalizations=[ + NormalizeSampled( + keys=["Phase","HSP90"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=4, + spatial_size=[-1, 512, 512], + keys=["Phase","HSP90"], + w_key="Inf_mask", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + +# Visualize the dataset and the batch using napari +# Set the display +# os.environ['DISPLAY'] = ':1' + +# # Create a napari viewer +# viewer = napari.Viewer() + +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# # Start the napari event loop +# napari.run() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=200, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet25D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% +# Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py similarity index 97% rename from viscy/scripts/infection_phenotyping/Infection_classification_model.py rename to viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py index 37f606cb2..52af46732 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py @@ -9,7 +9,7 @@ from viscy.transforms import RandWeightedCropd from viscy.transforms import NormalizeSampled from viscy.data.hcs import HCSDataModule -from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D # %% Create a dataloader and visualize the batches. diff --git a/viscy/scripts/infection_phenotyping/classify_infection_25D.py b/viscy/scripts/infection_phenotyping/classify_infection_25D.py new file mode 100644 index 000000000..c78a7e8f0 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/classify_infection_25D.py @@ -0,0 +1,335 @@ + +import torch +import torch.nn as nn +import lightning.pytorch as pl +import torch.nn.functional as F +from torch import Tensor +import cv2 + +# import torchview +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap +from skimage.measure import regionprops, label +import numpy as np +import matplotlib.pyplot as plt + +from monai.transforms import DivisiblePad +from viscy.unet.networks.Unet25D import Unet25d +from viscy.data.hcs import Sample + +# +# %% Methods to compute confusion matrix per cell using torchmetrics + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1,1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0,1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1,0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0,0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig +# Define a 25d unet model for infection classification as a lightning module. + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet25d(in_channels=in_channels, out_channels=out_channels, num_blocks=4, num_block_layers=4) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2,2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection_2D.py similarity index 98% rename from viscy/scripts/infection_phenotyping/classify_infection.py rename to viscy/scripts/infection_phenotyping/classify_infection_2D.py index ee904434a..b4269c746 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection_2D.py @@ -16,6 +16,7 @@ from monai.transforms import DivisiblePad from viscy.unet.networks.Unet2D import Unet2d +# from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample # @@ -244,8 +245,10 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image def on_test_start(self): self.pred_cm = torch.zeros((2,2)) From 2ea889227c4370e9a9419e9375986728d2903db6 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 27 May 2024 09:15:09 -0700 Subject: [PATCH 72/75] added covnext script --- .../Infection_classification_covnextModel.py | 154 ++++++++ .../classify_infection_covnext.py | 347 ++++++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py create mode 100644 viscy/scripts/infection_phenotyping/classify_infection_covnext.py diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py new file mode 100644 index 000000000..2b8a16348 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -0,0 +1,154 @@ +# %% +import torch +import lightning.pytorch as pl +import torch.nn as nn + +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from viscy.transforms import RandWeightedCropd +from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_covnext import SemanticSegUNet25D + +from iohub.ngff import open_ome_zarr + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" + +# find ratio of background, uninfected and infected pixels +zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", +) +in_chan_names = zarr_input.channel_names + +num_pixels_bkg = 0 +num_pixels_uninf = 0 +num_pixels_inf = 0 +num_pixels = 0 +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T,C,Z,Y,X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' + num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() + num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() + num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() + num_pixels = num_pixels + Z*X*Y + +pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_sum = sum(pixel_ratio_1) +pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] + +# %% craete data module + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Phase", "HSP90"], + target_channel=["Inf_mask"], + yx_patch_size=[256, 256], + split_ratio=0.8, + z_window_size=5, + architecture="2.2D", + num_workers=3, + batch_size=16, + normalizations=[ + NormalizeSampled( + keys=["Phase","HSP90"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=4, + spatial_size=[-1, 256, 256], + keys=["Phase","HSP90"], + w_key="Inf_mask", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + +# Visualize the dataset and the batch using napari +# Set the display +# os.environ['DISPLAY'] = ':1' + +# # Create a napari viewer +# viewer = napari.Viewer() + +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# # Start the napari event loop +# napari.run() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=200, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet25D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% +# Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py new file mode 100644 index 000000000..edf2feb40 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py @@ -0,0 +1,347 @@ + +import torch +import torch.nn as nn +import lightning.pytorch as pl +import torch.nn.functional as F +from torch import Tensor +import cv2 + +# import torchview +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap +from skimage.measure import regionprops, label +import numpy as np +import matplotlib.pyplot as plt + +from monai.transforms import DivisiblePad +from viscy.unet.networks.Unet25D import Unet25d +from viscy.data.hcs import Sample +from viscy.light.engine import VSUNet + +# +# %% Methods to compute confusion matrix per cell using torchmetrics + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1,1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0,1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1,0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0,0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig +# Define a 25d unet model for infection classification as a lightning module. + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = VSUNet( + architecture="2.2D", + model_config={ + "in_channels": 1, + "out_channels": 3, + "in_stack_depth": 5, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (5, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + }, + ) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2,2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) +# %% From 220eba131dd8dd20bd536823ea6f238f93a4bc37 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 28 May 2024 13:15:33 -0700 Subject: [PATCH 73/75] fix model input parameter --- .../Infection_classification_covnextModel.py | 4 +++- .../infection_phenotyping/classify_infection_covnext.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py index 2b8a16348..60ce0b5af 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -1,4 +1,6 @@ # %% +# import sys +# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") import torch import lightning.pytorch as pl import torch.nn as nn @@ -134,7 +136,7 @@ mode="min", ) -# Add the checkpoint callback to the trainer +# Add the checkpoint callback to the trainer`` trainer.callbacks.append(checkpoint_callback) # Fit the model diff --git a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py index edf2feb40..2ba698eed 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py +++ b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py @@ -147,8 +147,8 @@ def __init__( self.unet_model = VSUNet( architecture="2.2D", model_config={ - "in_channels": 1, - "out_channels": 3, + "in_channels": in_channels, + "out_channels": out_channels, "in_stack_depth": 5, "backbone": "convnextv2_tiny", "stem_kernel_size": (5, 4, 4), From c4839da3248bee2bde61e58c733f86ecbe486b1c Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 29 May 2024 09:33:10 -0700 Subject: [PATCH 74/75] update input file --- .../Infection_classification_covnextModel.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py index 60ce0b5af..0ecd6bdd4 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -18,7 +18,7 @@ # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" # find ratio of background, uninfected and infected pixels zarr_input = open_ome_zarr( @@ -56,7 +56,7 @@ # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["Phase", "HSP90"], + source_channel=["Phase", "HSP90", "phase_nucl_iqr","hsp90_skew"], target_channel=["Inf_mask"], yx_patch_size=[256, 256], split_ratio=0.8, @@ -66,7 +66,7 @@ batch_size=16, normalizations=[ NormalizeSampled( - keys=["Phase","HSP90"], + keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -76,7 +76,7 @@ RandWeightedCropd( num_samples=4, spatial_size=[-1, 256, 256], - keys=["Phase","HSP90"], + keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], w_key="Inf_mask", ) ], @@ -141,7 +141,7 @@ # Fit the model model = SemanticSegUNet25D( - in_channels=2, + in_channels=4, out_channels=3, loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), ) From 6c1f4d3b6b571209cdfdd71b519215b137ea4aec Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 1 Jul 2024 09:46:28 -0700 Subject: [PATCH 75/75] demo scripts --- examples/compbio/README.md | 2 + examples/compbio/ddp.py | 80 ++++++++++ examples/compbio/ddp.sh | 22 +++ examples/compbio/torch_into.py | 148 ++++++++++++++++++ .../classify_infection_2D.py | 76 +++++---- 5 files changed, 298 insertions(+), 30 deletions(-) create mode 100644 examples/compbio/README.md create mode 100644 examples/compbio/ddp.py create mode 100644 examples/compbio/ddp.sh create mode 100644 examples/compbio/torch_into.py diff --git a/examples/compbio/README.md b/examples/compbio/README.md new file mode 100644 index 000000000..2ec2f662d --- /dev/null +++ b/examples/compbio/README.md @@ -0,0 +1,2 @@ +# Demo scripts for hackathon + diff --git a/examples/compbio/ddp.py b/examples/compbio/ddp.py new file mode 100644 index 000000000..b8669de29 --- /dev/null +++ b/examples/compbio/ddp.py @@ -0,0 +1,80 @@ +""" +# Distributed training + +Demonstrate how to train a model using distributed data parallel (DDP) with PyTorch Lightning. +""" + +import os +from pathlib import Path + +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from viscy.transforms import NormalizeSampled, RandWeightedCropd + + +def main(): + dm = HCSDataModule( + data_path="/hpc/mydata/ziwen.liu/demo/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr", + source_channel=["Sensor", "Phase"], + target_channel=["Inf_mask"], + yx_patch_size=(128, 128), + split_ratio=0.5, + z_window_size=1, + architecture="2D", + num_workers=8, + batch_size=128, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ) + ], + ) + dm.prepare_data() + dm.setup(stage="fit") + + model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + loss_function=torch.nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), + ) + log_dir = Path(os.getenv("MYDATA", "")) / "torch_demo" + trainer = Trainer( + accelerator="gpu", + strategy="ddp_find_unused_parameters_true", + precision=32, + num_nodes=1, + devices=2, + fast_dev_run=True, + max_epochs=100, + logger=TensorBoardLogger(save_dir=log_dir, version="interactive_demo"), + log_every_n_steps=10, + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint( + monitor="loss/validate", save_top_k=5, every_n_epochs=1, save_last=True + ), + ], + ) + + torch.set_float32_matmul_precision("high") + trainer.fit(model, dm) + + +if __name__ == "__main__": + main() diff --git a/examples/compbio/ddp.sh b/examples/compbio/ddp.sh new file mode 100644 index 000000000..ba4737bb1 --- /dev/null +++ b/examples/compbio/ddp.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +#SBATCH --job-name=ddp_train +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=16 +#SBATCH --mem-per-cpu=7G +#SBATCH --time=0-12:00:00 + + +# debugging flags (optional) +# https://lightning.ai/docs/pytorch/stable/clouds/cluster_advanced.html +export NCCL_DEBUG=INFO +export PYTHONFAULTHANDLER=1 + + +module load anaconda/2022.05 +conda activate viscy + +srun python ddp.py \ No newline at end of file diff --git a/examples/compbio/torch_into.py b/examples/compbio/torch_into.py new file mode 100644 index 000000000..5e8434fd1 --- /dev/null +++ b/examples/compbio/torch_into.py @@ -0,0 +1,148 @@ +# %% [markdown] +""" +# Infected cell segmentation + +Interactive script to demonstrate PyTorch Lightning training +with a semantic segmentation task. +""" + +# %% +import matplotlib.pyplot as plt +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from skimage.color import label2rgb +from torchview import draw_graph + +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from viscy.transforms import NormalizeSampled, RandWeightedCropd + +# use tf32 for matmul +torch.set_float32_matmul_precision("high") + +# %% [markdown] +""" +## Dataset + +In this dataset, we have images of A549 cells infected with Dengue virus +in two channels: + +- The cells are engineered to express a fluorescent protein (viral sensor) +that translocates from the cytoplasm to the nucleus upon infection. +- Quantitative phase images are reconstructed from brightfield images. + +## Task +The goal is to identify infected and uninfected cells from these images. +For the training target, cell nuclei were segmented from virtual staining, +and manually labelled as infected (1) or uninfected (2), +while background was labelled as 0. +We will train a U-Net to predict these labels from the images. +Is is a semantic segmentation task, +where assign a label (class) to each pixel in the image. +""" + + +# %% +# setup datamodule +data_module = HCSDataModule( + data_path="/hpc/mydata/ziwen.liu/demo/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr", + source_channel=["Sensor", "Phase"], + target_channel=["Inf_mask"], + yx_patch_size=(128, 128), + split_ratio=0.5, + z_window_size=1, + architecture="2D", + num_workers=8, + batch_size=128, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ) + ], +) + +data_module.prepare_data() +data_module.setup(stage="fit") + +# %% +# sample from training data +num_samples = 8 + +for batch in data_module.train_dataloader(): + image = batch["source"][:num_samples].numpy() + label = batch["target"][:num_samples].numpy().astype("uint8") + break + +# %% +# visualize the samples +fig, ax = plt.subplots(num_samples, 3, figsize=(3, 8)) + +for i in range(num_samples): + ax[i, 0].imshow(image[i, 0, 0], cmap="gray") + ax[i, 1].imshow(image[i, 1, 0], cmap="gray") + ax[i, 2].imshow(label2rgb(label[i, 0, 0], bg_label=0)) + +for a in ax.ravel(): + a.axis("off") + +fig.tight_layout() + +# %% +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + loss_function=torch.nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), +) + +# %% +model_graph = draw_graph( + model=model, + input_data=torch.rand(1, 2, 1, 128, 128), + graph_name="2D UNet", + roll=True, + depth=2, + device="cpu", +) + +model_graph.visual_graph + +# %% +trainer = Trainer( + accelerator="gpu", + precision=32, + devices=1, + num_nodes=1, + fast_dev_run=True, + max_epochs=100, + logger=TensorBoardLogger( + save_dir="/hpc/mydata/ziwen.liu/demo/logs", + version="interactive_demo", + log_graph=True, + ), + log_every_n_steps=10, + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint( + monitor="loss/validate", save_top_k=5, every_n_epochs=1, save_last=True + ), + ], +) + + +# %% +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection_2D.py b/viscy/scripts/infection_phenotyping/classify_infection_2D.py index b4269c746..74a6038e9 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection_2D.py +++ b/viscy/scripts/infection_phenotyping/classify_infection_2D.py @@ -1,27 +1,27 @@ +# import torchview +from typing import Literal, Sequence +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn -import lightning.pytorch as pl import torch.nn.functional as F -from torch import Tensor -import cv2 - -# import torchview -from typing import Literal, Sequence +from matplotlib.pyplot import get_cmap +from monai.transforms import DivisiblePad from skimage.exposure import rescale_intensity -from matplotlib.cm import get_cmap -from skimage.measure import regionprops, label -import numpy as np -import matplotlib.pyplot as plt +from skimage.measure import label, regionprops +from torch import Tensor -from monai.transforms import DivisiblePad -from viscy.unet.networks.Unet2D import Unet2d # from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample +from viscy.unet.networks.Unet2D import Unet2d -# +# # %% Methods to compute confusion matrix per cell using torchmetrics + # The confusion matrix at the single-cell resolution. def confusion_matrix_per_cell( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int @@ -64,7 +64,11 @@ def compute_confusion_matrix( y_pred_cpu = y_pred[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) # find objects in every image @@ -79,17 +83,18 @@ def compute_confusion_matrix( test_id = y_true_reshaped[int(row), int(col)] if pred_id == 1 and test_id == 1: - conf_mat[1,1] += 1 + conf_mat[1, 1] += 1 if pred_id == 1 and test_id == 2: - conf_mat[0,1] += 1 + conf_mat[0, 1] += 1 if pred_id == 2 and test_id == 1: - conf_mat[1,0] += 1 + conf_mat[1, 0] += 1 if pred_id == 2 and test_id == 2: - conf_mat[0,0] += 1 + conf_mat[0, 0] += 1 # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. return conf_mat + def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # Create a figure and axis to plot the confusion matrix fig, ax = plt.subplots() @@ -98,7 +103,9 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): cax = ax.matshow(confusion_matrix, cmap="viridis") # Create a colorbar and set the label - index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary fig.colorbar(cax, label="Frequency") # Set labels for the classes @@ -125,8 +132,11 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # plt.show(fig) # Show the figure return fig + + # Define a 2d unet model for infection classification as a lightning module. + class SemanticSegUNet2D(pl.LightningModule): # Model for semantic segmentation. def __init__( @@ -169,8 +179,6 @@ def __init__( self.pred_cm = None # Initialize the confusion matrix self.index_to_label_dict = ["Infected", "Uninfected"] - - # Define the forward pass def forward(self, x): return self.unet_model(x) # Pass the input through the UNet model @@ -241,39 +249,45 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels # prob_chan = prob_pred[:, 2, :, :] # prob_chan = prob_chan.unsqueeze(1) return labels_pred # log the class predicted image # return prob_chan # log the probability predicted image - + def on_test_start(self): - self.pred_cm = torch.zeros((2,2)) + self.pred_cm = torch.zeros((2, 2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - + def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source logits = self._predict_pad.inverse(self.forward(source)) prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=2 ) # Calculate the confusion matrix per cell self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - + self.logger.experiment.add_figure( "Confusion Matrix per Cell", plot_confusion_matrix(pred_cm, self.index_to_label_dict), self.current_epoch, ) - # Accumulate the confusion matrix at the end of test epoch and log. + # Accumulate the confusion matrix at the end of test epoch and log. def on_test_end(self): confusion_matrix_sum = self.pred_cm self.logger.experiment.add_figure( @@ -334,4 +348,6 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + # %%