diff --git a/applications/cytoland/README.md b/applications/cytoland/README.md index 04d3a7abc..48894fb09 100644 --- a/applications/cytoland/README.md +++ b/applications/cytoland/README.md @@ -37,6 +37,21 @@ data: class_path: viscy_data.hcs.HCSDataModule ``` +## Tutorials and demos + +Scripts and tutorials live under [`examples/`](./examples/): + +| Folder | What it demonstrates | +|--------|----------------------| +| [`examples/VS_model_inference/`](./examples/VS_model_inference/) | Python API inference demos for VSCyto2D, VSCyto3D, VSNeuromast, and TTA-augmented sliding-window prediction | +| [`examples/vcp_tutorials/`](./examples/vcp_tutorials/) | Virtual Cell Platform quick-start and organism-specific walkthroughs (HEK293T, neuromast) | +| [`examples/dl-course-exercise/`](./examples/dl-course-exercise/) | Image-translation course exercise (training from scratch + evaluation) — used at DL@MBL and DL@Janelia | +| [`examples/configs/`](./examples/configs/) | YAML configs for `viscy fit` / `viscy predict` across models (VSCyto2D/3D, VSNeuromast, FNet3D, dynacell) | + +All demo scripts are written as jupytext-style percent-cell `.py` files. +Regenerate paired `.ipynb` notebooks with `jupytext --to ipynb solution.py` +if you prefer the notebook UI. + ## Models | Model | Input | Output | Architecture | diff --git a/applications/cytoland/examples/VS_model_inference/demo_vscyto2d.py b/applications/cytoland/examples/VS_model_inference/demo_vscyto2d.py new file mode 100644 index 000000000..5054f4e7c --- /dev/null +++ b/applications/cytoland/examples/VS_model_inference/demo_vscyto2d.py @@ -0,0 +1,147 @@ +# %% [markdown] +""" +# 2D Virtual Staining of A549 Cells +--- +## Prediction using the VSCyto2D to predict nuclei and plasma membrane from phase. +This example shows how to virtually stain A549 cells using the _VSCyto2D_ model. +The model is trained to predict the membrane and nuclei channels from the phase channel. +""" + +# %% Imports and paths +from pathlib import Path + +from iohub import open_ome_zarr +from plot import plot_vs_n_fluor + +# Cytoland and VisCy modular classes for the trainer and model +from cytoland.engine import FcmaeUNet +from viscy_data.hcs import HCSDataModule +from viscy_transforms import NormalizeSampled +from viscy_utils.callbacks import HCSPredictionWriter +from viscy_utils.trainer import VisCyTrainer + +# %% [markdown] tags=[] +# +#
+# +# # Download the dataset and checkpoints for the VSCyto2D model +# +# - Download the VSCyto2D test dataset and model checkpoint from here:
+# https://public.czbiohub.org/comp.micro/viscy +# - Update the `input_data_path` and `model_ckpt_path` variables with the path to the downloaded files. +# - Select a FOV (i.e 0/0/0). +# - Set an output path for the predictions. +# +#
+ +# %% +# TODO: Set download paths +root_dir = Path("") +# TODO: modify the path to the downloaded dataset +input_data_path = root_dir / "VSCyto2D/test/a549_hoechst_cellmask_test.zarr" +# TODO: modify the path to the downloaded checkpoint +model_ckpt_path = "/epoch=399-step=23200.ckpt" +# TODO: modify the path +# Zarr store to save the predictions +output_path = root_dir / "./a549_prediction.zarr" +# TODO: Choose an FOV +fov = "0/0/0" + +input_data_path = input_data_path / fov + +# %% +# Create the VSCyto2D network + +# Reduce the batch size if encountering out-of-memory errors +BATCH_SIZE = 8 +# NOTE: Set the number of workers to 0 for Windows and macOS +# since multiprocessing only works with a +# `if __name__ == '__main__':` guard. +# On Linux, set it to the number of CPU cores to maximize performance. +NUM_WORKERS = 0 +phase_channel_name = "Phase3D" + +# %%[markdown] +""" +For this example we will use the following parameters: +For more information on the VSCyto2D model, +see ``viscy.unet.networks.fcmae`` +([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/unet/networks/fcmae.py)) +for configuration details. +""" +# %% +# Setup the data module. +data_module = HCSDataModule( + data_path=input_data_path, + source_channel=phase_channel_name, + target_channel=["Membrane", "Nuclei"], + z_window_size=1, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + architecture="fcmae", + normalizations=[ + NormalizeSampled( + [phase_channel_name], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) +data_module.prepare_data() +data_module.setup(stage="predict") +# %% +# Setup the model. +# Dictionary that specifies key parameters of the model. +config_VSCyto2D = { + "in_channels": 1, + "out_channels": 2, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "decoder_conv_blocks": 2, + "stem_kernel_size": [1, 2, 2], + "in_stack_depth": 1, + "pretraining": False, +} + +model_VSCyto2D = FcmaeUNet.load_from_checkpoint(model_ckpt_path, model_config=config_VSCyto2D) +model_VSCyto2D.eval() + +# %% +# Setup the Trainer +trainer = VisCyTrainer( + accelerator="gpu", + callbacks=[HCSPredictionWriter(output_path)], +) + +# Start the predictions +trainer.predict( + model=model_VSCyto2D, + datamodule=data_module, + return_predictions=False, +) + +# %% +# Open the output_zarr store and inspect the output +# Show the individual channels and the fused in a 1x3 plot +output_path = Path(output_path) / fov + +# %% +# Open the predicted data +vs_store = open_ome_zarr(output_path, mode="r") +# Get the 2D images +vs_nucleus = vs_store[0][0, 0, 0] # (t,c,z,y,x) +vs_membrane = vs_store[0][0, 1, 0] # (t,c,z,y,x) +# Open the experimental fluorescence +fluor_store = open_ome_zarr(input_data_path, mode="r") +# Get the 2D images +# NOTE: Channel indeces hardcoded for this dataset +fluor_nucleus = fluor_store[0][0, 1, 0] # (t,c,z,y,x) +fluor_membrane = fluor_store[0][0, 2, 0] # (t,c,z,y,x) + +# Plot +plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane) + +vs_store.close() +fluor_store.close() diff --git a/applications/cytoland/examples/VS_model_inference/demo_vscyto3d.py b/applications/cytoland/examples/VS_model_inference/demo_vscyto3d.py new file mode 100644 index 000000000..e25757e99 --- /dev/null +++ b/applications/cytoland/examples/VS_model_inference/demo_vscyto3d.py @@ -0,0 +1,159 @@ +# %% [markdown] +""" +# 3D Virtual Staining of HEK293T Cells +--- +## Prediction using the VSCyto3D to predict nuclei and membrane from phase. +This example shows how to virtually stain HEK293T cells using the _VSCyto3D_ model. +The model is trained to predict the membrane and nuclei channels from the phase channel. +""" + +# %% Imports and paths +from pathlib import Path + +from iohub import open_ome_zarr +from plot import plot_vs_n_fluor + +# Cytoland and VisCy modular classes for the trainer and model +from cytoland.engine import VSUNet +from viscy_data.hcs import HCSDataModule +from viscy_transforms import NormalizeSampled +from viscy_utils.callbacks import HCSPredictionWriter +from viscy_utils.trainer import VisCyTrainer + +# %% [markdown] +""" +## Data and Model Paths + +The dataset and model checkpoint files need to be downloaded before running this example. +""" + +# %% [markdown] tags=[] +# +#
+# +# # Download the dataset and checkpoints VSCyto3D +# +# - Download the VSCyto3D test dataset and model checkpoint from here:
+# https://public.czbiohub.org/comp.micro/viscy +# - Update the `input_data_path` and `model_ckpt_path` variables with the path to the downloaded files. +# - Select a FOV (i.e plate/0/0). +# - Set an output path for the predictions. +# +#
+# %% +# TODO: modify the path to the downloaded dataset +input_data_path = "/no_pertubation_Phase1e-3_Denconv_Nuc8e-4_Mem8e-4_pad15_bg50.zarr" + +# TODO: modify the path to the downloaded checkpoint +model_ckpt_path = "/epoch=48-step=18130.ckpt" + +# TODO: modify the path +# Zarr store to save the predictions +output_path = "./hek_prediction_3d.zarr" + +# TODO: Choose an FOV +# FOV of interest +fov = "plate/0/0" + +input_data_path = Path(input_data_path) / fov + +# %% +# Create the VSCyto3D model + +# Reduce the batch size if encountering out-of-memory errors +BATCH_SIZE = 2 +# NOTE: Set the number of workers to 0 for Windows and macOS +# since multiprocessing only works with a +# `if __name__ == '__main__':` guard. +# On Linux, set it to the number of CPU cores to maximize performance. +NUM_WORKERS = 0 +phase_channel_name = "Phase3D" + +# %%[markdown] +""" +For this example we will use the following parameters: +### For more information on the VSCyto3D model: +See ``viscy.unet.networks.unext2`` +([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/unet/networks/unext2.py)) +for configuration details. +""" +# %% +# Setup the data module. +data_module = HCSDataModule( + data_path=input_data_path, + source_channel=phase_channel_name, + target_channel=["Membrane", "Nuclei"], + z_window_size=5, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + architecture="UNeXt2", + normalizations=[ + NormalizeSampled( + [phase_channel_name], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) +data_module.prepare_data() +data_module.setup(stage="predict") +# %% +# Setup the model. +# Dictionary that specifies key parameters of the model. +config_VSCyto3D = { + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 5, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (5, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + "head_pool": True, +} + +model_VSCyto3D = VSUNet.load_from_checkpoint(model_ckpt_path, architecture="UNeXt2", model_config=config_VSCyto3D) +model_VSCyto3D.eval() + +# %% +# Setup the Trainer +trainer = VisCyTrainer( + accelerator="gpu", + callbacks=[HCSPredictionWriter(output_path)], +) + +# Start the predictions +trainer.predict( + model=model_VSCyto3D, + datamodule=data_module, + return_predictions=False, +) + +# %% +# Open the output_zarr store and inspect the output +# Show the individual channels and the fused in a 1x3 plot +output_path = Path(output_path) / fov + +# %% +# Open the predicted data +vs_store = open_ome_zarr(output_path, mode="r") +T, C, Z, Y, X = vs_store.data.shape + +# Get a z-slice +z_slice = Z // 2 # NOTE: using the middle slice of the stack. Change as needed. +vs_nucleus = vs_store[0][0, 0, z_slice] # (t,c,z,y,x) +vs_membrane = vs_store[0][0, 1, z_slice] # (t,c,z,y,x) +# Open the experimental fluorescence +fluor_store = open_ome_zarr(input_data_path, mode="r") +# Get the 2D images +# NOTE: Channel indeces hardcoded for this dataset +fluor_nucleus = fluor_store[0][0, 2, z_slice] # (t,c,z,y,x) +fluor_membrane = fluor_store[0][0, 1, z_slice] # (t,c,z,y,x) + +# Plot +plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane) + +# Close stores +vs_store.close() +fluor_store.close() diff --git a/applications/cytoland/examples/VS_model_inference/demo_vscyto_w_ttas.py b/applications/cytoland/examples/VS_model_inference/demo_vscyto_w_ttas.py new file mode 100644 index 000000000..9ecd82d63 --- /dev/null +++ b/applications/cytoland/examples/VS_model_inference/demo_vscyto_w_ttas.py @@ -0,0 +1,73 @@ +# %% +""" +Demo: In-memory volume prediction using predict_sliding_windows. + +This API provides the same results as the `viscy predict` CLI (HCSPredictionWriter) +since both use the same linear feathering blending algorithm for overlapping windows. +""" + +from pathlib import Path + +import napari +import numpy as np +import torch +from iohub import open_ome_zarr + +from cytoland.engine import AugmentedPredictionVSUNet, VSUNet + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Instantiate model manually +model = ( + VSUNet( + architecture="fcmae", + model_config={ + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 21, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "decoder_conv_blocks": 2, + "stem_kernel_size": [7, 4, 4], + "pretraining": False, + "head_conv": True, + "head_conv_expansion_ratio": 4, + "head_conv_pool": False, + }, + ckpt_path="/path/to/checkpoint.ckpt", + ) + .to(DEVICE) + .eval() +) + +vs = ( + AugmentedPredictionVSUNet( + model=model.model, + forward_transforms=[lambda t: t], + inverse_transforms=[lambda t: t], + ) + .to(DEVICE) + .eval() +) + +# Load data +path = Path("/path/to/your.zarr/0/1/000000") +with open_ome_zarr(path) as ds: + vol_np = np.asarray(ds.data[0:1, 0:1]) # (1, 1, Z, Y, X) + +vol = torch.from_numpy(vol_np).float().to(DEVICE) + +# Run inference with sliding windows and linear feathering blending +# step=1 gives maximum overlap; increase step for faster inference +with torch.inference_mode(): + pred = vs.predict_sliding_windows(vol, out_channel=2, step=1) + +# Visualize +pred_np = pred.cpu().numpy() +nuc, mem = pred_np[0, 0], pred_np[0, 1] + +viewer = napari.Viewer() +viewer.add_image(vol_np, name="phase_input", colormap="gray") +viewer.add_image(nuc, name="virt_nuclei", colormap="magenta") +viewer.add_image(mem, name="virt_membrane", colormap="cyan") +napari.run() diff --git a/applications/cytoland/examples/VS_model_inference/demo_vsneuromast.py b/applications/cytoland/examples/VS_model_inference/demo_vsneuromast.py new file mode 100644 index 000000000..61fa76573 --- /dev/null +++ b/applications/cytoland/examples/VS_model_inference/demo_vsneuromast.py @@ -0,0 +1,155 @@ +# %% [markdown] +""" +# 3D Virtual Staining of Neuromast +--- +## Prediction using the VSNeuromast to predict nuclei and membrane from phase. +This example shows how to virtually stain zebrafish neuromast cells using the _VSNeuromast_ model. +The model is trained to predict the membrane and nuclei channels from the phase channel. +""" + +# %% Imports and paths +from pathlib import Path + +from iohub import open_ome_zarr +from plot import plot_vs_n_fluor + +# Cytoland and VisCy modular classes for the trainer and model +from cytoland.engine import VSUNet +from viscy_data.hcs import HCSDataModule +from viscy_transforms import NormalizeSampled +from viscy_utils.callbacks import HCSPredictionWriter +from viscy_utils.trainer import VisCyTrainer + +# %% [markdown] +""" +## Data and Model Paths + +The dataset and model checkpoint files need to be downloaded before running this example. +""" + +# %% [markdown] tags=[] +# +#
+# +# # Download the dataset and checkpoints +# +# - Download the neuromast test dataset and model checkpoint from here:
+# https://public.czbiohub.org/comp.micro/viscy +# - Update the `input_data_path` and `model_ckpt_path` variables with the path to the downloaded files. +# - Select a FOV (i.e 0/3/0). +# - Set an output path for the predictions. +# +#
+# %% +# TODO: modify the path to the downloaded dataset +input_data_path = "/20230803_fish2_60x_1_cropped_zyx_resampled_clipped_2.zarr" + +# TODO: modify the path to the downloaded checkpoint +model_ckpt_path = "/epoch=44-step=1215.ckpt" + +# TODO: modify the path +# Zarr store to save the predictions +output_path = "./test_neuromast_demo.zarr" + +# TODO: Choose an FOV +# FOV of interest +fov = "0/3/0" + +input_data_path = Path(input_data_path) / fov +# %% +# Create the VSNeuromast model + +# Reduce the batch size if encountering out-of-memory errors +BATCH_SIZE = 2 +# NOTE: Set the number of workers to 0 for Windows and macOS +# since multiprocessing only works with a +# `if __name__ == '__main__':` guard. +# On Linux, set it to the number of CPU cores to maximize performance. +NUM_WORKERS = 0 +phase_channel_name = "Phase3D" + +# %%[markdown] +""" +For this example we will use the following parameters: +### For more information on the VSNeuromast model: +See ``viscy.unet.networks.unext2`` ([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/unet/networks/unext2.py)) for configuration details. +""" +# %% +# Setup the data module. +data_module = HCSDataModule( + data_path=input_data_path, + source_channel=phase_channel_name, + target_channel=["Membrane", "Nuclei"], + z_window_size=21, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + architecture="UNeXt2", + normalizations=[ + NormalizeSampled( + [phase_channel_name], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) +data_module.prepare_data() +data_module.setup(stage="predict") +# %% +# Setup the model. +# Dictionary that specifies key parameters of the model. +config_VSNeuromast = { + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 21, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (7, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + "head_pool": True, +} + +model_VSNeuromast = VSUNet.load_from_checkpoint(model_ckpt_path, architecture="UNeXt2", model_config=config_VSNeuromast) +model_VSNeuromast.eval() + +# %% +# Setup the Trainer +trainer = VisCyTrainer( + accelerator="gpu", + callbacks=[HCSPredictionWriter(output_path)], +) + +# Start the predictions +trainer.predict( + model=model_VSNeuromast, + datamodule=data_module, + return_predictions=False, +) + +# %% +# Open the output_zarr store and inspect the output +# Show the individual channels and the fused in a 1x3 plot +output_path = Path(output_path) / fov + +# %% +# Open the predicted data +vs_store = open_ome_zarr(output_path, mode="r") +T, C, Z, Y, X = vs_store.data.shape +# Get a z-slice +z_slice = Z // 2 # NOTE: using the middle slice of the stack. Change as needed. +vs_nucleus = vs_store[0][0, 0, z_slice] # (t,c,z,y,x) +vs_membrane = vs_store[0][0, 1, z_slice] # (t,c,z,y,x) + +# Open the experimental fluorescence +fluor_store = open_ome_zarr(input_data_path, mode="r") +# Get the 2D images +# NOTE: Channel indeces hardcoded for this dataset +fluor_nucleus = fluor_store[0][0, 1, z_slice] # (t,c,z,y,x) +fluor_membrane = fluor_store[0][0, 2, z_slice] # (t,c,z,y,x) + +# Plot +plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane) + +vs_store.close() +fluor_store.close() diff --git a/applications/cytoland/examples/VS_model_inference/plot.py b/applications/cytoland/examples/VS_model_inference/plot.py new file mode 100644 index 000000000..2f2bb1ac7 --- /dev/null +++ b/applications/cytoland/examples/VS_model_inference/plot.py @@ -0,0 +1,74 @@ +import matplotlib.pyplot as plt +import numpy as np +from skimage.exposure import rescale_intensity + + +def plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane): + colormap_1 = [0.1254902, 0.6784314, 0.972549] # bop blue + colormap_2 = [0.972549, 0.6784314, 0.1254902] # bop orange + colormap_3 = [0, 1, 0] # green + colormap_4 = [1, 0, 1] # magenta + + # Rescale the intensity + vs_nucleus = rescale_intensity(vs_nucleus, out_range=(0, 1)) + vs_membrane = rescale_intensity(vs_membrane, out_range=(0, 1)) + # VS Nucleus RGB + vs_nucleus_rgb = np.zeros((*vs_nucleus.shape[-2:], 3)) + vs_nucleus_rgb[:, :, 0] = vs_nucleus * colormap_1[0] + vs_nucleus_rgb[:, :, 1] = vs_nucleus * colormap_1[1] + vs_nucleus_rgb[:, :, 2] = vs_nucleus * colormap_1[2] + # VS Membrane RGB + vs_membrane_rgb = np.zeros((*vs_membrane.data.shape[-2:], 3)) + vs_membrane_rgb[:, :, 0] = vs_membrane * colormap_2[0] + vs_membrane_rgb[:, :, 1] = vs_membrane * colormap_2[1] + vs_membrane_rgb[:, :, 2] = vs_membrane * colormap_2[2] + # Merge the two channels + merged_vs = np.zeros((*vs_nucleus.shape[-2:], 3)) + merged_vs[:, :, 0] = vs_nucleus * colormap_1[0] + vs_membrane * colormap_2[0] + merged_vs[:, :, 1] = vs_nucleus * colormap_1[1] + vs_membrane * colormap_2[1] + merged_vs[:, :, 2] = vs_nucleus * colormap_1[2] + vs_membrane * colormap_2[2] + + # Rescale the intensity + fluor_nucleus = rescale_intensity(fluor_nucleus, out_range=(0, 1)) + fluor_membrane = rescale_intensity(fluor_membrane, out_range=(0, 1)) + # fluor Nucleus RGB + fluor_nucleus_rgb = np.zeros((*fluor_nucleus.shape[-2:], 3)) + fluor_nucleus_rgb[:, :, 0] = fluor_nucleus * colormap_3[0] + fluor_nucleus_rgb[:, :, 1] = fluor_nucleus * colormap_3[1] + fluor_nucleus_rgb[:, :, 2] = fluor_nucleus * colormap_3[2] + # fluor Membrane RGB + fluor_membrane_rgb = np.zeros((*fluor_membrane.shape[-2:], 3)) + fluor_membrane_rgb[:, :, 0] = fluor_membrane * colormap_4[0] + fluor_membrane_rgb[:, :, 1] = fluor_membrane * colormap_4[1] + fluor_membrane_rgb[:, :, 2] = fluor_membrane * colormap_4[2] + # Merge the two channels + merged_fluor = np.zeros((*fluor_nucleus.shape[-2:], 3)) + merged_fluor[:, :, 0] = fluor_nucleus * colormap_3[0] + fluor_membrane * colormap_4[0] + merged_fluor[:, :, 1] = fluor_nucleus * colormap_3[1] + fluor_membrane * colormap_4[1] + merged_fluor[:, :, 2] = fluor_nucleus * colormap_3[2] + fluor_membrane * colormap_4[2] + + # %% + # Plot + fig, ax = plt.subplots(2, 3, figsize=(15, 10)) + + # Virtual staining plots + ax[0, 0].imshow(vs_nucleus_rgb) + ax[0, 0].set_title("VS Nuclei") + ax[0, 1].imshow(vs_membrane_rgb) + ax[0, 1].set_title("VS Membrane") + ax[0, 2].imshow(merged_vs) + ax[0, 2].set_title("VS Nuclei+Membrane") + + # Experimental fluorescence plots + ax[1, 0].imshow(fluor_nucleus_rgb) + ax[1, 0].set_title("Experimental Fluorescence Nuclei") + ax[1, 1].imshow(fluor_membrane_rgb) + ax[1, 1].set_title("Experimental Fluorescence Membrane") + ax[1, 2].imshow(merged_fluor) + ax[1, 2].set_title("Experimental Fluorescence Nuclei+Membrane") + + # turnoff axis + for a in ax.flatten(): + a.axis("off") + plt.margins(0, 0) + plt.show() diff --git a/applications/cytoland/examples/dl-course-exercise/README.md b/applications/cytoland/examples/dl-course-exercise/README.md new file mode 100644 index 000000000..c6c44eeac --- /dev/null +++ b/applications/cytoland/examples/dl-course-exercise/README.md @@ -0,0 +1,130 @@ +# Exercise 6: Image translation - Part 1 + +This demo script was developed for the DL@MBL 2024 course by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, with many inputs and bugfixes by [Morgan Schwartz](https://github.com/msschwartz21), [Caroline Malin-Mayor](https://github.com/cmalinmayor), and [Peter Park](https://github.com/peterhpark). + + +# Image translation (Virtual Staining) + +Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco. + +## Overview + +In this exercise, we will predict fluorescence images of nuclei and plasma membrane markers from quantitative phase images of cells, i.e., we will _virtually stain_ the nuclei and plasma membrane visible in the phase image. +This is an example of an image translation task. We will apply spatial and intensity augmentations to train robust models and evaluate their performance. Finally, we will explore the opposite process of predicting a phase image from a fluorescence membrane label. + +[![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755) +(Click on image to play video) + +## Goals + +### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and TensorBoard. + + - Use a OME-Zarr dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), + each FOV has 3 channels (phase, nuclei, and cell membrane). + The nuclei were stained with DAPI and the cell membrane with Cellmask. + - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html) + and the high-content-screen (HCS) format. + - Use [MONAI](https://monai.io/) to implement data augmentations. + +### Part 2: Train and evaluate the model to translate phase into fluorescence, and vice versa. + - Train a 2D UNeXt2 model to predict nuclei and membrane from phase images. + - Compare the performance of the trained model and a pre-trained model. + - Evaluate the model using pixel-level and instance-level metrics. + + +Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos), +our deep learning pipeline for training and deploying computer vision models +for image-based phenotyping including the robust virtual staining of landmark organelles. +VisCy exploits recent advances in data and metadata formats +([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, +[PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). + +## Setup + +There are two setup scripts depending on your role: + +- **Students:** run [`setup_student.sh`](setup_student.sh) — creates a per-user + Python venv, registers a Jupyter kernel, and downloads the data only if it + isn't already on disk. +- **TAs / course operators:** run [`setup_TA.sh`](setup_TA.sh) before the + course to pre-stage the ~14 GB of data + checkpoint onto a shared + filesystem so each student doesn't have to re-download it. + +### Student + +From the exercise folder: + +```bash +cd applications/cytoland/examples/dl-course-exercise +bash setup_student.sh +``` + +If your TA pre-staged the data on a shared mount, point `DATA_ROOT` at it to +skip the download: + +```bash +DATA_ROOT=/mnt/shared/image_translation bash setup_student.sh +``` + +The script will: + +- Install [`uv`](https://docs.astral.sh/uv/) if it isn't already on your PATH. +- Create a Python 3.13 virtual environment at `./.venv`. +- Install `cytoland` + `viscy` (`>=0.5.0a0`) plus the tutorial extras: + `cellpose`, `torchview`, `microssim`, `jupyter`, `ipykernel`, + `ipywidgets`, `jupytext`. If you ran the script from inside a clone of + the [VisCy monorepo](https://github.com/mehta-lab/VisCy), it installs + `cytoland` editable from the local workspace; otherwise it installs from + PyPI. +- Register the venv as a Jupyter kernel named **`06_image_translation`** + (display name: *Python (06_image_translation)*). +- Download the training / test OME-Zarr datasets and the VSCyto2D + pretrained checkpoint into `$DATA_ROOT` (default `~/data/06_image_translation/`), + unless the data is already present. + +Everything is self-contained inside this folder — no conda required. + +### TA / course operator + +Run once before the course, ideally targeting a shared mount: + +```bash +cd applications/cytoland/examples/dl-course-exercise +DATA_ROOT=/mnt/shared/image_translation bash setup_TA.sh +``` + +This downloads the OME-Zarr datasets (~14 GB) and the pretrained checkpoint +into `$DATA_ROOT`. Typical runtime is 20–40 min. It does **not** create a +Python environment — students do that themselves with `setup_student.sh`. + +## Use VSCode + +Install VSCode and the Python + Jupyter extensions, then open +[`solution.py`](solution.py) and pick the **Python (06_image_translation)** +kernel from the top-right kernel selector. The script uses +[cell mode](https://code.visualstudio.com/docs/python/jupyter-support-py), so +you can execute each `# %%` block interactively. + +## Use Jupyter Notebook + +Generate a notebook from the solution script and launch Jupyter: + +```bash +./.venv/bin/jupytext --to ipynb solution.py +./.venv/bin/jupyter notebook solution.ipynb +``` + +Pick **Python (06_image_translation)** as the kernel. + +If the kernel is missing (e.g. you reinstalled the venv), re-register it: + +```bash +./.venv/bin/python -m ipykernel install --user \ + --name 06_image_translation \ + --display-name "Python (06_image_translation)" +``` + +### References + +- [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) +- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) diff --git a/applications/cytoland/examples/dl-course-exercise/prepare-exercise.sh b/applications/cytoland/examples/dl-course-exercise/prepare-exercise.sh new file mode 100644 index 000000000..630fb201e --- /dev/null +++ b/applications/cytoland/examples/dl-course-exercise/prepare-exercise.sh @@ -0,0 +1,10 @@ +# Run ruff format on .py files +# ruff format viscy + +# Convert .py to ipynb + +# "cell_metadata_filter": "all" preserve cell tags including our solution tags +jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update solution.py +jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update solution.py --output exercise.ipynb +jupyter nbconvert solution.ipynb --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags task --to notebook --output solution.ipynb +jupyter nbconvert exercise.ipynb --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook --output exercise.ipynb diff --git a/applications/cytoland/examples/dl-course-exercise/setup_TA.sh b/applications/cytoland/examples/dl-course-exercise/setup_TA.sh new file mode 100644 index 000000000..5cffb5f03 --- /dev/null +++ b/applications/cytoland/examples/dl-course-exercise/setup_TA.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env -S bash -i +# +# Image-translation exercise — TA / course-operator setup. +# +# Pre-stage the OME-Zarr datasets and pretrained VSCyto2D checkpoint onto a +# shared filesystem BEFORE the course starts so each student doesn't have to +# re-download ~14 GB. This typically takes 20–40 min depending on link speed +# and storage backend. +# +# Usage: +# +# # Default: stage to ~/data/image_translation/ +# bash setup_TA.sh +# +# # Stage to a shared mount (recommended for courses): +# DATA_ROOT=/mnt/efs/image_translation bash setup_TA.sh +# +# Once this finishes, students point setup_student.sh at the same DATA_ROOT +# and skip the download: +# +# DATA_ROOT=/mnt/efs/image_translation bash setup_student.sh +# +# This script does NOT create a Python environment. Run setup_student.sh for +# that (it can be run before, after, or instead of this script). + +set -euo pipefail + +START_DIR=$(pwd) +KERNEL_NAME="${KERNEL_NAME:-06_image_translation}" +DATA_ROOT="${DATA_ROOT:-$HOME/data/$KERNEL_NAME}" + +mkdir -p "$DATA_ROOT/training" "$DATA_ROOT/test" "$DATA_ROOT/pretrained_models" + +echo "Staging data + checkpoint into $DATA_ROOT ..." +echo "(this typically takes 20-40 min)" + +cd "$DATA_ROOT/training" +wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/training/zarrv3/a549_hoechst_cellmask_train_val.zarr/" + +cd "$DATA_ROOT/test" +wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/test/zarrv3/a549_hoechst_cellmask_test.zarr/" + +cd "$DATA_ROOT/pretrained_models" +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/VSCyto2D/epoch=399-step=23200.ckpt" +# Second checkpoint used in Task 2.5 (fluorescence -> phase reverse model). +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/AIMBL_Demo/fluor2phase_step668.ckpt" + +cd "$START_DIR" + +cat <=0.5.0a0) plus the tutorial extras: +# cellpose, torchview, microssim, jupyter, ipywidgets, jupytext. +# If run from inside a checkout of the VisCy monorepo, installs +# the local cytoland workspace package in editable mode (pulls +# viscy-data, viscy-models, viscy-transforms, viscy-utils from +# the workspace). Otherwise installs from PyPI. +# 4. Registers the venv as a Jupyter kernel named "06_image_translation" +# so students see it in VSCode / JupyterLab. +# 5. Downloads the training / test OME-Zarr datasets and the VSCyto2D +# pretrained checkpoint into $DATA_ROOT (default ~/data/06_image_translation), +# ONLY IF the data is not already there. If a TA has pre-staged data +# on a shared filesystem, point DATA_ROOT at it to skip the download: +# +# DATA_ROOT=/mnt/shared/image_translation bash setup_student.sh +# +# Run this from the exercise folder: +# cd applications/cytoland/examples/dl-course-exercise +# bash setup_student.sh + +set -euo pipefail + +START_DIR=$(pwd) +KERNEL_NAME="${KERNEL_NAME:-06_image_translation}" +PYTHON_VERSION="${PYTHON_VERSION:-3.13}" + +# --- Detect optional VisCy monorepo root (four levels up from this script) - +# When this exercise lives inside a viscy clone, install cytoland in editable +# mode against the local workspace. Otherwise fall back to PyPI. +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +MONOREPO_ROOT="$(cd "$SCRIPT_DIR/../../../.." 2>/dev/null && pwd || true)" +if [[ -n "${MONOREPO_ROOT:-}" && -f "$MONOREPO_ROOT/pyproject.toml" ]] \ + && grep -q '^name = "viscy"' "$MONOREPO_ROOT/pyproject.toml"; then + INSTALL_MODE="workspace" +else + INSTALL_MODE="pypi" + MONOREPO_ROOT="" +fi +echo "Install mode: $INSTALL_MODE" + +# --- 1. Install uv if missing ---------------------------------------------- +if ! command -v uv >/dev/null 2>&1; then + echo "uv not found — installing to ~/.local/bin ..." + curl -LsSf https://astral.sh/uv/install.sh | sh + # The installer updates shell profiles but not the current shell + export PATH="$HOME/.local/bin:$PATH" +fi +echo "Using uv: $(uv --version)" + +# --- 2. Create a venv under this exercise folder --------------------------- +VENV_DIR="$SCRIPT_DIR/.venv" +uv venv --python "$PYTHON_VERSION" "$VENV_DIR" +PY="$VENV_DIR/bin/python" + +# --- 3. Install cytoland + viscy + tutorial extras ------------------------- +if [[ "$INSTALL_MODE" == "workspace" ]]; then + echo "Installing cytoland (editable) from $MONOREPO_ROOT ..." + uv pip install --python "$PY" -e "$MONOREPO_ROOT/applications/cytoland[metrics]" +else + echo "Installing cytoland + viscy from PyPI (>=0.5.0a0) ..." + uv pip install --python "$PY" --prerelease=allow \ + "viscy>=0.5.0a0" \ + "cytoland[metrics]>=0.5.0a0" +fi +uv pip install --python "$PY" \ + cellpose \ + torchview \ + microssim \ + jupyter \ + ipykernel \ + ipywidgets \ + jupytext \ + nbformat \ + nbconvert + +# --- 4. Register the venv as a Jupyter kernel ------------------------------ +"$PY" -m ipykernel install --user \ + --name "$KERNEL_NAME" \ + --display-name "Python ($KERNEL_NAME)" +echo "Registered Jupyter kernel: $KERNEL_NAME" + +# --- 5. Download data + pretrained checkpoints (skip if already present) ---- +DATA_ROOT="${DATA_ROOT:-$HOME/data/$KERNEL_NAME}" +TRAINING_ZARR="$DATA_ROOT/training/a549_hoechst_cellmask_train_val.zarr" +TEST_ZARR="$DATA_ROOT/test/a549_hoechst_cellmask_test.zarr" +CHECKPOINT="$DATA_ROOT/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt" +FLUOR2PHASE_CKPT="$DATA_ROOT/pretrained_models/AIMBL_Demo/fluor2phase_step668.ckpt" + +mkdir -p "$DATA_ROOT/training" "$DATA_ROOT/test" "$DATA_ROOT/pretrained_models" + +if [[ -d "$TRAINING_ZARR" && -d "$TEST_ZARR" && -f "$CHECKPOINT" && -f "$FLUOR2PHASE_CKPT" ]]; then + echo "Data already present at $DATA_ROOT — skipping download." +else + echo "Downloading data + checkpoints to $DATA_ROOT ..." + cd "$DATA_ROOT/training" + wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/training/zarrv3/a549_hoechst_cellmask_train_val.zarr/" + + cd "$DATA_ROOT/test" + wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/test/zarrv3/a549_hoechst_cellmask_test.zarr/" + + cd "$DATA_ROOT/pretrained_models" + wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/VSCyto2D/epoch=399-step=23200.ckpt" + # Second checkpoint used in Task 2.5 (fluorescence -> phase reverse model). + wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/AIMBL_Demo/fluor2phase_step668.ckpt" +fi + +cd "$START_DIR" + +cat < +# The exercise is organized in 3 parts: + +# + +# + +# %% [markdown] tags=[] +#
+# Set your python kernel to 06_image_translation +#
+ +# %% [markdown] tags=[] +# ## PyTorch Lightning in one minute +# +# If you've used plain PyTorch you already know the pattern: write a model, write a +# `for batch in dataloader` loop, move tensors to `cuda`, call `loss.backward()`, step +# the optimizer, remember to `zero_grad()`, log every N steps, save a checkpoint, and +# repeat for validation. That boilerplate is the same in every project — so +# [PyTorch Lightning](https://lightning.ai) factors it out into **three objects** and +# owns the training loop for you. +# +# | Lightning object | What it holds | In this exercise | +# | --- | --- | --- | +# | `LightningDataModule` | How to load, split, augment, and batch your data (`train/val/test/predict_dataloader`) | `HCSDataModule` — reads OME-Zarr and yields `{"source": ..., "target": ...}` dicts | +# | `LightningModule` | The network, the loss, and what happens in `training_step` / `validation_step` (one batch at a time) | `VSUNet` — wraps the UNeXt2 architecture and the virtual-staining loss | +# | `Trainer` | The loop: device placement, mixed precision, logging, checkpointing, multi-GPU | `VisCyTrainer` — a thin subclass with VisCy-friendly defaults | +# +# You don't write a `for` loop. You call **`trainer.fit(model, datamodule)`** and +# Lightning drives everything. The trainer handles: +# +# - moving batches to the right device (`accelerator="gpu"`, `devices=[0]`) +# - mixed-precision training (`precision="16-mixed"`) so you use less GPU memory +# - when to log metrics / images (`log_every_n_steps`) and where (`logger=TensorBoardLogger(...)`) +# - saving checkpoints automatically under the logger's directory +# - running a sanity check on a single batch before real training (`fast_dev_run=True`) +# +# VisCy builds on top of Lightning and provides the `HCSDataModule` and `VSUNet` +# classes so you don't have to subclass `LightningDataModule` / `LightningModule` +# yourself — you configure them via constructor arguments and let Lightning run. +# When you see `trainer.fit(...)` below, that single call replaces a ~50-line hand- +# written training loop. + +# %% [markdown] +# # Part 1: Log training data to tensorboard, start training a model. +# --------- +# Learning goals: + +# - Load the OME-zarr dataset and examine the channels (A549). +# - Configure and understand the data loader. +# - Log some patches to tensorboard. +# - Initialize a 2D UNeXt2 model for virtual staining of nuclei and membrane from phase. +# - Start training the model to predict nuclei and membrane from phase. + +# %% Imports +import os +from glob import glob +from pathlib import Path +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torchview +import torchvision +from cellpose import models +from iohub import open_ome_zarr +from iohub.reader import print_info +from lightning.pytorch import seed_everything +from lightning.pytorch.loggers import TensorBoardLogger + +# microSSIM: SSIM variant designed for fluorescence microscopy. +from microssim import micro_structural_similarity +from natsort import natsorted +from numpy.typing import ArrayLike + +# pytorch lightning wrapper for Tensorboard. +from skimage.color import label2rgb +from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard +from torchmetrics.functional import accuracy, jaccard_index +from torchmetrics.functional.segmentation import dice_score +from tqdm import tqdm + +# Trainer class and UNet from the cytoland package. +from cytoland.engine import VSUNet + +# HCSDataModule makes it easy to load data during training. +from viscy_data.hcs import HCSDataModule + +# training augmentations +from viscy_transforms import ( + NormalizeSampled, + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) +from viscy_utils.evaluation.metrics import mean_average_precision +from viscy_utils.losses import MixedLoss +from viscy_utils.trainer import VisCyTrainer + +# %% +# seed random number generators for reproducibility. +seed_everything(42, workers=True) + +# Paths to data and log directory +top_dir = Path("~/data").expanduser() # If this fails, point to your data directory (e.g. a shared course mount). + +# Path to the training data +data_path = top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr" + +# Path where we will save our training logs +training_top_dir = Path(f"{os.getcwd()}/data/") +# Create top_training_dir directory if needed, and launch tensorboard +training_top_dir.mkdir(parents=True, exist_ok=True) +log_dir = training_top_dir / "06_image_translation/logs/" +# Create log directory if needed, and launch tensorboard +log_dir.mkdir(parents=True, exist_ok=True) + +if not data_path.exists(): + raise FileNotFoundError(f"Data not found at {data_path}. Please check the top_dir and data_path variables.") + +# %% [markdown] tags=[] +# The next cell starts tensorboard. + +#
+# If you launched jupyter lab from ssh terminal, add --host <your-server-name> to the tensorboard command below. <your-server-name> is the address of your compute node that ends in amazonaws.com. + +#
+ + +# %% tags=[] +# Imports and paths +# Function to find an available port +def find_free_port(): + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +# Launch TensorBoard on the browser +def launch_tensorboard(log_dir): + import subprocess + + port = find_free_port() + tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}" + process = subprocess.Popen(tensorboard_cmd, shell=True) + print( + f"TensorBoard started at http://localhost:{port}. \n" + "If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL." + ) + return process + + +# Launch tensorboard and click on the link to view the logs. +tensorboard_process = launch_tensorboard(log_dir) +# %% [markdown] tags = [] +#
+# If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard.
+# Take note of the port number was assigned in the previous cell.(i.e http://localhost:{port_number_assigned})
+ +# Locate the your VSCode terminal and select the Ports tab
+#
    +#
  • Add a new port with the port_number_assigned +#
+# Click on the link to view the tensorboard and it should open in your browser. +#
+ + +# %% [markdown] tags=[] +# ## Load OME-Zarr Dataset +# +# **OME-Zarr** is a chunked, cloud-friendly microscopy format; **HCS layout** +# nests the zarr store like a physical plate — `row/col/field/level/T/C/Z/Y/X` — +# so each FOV is addressable by `dataset[f"{row}/{col}/{field}/{level}"]` and +# returns an `(T, C, Z, Y, X)` array. +# +# This dataset has 34 FOVs of 2048×2048 images across 3 channels (QPI, nuclei +# stained with DAPI, membrane stained with Cellmask), a single pyramid level +# `0`, and a single time point. + +# %% [markdown] tags=[] +#
+# You can inspect the tree structure by using your terminal: +# iohub info -v "path-to-ome-zarr" + +#
+# More info on the CLI: +# iohub info --help to see the help menu. +#
+# %% +# This is the python function called by `iohub info` CLI command +print_info(data_path, verbose=True) + +# Open and inspect the dataset. +dataset = open_ome_zarr(data_path) + +# %% [markdown] tags=[] +#
+# +# ### Task 1.1 +# Look at a couple different fields of view (FOVs) by changing the `field` variable. +# Check the cell density, the cell morphologies, and fluorescence signal. +# HINT: look at the HCS Plate format to see what your options are. +#
+# %% tags=[] +# Use the field and pyramid_level below to visualize data. +row = 0 +col = 0 +field = 9 # TODO: Change this to explore data. + +pyaramid_level = 0 + +# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec. +n_channels = len(dataset.channel_names) + +image = dataset[f"{row}/{col}/{field}/{pyaramid_level}"].numpy() +print(f"data shape: {image.shape}, FOV: {field}, pyramid level: {pyaramid_level}") + +figure, axes = plt.subplots(1, n_channels, figsize=(9, 3)) + +for i in range(n_channels): + channel_image = image[0, i, 0] + # Adjust contrast to 0.5th and 99.5th percentile of pixel values. + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[i].imshow(channel_image, cmap="gray") + axes[i].axis("off") + axes[i].set_title(dataset.channel_names[i]) +plt.tight_layout() + +# %% [markdown] tags=[] +# ## Explore the effects of augmentation on batch. +# +# Time to meet the first of the three Lightning objects from the primer above: the +# **DataModule**. `HCSDataModule` is VisCy's `LightningDataModule` — it knows how +# to read an OME-Zarr store, split FOVs into train/val, apply normalization and +# augmentations, and hand the Trainer a PyTorch `DataLoader`. You configure it +# once; Lightning calls the right method (`train_dataloader()`, +# `val_dataloader()`, etc.) at the right time. +# +# Every sample `HCSDataModule` yields is a Python `dict` (not a tuple) with: +# +# - `source`: the input image, a tensor of shape `(1, 1, Y, X)` → `(C, Z, Y, X)` +# - `target`: the target image, a tensor of shape `(2, 1, Y, X)` → `(C, Z, Y, X)` +# - `index` : the tuple `(HCS location, time, z-slice)` identifying the sample +# +# A `batch` is a dict of the same keys with an extra leading batch dimension, e.g. +# `batch["source"].shape == (B, 1, 1, Y, X)`. The `training_step` method inside +# `VSUNet` receives this dict directly — no unpacking required. + +# %% [markdown] tags=[] +#
+# +# ### Task 1.2 +# - Run the next cell to setup a logger for your augmentations. +# - Setup the `HCSDataloader()` in for training. +# - Configure the dataloader for the `"UNeXt2_2D"` +# - Configure the dataloader for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task. +# - Configure the dataloader for training. Hint: use the `HCSDataloader.setup()` +# - Open your tensorboard and look at the `IMAGES tab`. +# +# Note: If tensorboard is not showing images or the plots, try refreshing and using the "Images" tab. +#
+ + +# %% +# Define a function to write a batch to tensorboard log. +def log_batch_tensorboard(batch, batchno, writer, card_name): + """ + Logs a batch of images to TensorBoard. + + Args: + batch (dict): A dictionary containing the batch of images to be logged. + writer (SummaryWriter): A TensorBoard SummaryWriter object. + card_name (str): The name of the card to be displayed in TensorBoard. + + Returns: + None + """ + batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor. + batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor. + batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor. + + p1, p99 = np.percentile(batch_membrane, (0.1, 99.9)) + batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1) + + p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9)) + batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1) + + p1, p99 = np.percentile(batch_phase, (0.1, 99.9)) + batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1) + + [N, C, H, W] = batch_phase.shape + interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype) + interleaved_images[0::3, :] = batch_phase + interleaved_images[1::3, :] = batch_nuclei + interleaved_images[2::3, :] = batch_membrane + + grid = torchvision.utils.make_grid(interleaved_images, nrow=3) + + # add the grid to tensorboard + writer.add_image(card_name, grid, batchno) + + +# Define a function to visualize a batch on jupyter, in case tensorboard is finicky +def log_batch_jupyter(batch): + """ + Logs a batch of images on jupyter using ipywidget. + + Args: + batch (dict): A dictionary containing the batch of images to be logged. + + Returns: + None + """ + batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor. + batch_size = batch_phase.shape[0] + batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor. + batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor. + + p1, p99 = np.percentile(batch_membrane, (0.1, 99.9)) + batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1) + + p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9)) + batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1) + + p1, p99 = np.percentile(batch_phase, (0.1, 99.9)) + batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1) + + n_channels = batch["target"].shape[1] + batch["source"].shape[1] + plt.figure() + fig, axes = plt.subplots(batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2)) + [N, C, H, W] = batch_phase.shape + for sample_id in range(batch_size): + axes[sample_id, 0].imshow(batch_phase[sample_id, 0]) + axes[sample_id, 1].imshow(batch_nuclei[sample_id, 0]) + axes[sample_id, 2].imshow(batch_membrane[sample_id, 0]) + + for i in range(n_channels): + axes[sample_id, i].axis("off") + axes[sample_id, i].set_title(dataset.channel_names[i]) + plt.tight_layout() + plt.show() + + +# %% tags=["task"] +# Initialize the data module. + +BATCH_SIZE = 4 + +# 4 is a perfectly reasonable batch size +# (batch size does not have to be a power of 2) +# See: https://sebastianraschka.com/blog/2022/batch-size-2.html + +# ####################### +# ##### TODO ######## +# ####################### +# HINT: Run dataset.channel_names +source_channel = ["TODO"] +target_channel = ["TODO", "TODO"] + +# ####################### +# ##### TODO ######## +# ####################### +data_module = HCSDataModule( + data_path, + z_window_size=1, + source_channel=source_channel, + target_channel=target_channel, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations. + augmentations=[], # Turn off augmentation for now. + normalizations=[], # Turn off normalization for now. +) +# ####################### +# ##### TODO ######## +# ####################### +# Setup the data_module to fit. HINT: data_module.setup() + + +# Evaluate the data module +print( + f"Samples in training set: {len(data_module.train_dataset)}, " + f"samples in validation set:{len(data_module.val_dataset)}" +) +train_dataloader = data_module.train_dataloader() +# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard. +writer = SummaryWriter(log_dir=f"{log_dir}/view_batch") +# Draw a batch and write to tensorboard. +batch = next(iter(train_dataloader)) +log_batch_tensorboard(batch, 0, writer, "augmentation/none") +writer.close() +# %% tags=["solution"] +# ####################### +# ##### SOLUTION ######## +# ####################### + +BATCH_SIZE = 4 +# 4 is a perfectly reasonable batch size +# (batch size does not have to be a power of 2) +# See: https://sebastianraschka.com/blog/2022/batch-size-2.html + +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] + +data_module = HCSDataModule( + data_path, + z_window_size=1, + source_channel=source_channel, + target_channel=target_channel, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations. + augmentations=[], # Turn off augmentation for now. + normalizations=[], # Turn off normalization for now. +) + +# Setup the data_module to fit. HINT: data_module.setup() +data_module.setup("fit") + +# Evaluate the data module +print( + f"Samples in training set: {len(data_module.train_dataset)}, " + f"samples in validation set:{len(data_module.val_dataset)}" +) +train_dataloader = data_module.train_dataloader() +# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard. +writer = SummaryWriter(log_dir=f"{log_dir}/view_batch") +# Draw a batch and write to tensorboard. +batch = next(iter(train_dataloader)) +log_batch_tensorboard(batch, 0, writer, "augmentation/none") +writer.close() +# %% [markdown] tags=[] +#
+# +# ### Questions +# 1. What are the two channels in the target image? +# 2. How many samples are in the training and validation set? What determined that split? +# +# Note: If tensorboard is not showing images, try refreshing and using the "Images" tab. +#
+ +# %% [markdown] tags=[] +# If your tensorboard is causing issues, you can visualize directly on Jupyter /VSCode +# %% +# Visualize in Jupyter +log_batch_jupyter(batch) + +# %% [markdown] tags=[] +#
+#

Question for Task 1.3

+# 1. How do they make the model more robust to imaging parameters or conditions +# without having to acquire data for every possible condition?
+#
+# %% [markdown] tags=[] +# Each augmentation simulates a real-world source of microscope-to-microscope +# variation so the model doesn't overfit to the training conditions: +# +# | Transform | Simulates | +# | --- | --- | +# | `RandWeightedCropd` | random crops biased toward signal-dense regions (foreground oversampling) | +# | `RandAffined` | stage rotation, scale drift, slight shear between acquisitions | +# | `RandAdjustContrastd` | illumination / exposure differences | +# | `RandScaleIntensityd` | gain / brightness differences between cameras | +# | `RandGaussianNoised` | shot and read noise at different detector settings | +# | `RandGaussianSmoothd` | small focus drift / defocus | +# %% [markdown] tags=[] +#
+# +# ### Task 1.3 +# Add the following augmentations: +# - Add augmentations to rotate about $\pi$ around z-axis, 30% scale in (y,x), +# shearing of 1% in (y,x), and no padding with zeros with a probablity of 80%. +# - Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. +# +# HINT: `RandAffined()` and `RandGaussianNoised()` are MONAI dictionary +# transforms re-exported from `viscy_transforms`. See the MONAI docs for +# arguments and probability semantics: +# [RandAffined](https://docs.monai.io/en/stable/transforms.html#randaffined), +# [RandGaussianNoised](https://docs.monai.io/en/stable/transforms.html#randgaussiannoised). +# You can also inspect any transform in a cell with `RandAffined?`.

+# [Compare your choice of augmentations against the pretrained models and config files](https://github.com/mehta-lab/VisCy/releases/download/v0.1.0/VisCy-0.1.0-VS-models.zip). +#
+# %% tags=["task"] +# Here we turn on data augmentation and rerun setup +# ####################### +# ##### TODO ######## +# ####################### +# HINT: Run dataset.channel_names +source_channel = ["TODO"] +target_channel = ["TODO", "TODO"] + +augmentations = [ + RandWeightedCropd( + keys=source_channel + target_channel, + spatial_size=(1, 384, 384), + num_samples=2, + w_key=target_channel[0], + ), + # ####################### + # ##### TODO ######## + # ####################### + ## TODO: Add Random Affine Transorms + ## Write code below + # ####################### + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), + # ####################### + # ##### TODO ######## + # ####################### + ## TODO: Add Random Gaussian Noise + ## Write code below + # ####################### + RandGaussianSmoothd( + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, + ), +] + +normalizations = [ + NormalizeSampled( + keys=source_channel, + level="fov_statistics", + subtrahend="mean", + divisor="std", + ), + NormalizeSampled( + keys=target_channel, + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ), +] + +data_module.augmentations = augmentations +data_module.normalizations = normalizations + +data_module.setup("fit") + +# get the new data loader with augmentation turned on +augmented_train_dataloader = data_module.train_dataloader() + +# Draw batches and write to tensorboard +writer = SummaryWriter(log_dir=f"{log_dir}/view_batch") +augmented_batch = next(iter(augmented_train_dataloader)) +log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some") +writer.close() + +# %% tags=["solution"] +# ####################### +# ##### SOLUTION ######## +# ####################### +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] + +augmentations = [ + RandWeightedCropd( + keys=source_channel + target_channel, + spatial_size=(1, 384, 384), + num_samples=2, + w_key=target_channel[0], + ), + RandAffined( + keys=source_channel + target_channel, + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + padding_mode="zeros", + shear_range=[0.0, 0.01, 0.01], + ), + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), + RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3), + RandGaussianSmoothd( + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, + ), +] + +normalizations = [ + NormalizeSampled( + keys=source_channel, + level="fov_statistics", + subtrahend="mean", + divisor="std", + ), + NormalizeSampled( + keys=target_channel, + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ), +] + +data_module.augmentations = augmentations + +# Setup the data_module to fit. HINT: data_module.setup() +data_module.setup("fit") + +# get the new data loader with augmentation turned on +augmented_train_dataloader = data_module.train_dataloader() + +# Draw batches and write to tensorboard +writer = SummaryWriter(log_dir=f"{log_dir}/view_batch") +augmented_batch = next(iter(augmented_train_dataloader)) +log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some") +writer.close() + +# %% [markdown] tags=[] +#
+#

Question for Task 1.3

+# 1. Look at your tensorboard. Can you tell the agumentations were applied to the sample batch? Compare the batch with and without augmentations.
+# 2. Are these augmentations good enough? What else would you add? +#
+ +# %% [markdown] +# Visualize directly on Jupyter + +# %% +log_batch_jupyter(augmented_batch) + +# %% [markdown] tags=[] +# ## Train a 2D U-Net model to predict nuclei and membrane from phase. +# ### Constructing a 2D UNeXt2 using VisCy +# +# Now we meet the second Lightning object: the **`LightningModule`**. `VSUNet` is +# VisCy's `LightningModule` and it bundles three things that plain PyTorch keeps +# separate: +# +# 1. **The network** — a UNeXt2 architecture, configured through `model_config`. +# 2. **The loss** — passed in as `loss_function=MixedLoss(...)`. +# 3. **The per-batch logic** — `training_step` and `validation_step` methods that +# take one `{"source", "target"}` batch, run the forward pass, compute the +# loss, and return it. You don't see these methods here because they're +# defined once inside `VSUNet`; Lightning calls them for you. +# +# Other constructor arguments you'll recognize from plain PyTorch training: +# `lr` is the learning rate, `schedule="WarmupCosine"` picks the LR schedule, +# and `freeze_encoder=False` lets gradients flow through the whole network. +# `log_batches_per_epoch` is a VisCy extra — it tells the module how many image +# samples to push to TensorBoard each epoch. +# %% [markdown] +# **Architecture config** — UNeXt2 is a U-Net with ConvNeXt-style blocks: +# +# - `encoder_blocks=[3, 3, 9, 3]` and `dims=[96, 192, 384, 768]` — 4 downsampling +# stages with that many blocks and feature channels per stage (last stage is +# the bottleneck). More blocks / dims = more capacity and more compute. +# - `decoder_conv_blocks=2` — conv blocks after each upsampling step. +# - `stem_kernel_size=(1, 2, 2)` and `in_stack_depth=1` — this is a 2D model, +# so we use 1 z-slice and a stem that doesn't convolve across z. +# +# **Loss** — `MixedLoss(l1_alpha=0.5, ms_dssim_alpha=0.5)` combines per-pixel +# L1 (penalizes intensity error) with multi-scale SSIM (penalizes structural +# error — edges, texture, shape). L1 alone produces blurry outputs; MS-SSIM +# alone ignores absolute intensity. The 0.5/0.5 mix balances both. +# +# **Schedule** — `schedule="WarmupCosine"`, `lr=6e-4`: the learning rate ramps +# up from 0 over the first few epochs (warmup), then follows a cosine decay +# toward 0. Warmup avoids early gradient blow-up with AdamW; cosine decay is a +# strong default for vision transformer / ConvNeXt-style encoders. + +# %% [markdown] +#
+# +# ### Task 1.4 +# - Run the next cell to instantiate the `UNeXt2_2D` model +# - Configure the network for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task. +# - Call the VSUNet with the `"UNeXt2_2D"` architecture. +# - Run the next cells to instantiate data module and trainer. +# - Add the source channel name and the target channel names +# - Start the training
+# +# Note
+# See ``viscy.translation.engine.VSUNet`` ([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/translation/engine.py)) and ``viscy.unet.networks.fcmae`` ([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/unet/networks/fcmae.py)) to learn more about the configuration parameters and FCMAE architecture. +#
+ +# %% tags=["task"] +# Create a 2D UNet. +GPU_ID = 0 + +BATCH_SIZE = 16 +YX_PATCH_SIZE = (256, 256) + +# ####################### +# ##### TODO ######## +# ####################### +# Dictionary that specifies key parameters of the model. +phase2fluor_config = dict( + in_channels=..., # TODO how many input channels are we feeding Hint: int?, + out_channels=..., # TODO how many output channels are we solving for? Hint: int, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=..., # TODO: was this a 2D or 3D input? HINT: int, + pretraining=False, +) + +# ####################### +# ##### TODO ######## +# ####################### +phase2fluor_model = VSUNet( + architecture=..., # TODO: 2D UNeXt2 architecture + model_config=phase2fluor_config.copy(), + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), + schedule="WarmupCosine", + lr=6e-4, + log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. + freeze_encoder=False, +) + +# ####################### +# ##### TODO ######## +# ####################### +# HINT: Run dataset.channel_names +source_channel = ["TODO"] +target_channel = ["TODO", "TODO"] + +# Setup the data module. +phase2fluor_2D_data = HCSDataModule( + data_path, + source_channel=source_channel, + target_channel=target_channel, + z_window_size=1, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + yx_patch_size=YX_PATCH_SIZE, + augmentations=augmentations, + normalizations=normalizations, +) +phase2fluor_2D_data.setup("fit") +# fast_dev_run runs a single batch of data through the model to check for errors. +trainer = VisCyTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True) + +# trainer class takes the model and the data module as inputs. +trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data) + + +# %% tags=["solution"] + +# Here we are creating a 2D UNet. +GPU_ID = 0 + +BATCH_SIZE = 16 +YX_PATCH_SIZE = (256, 256) + +# Dictionary that specifies key parameters of the model. +# ####################### +# ##### SOLUTION ######## +# ####################### +phase2fluor_config = dict( + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) + +phase2fluor_model = VSUNet( + architecture="UNeXt2_2D", # 2D UNeXt2 architecture + model_config=phase2fluor_config.copy(), + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), + schedule="WarmupCosine", + lr=6e-4, + log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. + freeze_encoder=False, +) + +# ### Instantiate data module and trainer, test that we are setup to launch training. +# ####################### +# ##### SOLUTION ######## +# ####################### +# Selecting the source and target channel names from the dataset. +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] +# Setup the data module. +phase2fluor_2D_data = HCSDataModule( + data_path, + source_channel=source_channel, + target_channel=target_channel, + z_window_size=1, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + yx_patch_size=YX_PATCH_SIZE, + augmentations=augmentations, + normalizations=normalizations, +) +# ####################### +# ##### SOLUTION ######## +# ####################### +phase2fluor_2D_data.setup("fit") + +# --- The third Lightning object: the Trainer --- +# +# This is the object that replaces the hand-written training loop. Each kwarg +# controls one piece of the boilerplate Lightning is handling for you: +# +# - accelerator="gpu", devices=[GPU_ID] +# Pick the device. No more ".to(device)" sprinkled through your code — +# Lightning moves model + every batch for you. +# - precision="16-mixed" +# Automatic mixed-precision training (fp16 activations, fp32 master +# weights). Cuts GPU memory roughly in half and speeds up matmuls on +# modern GPUs — no autocast() context managers needed. +# - fast_dev_run=True +# Sanity check: run ONE training batch + ONE validation batch and exit. +# Use this on every new pipeline to catch shape bugs, NaN losses, or +# bad paths *before* you commit to a multi-hour training job. +# +# trainer.fit(model, datamodule=...) then drives the whole thing: it calls +# datamodule.setup(), pulls batches from train_dataloader(), invokes +# model.training_step(batch), runs loss.backward() + optimizer.step() + +# zero_grad(), runs validation, logs to TensorBoard, and saves checkpoints. +trainer = VisCyTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True) +trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data) + +# %% [markdown] tags=[] +# ## View model graph. +# +# PyTorch uses dynamic graphs under the hood. +# The graphs are constructed on the fly. +# This is in contrast to TensorFlow, +# where the graph is constructed before the training loop and remains static. +# In other words, the graph of the network can change with every forward pass. +# Therefore, we need to supply an input tensor to construct the graph. +# The input tensor can be a random tensor of the correct shape and type. +# We can also supply a real image from the dataset. +# The latter is more useful for debugging. + +# %% [markdown] +#
+# +# ### Task 1.5 +# Run the next cell to generate a graph representation of the model architecture. +#
+ +# %% +# visualize graph of phase2fluor model as image. +model_graph_phase2fluor = torchview.draw_graph( + phase2fluor_model, + phase2fluor_2D_data.train_dataset[0]["source"][0].unsqueeze(dim=0), + roll=True, + depth=3, # adjust depth to zoom in. + device="cpu", + # expand_nested=True, +) +# Print the image of the model. +model_graph_phase2fluor.visual_graph + +# %% [markdown] tags=[] +#
+# +# ### Question: +# Can you recognize the UNet structure and skip connections in this graph visualization? +#
+ +# %% [markdown] +#
+ +#

Task 1.6

+# Start training by running the following cell. Check the new logs on the tensorboard. +#
+ +# %% [markdown] +#
+# Before re-running training: if a previous training cell is still +# holding the GPU (you'll see CUDA out of memory), restart the +# Jupyter kernel (Kernel → Restart in Jupyter, or Restart in +# VSCode) to release the previous model and optimizer state. The dataset and +# augmentations will rebuild quickly; only the trained weights need to be +# re-loaded via load_from_checkpoint if you want to resume. +#
+ +# %% [markdown] +# Now that `fast_dev_run` confirmed the pipeline works end-to-end, we switch +# to a "real" Trainer configured for an actual multi-epoch run. New Lightning +# knobs appearing here: +# +# - `max_epochs=n_epochs` — run this many passes over the training set, then stop. +# - `log_every_n_steps=steps_per_epoch // 2` — how often Lightning flushes +# scalars (loss, learning rate) to the logger. Setting it to half an epoch +# gives us two data points per epoch without spamming TensorBoard. +# - `logger=TensorBoardLogger(save_dir=log_dir, name="phase2fluor", log_graph=True)` +# — Lightning writes TensorBoard event files *and* model checkpoints under +# `{save_dir}/{name}/version_N/`. You don't call `torch.save` yourself; the +# trainer persists checkpoints automatically, and `log_graph=True` adds the +# network architecture to the Graphs tab. +# +# Calling `trainer.fit` again below runs the full training loop — forward, +# loss, backward, optimizer step, validation every epoch, checkpoint at the +# end — across `max_epochs` epochs. + +# %% +# Check if GPU is available +# You can check by typing `nvidia-smi` +GPU_ID = 0 + +n_samples = len(phase2fluor_2D_data.train_dataset) +steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. +n_epochs = 80 # Set this to 80-100 or the number of epochs you want to train for. + +trainer = VisCyTrainer( + accelerator="gpu", + devices=[GPU_ID], + max_epochs=n_epochs, + precision="16-mixed", + log_every_n_steps=steps_per_epoch // 2, + # log losses and image samples 2 times per epoch. + logger=TensorBoardLogger( + save_dir=log_dir, + # lightning trainer transparently saves logs and model checkpoints in this directory. + name="phase2fluor", + log_graph=True, + ), +) +# Launch training and check that loss and images are being logged on tensorboard. +trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data) + +# Move the model to the GPU. +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +phase2fluor_model.to(device) +# %% [markdown] tags=[] +#
+ +#

Checkpoint 1

+ +# While your model is training, let's think about the following questions:
+#
    +#
  • What is the information content of each channel in the dataset?
  • +#
  • How would you use image translation models?
  • +#
  • What can you try to improve the performance of each model?
  • +#
+ +# Now the training has started, +# we can come back after a while and evaluate the performance! + +#
+# %% [markdown] tags=[] +# # Part 2: Assess your trained model +# +# We evaluate on a held-out test set using two complementary families of metrics: +# +# - **Regression / pixel-level** (Pearson, microSSIM): are predicted +# intensities close to ground truth, per pixel? Cheap, but can hide +# topological errors — a model that merges two nuclei may still score well +# pixel-wise. +# - **Segmentation / instance-level** (Jaccard/IoU, Dice, mAP over IoU +# thresholds): run Cellpose on both predicted and measured fluorescence, +# then compare instance masks. This is what ultimately matters for +# downstream analysis (counting cells, tracking, phenotyping). +# +# Also inspect the validation samples on TensorBoard — the experimental +# nuclei channel is noisy, so "ground truth" is itself imperfect. + +# %% [markdown] +#
+ +#

Task 2.1 Define metrics

+ +# For each of the above metrics, write a brief definition of what they are and what they mean +# for this image translation task. Use your favorite search engine and/or resources. + +#
+ +# %% [markdown] tags=[] +# ``` +# ####################### +# ##### Solution ######## +# ####################### +# ``` +# +# - **Pearson Correlation**: linear correlation between predicted and target +# intensities across all pixels, in `[-1, 1]`. `1` means the prediction is a +# perfect affine rescaling of the target; invariant to mean / contrast +# offsets. Good at flagging "the pattern is right" but blind to structural +# errors that preserve correlation (e.g. a uniformly blurred prediction). +# +# - **microSSIM**: a microscopy-aware variant of +# [Structural Similarity (SSIM)](https://en.wikipedia.org/wiki/Structural_similarity). +# Classic SSIM patch-wise compares local mean, variance, and covariance and +# captures structure Pearson misses (blurring, contrast loss) — but it +# assumes the natural-image dynamic range. Fluorescence microscopy images +# are sparse, dim, and noisy: with the default SSIM parameters the scores +# collapse into a narrow band that barely separates good and bad +# predictions. [microSSIM](https://github.com/juglab/MicroSSIM) +# ([Ashesh et al., 2024](https://arxiv.org/abs/2408.08747)) fixes this by +# subtracting the image background and fitting a per-image rescaling factor +# before computing SSIM, so the metric becomes sensitive over the range of +# intensities microscopy predictions actually live in. We use it as a +# drop-in replacement for `skimage.metrics.structural_similarity`. + +# %% [markdown] tags=[] +# ### Let's compute metrics directly and plot below. +# %% [markdown] tags=[] +#
+# If you weren't able to train or training didn't complete please run the following lines to load the latest checkpoint
+# +# ```python +# phase2fluor_model_ckpt = natsorted(glob( +# str(top_dir / "06_image_translation/logs/phase2fluor/version*/checkpoints/*.ckpt") +# ))[-1] +# ``` +#
+# NOTE: if their model didn't go past epoch 5, lost their checkpoint, or didnt train anything. +# Run the following: +# +# ```python +# phase2fluor_model_ckpt = natsorted(glob( +# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt") +# ))[-1] +# ``` + +# ```python +# phase2fluor_config = dict( +# in_channels=1, +# out_channels=2, +# encoder_blocks=[3, 3, 9, 3], +# dims=[96, 192, 384, 768], +# decoder_conv_blocks=2, +# stem_kernel_size=(1, 2, 2), +# in_stack_depth=1, +# pretraining=False, +# ) +# Load the model checkpoint +# phase2fluor_model = VSUNet.load_from_checkpoint( +# phase2fluor_model_ckpt, +# architecture="UNeXt2_2D", +# model_config = phase2fluor_config, +# accelerator='gpu' +# ) +# ```` +#
+# %% +# Setup the test data module. +test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] + +test_data = HCSDataModule( + test_data_path, + source_channel=source_channel, + target_channel=target_channel, + z_window_size=1, + batch_size=1, + num_workers=8, +) +test_data.setup("test") + +test_metrics = pd.DataFrame(columns=["pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"]) + + +# %% +# Compute metrics directly and plot here. +def normalize_fov(input: ArrayLike): + "Normalizing the fov with zero mean and unit variance" + mean = np.mean(input) + std = np.std(input) + return (input - mean) / std + + +for i, sample in enumerate(tqdm(test_data.test_dataloader(), desc="Computing metrics per sample")): + phase_image = sample["source"].to(phase2fluor_model.device) + with torch.inference_mode(): # turn off gradient computation. + predicted_image = phase2fluor_model(phase_image) + + target_image = sample["target"].cpu().numpy().squeeze(0) # Squeezing batch dimension. + predicted_image = predicted_image.cpu().numpy().squeeze(0) + phase_image = phase_image.cpu().numpy().squeeze(0) + target_mem = normalize_fov(target_image[1, 0, :, :]) + target_nuc = normalize_fov(target_image[0, 0, :, :]) + # slicing channel dimension, squeezing z-dimension. + predicted_mem = normalize_fov(predicted_image[1, :, :, :].squeeze(0)) + predicted_nuc = normalize_fov(predicted_image[0, :, :, :].squeeze(0)) + + # Compute microSSIM and pearson correlation. + ssim_nuc = micro_structural_similarity(target_nuc, predicted_nuc) + ssim_mem = micro_structural_similarity(target_mem, predicted_mem) + pearson_nuc = np.corrcoef(target_nuc.flatten(), predicted_nuc.flatten())[0, 1] + pearson_mem = np.corrcoef(target_mem.flatten(), predicted_mem.flatten())[0, 1] + + test_metrics.loc[i] = { + "pearson_nuc": pearson_nuc, + "microSSIM_nuc": ssim_nuc, + "pearson_mem": pearson_mem, + "microSSIM_mem": ssim_mem, + } + +# Plot the following metrics +test_metrics.boxplot( + column=["pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"], + rot=30, +) + + +# %% +# Adjust the image to the 0.5-99.5 percentile range. +def process_image(image): + p_low, p_high = np.percentile(image, (0.5, 99.5)) + return np.clip(image, p_low, p_high) + + +# Plot the predicted image vs target image. +channel_titles = [ + "Phase", + "Target Nuclei", + "Target Membrane", + "Predicted Nuclei", + "Predicted Membrane", +] +fig, axes = plt.subplots(5, 1, figsize=(20, 20)) + +# Get a writer to output the images into tensorboard and plot the source, predictions and target images +for i, sample in enumerate(test_data.test_dataloader()): + # Plot the phase image + phase_image = sample["source"] + channel_image = phase_image[0, 0, 0] + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[0].imshow(channel_image, cmap="gray") + axes[0].axis("off") + axes[0].set_title(channel_titles[0]) + + with torch.inference_mode(): # turn off gradient computation. + predicted_image = phase2fluor_model(phase_image.to(phase2fluor_model.device)).cpu().numpy().squeeze(0) + + target_image = sample["target"].cpu().numpy().squeeze(0) + phase_raw = process_image(phase_image[0, 0, 0]) + predicted_nuclei = process_image(predicted_image[0, 0]) + predicted_membrane = process_image(predicted_image[1, 0]) + target_nuclei = process_image(target_image[0, 0]) + target_membrane = process_image(target_image[1, 0]) + # Concatenate all images side by side + combined_image = np.concatenate( + ( + phase_raw, + predicted_nuclei, + predicted_membrane, + target_nuclei, + target_membrane, + ), + axis=1, + ) + + # Plot the phase,target nuclei, target membrane, predicted nuclei, predicted membrane + axes[1].imshow(target_nuclei, cmap="gray") + axes[2].imshow(target_membrane, cmap="gray") + axes[3].imshow(predicted_nuclei, cmap="gray") + axes[4].imshow(predicted_membrane, cmap="gray") + + for ax in axes: + ax.axis("off") + plt.tight_layout() + plt.show() + break +# %% [markdown] tags=[] +#
+ +#

Task 2.2 Loading the pretrained model VSCyto2D

+# Here we will compare your model with the VSCyto2D pretrained model by computing the pixel-based metrics and segmentation-based metrics. +# +#
    +#
  • The pretrained checkpoint was downloaded by setup.sh to +# ~/data/06_image_translation/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt +# — if missing, download it directly from +# public.czbiohub.org. +# Check with ls ~/data/06_image_translation/pretrained_models/VSCyto2D/.
  • +#
  • Load the VSCyto2D model checkpoint and the configuration file
  • +#
  • Compute the pixel-based metrics and segmentation-based metrics between the model you trained and the pretrained model
  • +#
+#
+ +#
+ + +# %% tags=["task"] +################# +##### TODO ###### +################# +# Let's load the pretrained model checkpoint +pretrained_model_ckpt = top_dir / ... ## Add the path to the "VSCyto2D/epoch=399-step=23200.ckpt" + +# TODO: Load the phase2fluor_config just like the model you trained +phase2fluor_config = dict() ## + +# TODO: Load the checkpoint. Write the architecture name. HINT: look at the previous config. +pretrained_phase2fluor = VSUNet.load_from_checkpoint( + pretrained_model_ckpt, + architecture=..., + model_config=phase2fluor_config, + accelerator="gpu", +) +# TODO: Setup the dataloader in evaluation/predict mode +# + +# %% tags=["solution"] +# ####################### +# ##### SOLUTION ######## +# ####################### + +pretrained_model_ckpt = top_dir / "06_image_translation/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt" + +phase2fluor_config = dict( + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) +# Load the model checkpoint +pretrained_phase2fluor = VSUNet.load_from_checkpoint( + pretrained_model_ckpt, + architecture="UNeXt2_2D", + model_config=phase2fluor_config, + accelerator="gpu", +) +pretrained_phase2fluor.eval() + +### Re-load your trained model +# NOTE: assuming the latest checkpoint it your latest training and model +phase2fluor_model_ckpt = natsorted( + glob(str(training_top_dir / "06_image_translation/logs/phase2fluor/version*/checkpoints/*.ckpt")) +)[-1] + +# NOTE: if their model didn't go past epoch 5, lost their checkpoint, or didnt train anything. +# Uncomment the next lines +# phase2fluor_model_ckpt = natsorted(glob( +# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt") +# ))[-1] + +phase2fluor_config = dict( + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) +# Load the model checkpoint +phase2fluor_model = VSUNet.load_from_checkpoint( + phase2fluor_model_ckpt, + architecture="UNeXt2_2D", + model_config=phase2fluor_config, + accelerator="gpu", +) +phase2fluor_model.eval() +# %% [markdown] tags=[] +#
+#

Question

+# 1. Can we evaluate a model's performance based on their segmentations?
+# 2. Look up IoU or Jaccard index, dice coefficient, and AP metrics. LINK:https://metrics-reloaded.dkfz.de/metric-library
+# We will evaluate the performance of your trained model with a pre-trained model using pixel based metrics as above and +# segmantation based metrics including (mAP@0.5, dice, accuracy and jaccard index).
+#
+# %% [markdown] tags=["solution"] +# +# - IoU (Intersection over Union): Also referred to as the Jaccard index, is essentially a method to quantify the percent overlap between the target and predicted masks. +# It is calculated as the intersection of the target and predicted masks divided by the union of the target and predicted masks.
+# - Dice Coefficient: Metric used to evaluate the similarity between two sets.
+# It is calculated as twice the intersection of the target and predicted masks divided by the sum of the target and predicted masks.
+# - mAP (mean Average Precision): The mean Average Precision (mAP) is a metric used to evaluate the performance of object detection models. +# It is calculated as the average precision across all classes and is used to measure the accuracy of the model in localizing objects. +# +# %% [markdown] tags=[] +# ### Let's compute the metrics for the test dataset +# Before you run the following code, make sure you have the pretrained model loaded and the test data is ready. + +# The following code will compute the following: +# - the pixel-based metrics (pearson correlation, SSIM) +# - segmentation-based metrics (mAP@0.5, dice, accuracy, jaccard index) + + +# #### Note: +# - The segmentation-based metrics are computed using the cellpose stock `nuclei` model +# - The metrics will be store in the `test_pixel_metrics` and `test_segmentation_metrics` dataframes +# - The segmentations will be stored in the `segmentation_store` zarr file +# - Analyze the code while it runs. +# %% +# Create cellpose model once for reuse +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +cellpose_model = models.CellposeModel(gpu=True if device.type == "cuda" else False, device=device) + + +# Define the function to compute the cellpose segmentation +def cellpose_segmentation(prediction: ArrayLike, target: ArrayLike) -> Tuple[torch.ShortTensor]: + # NOTE these are hardcoded for this notebook and A549 dataset + + # Convert 2D arrays to 3D format expected by cellpose v4.0.1+ + # Add channel dimension and replicate to 3 channels (RGB format) + if prediction.ndim == 2: + prediction = np.tile(prediction, (3, 1, 1)) # Shape: (3, H, W) + if target.ndim == 2: + target = np.tile(target, (3, 1, 1)) # Shape: (3, H, W) + + cp_nuc_kwargs = { + "diameter": 65, + "cellprob_threshold": 0.0, + } + + pred_label, _, _ = cellpose_model.eval(prediction, **cp_nuc_kwargs) + target_label, _, _ = cellpose_model.eval(target, **cp_nuc_kwargs) + + pred_label = pred_label.astype(np.int32) + target_label = target_label.astype(np.int32) + pred_label = torch.ShortTensor(pred_label) + target_label = torch.ShortTensor(target_label) + + return (pred_label, target_label) + + +# %% +# Setting the paths for the test data and the output segmentation +test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" +output_segmentation_path = training_top_dir / "06_image_translation/pretrained_model_segmentations.zarr" + +# Creating the dataframes to store the pixel and segmentation metrics +test_pixel_metrics = pd.DataFrame( + columns=["model", "fov", "pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"] +) +test_segmentation_metrics = pd.DataFrame( + columns=[ + "model", + "fov", + "masks_per_fov", + "accuracy", + "dice", + "jaccard", + "mAP", + "mAP_50", + "mAP_75", + "mAR_100", + ] +) +# Opening the test dataset +test_dataset = open_ome_zarr(test_data_path) + +# Creating an output store for the predictions and segmentations +segmentation_store = open_ome_zarr( + output_segmentation_path, + channel_names=["nuc_pred", "mem_pred", "nuc_labels"], + mode="w", + layout="hcs", +) + +# Looking at the test dataset +print("Test dataset:") +test_dataset.print_tree() +channel_names = test_dataset.channel_names +print(f"Channel names: {channel_names}") + +# Finding the channel indices for the corresponding channel names +phase_cidx = channel_names.index("Phase3D") +nuc_cidx = channel_names.index("Nucl") +mem_cidx = channel_names.index("Mem") +nuc_label_cidx = channel_names.index("nuclei_segmentation") + + +# %% +def min_max_scale(image: ArrayLike) -> ArrayLike: + "Normalizing the image using min-max scaling" + min_val = image.min() + max_val = image.max() + return (image - min_val) / (max_val - min_val) + + +# %% [markdown] +# ## Visualize segmentation comparison: Fluorescence vs Virtual Staining vs Pretrained +# Let's compare nucleus and membrane segmentation across all three models + +# %% +# Get a sample FOV for visualization +positions = list(test_dataset.positions()) +sample_fov, sample_pos = positions[0] # Use first FOV as example + +T, C, Z, Y, X = sample_pos.data.shape +Z_slice = slice(Z // 2, Z // 2 + 1) + +# Get the data +sample_phase = sample_pos.data[:, phase_cidx : phase_cidx + 1, Z_slice] +sample_nucleus = sample_pos.data[0, nuc_cidx : nuc_cidx + 1, Z_slice] +sample_membrane = sample_pos.data[0, mem_cidx : mem_cidx + 1, Z_slice] + +# Crop 300x300 pixels from center +center_y, center_x = sample_nucleus.shape[2] // 2, sample_nucleus.shape[3] // 2 +crop_size = 300 +y_start = max(0, center_y - crop_size // 2) +y_end = min(sample_nucleus.shape[2], center_y + crop_size // 2) +x_start = max(0, center_x - crop_size // 2) +x_end = min(sample_nucleus.shape[3], center_x + crop_size // 2) + +# Crop fluorescence data +sample_nucleus_crop = min_max_scale(sample_nucleus[0, 0, y_start:y_end, x_start:x_end]) +sample_membrane_crop = min_max_scale(sample_membrane[0, 0, y_start:y_end, x_start:x_end]) + +# Generate virtual stained data from phase (trained model) +sample_phase_tensor = torch.tensor(sample_phase, dtype=torch.float32).to(device) +with torch.inference_mode(): + predicted_image = phase2fluor_model(sample_phase_tensor) +predicted_nuc_crop = min_max_scale(predicted_image.cpu().numpy()[0, 0, 0, y_start:y_end, x_start:x_end]) +predicted_mem_crop = min_max_scale(predicted_image.cpu().numpy()[0, 1, 0, y_start:y_end, x_start:x_end]) + +# Generate virtual stained data from pretrained model +with torch.inference_mode(): + predicted_image_pretrained = pretrained_phase2fluor(sample_phase_tensor) +predicted_nuc_pretrained_crop = min_max_scale( + predicted_image_pretrained.cpu().numpy()[0, 0, 0, y_start:y_end, x_start:x_end] +) +predicted_mem_pretrained_crop = min_max_scale( + predicted_image_pretrained.cpu().numpy()[0, 1, 0, y_start:y_end, x_start:x_end] +) + +# Run segmentation on all nuclei +fluor_nuc_seg, _ = cellpose_segmentation(sample_nucleus_crop, sample_nucleus_crop) +virtual_nuc_seg, _ = cellpose_segmentation(predicted_nuc_crop, predicted_nuc_crop) +pretrained_nuc_seg, _ = cellpose_segmentation(predicted_nuc_pretrained_crop, predicted_nuc_pretrained_crop) + +# Run segmentation on all membranes (using nucleus parameters for consistency) +fluor_mem_seg, _ = cellpose_segmentation(sample_membrane_crop, sample_membrane_crop) +virtual_mem_seg, _ = cellpose_segmentation(predicted_mem_crop, predicted_mem_crop) +pretrained_mem_seg, _ = cellpose_segmentation(predicted_mem_pretrained_crop, predicted_mem_pretrained_crop) + +# Convert to numpy +fluor_nuc_seg = fluor_nuc_seg.numpy() +virtual_nuc_seg = virtual_nuc_seg.numpy() +pretrained_nuc_seg = pretrained_nuc_seg.numpy() +fluor_mem_seg = fluor_mem_seg.numpy() +virtual_mem_seg = virtual_mem_seg.numpy() +pretrained_mem_seg = pretrained_mem_seg.numpy() + +# Create 3x4 visualization +fig, axes = plt.subplots(3, 4, figsize=(16, 12)) + +# Row 1: Fluorescence data +axes[0, 0].imshow(sample_nucleus_crop, cmap="gray") +axes[0, 0].set_title("Fluorescence Nucleus") +axes[0, 0].axis("off") + +fluor_nuc_overlay = label2rgb(fluor_nuc_seg, sample_nucleus_crop, bg_label=0) +axes[0, 1].imshow(fluor_nuc_overlay) +axes[0, 1].set_title("Nucleus Segmentation") +axes[0, 1].axis("off") + +axes[0, 2].imshow(sample_membrane_crop, cmap="gray") +axes[0, 2].set_title("Fluorescence Membrane") +axes[0, 2].axis("off") + +fluor_mem_overlay = label2rgb(fluor_mem_seg, sample_membrane_crop, bg_label=0) +axes[0, 3].imshow(fluor_mem_overlay) +axes[0, 3].set_title("Membrane Segmentation") +axes[0, 3].axis("off") + +# Row 2: Virtual stained data (trained) +axes[1, 0].imshow(predicted_nuc_crop, cmap="gray") +axes[1, 0].set_title("Virtual Nucleus (Trained)") +axes[1, 0].axis("off") + +virtual_nuc_overlay = label2rgb(virtual_nuc_seg, predicted_nuc_crop, bg_label=0) +axes[1, 1].imshow(virtual_nuc_overlay) +axes[1, 1].set_title("Nucleus Segmentation") +axes[1, 1].axis("off") + +axes[1, 2].imshow(predicted_mem_crop, cmap="gray") +axes[1, 2].set_title("Virtual Membrane (Trained)") +axes[1, 2].axis("off") + +virtual_mem_overlay = label2rgb(virtual_mem_seg, predicted_mem_crop, bg_label=0) +axes[1, 3].imshow(virtual_mem_overlay) +axes[1, 3].set_title("Membrane Segmentation") +axes[1, 3].axis("off") + +# Row 3: Virtual stained data (pretrained) +axes[2, 0].imshow(predicted_nuc_pretrained_crop, cmap="gray") +axes[2, 0].set_title("Virtual Nucleus (Pretrained)") +axes[2, 0].axis("off") + +pretrained_nuc_overlay = label2rgb(pretrained_nuc_seg, predicted_nuc_pretrained_crop, bg_label=0) +axes[2, 1].imshow(pretrained_nuc_overlay) +axes[2, 1].set_title("Nucleus Segmentation") +axes[2, 1].axis("off") + +axes[2, 2].imshow(predicted_mem_pretrained_crop, cmap="gray") +axes[2, 2].set_title("Virtual Membrane (Pretrained)") +axes[2, 2].axis("off") + +pretrained_mem_overlay = label2rgb(pretrained_mem_seg, predicted_mem_pretrained_crop, bg_label=0) +axes[2, 3].imshow(pretrained_mem_overlay) +axes[2, 3].set_title("Membrane Segmentation") +axes[2, 3].axis("off") + +plt.suptitle(f"Complete Segmentation Comparison - FOV: {sample_fov}", fontsize=16) +plt.tight_layout() +plt.show() + +print("Nucleus segmentation counts:") +print(f" Fluorescence: {len(np.unique(fluor_nuc_seg)) - 1} nuclei") +print(f" Virtual (trained): {len(np.unique(virtual_nuc_seg)) - 1} nuclei") +print(f" Virtual (pretrained): {len(np.unique(pretrained_nuc_seg)) - 1} nuclei") + +print("\nMembrane segmentation counts:") +print(f" Fluorescence: {len(np.unique(fluor_mem_seg)) - 1} objects") +print(f" Virtual (trained): {len(np.unique(virtual_mem_seg)) - 1} objects") +print(f" Virtual (pretrained): {len(np.unique(pretrained_mem_seg)) - 1} objects") + +# %% [markdown] +# Now let's compute metrics across all FOVs + +# %% +# Iterating through the test dataset positions to: +total_positions = len(positions) + +# Initializing the progress bar with the total number of positions +with tqdm(total=total_positions, desc="Processing FOVs") as pbar: + # Iterating through the test dataset positions + for fov, pos in positions: + T, C, Z, Y, X = pos.data.shape + Z_slice = slice(Z // 2, Z // 2 + 1) + # Getting the arrays and the center slices + phase_image = pos.data[:, phase_cidx : phase_cidx + 1, Z_slice] + target_nucleus = pos.data[0, nuc_cidx : nuc_cidx + 1, Z_slice] + target_membrane = pos.data[0, mem_cidx : mem_cidx + 1, Z_slice] + target_nuc_label = pos.data[0, nuc_label_cidx : nuc_label_cidx + 1, Z_slice] + + # normalize the phase + phase_image = normalize_fov(phase_image) + + # Running the prediction for both models + phase_image = torch.from_numpy(phase_image).type(torch.float32) + phase_image = phase_image.to(phase2fluor_model.device) + with torch.inference_mode(): # turn off gradient computation. + predicted_image_phase2fluor = phase2fluor_model(phase_image) + predicted_image_pretrained = pretrained_phase2fluor(phase_image) + + # Loading and Normalizing the target and predictions for both models + predicted_image_phase2fluor = predicted_image_phase2fluor.cpu().numpy().squeeze(0) + predicted_image_pretrained = predicted_image_pretrained.cpu().numpy().squeeze(0) + phase_image = phase_image.cpu().numpy().squeeze(0) + + target_mem = min_max_scale(target_membrane[0, 0]) + target_nuc = min_max_scale(target_nucleus[0, 0]) + + # Normalizing the dataset using min-max scaling + predicted_mem_phase2fluor = min_max_scale(predicted_image_phase2fluor[1, :, :, :].squeeze(0)) + predicted_nuc_phase2fluor = min_max_scale(predicted_image_phase2fluor[0, :, :, :].squeeze(0)) + + predicted_mem_pretrained = min_max_scale(predicted_image_pretrained[1, :, :, :].squeeze(0)) + predicted_nuc_pretrained = min_max_scale(predicted_image_pretrained[0, :, :, :].squeeze(0)) + + ####### Pixel-based Metrics ############ + # Compute microSSIM and Pearson correlation for phase2fluor_model + pbar.set_description(f"Processing FOV {fov} - Computing Pixel Metrics") + pbar.refresh() + ssim_nuc_phase2fluor = micro_structural_similarity(target_nuc, predicted_nuc_phase2fluor) + ssim_mem_phase2fluor = micro_structural_similarity(target_mem, predicted_mem_phase2fluor) + pearson_nuc_phase2fluor = np.corrcoef(target_nuc.flatten(), predicted_nuc_phase2fluor.flatten())[0, 1] + pearson_mem_phase2fluor = np.corrcoef(target_mem.flatten(), predicted_mem_phase2fluor.flatten())[0, 1] + + test_pixel_metrics.loc[len(test_pixel_metrics)] = { + "model": "phase2fluor", + "fov": fov, + "pearson_nuc": pearson_nuc_phase2fluor, + "microSSIM_nuc": ssim_nuc_phase2fluor, + "pearson_mem": pearson_mem_phase2fluor, + "microSSIM_mem": ssim_mem_phase2fluor, + } + # Compute microSSIM and Pearson correlation for pretrained_model + ssim_nuc_pretrained = micro_structural_similarity(target_nuc, predicted_nuc_pretrained) + ssim_mem_pretrained = micro_structural_similarity(target_mem, predicted_mem_pretrained) + pearson_nuc_pretrained = np.corrcoef(target_nuc.flatten(), predicted_nuc_pretrained.flatten())[0, 1] + pearson_mem_pretrained = np.corrcoef(target_mem.flatten(), predicted_mem_pretrained.flatten())[0, 1] + + test_pixel_metrics.loc[len(test_pixel_metrics)] = { + "model": "pretrained_phase2fluor", + "fov": fov, + "pearson_nuc": pearson_nuc_pretrained, + "microSSIM_nuc": ssim_nuc_pretrained, + "pearson_mem": pearson_mem_pretrained, + "microSSIM_mem": ssim_mem_pretrained, + } + + ###### Segmentation based metrics ######### + # Load the manually curated nuclei target label + pbar.set_description(f"Processing FOV {fov} - Computing Segmentation Metrics") + pbar.refresh() + pred_label, target_label = cellpose_segmentation(predicted_nuc_phase2fluor, target_nucleus) + # Binary labels + pred_label_binary = pred_label > 0 + target_label_binary = target_label > 0 + + # Use Coco metrics to get mean average precision + coco_metrics = mean_average_precision(pred_label, target_label) + # Find unique number of labels + num_masks_fov = len(np.unique(pred_label)) + + test_segmentation_metrics.loc[len(test_segmentation_metrics)] = { + "model": "phase2fluor", + "fov": fov, + "masks_per_fov": num_masks_fov, + "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(), + "dice": dice_score( + pred_label_binary.long()[None], + target_label_binary.long()[None], + num_classes=2, + input_format="index", + average="micro", + ).item(), + "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(), + "mAP": coco_metrics["map"].item(), + "mAP_50": coco_metrics["map_50"].item(), + "mAP_75": coco_metrics["map_75"].item(), + "mAR_100": coco_metrics["mar_100"].item(), + } + + pred_label, target_label = cellpose_segmentation(predicted_nuc_pretrained, target_nucleus) + + # Binary labels + pred_label_binary = pred_label > 0 + target_label_binary = target_label > 0 + + # Use Coco metrics to get mean average precision + coco_metrics = mean_average_precision(pred_label, target_label) + # Find unique number of labels + num_masks_fov = len(np.unique(pred_label)) + + test_segmentation_metrics.loc[len(test_segmentation_metrics)] = { + "model": "phase2fluor_pretrained", + "fov": fov, + "masks_per_fov": num_masks_fov, + "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(), + "dice": dice_score( + pred_label_binary.long()[None], + target_label_binary.long()[None], + num_classes=2, + input_format="index", + average="micro", + ).item(), + "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(), + "mAP": coco_metrics["map"].item(), + "mAP_50": coco_metrics["map_50"].item(), + "mAP_75": coco_metrics["map_75"].item(), + "mAR_100": coco_metrics["mar_100"].item(), + } + + # Save the predictions and segmentations + position = segmentation_store.create_position(*Path(fov).parts[-3:]) + output_array = np.zeros((T, 3, 1, Y, X), dtype=np.float32) + output_array[0, 0, 0] = predicted_nuc_pretrained + output_array[0, 1, 0] = predicted_mem_pretrained + output_array[0, 2, 0] = np.array(pred_label) + position.create_image("0", output_array) + + # Update the progress bar + pbar.set_description("Processing FOVs") + pbar.update(1) + +# Close the OME-Zarr files +test_dataset.close() +segmentation_store.close() +# %% +# Save the test metrics into a dataframe +pixel_metrics_path = training_top_dir / "06_image_translation/VS_metrics_pixel.csv" +segmentation_metrics_path = training_top_dir / "06_image_translation/VS_metrics_segments.csv" +test_pixel_metrics.to_csv(pixel_metrics_path) +test_segmentation_metrics.to_csv(segmentation_metrics_path) + +# %% [markdown] tags=[] +#
+ +#

Task 2.3 Compare the model's metrics

+# In the previous section, we computed the pixel-based metrics and segmentation-based metrics. +# Now we will compare the performance of the model you trained with the pretrained model by plotting the boxplots. + +# After you plot the metrics answer the following: +#
    +#
  • What do these metrics tells us about the performance of the model?
  • +#
  • How do you interpret the differences in the metrics between the models?
  • +#
  • How is your model compared to the pretrained model? How can you improve it?
  • +#
+#
+ +# %% +# Show boxplot of the metrics +# Boxplot of the metrics +test_pixel_metrics.boxplot( + by="model", + column=["pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"], + rot=30, + figsize=(8, 8), +) +plt.suptitle("Model Pixel Metrics") +plt.show() +# Show boxplot of the metrics +# Boxplot of the metrics +test_segmentation_metrics.boxplot( + by="model", + column=["jaccard", "accuracy", "mAP_75", "mAP_50"], + rot=30, + figsize=(8, 8), +) +plt.suptitle("Model Segmentation Metrics") +plt.show() + +# %% [markdown] tags=["task"] +#
+#

Questions

+#
    +#
  • What do these metrics tells us about the performance of the model?
  • +#
  • How do you interpret the differences in the metrics between the models?
  • +#
  • How is your model compared to the pretrained model? How can you improve it?
  • +#
+#
+ +# %% [markdown] +# ### Plotting the predictions and segmentations +#
+# +#

Task 2.4: Visualize the predictions and segmentations

+# Here we will plot the predictions and segmentations side by side for the pretrained and trained models.
+#
    +#
  • How does your model, the pretrained model and the ground truth compare?
  • +#
  • How do the segmentations compare?
  • +#
+# Feel free to modify the crop size and Y,X slicing to view different areas of the FOV +#
+# %% tags=["task"] + +# Get the shape of the 2D image +Y, X = phase_image.shape[-2:] +######## TODO ########## +# Modify the crop size and Y,X slicing to view different areas of the FOV + +crop = 256 +y_slice = slice(Y // 2 - crop // 2, Y // 2 + crop // 2) +x_slice = slice(X // 2 - crop // 2, X // 2 + crop // 2) +####################### +# Plotting side by side comparisons +fig, axs = plt.subplots(4, 3, figsize=(15, 20)) + +# First row: phase_image, target_nuc, target_mem +axs[0, 0].imshow(phase_image[0, 0, y_slice, x_slice], cmap="gray") +axs[0, 0].set_title("Phase Image") +axs[0, 1].imshow(target_nuc[y_slice, x_slice], cmap="gray") +axs[0, 1].set_title("Target Nucleus") +axs[0, 2].imshow(target_mem[y_slice, x_slice], cmap="gray") +axs[0, 2].set_title("Target Membrane") + +# Second row: target_nuc, pred_nuc_phase2fluor, pred_nuc_pretrained +axs[1, 0].imshow(target_nuc[y_slice, x_slice], cmap="gray") +axs[1, 0].set_title("Target Nucleus") +axs[1, 1].imshow(predicted_nuc_phase2fluor[y_slice, x_slice], cmap="gray") +axs[1, 1].set_title("Pred Nucleus Phase2Fluor") +axs[1, 2].imshow(predicted_nuc_pretrained[y_slice, x_slice], cmap="gray") +axs[1, 2].set_title("Pred Nucleus Pretrained") + +# Third row: target_mem, pred_mem_phase2fluor, pred_mem_pretrained +axs[2, 0].imshow(target_mem[y_slice, x_slice], cmap="gray") +axs[2, 0].set_title("Target Membrane") +axs[2, 1].imshow(predicted_mem_phase2fluor[y_slice, x_slice], cmap="gray") +axs[2, 1].set_title("Pred Membrane Phase2Fluor") +axs[2, 2].imshow(predicted_mem_pretrained[y_slice, x_slice], cmap="gray") +axs[2, 2].set_title("Pred Membrane Pretrained") + +# Fourth row: target_nuc, segment_nuc, segment_nuc2 +axs[3, 0].imshow(target_nuc[y_slice, x_slice], cmap="gray") +axs[3, 0].set_title("Target Nucleus") +axs[3, 1].imshow(label2rgb(np.array(target_label[y_slice, x_slice], dtype="int")), cmap="gray") +axs[3, 1].set_title("Segmented Nucleus (Target)") +axs[3, 2].imshow(label2rgb(np.array(pred_label[y_slice, x_slice], dtype="int")), cmap="gray") +axs[3, 2].set_title("Segmented Nucleus") + +# Hide axes ticks +for ax in axs.flat: + ax.set_xticks([]) + ax.set_yticks([]) + +plt.tight_layout() +plt.show() + + +# %% [markdown] tags=[] +#
+ +#

Checkpoint 2

+# +# Congratulations! You have completed the second checkpoint. You have: +# - Visualized the predictions and segmentations of the model.
+# - Evaluated the performance of the model using pixel-based metrics and segmentation-based metrics.
+# - Compared the performance of the model you trained with the pretrained model.
+# +#
+ +# %% [markdown] tags=[] +#
+# +# ### Task 2.5: Evaluate a fluorescence to phase model +# In this section, we will explore the inverse transformation using fluorescence images +# (nuclei + membrane) to predict the phase image. +# +#

Learning Goals:

+#
    +#
  • Understand the concept of fluorescence to phase transformations in image translation
  • +#
  • Load a pretrained model for the reverse task (fluor → phase)
  • +#
  • Compare input fluorescence channels with predicted phase
  • +#
  • Analyze why the phase prediction is not perfect
  • +#
+# We'll use a pretrained model that was trained to predict phase from fluorescence channels. +#
+ +# %% [markdown] tags=[] +#
+# +#

Questions

+#
    +#
  • How much information is lost in the phase to fluorescence transformation?
  • +#
  • Why might perfect reconstruction not be possible?
  • +#
  • Can multiple phase patterns produce similar fluorescence signals?
  • +#
+#
+ +# %% +# Path to the pretrained fluorescence to phase model checkpoint +fluor2phase_model_path = top_dir / "06_image_translation/pretrained_models/AIMBL_Demo/fluor2phase_step668.ckpt" + + +# %% tags=["task"] +# Load a pretrained model for fluorescence to phase translation +from pathlib import Path + +import torch + +# ####################### +# ##### TODO ######## +# ####################### +# TODO: Load the pretrained fluorescence to phase model +# HINT: Look for pretrained models in the VisCy repository or use a model checkpoint +# HINT: The model should take 2 input channels (nuclei + membrane) and output 1 channel (phase) +# HINT: Use similar architecture as before but with different input/output channels + +# For now, we'll create a placeholder - replace with actual model loading +print("Loading pretrained fluorescence-to-phase model...") + +# TODO: Replace this with actual model loading code +fluor2phase_config = dict( + in_channels=..., # Nuclei + Membrane channels + out_channels=..., # Phase channel + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, +) +fluor2phase_model = VSUNet.load_from_checkpoint( + fluor2phase_model_path, model_config=fluor2phase_config, architecture="fcmae" +) +assert fluor2phase_model is not None, ( + "Fluorescence to phase model not loaded. Check the model config,and the path to the model checkpoint." +) +fluor2phase_model.eval() + +# %% tags=["task"] +# Test the fluorescence to phase model on our test data + +source_channel_fluor = ["TODO", "TODO"] +target_channel_labelfree = ["TODO"] + +test_data_fluor2phase = HCSDataModule( + test_data_path, + source_channel=source_channel_fluor, + target_channel=target_channel_labelfree, + z_window_size=1, + batch_size=1, + num_workers=8, +) +test_data_fluor2phase.setup("test") + + +# Get a test sample +sample = next(iter(test_data_fluor2phase.test_dataloader())) + +# ####################### +# ##### TODO ######## +# ####################### +# TODO: Extract the input channels (fluorescence) and target (phase) +# HINT: Print the keys of the `sample` dictionary +# HINT: Input should be nuclei and membrane channels concatenated +# HINT: Target should be the original phase image + +fluor_input = ... # TODO: Source +target_phase = ... # TODO: Target + +# TODO: Make prediction with the fluorescence to phase model +# NOTE: The `fluor2phase_model`, returns a tuple. Select the first item with `[0]` +with torch.inference_mode(): + predicted_phase = ... + +# ####################### +# ##### TODO ######## +# ####################### +# Calculate metrics between predicted and target phase +# HINT: Use SSIM and Pearson correlation as before + +# TODO: Normalize data range to 0-1 +###### YOUR CODE HERE ###### + +# TODO: Calculate SSIM and Pearson correlation +###### YOUR CODE HERE ###### + +# TODO: Print metrics +print("Phase Reconstruction Metrics:") +print(f"SSIM: {ssim_phase:.3f}") +print(f"Pearson Correlation: {pearson_phase:.3f}") + + +# %% tags=["solution"] +# Load a pretrained model for fluorescence to phase translation +from pathlib import Path + +import torch + +# Load the pretrained fluorescence to phase model +print("Loading pretrained fluorescence-to-phase model...") + +# Note: This assumes a pretrained model is available. In practice, you would: +# 1. Download from VisCy releases or train your own +# 2. Adjust the path accordingly + +# For demonstration, we'll create a model with the correct architecture +fluor2phase_config = dict( + in_channels=2, # Nuclei + Membrane channels + out_channels=1, # Phase channel + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, +) + +# Create the fluorescence to phase model architecture +print("Fluorescence-to-phase model created (note: using untrained model for demonstration)") +print("In practice, load a pretrained checkpoint for meaningful results") + +print("\nLoading pretrained fluorescence-to-phase model...") +fluor2phase_model_path = top_dir / "06_image_translation/pretrained_models/AIMBL_Demo/fluor2phase_step668.ckpt" +assert fluor2phase_model_path.exists(), "Fluorescence-to-phase model checkpoint not found. Please check the path." +fluor2phase_model = VSUNet.load_from_checkpoint( + fluor2phase_model_path, model_config=fluor2phase_config, architecture="fcmae" +) +fluor2phase_model.eval() + +# %% tags=["solution"] +# Test the fluorescence to phase model on our test data + +source_channel_fluor = ["Nucl", "Mem"] +target_channel_labelfree = ["Phase3D"] + +test_data_fluor2phase = HCSDataModule( + test_data_path, + source_channel=source_channel_fluor, + target_channel=target_channel_labelfree, + z_window_size=1, + batch_size=1, + num_workers=8, +) +test_data_fluor2phase.setup("test") + +# Get a test sample +sample = next(iter(test_data_fluor2phase.test_dataloader())) + +# Extract input channels (fluorescence nuclei and membrane) and target (phase) +fluor_input = sample["source"].to(fluor2phase_model.device) +target_image = sample["target"].cpu().numpy().squeeze(0) + +# Run inference +with torch.inference_mode(): + predicted_phase = fluor2phase_model(fluor_input)[0] + +fluor_input = fluor_input.cpu().numpy() +predicted_image = predicted_phase.cpu().numpy().squeeze(0) +target_phase = rescale_intensity(target_image[0, 0], out_range=(0, 1)) +predicted_phase = rescale_intensity(predicted_image[0, 0], out_range=(0, 1)) +ssim_phase = metrics.structural_similarity(target_phase, predicted_phase, data_range=1) +pearson_phase = np.corrcoef(target_phase.flatten(), predicted_phase.flatten())[0, 1] + +print("Phase Reconstruction Metrics:") +print(f"SSIM: {ssim_phase:.3f}") +print(f"Pearson Correlation: {pearson_phase:.3f}") + +# %% +# Visualize the fluorescence to phase transformation results +# TODO: Visualize the fluorescence to phase transformation results. Modify is as you see fit. + +fig, axs = plt.subplots(2, 3, figsize=(15, 10)) + +axs[0, 0].imshow(fluor_input[0, 0, 0], cmap="gray") +axs[0, 0].set_title("Input: Nuclei Channel") +axs[0, 1].imshow(fluor_input[0, 1, 0], cmap="gray") +axs[0, 1].set_title("Input: Membrane Channel") +axs[0, 2].imshow(fluor_input[0, 0, 0] + fluor_input[0, 1, 0], cmap="gray") +axs[0, 2].set_title("Combined Fluorescence\n(Nuclei + Membrane)") + +axs[1, 0].imshow(target_phase, cmap="gray") +axs[1, 0].set_title("Target Phase Image") +axs[1, 1].imshow(predicted_phase, cmap="gray") +axs[1, 1].set_title(f"Predicted Phase\nSSIM: {ssim_phase:.3f}") +axs[1, 2].imshow(np.abs(target_phase - predicted_phase), cmap="magma") +axs[1, 2].set_title("Absolute Difference\n|Target - Predicted|") + +for ax in axs.flat: + ax.set_xticks([]) + ax.set_yticks([]) + +plt.tight_layout() +plt.show() + +# %% [markdown] tags=[] +#
+#

Analysis Questions: Why is Phase Reconstruction Imperfect?

+# +# Looking at your results, consider these questions: +# +#
    +#
  • Does the fluorescence image contain all the information needed to reconstruct the phase?
  • +#
  • What structures are visible in phase but not in fluorescence channels?
  • +#
  • Which has higher information content: phase or fluorescence images?
  • +#
  • What does the reconstruction error map tell you about what's difficult to predict?
  • +#
+#
+ +# %% [markdown] tags=[] +#
+#

Key Insights from Fluorescence to Phase Model

+# +# This exploration reveals fundamental limitations in image-to-image translation: +#
    +#
  • Phase images contain rich structural information about unlabeled cellular components
  • +#
  • Fluorescence only captures specific labeled structures (nuclei, membranes,etc.)
  • +#
  • The fluorescence to phase model is an ill-posed problem - multiple phase images could produce similar fluorescence patterns
  • +#
  • Models can only predict based on correlations learned during training
  • +#
  • Structural details not correlated with fluorescence signals cannot be recovered
  • +#
+# +# #### Now, let's return to the `phase2fluor` model! +# +#
+ +# %% [markdown] tags=[] +#
+#

Bonus: Test Time Augmentation (TTA)

+# +# Test Time Augmentation is a technique where you apply multiple augmentations to a single test image, +# make predictions on each augmented version, and then combine the results (usually by averaging). +# +# **In this section we will:** +#
    +#
  • Use `Rotate90d` and `Flipd` for deterministic transformations
  • +#
  • Apply transforms, make predictions, then apply inverse transforms
  • +#
  • Average all predictions to get the final TTA result that is more robust to geometric variations.
  • +#
+# +# Reference: N.Moshkov (2020) https://www.nature.com/articles/s41598-020-61808-3 +# +# Hint: You can use the `Rotate90` and `Flip` transforms from MONAI. +# Example forward transform: `Rotate90(k=1, spatial_axes=(-1, -2))` +# Example inverse transform: `Rotate90(k=3, spatial_axes=(-1, -2))` +# +#
+ +# %% tags=["task"] +from monai.transforms import ( + Flip, + Rotate90, +) + +# Get a test sample +sample = next(iter(test_data.test_dataloader())) +source_tensor = sample["source"].to(phase2fluor_model.device) +target_tensor = sample["target"] +target_nuc = target_tensor[0, 0].cpu().numpy() +target_mem = target_tensor[0, 1].cpu().numpy() + +# Saving the single prediction without TTA for later comparison +with torch.inference_mode(): + single_pred = phase2fluor_model(source_tensor) + single_pred_nuc = single_pred[0, 0].cpu().numpy() + single_pred_mem = single_pred[0, 1].cpu().numpy() + +# TODO: Define TTA transforms using MONAI as a list of tuples (forward, inverse) +###### YOUR CODE HERE ###### +transform_list = [("TODO", "TODO")] + +# TODO: Apply test-time augmentation +# 1. Get original prediction (no augmentation) +# 2. For each transform: +# - Apply transform to input +# - Run inference +# - De-apply transform to prediction +# 3. Average all predictions + +predictions = [] + +for forward_transform, inverse_transform in transform_list: + # Apply transform to each sample in batch + augmented_batch = [] + for i in range(source_tensor.shape[0]): + # Apply the forward and store them + ###### YOUR CODE HERE ###### + aug_img = ... + augmented_batch.append(aug_img) + augmented_source = torch.stack(augmented_batch).to(source_tensor.device) + + # TODO: Run inference on augmented input + with torch.inference_mode(): + ###### YOUR CODE HERE ###### + augmented_pred = ... + + # TODO: De-apply transform to prediction + deaugmented_batch = [] + for i in range(augmented_pred.shape[0]): + ###### YOUR CODE HERE ###### + deaug_pred = ... + deaugmented_pred = torch.stack(deaugmented_batch) + + predictions.append(deaugmented_pred.cpu().numpy()) + +# TODO: Average all predictions or take the median +###### YOUR CODE HERE ###### +averaged_pred = ... + +# TODO: Extract nucleus and membrane predictions +###### YOUR CODE HERE ###### +tta_pred_nuc = ... +tta_pred_mem = ... + +# %% tags=["task"] +# TODO: Compare TTA results with single prediction +# Calculate metrics (SSIM, Pearson correlation) for both approaches. Do not forget to normalize the data range to 0-1. + +# TODO Normalize data range to 0-1 +###### YOUR CODE HERE ###### + +# TODO Calculate metrics +###### YOUR CODE HERE ###### + +# TODO # TTA prediction metrics +###### YOUR CODE HERE ###### + +# Print comparison +print("\nMetrics Comparison:") +print(f"{'Metric':<20} {'Single':<10} {'TTA':<10} {'Improvement':<12}") +print("-" * 55) +print(f"{'SSIM Nucleus':<20} {ssim_nuc_single:.3f} {ssim_nuc_tta:.3f} {ssim_nuc_tta - ssim_nuc_single:+.3f}") +print(f"{'SSIM Membrane':<20} {ssim_mem_single:.3f} {ssim_mem_tta:.3f} {ssim_mem_tta - ssim_mem_single:+.3f}") +print( + f"{'Pearson Nucleus':<20} {pearson_nuc_single:.3f} {pearson_nuc_tta:.3f} {pearson_nuc_tta - pearson_nuc_single:+.3f}" +) +print( + f"{'Pearson Membrane':<20} {pearson_mem_single:.3f} {pearson_mem_tta:.3f} {pearson_mem_tta - pearson_mem_single:+.3f}" +) + +# %% tags=["solution"] + +# Normalize data range to 0-1 +target_nuc[0] = rescale_intensity(target_nuc[0], in_range="image", out_range=(0, 1)) +single_pred_nuc[0] = rescale_intensity(single_pred_nuc[0], in_range="image", out_range=(0, 1)) +target_mem[0] = rescale_intensity(target_mem[0], in_range="image", out_range=(0, 1)) +single_pred_mem[0] = rescale_intensity(single_pred_mem[0], in_range="image", out_range=(0, 1)) +target_nuc[0] = rescale_intensity(target_nuc[0], in_range="image", out_range=(0, 1)) +tta_pred_nuc[0] = rescale_intensity(tta_pred_nuc[0], in_range="image", out_range=(0, 1)) +tta_pred_mem[0] = rescale_intensity(tta_pred_mem[0], in_range="image", out_range=(0, 1)) +tta_pred_nuc[0] = rescale_intensity(tta_pred_nuc[0], in_range="image", out_range=(0, 1)) + +# Calculate metrics +ssim_nuc_single = metrics.structural_similarity(target_nuc[0], single_pred_nuc[0], data_range=1) +ssim_mem_single = metrics.structural_similarity(target_mem[0], single_pred_mem[0], data_range=1) +pearson_nuc_single = np.corrcoef(target_nuc[0].flatten(), single_pred_nuc[0].flatten())[0, 1] +pearson_mem_single = np.corrcoef(target_mem[0].flatten(), single_pred_mem[0].flatten())[0, 1] + +# TTA prediction metrics +ssim_nuc_tta = metrics.structural_similarity(target_nuc[0], tta_pred_nuc[0], data_range=1) +ssim_mem_tta = metrics.structural_similarity(target_mem[0], tta_pred_mem[0], data_range=1) +pearson_nuc_tta = np.corrcoef(target_nuc[0].flatten(), tta_pred_nuc[0].flatten())[0, 1] +pearson_mem_tta = np.corrcoef(target_mem[0].flatten(), tta_pred_mem[0].flatten())[0, 1] + +# Print comparison +print("\nMetrics Comparison:") +print(f"{'Metric':<20} {'Single':<10} {'TTA':<10} {'Improvement':<12}") +print("-" * 55) +print(f"{'SSIM Nucleus':<20} {ssim_nuc_single:.3f} {ssim_nuc_tta:.3f} {ssim_nuc_tta - ssim_nuc_single:+.3f}") +print(f"{'SSIM Membrane':<20} {ssim_mem_single:.3f} {ssim_mem_tta:.3f} {ssim_mem_tta - ssim_mem_single:+.3f}") +print( + f"{'Pearson Nucleus':<20} {pearson_nuc_single:.3f} {pearson_nuc_tta:.3f} {pearson_nuc_tta - pearson_nuc_single:+.3f}" +) +print( + f"{'Pearson Membrane':<20} {pearson_mem_single:.3f} {pearson_mem_tta:.3f} {pearson_mem_tta - pearson_mem_single:+.3f}" +) + +# %% +# TODO: Modify as you see fit to compute the metrics on the full FOV. +# Visualize the comparison +# Modify as you see fit to visualize the results + +fig, axs = plt.subplots(3, 3, figsize=(15, 15)) + +# First row: Input phase and targets +axs[0, 0].imshow(source_tensor[0, 0, 0].cpu().numpy(), cmap="gray") +axs[0, 0].set_title("Input Phase") +axs[0, 1].imshow(target_nuc[0], cmap="gray") +axs[0, 1].set_title("Target Nucleus") +axs[0, 2].imshow(target_mem[0], cmap="gray") +axs[0, 2].set_title("Target Membrane") + +# Second row: Single predictions +axs[1, 0].imshow(source_tensor[0, 0, 0].cpu().numpy(), cmap="gray") +axs[1, 0].set_title("Input Phase") +axs[1, 1].imshow(single_pred_nuc[0], cmap="gray") +axs[1, 1].set_title(f"Single Pred Nucleus\nSSIM: {ssim_nuc_single:.3f}") +axs[1, 2].imshow(single_pred_mem[0], cmap="gray") +axs[1, 2].set_title(f"Single Pred Membrane\nSSIM: {ssim_mem_single:.3f}") + +# Third row: TTA predictions +axs[2, 0].imshow(source_tensor[0, 0, 0].cpu().numpy(), cmap="gray") +axs[2, 0].set_title("Input Phase") +axs[2, 1].imshow(tta_pred_nuc[0], cmap="gray") +axs[2, 1].set_title(f"TTA Pred Nucleus\nSSIM: {ssim_nuc_tta:.3f}") +axs[2, 2].imshow(tta_pred_mem[0], cmap="gray") +axs[2, 2].set_title(f"TTA Pred Membrane\nSSIM: {ssim_mem_tta:.3f}") + +# Remove ticks +for ax in axs.flat: + ax.set_xticks([]) + ax.set_yticks([]) + +plt.tight_layout() +plt.show() + +# %% tags=["solution"] +# Import additional MONAI transforms for TTA + +# Get a test sample +sample = next(iter(test_data.test_dataloader())) +source_tensor = sample["source"].to(phase2fluor_model.device) +target_tensor = sample["target"] +target_nuc = target_tensor[0, 0].cpu().numpy() +target_mem = target_tensor[0, 1].cpu().numpy() + +predictions = [] + +# Original prediction without augmentation +with torch.inference_mode(): + original_pred = phase2fluor_model(source_tensor) + predictions.append(original_pred.cpu().numpy()) + +# Define the TTA transforms and the inverse transforms as a list of tuples (forward, inverse) +transform_list = [ + (Rotate90(k=1, spatial_axes=(-1, -2)), Rotate90(k=3, spatial_axes=(-1, -2))), + (Rotate90(k=2, spatial_axes=(-1, -2)), Rotate90(k=2, spatial_axes=(-1, -2))), + (Rotate90(k=3, spatial_axes=(-1, -2)), Rotate90(k=1, spatial_axes=(-1, -2))), + (Flip(spatial_axis=-2), Flip(spatial_axis=-2)), + (Flip(spatial_axis=-1), Flip(spatial_axis=-1)), +] + +for forward_transform, inverse_transform in transform_list: + # Apply transform to each sample in batch + augmented_batch = [] + for i in range(source_tensor.shape[0]): + img = source_tensor[i].cpu().numpy() + aug_img = forward_transform(img) + augmented_batch.append(aug_img) + augmented_source = torch.stack(augmented_batch).to(source_tensor.device) + + # Run inference on augmented input + with torch.inference_mode(): + augmented_pred = phase2fluor_model(augmented_source) + + # De-apply transform to prediction + deaugmented_batch = [] + for i in range(augmented_pred.shape[0]): + pred = augmented_pred[i].cpu().numpy() + deaug_pred = inverse_transform(pred) + deaugmented_batch.append(deaug_pred) + deaugmented_pred = torch.stack(deaugmented_batch) + + predictions.append(deaugmented_pred.cpu().numpy()) + +# Average all predictions +averaged_pred = np.stack(predictions).mean(axis=0) + +# Extract nucleus and membrane predictions +tta_pred_nuc = averaged_pred[0, 0] +tta_pred_mem = averaged_pred[0, 1] + +# Compare with single prediction (no TTA) +with torch.inference_mode(): + single_pred = phase2fluor_model(source_tensor) + single_pred_nuc = single_pred[0, 0].cpu().numpy() + single_pred_mem = single_pred[0, 1].cpu().numpy() + + +# %% [markdown] tags=[] +#
+# +#

Discussion Questions for Test Time Augmentation

+# +#
    +#
  • Did TTA improve the metrics? By how much?
  • +#
  • What are the trade-offs of using TTA? (hint: think about computation time vs. accuracy)
  • +#
  • When would TTA be most beneficial in fluorescence microscopy?
  • +#
  • How could you modify the TTA strategy to be more effective for this specific virtual staining task?
  • +#
  • What other MONAI transforms could be useful for TTA in this context? (e.g., slight rotations, scaling)
  • +#
  • Is there any hallucinations that are removed with TTA?
  • +#
+#
+ +# %% [markdown] tags=[] +#
+#

Bonus Section Complete!

+# +# You have successfully implemented Test Time Augmentation using MONAI transforms! +# +# Key takeaways: +#
    +#
  • TTA is particularly useful when prediction quality is critical and computational budget allows
  • +#
  • Multiple geometric augmentations can reduce prediction variance and improve robustness
  • +#
  • TTA leverages deterministic transforms (`Rotate90d`, `Flipd`) instead of random ones
  • +#
  • The computational cost increases linearly with the number of TTA transforms
  • +#
+#
+ +# %% [markdown] tags=[] +# # Part 3: Visualizing the encoder and decoder features & exploring the model's range of validity +# +# - In this section, we will visualize the encoder and decoder features of the model you trained. +# - We will also explore the model's range of validity by looking at the feature maps of the encoder and decoder. +# +# %% [markdown] tags=[] +#
+#

Task 3.1: Let's look at what the model is learning

+# +# - If you are unfamiliar with Principal Component Analysis (PCA), you can read up here
+# - Run the next cells. We will visualize the encoder feature maps of the trained model. +# We will use PCA to visualize the feature maps by mapping the first 3 principal components to a colormap `Color`
+# +# +#
+ +# %% +""" +Script to visualize the encoder feature maps of a trained model. +Using PCA to visualize feature maps is inspired by +https://doi.org/10.48550/arXiv.2304.07193 (Oquab et al., 2023). +""" +from typing import NamedTuple # noqa: E402 + +from monai.networks.layers import GaussianFilter # noqa: E402 +from skimage.exposure import rescale_intensity # noqa: E402 +from sklearn.decomposition import PCA # noqa: E402 + + +def feature_map_pca(feature_map: np.array, n_components: int = 8) -> PCA: + """ + Compute PCA on a feature map. + :param np.array feature_map: (C, H, W) feature map + :param int n_components: number of components to keep + :return: PCA: fit sklearn PCA object + """ + # (C, H, W) -> (C, H*W) + feat = feature_map.reshape(feature_map.shape[0], -1) + pca = PCA(n_components=n_components) + pca.fit(feat) + return pca + + +def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray: + pca = feature_map_pca(feat[0], n_components=n_components) + pc_first_3 = pca.components_[:3].reshape(3, *feat.shape[-2:]) + return np.stack([rescale_intensity(pc, out_range=(0, 1)) for pc in pc_first_3], axis=-1) + + +# %% +# Load the test dataset +test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" +test_dataset = open_ome_zarr(test_data_path) + +# Looking at the test dataset +print("Test dataset:") +test_dataset.print_tree() + +# %% [markdown] tags=[] +#
+# +# - Change the `fov` and `crop` size to visualize the feature maps of the encoder and decoder
+# Note: the crop should be a multiple of 384 +#
+# %% +# Load one position +row = 0 +col = 0 +center_index = 2 +n = 1 +crop = 384 * n +fov = 10 + +# normalize phase +norm_meta = test_dataset.zattrs["normalization"]["Phase3D"]["dataset_statistics"] + +# Get the OME-Zarr metadata +Y, X = test_dataset[f"0/0/{fov}"].data.shape[-2:] +test_dataset.channel_names +phase_idx = test_dataset.channel_names.index("Phase3D") +assert crop // 2 < Y and crop // 2 < Y, "Crop size larger than the image. Check the image shape" + +phase_img = test_dataset[f"0/0/{fov}/0"][ + :, + phase_idx : phase_idx + 1, + 0:1, + Y // 2 - crop // 2 : Y // 2 + crop // 2, + X // 2 - crop // 2 : X // 2 + crop // 2, +] +fluo = test_dataset[f"0/0/{fov}/0"][ + 0, + 1:3, + 0, + Y // 2 - crop // 2 : Y // 2 + crop // 2, + X // 2 - crop // 2 : X // 2 + crop // 2, +] + +phase_img = (phase_img - norm_meta["median"]) / norm_meta["iqr"] +plt.imshow(phase_img[0, 0, 0], cmap="gray") + +# %% [markdown] tags=[] +#
+# For the following tasks we will use the pretrained model to extract the encoder and decoder features
+# Extra: If you are done with the whole checkpoint, you can try to look at what your trained model learned. +#
+# %% + +# Loading the pretrained model +pretrained_model_ckpt = top_dir / "06_image_translation/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt" +# model config as before +phase2fluor_config = dict( + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) + +# load model +model = VSUNet.load_from_checkpoint( + pretrained_model_ckpt, + architecture="UNeXt2_2D", + model_config=phase2fluor_config.copy(), + accelerator="gpu", +) + +# %% tags=[] +# Extract features +with torch.inference_mode(): + # encoder + encoder_features = model.model.encoder(torch.from_numpy(phase_img.astype(np.float32)).to(model.device))[0] + encoder_features_np = [f.detach().cpu().numpy() for f in encoder_features] + + # Print the encoder features shapes + for f in encoder_features_np: + print(f.shape) + + # decoder + features = encoder_features.copy() + features.reverse() + feat = features[0] + features.append(None) + decoder_features_np = [] + for skip, stage in zip(features[1:], model.model.decoder.decoder_stages): + feat = stage(feat, skip) + decoder_features_np.append(feat.detach().cpu().numpy()) + for f in decoder_features_np: + print(f.shape) + prediction = model.model.head(feat).detach().cpu().numpy() + + +# Defining the colors for plotting +class Color(NamedTuple): + r: float + g: float + b: float + + +# Defining the colors for plottting the PCA +BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902) +BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r) +GREEN = Color(0.0, 1.0, 0.0) +MAGENTA = Color(1.0, 0.0, 1.0) + + +# Defining the functions to rescale the image and composite the nuclear and membrane images +def rescale_clip(image: torch.Tensor) -> np.ndarray: + return rescale_intensity(image, out_range=(0, 1))[..., None].repeat(3, axis=-1) + + +def composite_nuc_mem(image: torch.Tensor, nuc_color: Color, mem_color: Color) -> np.ndarray: + c_nuc = rescale_clip(image[0]) * nuc_color + c_mem = rescale_clip(image[1]) * mem_color + return rescale_intensity(c_nuc + c_mem, out_range=(0, 1)) + + +def clip_p(image: np.ndarray) -> np.ndarray: + return rescale_intensity(image.clip(*np.percentile(image, [1, 99]))) + + +def clip_highlight(image: np.ndarray) -> np.ndarray: + return rescale_intensity(image.clip(0, np.percentile(image, 99.5))) + + +# Plot the PCA to RGB of the feature maps +f, ax = plt.subplots(10, 1, figsize=(5, 25)) +n_components = 4 +ax[0].imshow(phase_img[0, 0, 0], cmap="gray") +ax[0].set_title(f"Phase {phase_img.shape[1:]}") +ax[-1].imshow(clip_p(composite_nuc_mem(fluo, GREEN, MAGENTA))) +ax[-1].set_title("Fluorescence") + +for level, feat in enumerate(encoder_features_np): + ax[level + 1].imshow(pcs_to_rgb(feat, n_components=n_components)) + ax[level + 1].set_title(f"Encoder stage {level + 1} {feat.shape[1:]}") + +for level, feat in enumerate(decoder_features_np): + ax[5 + level].imshow(pcs_to_rgb(feat, n_components=n_components)) + ax[5 + level].set_title(f"Decoder stage {level + 1} {feat.shape[1:]}") + +pred_comp = composite_nuc_mem(prediction[0, :, 0], BOP_BLUE, BOP_ORANGE) +ax[-2].imshow(clip_p(pred_comp)) +ax[-2].set_title(f"Prediction {prediction.shape[1:]}") + +for a in ax.ravel(): + a.axis("off") +plt.tight_layout() + +# %% [markdown] tags=["task"] +#
+# +# ### Task 3.2: Select a sample batch to test the range of validty of the model +# - Run the next cell to setup the your dataloader for `test`
+# - Select a test batch from the `test_dataloader` by changing the `batch_number`
+# - Examine the plot of the source and target images of the batch
+# +# Note the 2D images have different focus
+#
+ +# %% +YX_PATCH_SIZE = (256 * 2, 256 * 2) +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] + +normalizations = [ + NormalizeSampled( + keys=source_channel, + level="fov_statistics", + subtrahend="mean", + divisor="std", + ), + NormalizeSampled( + keys=target_channel, + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ), +] + +# Re-load the dataloader +phase2fluor_2D_data = HCSDataModule( + data_path, + source_channel=source_channel, + target_channel=target_channel, + z_window_size=1, + split_ratio=0.8, + batch_size=1, + num_workers=8, + yx_patch_size=YX_PATCH_SIZE, + augmentations=[], + normalizations=normalizations, +) +phase2fluor_2D_data.setup("test") +# %% tags=[] +# ########## TODO ############## +batch_number = 3 # Change this to see different batches of data +# ####################### +y_slice = slice(Y // 2 - 256 * n // 2, Y // 2 + 256 * n // 2) +x_slice = slice(X // 2 - 256 * n // 2, X // 2 + 256 * n // 2) + +# Iterate through the test dataloader to get the desired batch +i = 0 +for batch in phase2fluor_2D_data.test_dataloader(): + # break if we reach the desired batch + if i == batch_number - 1: + break + i += 1 + +# Plot the batch source and target images +f, ax = plt.subplots(1, 2, figsize=(8, 12)) +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0].imshow( + batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(), + cmap="gray", + vmin=-15, + vmax=15, +) +ax[1].imshow(clip_highlight(target_composite[0, y_slice, x_slice])) +for a in ax.ravel(): + a.axis("off") +f.tight_layout() +plt.show() + +# %% [markdown] tags=[] +#
+# +# ### Task 3.3: Using the selected batch to test the model's range of validity +# +# - Given the selected batch use `monai.networks.layers.GaussianFilter` to blur the images with different sigmas. +# Check the documentation here
+# - Plot the source and predicted images comparing the source, target and added perturbations
+# - How is the model's predictions given the perturbations?
+#
+# %% tags=["task"] +# ########## TODO ############## +# Try out different multiples of 256 to visualize larger/smaller crops +n = 3 +# ############################## +# Center cropping the image +y_slice = slice(Y // 2 - 256 * n // 2, Y // 2 + 256 * n // 2) +x_slice = slice(X // 2 - 256 * n // 2, X // 2 + 256 * n // 2) + +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(), + cmap="gray", + vmin=-15, + vmax=15, +) +ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice])) +ax[0, 0].set_title("Source and target") + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1, 0].set_title("No perturbation") + +# Select a sigma for the Gaussian filtering +# ########## TODO ############## +# Tensor dimensions (B, C, Z, Y, X). +# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigmas +# Hint: Spatial (Z, Y, X) +gaussian_blur = GaussianFilter(...) +# ############################# +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) + +# %% tags=["solution"] +# ########## SOLUTION ############## +# Try out different multiples of 256 to visualize larger/smaller crops +n = 3 +# ############################## +# Center cropping the image +y_slice = slice(Y // 2 - 256 * n // 2, Y // 2 + 256 * n // 2) +x_slice = slice(X // 2 - 256 * n // 2, X // 2 + 256 * n // 2) + +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(), + cmap="gray", + vmin=-15, + vmax=15, +) +ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice])) +ax[0, 0].set_title("Source and target") + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1, 0].set_title("No perturbation") + + +# Select a sigma for the Gaussian filtering +# ########## SOLUTION ############## +# Tensor dimensions (B, C, Z, Y, X). +# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigma +# Hint: Spatial (Z, Y, X). Apply the same sigma to Y, X +gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0, 2, 2)) +# ############################# +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) + +# %% [markdown] tags=[] +#
+# +# ### Task 3.3: Using the selected batch to test the model's range of validity +# +# - Scale the pixel values up/down of the phase image
+# - Plot the source and predicted images comparing the source, target and added perturbations
+# - How is the model's predictions given the perturbations?
+#
+ +# %% tags=["task"] +n = 3 +y_slice = slice(Y // 2, Y // 2 + 256 * n) +x_slice = slice(X // 2, X // 2 + 256 * n) +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(), + cmap="gray", + vmin=-15, + vmax=15, +) +ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice])) +ax[0, 0].set_title("Source and target") + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1, 0].set_title("No perturbation") + + +# Rescale the pixel value up/down +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + # ########## TODO ############## + # Hint: Scale the phase intensity up/down until the model breaks + phase = phase * ... + # ####################### + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) + +# %% tags=["solution"] +n = 3 +y_slice = slice(Y // 2, Y // 2 + 256 * n) +x_slice = slice(X // 2, X // 2 + 256 * n) +f, ax = plt.subplots(3, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(), + cmap="gray", + vmin=-15, + vmax=15, +) +ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice])) +ax[0, 0].set_title("Source and target") + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1, 0].set_title("No perturbation") + + +# Rescale the pixel value up/down +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + # ########## SOLUTION ############## + # Hint: Scale the phase intensity up/down until the model breaks + phase = phase * 10 + # ####################### + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) + +# %% [markdown] +#
+#

Questions

+# How is the model's predictions given the blurring and scaling perturbations?
+#
+ +# %% tags=["solution"] +# ########## SOLUTIONS FOR ALL POSSIBLE PLOTTINGS ############## +# This plots all perturbations + +n = 3 +y_slice = slice(Y // 2, Y // 2 + 256 * n) +x_slice = slice(X // 2, X // 2 + 256 * n) +f, ax = plt.subplots(6, 2, figsize=(8, 12)) + +target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA) +ax[0, 0].imshow( + batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(), + cmap="gray", + vmin=-15, + vmax=15, +) +ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice])) +ax[0, 0].set_title("Source and target") + +# no perturbation +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[1, 1].imshow(pred_composite[0]) +ax[1, 0].set_title("No perturbation") + + +# 2-sigma gaussian blur +gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0, 2, 2)) +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[2, 1].imshow(pred_composite[0]) +ax[2, 0].set_title("Gaussian Blur Sigma=2") + + +# 5-sigma gaussian blur +gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0, 5, 5)) +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + phase = gaussian_blur(phase) + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[3, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[3, 1].imshow(pred_composite[0]) +ax[3, 0].set_title("Gaussian Blur Sigma=5") + + +# 0.1x scaling +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + phase = phase * 0.1 + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[4, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[4, 1].imshow(pred_composite[0]) +ax[4, 0].set_title("0.1x scaling") + +# 10x scaling +with torch.inference_mode(): + phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice] + phase = phase * 10 + pred = model(phase).cpu().numpy() +pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE) +ax[5, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15) +ax[5, 1].imshow(pred_composite[0]) +ax[5, 0].set_title("10x scaling") + +for a in ax.ravel(): + a.axis("off") + +f.tight_layout() +# %% [markdown] tags=[] +#
+ +#

+# 🎉 The end of the notebook 🎉 +#

+ +# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned. + +#
diff --git a/applications/cytoland/examples/phase_contrast/README.md b/applications/cytoland/examples/phase_contrast/README.md new file mode 100644 index 000000000..0131624a9 --- /dev/null +++ b/applications/cytoland/examples/phase_contrast/README.md @@ -0,0 +1,37 @@ +# Demo: Virtual staining of phase contrast data + +# Overview: + +Generalization to Zernike phase contrast images. This demo showcases the use of VSCyto3D model with and without augmentations on Zernike phase contrast data. + +## Setup + +Run the setup script to create the environment for this exercise and download the dataset. +```bash +source setup.sh +``` + +Activate your environment +```bash +conda activate vs_Phc +``` + +## Use vscode + +Install vscode, install jupyter extension inside vscode, and setup [cell mode](https://code.visualstudio.com/docs/python/jupyter-support-py). Open [solution.py](solution.py) and run the script interactively. + +## Use Jupyter Notebook + +Launch a jupyter environment + +``` +jupyter notebook +``` + +...and continue with the instructions in the notebook. + +If `vs_Phc` is not available as a kernel in jupyter, run: + +``` +python -m ipykernel install --user --name=vs_Phc +``` diff --git a/applications/cytoland/examples/phase_contrast/prepare-exercise.sh b/applications/cytoland/examples/phase_contrast/prepare-exercise.sh new file mode 100644 index 000000000..526b9fa3c --- /dev/null +++ b/applications/cytoland/examples/phase_contrast/prepare-exercise.sh @@ -0,0 +1,8 @@ +# Run ruff format on .py files +# ruff format solution.py + +# Convert .py to ipynb + +# "cell_metadata_filter": "all" preserve cell tags including our solution tags +jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update solution.py +jupyter nbconvert solution.ipynb --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags task --to notebook --output solution.ipynb diff --git a/applications/cytoland/examples/phase_contrast/setup.sh b/applications/cytoland/examples/phase_contrast/setup.sh new file mode 100644 index 000000000..f0bfef845 --- /dev/null +++ b/applications/cytoland/examples/phase_contrast/setup.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env -S bash -i + +START_DIR=$(pwd) + +conda deactivate +# Create conda environment +conda create -y --name vs_Phc python=3.11 + +# Install ipykernel in the environment. +conda install -y ipykernel nbformat nbconvert ruff jupytext ipywidgets --name vs_Phc +# Specifying the environment explicitly. +# conda activate sometimes doesn't work from within shell scripts. + +# Install cytoland (pulls in viscy-data, viscy-models, viscy-transforms, viscy-utils). +# Run this from the root of the VisCy monorepo checkout. +# Find path to the environment - conda activate doesn't work from within shell scripts. +ENV_PATH=$(conda info --envs | grep vs_Phc | awk '{print $NF}') +$ENV_PATH/bin/pip install -e "applications/cytoland[metrics]" + +# Create the directory structure +mkdir -p ~/data/vs_PhC/test +mkdir -p ~/data/vs_PhC/models + +# Change to the target directory +# Download the OME-Zarr dataset recursively +cd ~/data/vs_PhC/test +wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto3D/test/HEK_H2B_CAAX_PhC_40x_registered.zarr/" + +# Get the models +cd ~/data/vs_PhC/models +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/no_augmentations/best_epoch=30-step=6076.ckpt" +wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/epoch=48-step=18130.ckpt" + + +# Change back to the starting directory +cd $START_DIR diff --git a/applications/cytoland/examples/phase_contrast/solution.py b/applications/cytoland/examples/phase_contrast/solution.py new file mode 100644 index 000000000..19f5395eb --- /dev/null +++ b/applications/cytoland/examples/phase_contrast/solution.py @@ -0,0 +1,225 @@ +# %% [markdown] tags=[] +# # Virtual staining of phase contrast images using VSCyto3D with and without augmentations +# +# Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco +# +# ## Overview +# +# This notebook demonstrates how to use the VSCyto3D model to virtually stain phase contrast images. The phase contrast images were not part of the training. +# We will use the VSCyto3D model to predict the nuclei and cell membrane channels from a phase contrast image with two models: +# - One model trained without augmentations +# - One model trained with augmentations +# + +# %% tags=[] +# Imports +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from iohub import open_ome_zarr +from lightning.pytorch import seed_everything + +# Cytoland and VisCy modular classes for the trainer and model +from cytoland.engine import VSUNet +from viscy_data.hcs import HCSDataModule +from viscy_transforms import NormalizeSampled +from viscy_utils.trainer import VisCyTrainer + +# seed random number generators for reproducibility. +seed_everything(42, workers=True) +# %% +# Paths to data and log directory +top_dir = ( + Path("~/data/vs_PhC").expanduser() +) # If this fails, make sure this to point to your data directory in the shared mounting point inside /dlmbl/data + +# Path to the training data +data_path = top_dir / "test/HEK_H2B_CAAX_PhC_40x_registered.zarr" + +# %% [markdown] tags=[] +# ## Load OME-Zarr Dataset + +# There should be 34 FOVs in the dataset. +# +# Each FOV consists of 3 channels of 2048x2048 images, +# saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout) +# specified by the Open Microscopy Environment Next Generation File Format +# (OME-NGFF). +# +# The 3 channels correspond to the QPI, nuclei, and cell membrane. The nuclei were stained with DAPI and the cell membrane with Cellmask. +# +# - The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.` +# - These datasets only have 1 level in the pyramid (highest resolution) which is '0'. +# %% +# Open dataset and look at it's structure +dataset = open_ome_zarr(data_path) +dataset.print_tree() +# %% +row = 0 +col = 3 +field = "000000" # TODO Change this for a different FOV + +# NOTE: this dataset only has one level +pyaramid_level = 0 + +fov_path = f"{row}/{col}/{field}" +input_data_path = Path(data_path) / fov_path +image = dataset[fov_path][pyaramid_level].numpy() + +n_channels = len(dataset.channel_names) +Z, Y, X = image.shape[-3:] +figure, axes = plt.subplots(1, n_channels, figsize=(9, 3)) +title_names = ["PhC", "TXR", "Y5"] +for i in range(n_channels): + for i in range(n_channels): + channel_image = image[0, i, Z // 2] + # Invert the phase contrast channel + if i == 0: + channel_image = channel_image * -1 + # Adjust contrast to 0.5th and 99.5th percentile of pixel values. + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[i].imshow(channel_image, cmap="gray") + axes[i].axis("off") + axes[i].set_title(title_names[i]) +plt.tight_layout() + +# %% [markdown] tags=[] +# ## Create the VSCyto3D model +# Here we will instantiate the `HCSDataModule` that reads the ome-zarr dataset and prepares the data for inference. +# %% +# Reduce the batch size if encountering out-of-memory errors +BATCH_SIZE = 5 +# NOTE: Set the number of workers to 0 for Windows and macOS +# since multiprocessing only works with a +# `if __name__ == '__main__':` guard. +# On Linux, set it to the number of CPU cores to maximize performance. +NUM_WORKERS = 0 +source_channel_name = "BF" + +# %%[markdown] +""" +For this example we will use the following parameters: +### For more information on the VSCyto3D model: +See ``viscy.unet.networks.fcmae`` +([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/unext2.py#L252)) +for configuration details. +""" +# %% +# Setup the data module. +data_module = HCSDataModule( + data_path=input_data_path, + source_channel=source_channel_name, + target_channel=["Nuclei", "Membrane"], + z_window_size=5, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + architecture="UNeXt2", + normalizations=[ + NormalizeSampled( + [source_channel_name], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) +data_module.prepare_data() +data_module.setup(stage="predict") + +# %% [markdown] tags=[] +# ## Setup the _VSCyto3D_ model with and without augmentations +# We will load the model checkpoints and run inference on the phase contrast image.abs +# The model that utilizes augmentations shows better performance in the prediction of the nuclei and cell membrane channels. +# The phase contrast images were not part of the training for the `VSCyto3D`` model. +# %% + +# TODO: change if you want to use a different GPU +GPU_ID = 0 + +# TODO: point to the downloaded model checkpoints +no_augmentation_model_ckpt = top_dir / "models/no_augmentations/best_epoch=30-step=6076.ckpt" +VSCyto3D_model_ckpt = top_dir / "models/epoch=48-step=18130.ckpt" + +# Dictionary that specifies key parameters of the model. +config_VSCyto3D = { + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 5, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (5, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + "head_pool": True, +} + +# Model without augmentation +model_VSCyto3D_no_augmentation = VSUNet.load_from_checkpoint( + no_augmentation_model_ckpt, architecture="UNeXt2", model_config=config_VSCyto3D +) +model_VSCyto3D_no_augmentation.eval() +# Model with augmentation +model_VSCyto3D_w_augmentation = VSUNet.load_from_checkpoint( + VSCyto3D_model_ckpt, architecture="UNeXt2", model_config=config_VSCyto3D +) +model_VSCyto3D_w_augmentation.eval() + +# Setup the Trainer +trainer = VisCyTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed") + +n = 5 +patch_size = 256 +y_slice = slice(Y // 2 - patch_size * n // 2, Y // 2 + patch_size * n // 2) +x_slice = slice(X // 2 - patch_size * n // 2, X // 2 + patch_size * n // 2) + +# Get the Phase Contrast channel +c_idx = dataset.channel_names.index(source_channel_name) +phase_image = image[0:1, c_idx : c_idx + 1, Z // 2 - 3 : Z // 2 + 3, y_slice, x_slice] +# Normalize the image +median = dataset[fov_path].zattrs["normalization"][source_channel_name]["fov_statistics"]["median"] +iqr = dataset[fov_path].zattrs["normalization"][source_channel_name]["fov_statistics"]["iqr"] +phase_image = ((phase_image - median) / iqr) * -1 + +# Load the image to device +device = model_VSCyto3D_no_augmentation.device +phase_image = torch.tensor(phase_image).to(device) + +# Run inference on the given volume +with torch.inference_mode(): # turn off gradient computation. + pred_no_augmentation = model_VSCyto3D_no_augmentation(phase_image) + pred_w_augmentation = model_VSCyto3D_w_augmentation(phase_image) + +pred_no_augmentation = pred_no_augmentation.cpu().detach().numpy() +pred_w_augmentation = pred_w_augmentation.cpu().detach().numpy() +phase_image = phase_image.cpu().detach().numpy() +clim_max = 30 +clim_min = -20 + +# Plot the predicted images with model without augmentations +fig, ax = plt.subplots(2, 3, figsize=(12, 12)) +ax[0, 0].imshow(phase_image[0, 0, 2, :, :], cmap="gray", vmin=clim_min, vmax=clim_max) +ax[0, 0].axis("off") +ax[0, 0].set_title("Phase Contrast") +for i in range(2): + ax[0, i + 1].imshow(pred_no_augmentation[0, i, 2, :, :], cmap="gray") + ax[0, i + 1].axis("off") +ax[0, 1].set_title("VS_Nuclei without augmentations") +ax[0, 2].set_title("VS_Membrane without augmentations") + +# Plot the predicted images with VSCyto3D with augmentations +ax[1, 0].imshow(phase_image[0, 0, 2, :, :], cmap="gray", vmin=clim_min, vmax=clim_max) +ax[1, 0].axis("off") +ax[1, 0].set_title("Phase Contrast") +for i in range(2): + ax[1, i + 1].imshow( + pred_w_augmentation[0, i, 2, :, :], + cmap="gray", + ) + ax[1, i + 1].axis("off") +ax[1, 1].set_title("VS_Nuclei with augmentations") +ax[1, 2].set_title("VS_Membrane with augmentations") + +plt.tight_layout() diff --git a/applications/cytoland/examples/vcp_tutorials/README.md b/applications/cytoland/examples/vcp_tutorials/README.md new file mode 100644 index 000000000..c9e2eaf6c --- /dev/null +++ b/applications/cytoland/examples/vcp_tutorials/README.md @@ -0,0 +1,21 @@ +# Virtual Cell Platform Tutorials + +This directory contains tutorial notebooks for the Virtual Cell Platform, +available in both Python scripts and Jupyter notebooks. + +- [Quick Start](quick_start.ipynb): +get started with model inference in Python with a A549 cell dataset. +- [CLI inference and visualization](hek293t.ipynb): +run inference from CLI on a HEK293T cell dataset and visualize the results. +- [Virtual staining _in vivo_](neuromast.ipynb): +compare virtual staining and fluorescence in a time-lapse dataset of the zebrafish neuromast. + +## Development + +The development happens on the Python scripts, +which are converted to Jupyter notebooks with: + +```sh +# TODO: change the file name at the end to be the script to convert +jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update quick_start.py +``` diff --git a/applications/cytoland/examples/vcp_tutorials/hek293t.py b/applications/cytoland/examples/vcp_tutorials/hek293t.py new file mode 100644 index 000000000..41decac02 --- /dev/null +++ b/applications/cytoland/examples/vcp_tutorials/hek293t.py @@ -0,0 +1,226 @@ +# %% [markdown] +""" +# Cytoland Tutorial: Virtual Staining of HEK293T Cells with VSCyto3D + +**Estimated time to complete:** 15 minutes +""" + +# %% [markdown] +""" +# Learning Goals + +* Download the VSCyto3D model and an example dataset containing HEK293T cell images. +* Pre-compute normalization statistics for the images using the `viscy preprocess` command line interface (CLI). +* Run inference for joint virtual staining of cell nuclei and plasma membrane via the `viscy predict` CLI. +* Compare virtually and experimentally stained cells and see how virtual staining can rescue missing labels. +""" + +# %% [markdown] +""" +# Prerequisites + +Python>=3.11 +""" + +# %% [markdown] +""" +# Introduction + +See the [model card](https://virtualcellmodels.cziscience.com/paper/cytoland2025) +for more details about the Cytoland models. + +VSCyto3D is a 3D UNeXt2 model that has been trained on A549, HEK293T, and hiPSC cells using the Cytoland approach. +This model enables users to jointly stain cell nuclei and plasma membranes from 3D label-free images +for downstream analysis such as cell segmentation and tracking without the need for human annotation of volumetric data. +""" + +# %% [markdown] +""" +# Setup + +The commands below will install the required packages and download the example dataset and model checkpoint. +It may take a **few minutes** to download all the files. + +## Setup Google Colab + +To run this quick-start guide using Google Colab, +choose the 'T4' GPU runtime from the "Connect" dropdown menu +in the upper-right corner of this notebook for faster execution. +Using a GPU significantly speeds up running model inference, but CPU compute can also be used. + +## Setup Local Environment + +The commands below assume a Unix-like shell with `wget` installed. +On Windows, the files can be downloaded manually from the URLs. +""" + +# %% +# Install VisCy with the optional dependencies for this example +# See the [repository](https://github.com/mehta-lab/VisCy) for more details +# Here stackview and ipycanvas are installed for visualization +# !pip install -U "viscy[metrics,visual]==0.4.0a3" stackview ipycanvas==0.11 + +# %% +# Restart kernel if running in Google Colab +# This is required to use the packages installed above +# The 'kernel crashed' message is expected here +if "get_ipython" in globals(): + session = get_ipython() # noqa: F821 + if "google.colab" in str(session): + print("Shutting down colab session.") + session.kernel.do_shutdown(restart=True) + +# %% +# Download the example dataset +# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto3D/test/HEK293T-Phase3D-H2B-CAAX-example.zarr/" + +# %% +# Rename the downloaded dataset to what the example prediction config expects (`input.ome.zarr`) +# And validate the OME-Zarr metadata with iohub +# !mv HEK293T-Phase3D-H2B-CAAX-example.zarr input.ome.zarr +# !iohub info -v input.ome.zarr + +# %% +# Download the VSCyto3D model checkpoint and prediction config +# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/epoch=83-step=14532-loss=0.492.ckpt" +# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/predict.yml" + +# %% [markdown] +""" +# Use Case + +## Example Dataset + +The HEK293T example dataset used in this quick-start guide contains +quantitative phase and paired fluorescence images of cell nuclei and plasma membrane. +It is a subset (one cropped region of interest) from a test set used to evaluate the VSCyto3D model. +The full dataset can be downloaded from the +[BioImage Archive](https://www.ebi.ac.uk/biostudies/BioImages/studies/S-BIAD1702). + +Refer to our [preprint](https://doi.org/10.1101/2024.05.31.596901) for more details +about how the dataset and model were generated. + +## Using Custom Data + +The model only requires label-free images for inference. +To run inference on your own data, +convert them into the OME-Zarr data format using iohub or other +[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion), +and edit the `predict.yml` file to specify the input data path. +Specifically, the `data.init_args.data_path` field should be updated: + +```diff +- data_path: input.ome.zarr ++ data_path: /path/to/your.ome.zarr +``` + +The image may need to be resampled to roughly match the voxel size of the example dataset +(0.2x0.1x0.1 µm, ZYX). +""" + +# %% [markdown] +""" +# Run Model Inference + +On Google Colab, the preprocessing step takes about **1 minute**, +and the inference step takes about **2 minutes** (T4 GPU). +""" + +# %% +# Run the CLI command to pre-compute normalization statistics +# This includes the median and interquartile range (IQR) +# Used to shift and scale the intensity distribution of the input images +# !viscy preprocess --data_path=input.ome.zarr + +# %% +# Run the CLI command to run inference +# !viscy predict -c predict.yml + +# %% [markdown] +""" +# Analysis of Model Outputs + +Visualize the experimental and virtually stained images using the `stackview` package. +""" + +# %% [markdown] +""" +Visualizing large 3D multichannel images in a Jupyter notebook +**is prone to performance issues and may crash the notebook** if the images are too large +(the free Colab instances have limited CPU cores and memory). +The visualization code below is only intended for demonstration. +We strongly recommend downloading the images (from the 'files' bar in Colab) +and using a standalone viewer such as [napari](https://napari.org/). +""" + +# %% + +import numpy as np # noqa: E402 +import stackview # noqa: E402 +from iohub import open_ome_zarr # noqa: E402 +from skimage.exposure import rescale_intensity # noqa: E402 + +try: + from google.colab import output + + output.enable_custom_widget_manager() +except ImportError: + pass + + +# %% +# open the images +def split_and_rescale_channels(timepoint: np.ndarray) -> tuple[np.ndarray, ...]: + return (rescale_intensity(channel, out_range=(0, 1)) for channel in timepoint) + + +fov_name = "plate/0/11" +input_image = open_ome_zarr("input.ome.zarr")[fov_name]["0"] +prediction_image = open_ome_zarr("prediction.ome.zarr")[fov_name]["0"] + +phase, fluor_nucleus, fluor_membrane = split_and_rescale_channels(input_image[0]) +vs_nucleus, vs_membrane = split_and_rescale_channels(prediction_image[0]) + +# %% +# Drag the slider to start rendering +# Click on the numbered buttons to toggle the channels +stackview.switch( + # the 0, 1, 2, 3, 4 buttons will correspond to these 5 channels + # We apply a gamma adjustment to the phase channel to improve visibility in the overlay + images=[phase**2.5, fluor_nucleus, fluor_membrane, vs_nucleus, vs_membrane], + colormap=["gray", "pure_green", "pure_magenta", "pure_blue", "pure_yellow"], + toggleable=True, + zoom_factor=0.5, + display_min=0.0, + display_max=0.9, +) + +# %% [markdown] +""" +Note how the experimental fluorescence is missing for a subset of cells. +This is due to loss of genetic labeling. +The virtually stained images is not affected by this issue and can robustly label all cells. +""" + +# %% [markdown] +""" +# Summary + +In the above example, we demonstrated how to use the VSCyto3D model +for virtual staining of cell nuclei and plasma membranes, which can rescue missing labels. +""" + +# %% [markdown] +""" +## Contact & Feedback + +For issues or feedback about this tutorial please contact Ziwen Liu at [ziwen.liu@czbiohub.org](mailto:ziwen.liu@czbiohub.org). + +## Responsible Use + +We are committed to advancing the responsible development and use of artificial intelligence. +Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services. + +Should you have any security or privacy issues or questions related to the services, +please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively. +""" diff --git a/applications/cytoland/examples/vcp_tutorials/neuromast.py b/applications/cytoland/examples/vcp_tutorials/neuromast.py new file mode 100644 index 000000000..4ac547b83 --- /dev/null +++ b/applications/cytoland/examples/vcp_tutorials/neuromast.py @@ -0,0 +1,317 @@ +# %% [markdown] +""" +# Cytoland Tutorial: Virtual Staining of Zebrafish Neuromasts with VSNeuromast + +**Estimated time to complete:** 15 minutes +""" + +# %% [markdown] +""" +# Learning Goals + +* Download the VSNeuromast model and an example dataset containing time-lapse images of zebrafish neuromasts. +* Pre-compute normalization statistics for the images using the `viscy preprocess` command-line interface (CLI). +* Run inference for joint virtual staining of cell nuclei and plasma membrane via the `viscy predict` CLI. +* Visualize the effect of photobleaching in fluorescence imaging and how virtual staining can mitigate this issue. +""" + +# %% [markdown] +""" +# Prerequisites + +Python>=3.11 +""" + +# %% [markdown] +""" +# Introduction + +The zebrafish neuromasts are sensory organs on the lateral lines. +Given their relatively simple structure and high accessibility to live imaging, +they are used as a model system to study organogenesis _in vivo_. +However, multiplexed long-term fluorescence imaging at high spatial-temporal resolution +is often limited by photobleaching and phototoxicity. +Also, engineering fish lines with a combination of landmark fluorescent labels +(e.g. nuclei and plasma membrane) and functional reporters increases experimental complexity. +\ +VSNeuromast is a 3D UNeXt2 model that has been trained on images of +zebrafish neuromasts using the Cytoland approach. +(See the [model card](https://virtualcellmodels.cziscience.com/paper/cytoland2025) +for more details about the Cytoland models.) +This model enables users to jointly stain cell nuclei and plasma membranes from 3D label-free images +for downstream analysis such as cell segmentation and tracking. +""" + +# %% [markdown] +""" +# Setup + +The commands below will install the required packages and download the example dataset and model checkpoint. +It may take a **few minutes** to download all the files. + +## Setup Google Colab + +To run this quick-start guide using Google Colab, +choose the 'T4' GPU runtime from the "Connect" dropdown menu +in the upper-right corner of this notebook for faster execution. +Using a GPU significantly speeds up running model inference, but CPU compute can also be used. + +## Setup Local Environment + +The commands below assume a Unix-like shell with `wget` installed. +On Windows, the files can be downloaded manually from the URLs. +""" + +# %% +# Install VisCy with the optional dependencies for this example +# See the [repository](https://github.com/mehta-lab/VisCy) for more details +# Here stackview and ipycanvas are installed for visualization +# !pip install -U "viscy[metrics,visual]==0.4.0a3" + +# %% +# Restart kernel if running in Google Colab +# This is required to use the packages installed above +# The 'kernel crashed' message is expected here +if "get_ipython" in globals(): + session = get_ipython() + if "google.colab" in str(session): + print("Shutting down colab session.") + session.kernel.do_shutdown(restart=True) + +# %% +# Download the example dataset +# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSNeuromast/test/isim-bleaching-example.zarr/" + +# %% +# Rename the downloaded dataset to what the example prediction config expects (`input.ome.zarr`) +# And validate the OME-Zarr metadata with iohub +# !mv isim-bleaching-example.zarr input.ome.zarr +# !iohub info -v input.ome.zarr + +# %% +# Download the VSNeuromast model checkpoint and prediction config +# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSNeuromast/epoch=64-step=24960.ckpt" +# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSNeuromast/predict.yml" + +# %% [markdown] +""" +# Use Case + +## Example Dataset + +The neuromast example dataset used in this tutorial contains +quantitative phase and paired fluorescence images of the cell nuclei and the plasma membrane. +\ +**It is a subsampled time-lapse from a test set used to evaluate the VSNeuromast model.** +\ +The full dataset can be downloaded from the +[BioImage Archive](https://www.ebi.ac.uk/biostudies/BioImages/studies/S-BIAD1702). + +Refer to our [preprint](https://doi.org/10.1101/2024.05.31.596901) for more details +about how the dataset and model were generated. + +## Using Custom Data + +The model only requires label-free images for inference. +To run inference on your own data, +convert them into the [OME-Zarr](https://ngff.openmicroscopy.org/) +data format using iohub or other +[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion), +and edit the `predict.yml` file to specify the input data path. +Specifically, the `data.init_args.data_path` field should be updated: + +```diff +- data_path: input.ome.zarr ++ data_path: /path/to/your.ome.zarr +``` + +The image may need to be resampled to roughly match the voxel size of the example dataset +(0.25x0.108x0.108 µm, ZYX). +""" + +# %% [markdown] +""" +# Run Model Inference + +On Google Colab, the preprocessing step takes about **1 minute**, +and the inference step takes about **2 minutes** (T4 GPU). +""" + +# %% +# Run the CLI command to pre-compute normalization statistics +# This includes the median and interquartile range (IQR) +# Used to shift and scale the intensity distribution of the input images +# !viscy preprocess --data_path=input.ome.zarr + +# %% +# Run the CLI command to run inference +# !viscy predict -c predict.yml + +# %% [markdown] +""" +# Analysis of Model Outputs + +1. Visualize predicted images over time and compare with the fluorescence images. +2. Measure photobleaching in the fluorescence images +and how virtual staining can mitigate this issue. +Since most pixels in the images are background, +we will use the 99th percentile (brightest 1%) +of the intensity distribution as a proxy for foreground signal. +""" + +# %% +# imports +import matplotlib.pyplot as plt +import numpy as np +from cmap import Colormap +from iohub import open_ome_zarr +from numpy.typing import NDArray +from skimage.exposure import rescale_intensity + + +def render_rgb(image: np.ndarray, colormap: Colormap) -> tuple[NDArray, plt.cm.ScalarMappable]: + """Render a 2D grayscale image as RGB using a colormap. + + Parameters + ---------- + image : np.ndarray + intensity image + colormap : Colormap + colormap + + Returns + ------- + tuple[NDArray, plt.cm.ScalarMappable] + rendered RGB image and the color mapping + """ + image = rescale_intensity(image, out_range=(0, 1)) + image = colormap(image) + mappable = plt.cm.ScalarMappable(norm=plt.Normalize(0, 1), cmap=colormap.to_matplotlib()) + return image, mappable + + +# %% +# read a single Z-slice for visualization +z_slice = 30 + +with open_ome_zarr("input.ome.zarr/0/3/0") as fluor_store: + fluor_nucleus = fluor_store[0][:, 1, z_slice] + fluor_membrane = fluor_store[0][:, 0, z_slice] + +with open_ome_zarr("prediction.ome.zarr/0/3/0") as vs_store: + vs_nucleus = vs_store[0][:, 0, z_slice] + vs_membrane = vs_store[0][:, 1, z_slice] + + +# Render the images as RGB in false colors +vs_nucleus_rgb, vs_nucleus_mappable = render_rgb(vs_nucleus, Colormap("bop_blue")) +vs_membrane_rgb, vs_membrane_mappable = render_rgb(vs_membrane, Colormap("bop_orange")) +merged_vs = (vs_nucleus_rgb + vs_membrane_rgb).clip(0, 1) + +fluor_nucleus_rgb, fluor_nucleus_mappable = render_rgb(fluor_nucleus, Colormap("green")) +fluor_membrane_rgb, fluor_membrane_mappable = render_rgb(fluor_membrane, Colormap("magenta")) +merged_fluor = (fluor_nucleus_rgb + fluor_membrane_rgb).clip(0, 1) + +# Plot +fig = plt.figure(figsize=(12, 7), layout="constrained") + +images = {"fluorescence": merged_fluor, "virtual staining": merged_vs} + +for row, (subfig, (name, img)) in enumerate(zip(fig.subfigures(nrows=2, ncols=1), images.items())): + subfig.suptitle(name) + cax_nuc = subfig.add_axes([1, 0.55, 0.02, 0.3]) + cax_mem = subfig.add_axes([1, 0.15, 0.02, 0.3]) + axes = subfig.subplots(ncols=len(merged_vs)) + for t, ax in enumerate(axes): + if row == 1: + ax.set_title(f"{t * 30} min", y=-0.1) + ax.imshow(img[t]) + ax.axis("off") + if row == 0: + subfig.colorbar(fluor_nucleus_mappable, cax=cax_nuc, label="Nuclei (GFP)") + subfig.colorbar(fluor_membrane_mappable, cax=cax_mem, label="Membrane (mScarlett)") + elif row == 1: + subfig.colorbar(vs_nucleus_mappable, cax=cax_nuc, label="Nuclei (VS)") + subfig.colorbar(vs_membrane_mappable, cax=cax_mem, label="Membrane (VS)") + +plt.show() + +# %% [markdown] +""" +The plasma membrane fluorescence decreases over time, +while the virtual staining remains stable. +How significant is this effect? Is it consistent with photobleaching? +Analysis below will answer these questions. +""" + + +# %% +def highlight_intensity_normalized(fov_path: str, channel_name: str) -> list[float]: + """ + Compute highlight (99th percentile) intensity of each timepoint, + normalized to the first timepoint. + + Parameters + ---------- + fov_path : str + Path to the field of view (FOV). + channel_name : str + Name of the channel to compute highlight intensity for. + + Returns + ------- + NDArray + List of intensity values. + """ + with open_ome_zarr(fov_path) as fov: + channel_index = fov.get_channel_index(channel_name) + channel = fov["0"].dask_array()[:, channel_index] + highlights = [] + for t, volume in enumerate(channel): + highlights.append(np.percentile(volume.compute(), 99)) + return [h / highlights[0] for h in highlights] + + +# %% +# Plot intensity over time +mean_fl = highlight_intensity_normalized("input.ome.zarr/0/3/0", "mScarlett") +mean_vs = highlight_intensity_normalized("prediction.ome.zarr/0/3/0", "membrane_prediction") +time = np.arange(0, 100, 30) + +plt.plot(time, mean_fl, label="membrane fluorescence") +plt.plot(time, mean_vs, label="membrane virtual staining") +plt.xlabel("time / min") +plt.ylabel("normalized highlight intensity") +plt.legend() + +# %% [markdown] +""" +Here the highlight intensity of the fluorescence images decreases over time, +following a exponential decay pattern, indicating photobleaching. +The virtual staining is not affected by this issue. +(The object drifts slightly over time, so some inherent noise is expected.) +""" + +# %% [markdown] +""" +# Summary + +In the above example, we demonstrated how to use the VSNeuromast model +for virtual staining of cell nuclei and plasma membranes of the zebrafish neuromast _in vivo_, +which can avoid photobleaching in long-term live imaging. +""" + +# %% [markdown] +""" +## Contact & Feedback + +For issues or feedback about this tutorial please contact Ziwen Liu at [ziwen.liu@czbiohub.org](mailto:ziwen.liu@czbiohub.org). + +## Responsible Use + +We are committed to advancing the responsible development and use of artificial intelligence. +Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services. + +Should you have any security or privacy issues or questions related to the services, +please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively. +""" diff --git a/applications/cytoland/examples/vcp_tutorials/quick_start.py b/applications/cytoland/examples/vcp_tutorials/quick_start.py new file mode 100644 index 000000000..58d48c548 --- /dev/null +++ b/applications/cytoland/examples/vcp_tutorials/quick_start.py @@ -0,0 +1,332 @@ +# %% [markdown] +""" +# Quick Start: Cytoland + +**Estimated time to complete:** 15 minutes +""" + +# %% [markdown] +""" +# Learning Goals + +* Download the VSCyto2D model and an example dataset containing A549 cell images. +* Run VSCyto2D model inference for joint virtual staining of cell nuclei and plasma membrane. +* Visualize and compare virtually and experimentally stained cells. +""" + +# %% [markdown] +""" +# Prerequisites +Python>=3.11 + +""" + +# %% [markdown] +""" +# Introduction + +## Model + +The Cytoland virtual staining models are a collection of models (VSCyto2D, VSCyto3D, and VSNeuromast) +used to predict cellular landmarks (e.g., nuclei and plasma membranes) +from label-free images (e.g. quantitative phase, Zernike phase contrast, and brightfield). +This quick-start guide focuses on the VSCyto2D model. + +VSCyto2D is a 2D UNeXt2 model that has been trained on A549, HEK293T, and BJ-5ta cells. +This model enables users to jointly stain cell nuclei and plasma membranes from 2D label-free images +that are commonly generated for image-based screens. + +Alternative models are optimized for different sample types and imaging conditions: + +* [VSCyto3D](https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D): +3D UNeXt2 model for joint virtual staining of cell nuclei and plasma membrane +from high-resolution volumetric images. +* [VSNeuromast](https://public.czbiohub.org/comp.micro/viscy/VS_models/VSNeuromast): +3D UNeXt2 model for joint virtual staining of nuclei and plasma membrane in zebrafish neuromasts. + +## Example Dataset + +The A549 example dataset used in this quick-start guide contains +quantitative phase and paired fluorescence images of cell nuclei and plasma membrane. +It is stored in OME-Zarr format and can be downloaded from +[here](https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/test/a549_hoechst_cellmask_test.zarr). +It has pre-computed statistics for normalization, generated using the `viscy preprocess` CLI. + +Refer to our [preprint](https://doi.org/10.1101/2024.05.31.596901) for more details +about how the dataset and model were generated. + +## User Data + +The VSCyto2D model only requires label-free images for inference. +To run inference on your own data, +convert them into the OME-Zarr data format using iohub or other +[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion), +and run [pre-processing](https://github.com/mehta-lab/VisCy/blob/main/docs/usage.md#preprocessing) +with the `viscy preprocess` CLI. +""" + +# %% [markdown] +""" +# Setup + +The commands below will install the required packages and download the example dataset and model checkpoint. +It may take a few minutes to download all the files. + +## Setup Google Colab + +To run this quick-start guide using Google Colab, +choose the 'T4' GPU runtime from the "Connect" dropdown menu +in the upper-right corner of this notebook for faster execution. +Using a GPU significantly speeds up running model inference, but CPU compute can also be used. + +## Setup Local Environment + +The commands below assume a Unix-like shell with `wget` installed. +On Windows, the files can be downloaded manually from the URLs. +""" + +# %% +# Install VisCy with the optional dependencies for this example +# See the [repository](https://github.com/mehta-lab/VisCy) for more details +# !pip install "viscy[metrics,visual]==0.4.0a3" + +# %% +# restart kernel if running in Google Colab +if "get_ipython" in globals(): + session = get_ipython() # noqa: F821 + if "google.colab" in str(session): + print("Shutting down colab session.") + session.kernel.do_shutdown(restart=True) + +# %% +# Validate installation +# !viscy --help + +# %% +# Download the example dataset +# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/test/a549_hoechst_cellmask_test.zarr/" +# Download the model checkpoint +# !wget https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/VSCyto2D/epoch=399-step=23200.ckpt + +# %% [markdown] +""" +# Run Model Inference + +The following code will run inference on a single field of view (FOV) of the example dataset. +This can also be achieved by using the VisCy CLI. +""" + +# %% +from pathlib import Path # noqa: E402 + +from iohub import open_ome_zarr # noqa: E402 +from torchview import draw_graph # noqa: E402 + +from cytoland.engine import FcmaeUNet # noqa: E402 +from viscy_data.hcs import HCSDataModule # noqa: E402 +from viscy_transforms import NormalizeSampled # noqa: E402 +from viscy_utils.callbacks import HCSPredictionWriter # noqa: E402 +from viscy_utils.trainer import VisCyTrainer # noqa: E402 + +# %% +# NOTE: Nothing needs to be changed in this code block for the example to work. +# If using your own data, please modify the paths below. + +# TODO: Set download paths, by default the working directory is used +root_dir = Path() +# TODO: modify the path to the input dataset +input_data_path = root_dir / "a549_hoechst_cellmask_test.zarr" +# TODO: modify the path to the model checkpoint +model_ckpt_path = root_dir / "epoch=399-step=23200.ckpt" +# TODO: modify the path to save the predictions +output_path = root_dir / "a549_prediction.zarr" +# TODO: Choose an FOV +fov = "0/0/0" + + +# %% +# Configure the data module for loading example images in prediction mode. +# See API documentation for how to use it with a different dataset. +# For example, View the documentation for the HCSDataModule class by running: +# ?HCSDataModule + +# %% +# Setup the data module to use the example dataset +data_module = HCSDataModule( + # Path to HCS or Single-FOV OME-Zarr dataset + data_path=input_data_path / fov, + # Name of the input phase channel + source_channel="Phase3D", + # Desired name of the output channels + target_channel=["Membrane", "Nuclei"], + # Axial input size, 1 for 2D models + z_window_size=1, + # Batch size + # Adjust based on available memory (reduce if seeing OOM errors) + batch_size=8, + # Number of workers for data loading + # Set to 0 for Windows and macOS if running in a notebook, + # since multiprocessing only works with a `if __name__ == '__main__':` guard. + # On Linux, set it based on available CPU cores to maximize performance. + num_workers=4, + # Normalization strategy + # This one uses pre-computed statistics from `viscy preprocess` + # to subtract the median and divide by the interquartile range (IQR). + # It can also be replaced by other MONAI transforms. + normalizations=[ + NormalizeSampled( + ["Phase3D"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +# %% +# Load the VSCyto2D model from the downloaded checkpoint +# VSCyto2D is fine-tuned from a FCMAE-pretrained UNeXt2 model. +# See this module for options to configure the model: + +# ?FullyConvolutionalMAE + +# %% +vs_cyto_2d = FcmaeUNet.load_from_checkpoint( + # checkpoint path + model_ckpt_path, + model_config={ + # number of input channels + # must match the number of channels in the input data + "in_channels": 1, + # number of output channels + # must match the number of target channels in the data module + "out_channels": 2, + # number of ConvNeXt v2 blocks in each stage of the encoder + "encoder_blocks": [3, 3, 9, 3], + # feature map channels in each stage of the encoder + "dims": [96, 192, 384, 768], + # number of ConvNeXt v2 blocks in each stage of the decoder + "decoder_conv_blocks": 2, + # kernel size in the stem layer + "stem_kernel_size": [1, 2, 2], + # axial size of the input image + # must match the Z-window size in the data module + "in_stack_depth": 1, + # whether to perform masking (for FCMAE pre-training) + "pretraining": False, + }, +) + +# %% +# Visualize the model graph +model_graph = draw_graph( + vs_cyto_2d, + (vs_cyto_2d.example_input_array), + graph_name="VSCyto2D", + roll=True, + depth=3, + expand_nested=True, +) + +model_graph.visual_graph + +# %% +# Setup the trainer for prediction +# The trainer can be further configured to better utilize the available hardware, +# For example using GPUs and half precision. +# Callbacks can also be used to customize logging and prediction writing. +# See the API documentation for more details: +# ?VisCyTrainer + +# %% +# Initialize the trainer +# The prediction writer callback will save the predictions to an OME-Zarr store +trainer = VisCyTrainer(callbacks=[HCSPredictionWriter(output_path)]) + +# Run prediction +trainer.predict(model=vs_cyto_2d, datamodule=data_module, return_predictions=False) + +# %% [markdown] +""" +# Model Outputs + +The model outputs are also stored in an OME-Zarr store. +It can be visualized in an image viewer such as [napari](https://napari.org/). +Below we show a snapshot in the notebook. +""" + +# %% +# Read images from Zarr stores +# Choose the ROI for better visualization +y_slice = slice(0, 512) +x_slice = slice(0, 512) + +# Open the prediction store and get the 2D images from 5D arrays (t,c,z,y,x) +with open_ome_zarr(output_path / fov) as vs_store: + vs_nucleus = vs_store[0][0, 0, 0, y_slice, x_slice] + vs_membrane = vs_store[0][0, 1, 0, y_slice, x_slice] + +# Open the experimental fluorescence dataset +with open_ome_zarr(input_data_path / fov) as fluor_store: + fluor_nucleus = fluor_store[0][0, 1, 0, y_slice, x_slice] + fluor_membrane = fluor_store[0][0, 2, 0, y_slice, x_slice] + +# %% +# Plot +import matplotlib.pyplot as plt # noqa: E402 +import numpy as np # noqa: E402 +from cmap import Colormap # noqa: E402 +from skimage.exposure import rescale_intensity # noqa: E402 + + +def render_rgb(image: np.ndarray, colormap: Colormap): + image = rescale_intensity(image, out_range=(0, 1)) + image = colormap(image) + return image + + +# Render the images as RGB in false colors +vs_nucleus_rgb = render_rgb(vs_nucleus, Colormap("bop_blue")) +vs_membrane_rgb = render_rgb(vs_membrane, Colormap("bop_orange")) +merged_vs = (vs_nucleus_rgb + vs_membrane_rgb).clip(0, 1) + +fluor_nucleus_rgb = render_rgb(fluor_nucleus, Colormap("green")) +fluor_membrane_rgb = render_rgb(fluor_membrane, Colormap("magenta")) +merged_fluor = (fluor_nucleus_rgb + fluor_membrane_rgb).clip(0, 1) + +# Plot +# Show the individual channels and then fused in a grid +fig, ax = plt.subplots(2, 3, figsize=(15, 10)) + +# Virtual staining plots +ax[0, 0].imshow(vs_nucleus_rgb) +ax[0, 0].set_title("VS Nuclei") +ax[0, 1].imshow(vs_membrane_rgb) +ax[0, 1].set_title("VS Membrane") +ax[0, 2].imshow(merged_vs) +ax[0, 2].set_title("VS Nuclei+Membrane") + +# Experimental fluorescence plots +ax[1, 0].imshow(fluor_nucleus_rgb) +ax[1, 0].set_title("Experimental Fluorescence Nuclei") +ax[1, 1].imshow(fluor_membrane_rgb) +ax[1, 1].set_title("Experimental Fluorescence Membrane") +ax[1, 2].imshow(merged_fluor) +ax[1, 2].set_title("Experimental Fluorescence Nuclei+Membrane") + +# turnoff axis +for a in ax.flatten(): + a.axis("off") +plt.tight_layout() +plt.show() + +# %% [markdown] +""" +## Responsible Use + +We are committed to advancing the responsible development and use of artificial intelligence. +Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services. + +Should you have any security or privacy issues or questions related to the services, +please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively. +""" diff --git a/pyproject.toml b/pyproject.toml index 5caca7b62..1ec272574 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ lint.per-file-ignores."**/*.ipynb" = [ "D", "E402", "E501", "PD" ] lint.per-file-ignores."**/__init__.py" = [ "D104", "F401" ] lint.per-file-ignores."**/docs/**" = [ "I" ] lint.per-file-ignores."**/evaluation/**" = [ "D", "E501", "NPY002", "PD011" ] +lint.per-file-ignores."**/examples/**" = [ "D", "E402", "E501", "F821" ] lint.per-file-ignores."**/tests/**" = [ "D" ] lint.pydocstyle.convention = "numpy"