+# The exercise is organized in 3 parts:
+
+#
+# - Part 1 - Train a virtual staining model using iohub (I/O library), VisCy dataloaders, and tensorboard
+# - Part 2 - Evaluate the model to translate phase into fluorescence.
+# - Part 3 - Visualize the image transforms learned by the model and explore the model's regime of validity.
+#
+
+#
+
+# %% [markdown] tags=[]
+#
+# Set your python kernel to 06_image_translation
+#
+
+# %% [markdown] tags=[]
+# ## PyTorch Lightning in one minute
+#
+# If you've used plain PyTorch you already know the pattern: write a model, write a
+# `for batch in dataloader` loop, move tensors to `cuda`, call `loss.backward()`, step
+# the optimizer, remember to `zero_grad()`, log every N steps, save a checkpoint, and
+# repeat for validation. That boilerplate is the same in every project — so
+# [PyTorch Lightning](https://lightning.ai) factors it out into **three objects** and
+# owns the training loop for you.
+#
+# | Lightning object | What it holds | In this exercise |
+# | --- | --- | --- |
+# | `LightningDataModule` | How to load, split, augment, and batch your data (`train/val/test/predict_dataloader`) | `HCSDataModule` — reads OME-Zarr and yields `{"source": ..., "target": ...}` dicts |
+# | `LightningModule` | The network, the loss, and what happens in `training_step` / `validation_step` (one batch at a time) | `VSUNet` — wraps the UNeXt2 architecture and the virtual-staining loss |
+# | `Trainer` | The loop: device placement, mixed precision, logging, checkpointing, multi-GPU | `VisCyTrainer` — a thin subclass with VisCy-friendly defaults |
+#
+# You don't write a `for` loop. You call **`trainer.fit(model, datamodule)`** and
+# Lightning drives everything. The trainer handles:
+#
+# - moving batches to the right device (`accelerator="gpu"`, `devices=[0]`)
+# - mixed-precision training (`precision="16-mixed"`) so you use less GPU memory
+# - when to log metrics / images (`log_every_n_steps`) and where (`logger=TensorBoardLogger(...)`)
+# - saving checkpoints automatically under the logger's directory
+# - running a sanity check on a single batch before real training (`fast_dev_run=True`)
+#
+# VisCy builds on top of Lightning and provides the `HCSDataModule` and `VSUNet`
+# classes so you don't have to subclass `LightningDataModule` / `LightningModule`
+# yourself — you configure them via constructor arguments and let Lightning run.
+# When you see `trainer.fit(...)` below, that single call replaces a ~50-line hand-
+# written training loop.
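+#
+# To make this concrete, here is a minimal toy sketch (not part of this exercise) of the
+# plain-Lightning pattern that `HCSDataModule`, `VSUNet`, and `VisCyTrainer` fill in for you:
+#
+# ```python
+# import torch
+# import lightning.pytorch as pl
+#
+# class TinyRegression(pl.LightningModule):
+#     def __init__(self):
+#         super().__init__()
+#         self.net = torch.nn.Linear(8, 1)
+#
+#     def training_step(self, batch, batch_idx):
+#         x, y = batch
+#         loss = torch.nn.functional.mse_loss(self.net(x), y)
+#         self.log("train_loss", loss)  # routed to whichever logger the Trainer was given
+#         return loss  # Lightning runs backward(), optimizer.step(), zero_grad() for you
+#
+#     def configure_optimizers(self):
+#         return torch.optim.AdamW(self.parameters(), lr=1e-3)
+#
+# # trainer = pl.Trainer(max_epochs=1)
+# # trainer.fit(TinyRegression(), train_dataloaders=some_dataloader)
+# ```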
+
+# %% [markdown]
+# # Part 1: Log training data to tensorboard, start training a model.
+# ---------
+# Learning goals:
+
+# - Load the OME-zarr dataset and examine the channels (A549).
+# - Configure and understand the data loader.
+# - Log some patches to tensorboard.
+# - Initialize a 2D UNeXt2 model for virtual staining of nuclei and membrane from phase.
+# - Start training the model to predict nuclei and membrane from phase.
+
+# %% Imports
+import os
+from glob import glob
+from pathlib import Path
+from typing import Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import torchview
+import torchvision
+from cellpose import models
+from iohub import open_ome_zarr
+from iohub.reader import print_info
+from lightning.pytorch import seed_everything
+from lightning.pytorch.loggers import TensorBoardLogger
+
+# microSSIM: SSIM variant designed for fluorescence microscopy.
+from microssim import micro_structural_similarity
+from natsort import natsorted
+from numpy.typing import ArrayLike
+
+# label2rgb overlays segmentation labels on images for visualization.
+from skimage.color import label2rgb
+from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard
+from torchmetrics.functional import accuracy, jaccard_index
+from torchmetrics.functional.segmentation import dice_score
+from tqdm import tqdm
+
+# VSUNet: the virtual-staining UNet wrapped as a LightningModule, from the cytoland package.
+from cytoland.engine import VSUNet
+
+# HCSDataModule makes it easy to load data during training.
+from viscy_data.hcs import HCSDataModule
+
+# training augmentations
+from viscy_transforms import (
+ NormalizeSampled,
+ RandAdjustContrastd,
+ RandAffined,
+ RandGaussianNoised,
+ RandGaussianSmoothd,
+ RandScaleIntensityd,
+ RandWeightedCropd,
+)
+from viscy_utils.evaluation.metrics import mean_average_precision
+from viscy_utils.losses import MixedLoss
+from viscy_utils.trainer import VisCyTrainer
+
+# %%
+# seed random number generators for reproducibility.
+seed_everything(42, workers=True)
+
+# Paths to data and log directory
+top_dir = Path("~/data").expanduser() # If this fails, point to your data directory (e.g. a shared course mount).
+
+# Path to the training data
+data_path = top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr"
+
+# Path where we will save our training logs
+training_top_dir = Path(f"{os.getcwd()}/data/")
+# Create the training_top_dir directory if needed.
+training_top_dir.mkdir(parents=True, exist_ok=True)
+log_dir = training_top_dir / "06_image_translation/logs/"
+# Create the log directory if needed; tensorboard is launched in a later cell.
+log_dir.mkdir(parents=True, exist_ok=True)
+
+if not data_path.exists():
+ raise FileNotFoundError(f"Data not found at {data_path}. Please check the top_dir and data_path variables.")
+
+# %% [markdown] tags=[]
+# The next cell starts tensorboard.
+
+#
+# If you launched jupyter lab from an ssh terminal, add `--host <your-server-name>` to the tensorboard command below. `<your-server-name>` is the address of your compute node that ends in amazonaws.com.
+
+#
+
+
+# %% tags=[]
+# Function to find an available port
+def find_free_port():
+ import socket
+
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+# Launch TensorBoard on the browser
+def launch_tensorboard(log_dir):
+ import subprocess
+
+ port = find_free_port()
+ tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}"
+ process = subprocess.Popen(tensorboard_cmd, shell=True)
+ print(
+ f"TensorBoard started at http://localhost:{port}. \n"
+ "If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL."
+ )
+ return process
+
+
+# Launch tensorboard and click on the link to view the logs.
+tensorboard_process = launch_tensorboard(log_dir)
+# %% [markdown] tags=[]
+#
+# If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard.
+# Take note of the port number that was assigned in the previous cell (i.e. `http://localhost:{port_number_assigned}`).
+#
+# Locate your VSCode terminal and select the `Ports` tab:
+#
+# - Add a new port with the `port_number_assigned`.
+#
+# Click on the link to view the tensorboard; it should open in your browser.
+#
+
+
+# %% [markdown] tags=[]
+# ## Load OME-Zarr Dataset
+#
+# **OME-Zarr** is a chunked, cloud-friendly microscopy format; **HCS layout**
+# nests the zarr store like a physical plate — `row/col/field/level/T/C/Z/Y/X` —
+# so each FOV is addressable by `dataset[f"{row}/{col}/{field}/{level}"]` and
+# returns an `(T, C, Z, Y, X)` array.
+#
+# This dataset has 34 FOVs of 2048×2048 images across 3 channels (QPI, nuclei
+# stained with DAPI, membrane stained with Cellmask), a single pyramid level
+# `0`, and a single time point.
+
+# %% [markdown] tags=[]
+#
+# You can inspect the tree structure from your terminal:
+# `iohub info -v "path-to-ome-zarr"`
+#
+# More info on the CLI:
+# run `iohub info --help` to see the help menu.
+#
+# %%
+# This is the Python function called by the `iohub info` CLI command.
+print_info(data_path, verbose=True)
+
+# Open and inspect the dataset.
+dataset = open_ome_zarr(data_path)
+
+# %% [markdown] tags=[]
+#
+#
+# ### Task 1.1
+# Look at a couple different fields of view (FOVs) by changing the `field` variable.
+# Check the cell density, the cell morphologies, and fluorescence signal.
+# HINT: look at the HCS Plate format to see what your options are.
+#
+# %% tags=[]
+# Use the field and pyramid_level below to visualize data.
+row = 0
+col = 0
+field = 9 # TODO: Change this to explore data.
+
+pyramid_level = 0
+
+# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.
+n_channels = len(dataset.channel_names)
+
+image = dataset[f"{row}/{col}/{field}/{pyaramid_level}"].numpy()
+print(f"data shape: {image.shape}, FOV: {field}, pyramid level: {pyaramid_level}")
+
+figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))
+
+for i in range(n_channels):
+ channel_image = image[0, i, 0]
+ # Adjust contrast to 0.5th and 99.5th percentile of pixel values.
+ p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
+ channel_image = np.clip(channel_image, p_low, p_high)
+ axes[i].imshow(channel_image, cmap="gray")
+ axes[i].axis("off")
+ axes[i].set_title(dataset.channel_names[i])
+plt.tight_layout()
+
+# %% [markdown] tags=[]
+# ## Explore the effects of augmentation on a batch.
+#
+# Time to meet the first of the three Lightning objects from the primer above: the
+# **DataModule**. `HCSDataModule` is VisCy's `LightningDataModule` — it knows how
+# to read an OME-Zarr store, split FOVs into train/val, apply normalization and
+# augmentations, and hand the Trainer a PyTorch `DataLoader`. You configure it
+# once; Lightning calls the right method (`train_dataloader()`,
+# `val_dataloader()`, etc.) at the right time.
+#
+# Every sample `HCSDataModule` yields is a Python `dict` (not a tuple) with:
+#
+# - `source`: the input image, a tensor of shape `(1, 1, Y, X)` → `(C, Z, Y, X)`
+# - `target`: the target image, a tensor of shape `(2, 1, Y, X)` → `(C, Z, Y, X)`
+# - `index` : the tuple `(HCS location, time, z-slice)` identifying the sample
+#
+# A `batch` is a dict of the same keys with an extra leading batch dimension, e.g.
+# `batch["source"].shape == (B, 1, 1, Y, X)`. The `training_step` method inside
+# `VSUNet` receives this dict directly — no unpacking required.
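+#
+# For reference, once the data module is set up (Task 1.2 below), you can peek at a batch
+# like this (a quick sanity check, not required for the exercise):
+#
+# ```python
+# batch = next(iter(data_module.train_dataloader()))
+# print(batch.keys())           # expect 'source' and 'target' (plus sample indexing info)
+# print(batch["source"].shape)  # e.g. torch.Size([B, 1, 1, Y, X])
+# print(batch["target"].shape)  # e.g. torch.Size([B, 2, 1, Y, X])
+# ```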
+
+# %% [markdown] tags=[]
+#
+#
+# ### Task 1.2
+# - Run the next cell to setup a logger for your augmentations.
+# - Set up the `HCSDataModule()` for training.
+# - Configure the data module to be compatible with the `"UNeXt2_2D"` architecture.
+# - Configure the data module for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task.
+# - Configure the data module for training. Hint: use `HCSDataModule.setup()`.
+# - Open your tensorboard and look at the `IMAGES tab`.
+#
+# Note: If tensorboard is not showing images or the plots, try refreshing and using the "Images" tab.
+#
+
+
+# %%
+# Define a function to write a batch to tensorboard log.
+def log_batch_tensorboard(batch, batchno, writer, card_name):
+ """
+ Logs a batch of images to TensorBoard.
+
+ Args:
+        batch (dict): A dictionary containing the batch of images to be logged.
+        batchno (int): The global step at which to log the batch.
+        writer (SummaryWriter): A TensorBoard SummaryWriter object.
+ card_name (str): The name of the card to be displayed in TensorBoard.
+
+ Returns:
+ None
+ """
+ batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.
+ batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor.
+ batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor.
+
+ p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
+ batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)
+
+ p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
+ batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)
+
+ p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
+ batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)
+
+ [N, C, H, W] = batch_phase.shape
+ interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)
+ interleaved_images[0::3, :] = batch_phase
+ interleaved_images[1::3, :] = batch_nuclei
+ interleaved_images[2::3, :] = batch_membrane
+
+ grid = torchvision.utils.make_grid(interleaved_images, nrow=3)
+
+ # add the grid to tensorboard
+ writer.add_image(card_name, grid, batchno)
+
+
+# Define a function to visualize a batch on jupyter, in case tensorboard is finicky
+def log_batch_jupyter(batch):
+ """
+ Logs a batch of images on jupyter using ipywidget.
+
+ Args:
+ batch (dict): A dictionary containing the batch of images to be logged.
+
+ Returns:
+ None
+ """
+ batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.
+ batch_size = batch_phase.shape[0]
+ batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor.
+ batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(1) # batch_size x 1 x Y x X tensor.
+
+ p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
+ batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)
+
+ p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
+ batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)
+
+ p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
+ batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)
+
+ n_channels = batch["target"].shape[1] + batch["source"].shape[1]
+    fig, axes = plt.subplots(batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2))
+ [N, C, H, W] = batch_phase.shape
+ for sample_id in range(batch_size):
+ axes[sample_id, 0].imshow(batch_phase[sample_id, 0])
+ axes[sample_id, 1].imshow(batch_nuclei[sample_id, 0])
+ axes[sample_id, 2].imshow(batch_membrane[sample_id, 0])
+
+ for i in range(n_channels):
+ axes[sample_id, i].axis("off")
+ axes[sample_id, i].set_title(dataset.channel_names[i])
+ plt.tight_layout()
+ plt.show()
+
+
+# %% tags=["task"]
+# Initialize the data module.
+
+BATCH_SIZE = 4
+
+# 4 is a perfectly reasonable batch size
+# (batch size does not have to be a power of 2)
+# See: https://sebastianraschka.com/blog/2022/batch-size-2.html
+
+# #######################
+# ##### TODO ########
+# #######################
+# HINT: Run dataset.channel_names
+source_channel = ["TODO"]
+target_channel = ["TODO", "TODO"]
+
+# #######################
+# ##### TODO ########
+# #######################
+data_module = HCSDataModule(
+ data_path,
+ z_window_size=1,
+ source_channel=source_channel,
+ target_channel=target_channel,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations.
+ augmentations=[], # Turn off augmentation for now.
+ normalizations=[], # Turn off normalization for now.
+)
+# #######################
+# ##### TODO ########
+# #######################
+# Setup the data_module to fit. HINT: data_module.setup()
+
+
+# Evaluate the data module
+print(
+ f"Samples in training set: {len(data_module.train_dataset)}, "
+ f"samples in validation set:{len(data_module.val_dataset)}"
+)
+train_dataloader = data_module.train_dataloader()
+# Instantiate the tensorboard SummaryWriter and log the first training batch.
+writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
+# Draw a batch and write to tensorboard.
+batch = next(iter(train_dataloader))
+log_batch_tensorboard(batch, 0, writer, "augmentation/none")
+writer.close()
+# %% tags=["solution"]
+# #######################
+# ##### SOLUTION ########
+# #######################
+
+BATCH_SIZE = 4
+# 4 is a perfectly reasonable batch size
+# (batch size does not have to be a power of 2)
+# See: https://sebastianraschka.com/blog/2022/batch-size-2.html
+
+source_channel = ["Phase3D"]
+target_channel = ["Nucl", "Mem"]
+
+data_module = HCSDataModule(
+ data_path,
+ z_window_size=1,
+ source_channel=source_channel,
+ target_channel=target_channel,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations.
+ augmentations=[], # Turn off augmentation for now.
+ normalizations=[], # Turn off normalization for now.
+)
+
+# Setup the data_module to fit. HINT: data_module.setup()
+data_module.setup("fit")
+
+# Evaluate the data module
+print(
+ f"Samples in training set: {len(data_module.train_dataset)}, "
+ f"samples in validation set:{len(data_module.val_dataset)}"
+)
+train_dataloader = data_module.train_dataloader()
+# Instantiate the tensorboard SummaryWriter and log the first training batch.
+writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
+# Draw a batch and write to tensorboard.
+batch = next(iter(train_dataloader))
+log_batch_tensorboard(batch, 0, writer, "augmentation/none")
+writer.close()
+# %% [markdown] tags=[]
+#
+#
+# ### Questions
+# 1. What are the two channels in the target image?
+# 2. How many samples are in the training and validation set? What determined that split?
+#
+# Note: If tensorboard is not showing images, try refreshing and using the "Images" tab.
+#
+
+# %% [markdown] tags=[]
+# If your tensorboard is causing issues, you can visualize directly in Jupyter/VSCode.
+# %%
+# Visualize in Jupyter
+log_batch_jupyter(batch)
+
+# %% [markdown] tags=[]
+#
+#
+# ### Question for Task 1.3
+# 1. How do augmentations make the model more robust to imaging parameters or conditions
+# without having to acquire data for every possible condition?
+#
+# %% [markdown] tags=[]
+# Each augmentation simulates a real-world source of microscope-to-microscope
+# variation so the model doesn't overfit to the training conditions:
+#
+# | Transform | Simulates |
+# | --- | --- |
+# | `RandWeightedCropd` | random crops biased toward signal-dense regions (foreground oversampling) |
+# | `RandAffined` | stage rotation, scale drift, slight shear between acquisitions |
+# | `RandAdjustContrastd` | illumination / exposure differences |
+# | `RandScaleIntensityd` | gain / brightness differences between cameras |
+# | `RandGaussianNoised` | shot and read noise at different detector settings |
+# | `RandGaussianSmoothd` | small focus drift / defocus |
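+#
+# Because these are MONAI-style dictionary transforms, you can apply one to a toy sample
+# to see its effect in isolation (a minimal sketch, assuming the `viscy_transforms`
+# re-exports behave like the MONAI originals):
+#
+# ```python
+# import torch
+# from viscy_transforms import RandGaussianNoised
+#
+# sample = {"Phase3D": torch.zeros(1, 1, 64, 64)}  # toy (C, Z, Y, X) tensor keyed by channel name
+# noisy = RandGaussianNoised(keys=["Phase3D"], prob=1.0, mean=0.0, std=0.3)(sample)
+# print(noisy["Phase3D"].std())  # roughly 0.3, since the input was all zeros
+# ```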
+# %% [markdown] tags=[]
+#
+#
+# ### Task 1.3
+# Add the following augmentations:
+# - Add an affine augmentation that rotates up to $\pi$ around the z-axis, scales up to 30% in (y,x),
+# shears up to 1% in (y,x), and pads with zeros, all applied with a probability of 80%.
+# - Add Gaussian noise with a mean of 0.0 and a standard deviation of 0.3, applied with a probability of 50%.
+#
+# HINT: `RandAffined()` and `RandGaussianNoised()` are MONAI dictionary
+# transforms re-exported from `viscy_transforms`. See the MONAI docs for
+# arguments and probability semantics:
+# [RandAffined](https://docs.monai.io/en/stable/transforms.html#randaffined),
+# [RandGaussianNoised](https://docs.monai.io/en/stable/transforms.html#randgaussiannoised).
+# You can also inspect any transform in a cell with `RandAffined?`.
+# [Compare your choice of augmentations against the pretrained models and config files](https://github.com/mehta-lab/VisCy/releases/download/v0.1.0/VisCy-0.1.0-VS-models.zip).
+#
+# %% tags=["task"]
+# Here we turn on data augmentation and rerun setup
+# #######################
+# ##### TODO ########
+# #######################
+# HINT: Run dataset.channel_names
+source_channel = ["TODO"]
+target_channel = ["TODO", "TODO"]
+
+augmentations = [
+ RandWeightedCropd(
+ keys=source_channel + target_channel,
+ spatial_size=(1, 384, 384),
+ num_samples=2,
+ w_key=target_channel[0],
+ ),
+ # #######################
+ # ##### TODO ########
+ # #######################
+    ## TODO: Add Random Affine Transforms
+ ## Write code below
+ # #######################
+ RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)),
+ RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5),
+ # #######################
+ # ##### TODO ########
+ # #######################
+ ## TODO: Add Random Gaussian Noise
+ ## Write code below
+ # #######################
+ RandGaussianSmoothd(
+ keys=source_channel,
+ sigma_x=(0.25, 0.75),
+ sigma_y=(0.25, 0.75),
+ sigma_z=(0.0, 0.0),
+ prob=0.5,
+ ),
+]
+
+normalizations = [
+ NormalizeSampled(
+ keys=source_channel,
+ level="fov_statistics",
+ subtrahend="mean",
+ divisor="std",
+ ),
+ NormalizeSampled(
+ keys=target_channel,
+ level="fov_statistics",
+ subtrahend="median",
+ divisor="iqr",
+ ),
+]
+
+data_module.augmentations = augmentations
+data_module.normalizations = normalizations
+
+data_module.setup("fit")
+
+# get the new data loader with augmentation turned on
+augmented_train_dataloader = data_module.train_dataloader()
+
+# Draw batches and write to tensorboard
+writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
+augmented_batch = next(iter(augmented_train_dataloader))
+log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
+writer.close()
+
+# %% tags=["solution"]
+# #######################
+# ##### SOLUTION ########
+# #######################
+source_channel = ["Phase3D"]
+target_channel = ["Nucl", "Mem"]
+
+augmentations = [
+ RandWeightedCropd(
+ keys=source_channel + target_channel,
+ spatial_size=(1, 384, 384),
+ num_samples=2,
+ w_key=target_channel[0],
+ ),
+ RandAffined(
+ keys=source_channel + target_channel,
+ rotate_range=[3.14, 0.0, 0.0],
+ scale_range=[0.0, 0.3, 0.3],
+ prob=0.8,
+ padding_mode="zeros",
+ shear_range=[0.0, 0.01, 0.01],
+ ),
+ RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)),
+ RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5),
+ RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3),
+ RandGaussianSmoothd(
+ keys=source_channel,
+ sigma_x=(0.25, 0.75),
+ sigma_y=(0.25, 0.75),
+ sigma_z=(0.0, 0.0),
+ prob=0.5,
+ ),
+]
+
+normalizations = [
+ NormalizeSampled(
+ keys=source_channel,
+ level="fov_statistics",
+ subtrahend="mean",
+ divisor="std",
+ ),
+ NormalizeSampled(
+ keys=target_channel,
+ level="fov_statistics",
+ subtrahend="median",
+ divisor="iqr",
+ ),
+]
+
+data_module.augmentations = augmentations
+data_module.normalizations = normalizations
+
+# Setup the data_module to fit. HINT: data_module.setup()
+data_module.setup("fit")
+
+# get the new data loader with augmentation turned on
+augmented_train_dataloader = data_module.train_dataloader()
+
+# Draw batches and write to tensorboard
+writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
+augmented_batch = next(iter(augmented_train_dataloader))
+log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
+writer.close()
+
+# %% [markdown] tags=[]
+#
+#
+# ### Questions for Task 1.3
+# 1. Look at your tensorboard. Can you tell the augmentations were applied to the sample batch? Compare the batch with and without augmentations.
+# 2. Are these augmentations good enough? What else would you add?
+#
+
+# %% [markdown]
+# Visualize directly on Jupyter
+
+# %%
+log_batch_jupyter(augmented_batch)
+
+# %% [markdown] tags=[]
+# ## Train a 2D U-Net model to predict nuclei and membrane from phase.
+# ### Constructing a 2D UNeXt2 using VisCy
+#
+# Now we meet the second Lightning object: the **`LightningModule`**. `VSUNet` is
+# VisCy's `LightningModule` and it bundles three things that plain PyTorch keeps
+# separate:
+#
+# 1. **The network** — a UNeXt2 architecture, configured through `model_config`.
+# 2. **The loss** — passed in as `loss_function=MixedLoss(...)`.
+# 3. **The per-batch logic** — `training_step` and `validation_step` methods that
+# take one `{"source", "target"}` batch, run the forward pass, compute the
+# loss, and return it. You don't see these methods here because they're
+# defined once inside `VSUNet`; Lightning calls them for you.
+#
+# Other constructor arguments you'll recognize from plain PyTorch training:
+# `lr` is the learning rate, `schedule="WarmupCosine"` picks the LR schedule,
+# and `freeze_encoder=False` lets gradients flow through the whole network.
+# `log_batches_per_epoch` is a VisCy extra — it tells the module how many image
+# samples to push to TensorBoard each epoch.
+# %% [markdown]
+# **Architecture config** — UNeXt2 is a U-Net with ConvNeXt-style blocks:
+#
+# - `encoder_blocks=[3, 3, 9, 3]` and `dims=[96, 192, 384, 768]` — 4 downsampling
+# stages with that many blocks and feature channels per stage (last stage is
+# the bottleneck). More blocks / dims = more capacity and more compute.
+# - `decoder_conv_blocks=2` — conv blocks after each upsampling step.
+# - `stem_kernel_size=(1, 2, 2)` and `in_stack_depth=1` — this is a 2D model,
+# so we use 1 z-slice and a stem that doesn't convolve across z.
+#
+# **Loss** — `MixedLoss(l1_alpha=0.5, ms_dssim_alpha=0.5)` combines per-pixel
+# L1 (penalizes intensity error) with multi-scale SSIM (penalizes structural
+# error — edges, texture, shape). L1 alone produces blurry outputs; MS-SSIM
+# alone ignores absolute intensity. The 0.5/0.5 mix balances both.
+#
+# **Schedule** — `schedule="WarmupCosine"`, `lr=6e-4`: the learning rate ramps
+# up from 0 over the first few epochs (warmup), then follows a cosine decay
+# toward 0. Warmup avoids early gradient blow-up with AdamW; cosine decay is a
+# strong default for vision transformer / ConvNeXt-style encoders.
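+#
+# Conceptually, the mixed loss is a weighted sum of the two terms. The following is only a
+# sketch of that idea (the actual implementation lives in `viscy_utils.losses.MixedLoss`;
+# the exact DSSIM formulation and the torchmetrics MS-SSIM used here are assumptions for
+# illustration):
+#
+# ```python
+# import torch.nn.functional as F
+# from torchmetrics.functional import multiscale_structural_similarity_index_measure as ms_ssim
+#
+# def mixed_loss_sketch(pred, target, l1_alpha=0.5, ms_dssim_alpha=0.5):
+#     l1 = F.l1_loss(pred, target)  # per-pixel intensity error
+#     # structural dissimilarity, assuming normalized inputs: DSSIM = (1 - MS-SSIM) / 2
+#     ms_dssim = (1 - ms_ssim(pred, target, data_range=1.0)) / 2
+#     return l1_alpha * l1 + ms_dssim_alpha * ms_dssim
+# ```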
+
+# %% [markdown]
+#
+#
+# ### Task 1.4
+# - Run the next cell to instantiate the `UNeXt2_2D` model
+# - Configure the network for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task.
+# - Call the VSUNet with the `"UNeXt2_2D"` architecture.
+# - Run the next cells to instantiate data module and trainer.
+# - Add the source channel name and the target channel names
+# - Start the training
+#
+# Note
+# See ``viscy.translation.engine.VSUNet`` ([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/translation/engine.py)) and ``viscy.unet.networks.fcmae`` ([source code](https://github.com/mehta-lab/VisCy/blob/main/viscy/unet/networks/fcmae.py)) to learn more about the configuration parameters and FCMAE architecture.
+#
+
+# %% tags=["task"]
+# Create a 2D UNet.
+GPU_ID = 0
+
+BATCH_SIZE = 16
+YX_PATCH_SIZE = (256, 256)
+
+# #######################
+# ##### TODO ########
+# #######################
+# Dictionary that specifies key parameters of the model.
+phase2fluor_config = dict(
+    in_channels=...,  # TODO: how many input channels are we feeding in? Hint: int
+ out_channels=..., # TODO how many output channels are we solving for? Hint: int,
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=..., # TODO: was this a 2D or 3D input? HINT: int,
+ pretraining=False,
+)
+
+# #######################
+# ##### TODO ########
+# #######################
+phase2fluor_model = VSUNet(
+ architecture=..., # TODO: 2D UNeXt2 architecture
+ model_config=phase2fluor_config.copy(),
+ loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
+ schedule="WarmupCosine",
+ lr=6e-4,
+ log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.
+ freeze_encoder=False,
+)
+
+# #######################
+# ##### TODO ########
+# #######################
+# HINT: Run dataset.channel_names
+source_channel = ["TODO"]
+target_channel = ["TODO", "TODO"]
+
+# Setup the data module.
+phase2fluor_2D_data = HCSDataModule(
+ data_path,
+ source_channel=source_channel,
+ target_channel=target_channel,
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ yx_patch_size=YX_PATCH_SIZE,
+ augmentations=augmentations,
+ normalizations=normalizations,
+)
+phase2fluor_2D_data.setup("fit")
+# fast_dev_run runs a single batch of data through the model to check for errors.
+trainer = VisCyTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True)
+
+# trainer class takes the model and the data module as inputs.
+trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
+
+
+# %% tags=["solution"]
+
+# Here we are creating a 2D UNet.
+GPU_ID = 0
+
+BATCH_SIZE = 16
+YX_PATCH_SIZE = (256, 256)
+
+# Dictionary that specifies key parameters of the model.
+# #######################
+# ##### SOLUTION ########
+# #######################
+phase2fluor_config = dict(
+ in_channels=1,
+ out_channels=2,
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=1,
+ pretraining=False,
+)
+
+phase2fluor_model = VSUNet(
+ architecture="UNeXt2_2D", # 2D UNeXt2 architecture
+ model_config=phase2fluor_config.copy(),
+ loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
+ schedule="WarmupCosine",
+ lr=6e-4,
+ log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.
+ freeze_encoder=False,
+)
+
+# ### Instantiate data module and trainer, test that we are setup to launch training.
+# #######################
+# ##### SOLUTION ########
+# #######################
+# Selecting the source and target channel names from the dataset.
+source_channel = ["Phase3D"]
+target_channel = ["Nucl", "Mem"]
+# Setup the data module.
+phase2fluor_2D_data = HCSDataModule(
+ data_path,
+ source_channel=source_channel,
+ target_channel=target_channel,
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ yx_patch_size=YX_PATCH_SIZE,
+ augmentations=augmentations,
+ normalizations=normalizations,
+)
+# #######################
+# ##### SOLUTION ########
+# #######################
+phase2fluor_2D_data.setup("fit")
+
+# --- The third Lightning object: the Trainer ---
+#
+# This is the object that replaces the hand-written training loop. Each kwarg
+# controls one piece of the boilerplate Lightning is handling for you:
+#
+# - accelerator="gpu", devices=[GPU_ID]
+# Pick the device. No more ".to(device)" sprinkled through your code —
+# Lightning moves model + every batch for you.
+# - precision="16-mixed"
+# Automatic mixed-precision training (fp16 activations, fp32 master
+# weights). Cuts GPU memory roughly in half and speeds up matmuls on
+# modern GPUs — no autocast() context managers needed.
+# - fast_dev_run=True
+# Sanity check: run ONE training batch + ONE validation batch and exit.
+# Use this on every new pipeline to catch shape bugs, NaN losses, or
+# bad paths *before* you commit to a multi-hour training job.
+#
+# trainer.fit(model, datamodule=...) then drives the whole thing: it calls
+# datamodule.setup(), pulls batches from train_dataloader(), invokes
+# model.training_step(batch), runs loss.backward() + optimizer.step() +
+# zero_grad(), runs validation, logs to TensorBoard, and saves checkpoints.
+trainer = VisCyTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True)
+trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
+
+# %% [markdown] tags=[]
+# ## View model graph.
+#
+# PyTorch uses dynamic graphs under the hood.
+# The graphs are constructed on the fly.
+# This is in contrast to TensorFlow's graph mode,
+# where the graph is constructed before the training loop and remains static.
+# In other words, the graph of the network can change with every forward pass.
+# Therefore, we need to supply an input tensor to construct the graph.
+# The input tensor can be a random tensor of the correct shape and type.
+# We can also supply a real image from the dataset.
+# The latter is more useful for debugging.
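+# If you only need the graph shape, a random tensor with the right (B, C, Z, Y, X) dimensions
+# works just as well, e.g. (a sketch):
+#
+# ```python
+# model_graph = torchview.draw_graph(phase2fluor_model, torch.randn(1, 1, 1, 256, 256), depth=3, device="cpu")
+# ```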
+
+# %% [markdown]
+#
+#
+# ### Task 1.5
+# Run the next cell to generate a graph representation of the model architecture.
+#
+
+# %%
+# visualize graph of phase2fluor model as image.
+model_graph_phase2fluor = torchview.draw_graph(
+ phase2fluor_model,
+ phase2fluor_2D_data.train_dataset[0]["source"][0].unsqueeze(dim=0),
+ roll=True,
+ depth=3, # adjust depth to zoom in.
+ device="cpu",
+ # expand_nested=True,
+)
+# Print the image of the model.
+model_graph_phase2fluor.visual_graph
+
+# %% [markdown] tags=[]
+#
+#
+# ### Question:
+# Can you recognize the UNet structure and skip connections in this graph visualization?
+#
+
+# %% [markdown]
+#
+
+#
+# ### Task 1.6
+# Start training by running the following cell. Check the new logs on the tensorboard.
+#
+
+# %% [markdown]
+#
+# Before re-running training: if a previous training cell is still
+# holding the GPU (you'll see CUDA out of memory), restart the
+# Jupyter kernel (Kernel → Restart in Jupyter, or Restart in
+# VSCode) to release the previous model and optimizer state. The dataset and
+# augmentations will rebuild quickly; only the trained weights need to be
+# re-loaded via load_from_checkpoint if you want to resume.
+#
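+# If you hit a CUDA out-of-memory error, deleting the stale objects and emptying the cache
+# sometimes recovers the GPU without a restart (a sketch; you would then re-create the model,
+# e.g. via `load_from_checkpoint`, and the kernel restart remains the reliable fix):
+#
+# ```python
+# import gc
+# del trainer, phase2fluor_model  # drop references to the stale trainer and model
+# gc.collect()
+# torch.cuda.empty_cache()
+# ```
+#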
+
+# %% [markdown]
+# Now that `fast_dev_run` confirmed the pipeline works end-to-end, we switch
+# to a "real" Trainer configured for an actual multi-epoch run. New Lightning
+# knobs appearing here:
+#
+# - `max_epochs=n_epochs` — run this many passes over the training set, then stop.
+# - `log_every_n_steps=steps_per_epoch // 2` — how often Lightning flushes
+# scalars (loss, learning rate) to the logger. Setting it to half an epoch
+# gives us two data points per epoch without spamming TensorBoard.
+# - `logger=TensorBoardLogger(save_dir=log_dir, name="phase2fluor", log_graph=True)`
+# — Lightning writes TensorBoard event files *and* model checkpoints under
+# `{save_dir}/{name}/version_N/`. You don't call `torch.save` yourself; the
+# trainer persists checkpoints automatically, and `log_graph=True` adds the
+# network architecture to the Graphs tab.
+#
+# Calling `trainer.fit` again below runs the full training loop — forward,
+# loss, backward, optimizer step, validation every epoch, checkpoint at the
+# end — across `max_epochs` epochs.
+
+# %%
+# Check if GPU is available
+# You can check by typing `nvidia-smi`
+GPU_ID = 0
+
+n_samples = len(phase2fluor_2D_data.train_dataset)
+steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.
+n_epochs = 80 # Set this to 80-100 or the number of epochs you want to train for.
+
+trainer = VisCyTrainer(
+ accelerator="gpu",
+ devices=[GPU_ID],
+ max_epochs=n_epochs,
+ precision="16-mixed",
+ log_every_n_steps=steps_per_epoch // 2,
+ # log losses and image samples 2 times per epoch.
+ logger=TensorBoardLogger(
+ save_dir=log_dir,
+ # lightning trainer transparently saves logs and model checkpoints in this directory.
+ name="phase2fluor",
+ log_graph=True,
+ ),
+)
+# Launch training and check that loss and images are being logged on tensorboard.
+trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
+
+# Move the model to the GPU.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+phase2fluor_model.to(device)
+# %% [markdown] tags=[]
+#
+
+#
+# ## Checkpoint 1
+
+# While your model is training, let's think about the following questions:
+#
+# - What is the information content of each channel in the dataset?
+# - How would you use image translation models?
+# - What can you try to improve the performance of each model?
+#
+
+# Now that the training has started,
+# we can come back after a while and evaluate the performance!
+
+#
+# %% [markdown] tags=[]
+# # Part 2: Assess your trained model
+#
+# We evaluate on a held-out test set using two complementary families of metrics:
+#
+# - **Regression / pixel-level** (Pearson, microSSIM): are predicted
+# intensities close to ground truth, per pixel? Cheap, but can hide
+# topological errors — a model that merges two nuclei may still score well
+# pixel-wise.
+# - **Segmentation / instance-level** (Jaccard/IoU, Dice, mAP over IoU
+# thresholds): run Cellpose on both predicted and measured fluorescence,
+# then compare instance masks. This is what ultimately matters for
+# downstream analysis (counting cells, tracking, phenotyping).
+#
+# Also inspect the validation samples on TensorBoard — the experimental
+# nuclei channel is noisy, so "ground truth" is itself imperfect.
+
+# %% [markdown]
+#
+
+#
+# ### Task 2.1 Define metrics
+
+# For each of the above metrics, write a brief definition of what they are and what they mean
+# for this image translation task. Use your favorite search engine and/or resources.
+
+#
+
+# %% [markdown] tags=[]
+# ```
+# #######################
+# ##### Solution ########
+# #######################
+# ```
+#
+# - **Pearson Correlation**: linear correlation between predicted and target
+# intensities across all pixels, in `[-1, 1]`. `1` means the prediction is a
+# perfect affine rescaling of the target; invariant to mean / contrast
+# offsets. Good at flagging "the pattern is right" but blind to structural
+# errors that preserve correlation (e.g. a uniformly blurred prediction).
+#
+# - **microSSIM**: a microscopy-aware variant of
+# [Structural Similarity (SSIM)](https://en.wikipedia.org/wiki/Structural_similarity).
+# Classic SSIM patch-wise compares local mean, variance, and covariance and
+# captures structure Pearson misses (blurring, contrast loss) — but it
+# assumes the natural-image dynamic range. Fluorescence microscopy images
+# are sparse, dim, and noisy: with the default SSIM parameters the scores
+# collapse into a narrow band that barely separates good and bad
+# predictions. [microSSIM](https://github.com/juglab/MicroSSIM)
+# ([Ashesh et al., 2024](https://arxiv.org/abs/2408.08747)) fixes this by
+# subtracting the image background and fitting a per-image rescaling factor
+# before computing SSIM, so the metric becomes sensitive over the range of
+# intensities microscopy predictions actually live in. We use it as a
+# drop-in replacement for `skimage.metrics.structural_similarity`.
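+#
+# A tiny worked example of both pixel metrics (for intuition only; it assumes `microssim`
+# is installed, as imported at the top of this notebook):
+#
+# ```python
+# rng = np.random.default_rng(0)
+# target = rng.normal(size=(64, 64))
+# pred = 0.8 * target + 0.2 * rng.normal(size=(64, 64))  # correlated but noisy "prediction"
+#
+# pearson = np.corrcoef(target.ravel(), pred.ravel())[0, 1]  # close to 1 for a faithful prediction
+# ssim = micro_structural_similarity(target, pred)
+# print(f"Pearson: {pearson:.2f}, microSSIM: {ssim:.2f}")
+# ```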
+
+# %% [markdown] tags=[]
+# ### Let's compute metrics directly and plot below.
+# %% [markdown] tags=[]
+#
+# If you weren't able to train, or training didn't complete, run the following lines to load the latest checkpoint:
+#
+# ```python
+# phase2fluor_model_ckpt = natsorted(glob(
+#     str(training_top_dir / "06_image_translation/logs/phase2fluor/version*/checkpoints/*.ckpt")
+# ))[-1]
+# ```
+#
+# NOTE: if your model didn't go past epoch 5, you lost your checkpoint, or you didn't train anything,
+# run the following:
+#
+# ```python
+# phase2fluor_model_ckpt = natsorted(glob(
+# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
+# ))[-1]
+# ```
+
+# ```python
+# phase2fluor_config = dict(
+# in_channels=1,
+# out_channels=2,
+# encoder_blocks=[3, 3, 9, 3],
+# dims=[96, 192, 384, 768],
+# decoder_conv_blocks=2,
+# stem_kernel_size=(1, 2, 2),
+# in_stack_depth=1,
+# pretraining=False,
+# )
+# # Load the model checkpoint
+# phase2fluor_model = VSUNet.load_from_checkpoint(
+# phase2fluor_model_ckpt,
+# architecture="UNeXt2_2D",
+# model_config = phase2fluor_config,
+# accelerator='gpu'
+# )
+# ```
+#
+# %%
+# Setup the test data module.
+test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr"
+source_channel = ["Phase3D"]
+target_channel = ["Nucl", "Mem"]
+
+test_data = HCSDataModule(
+ test_data_path,
+ source_channel=source_channel,
+ target_channel=target_channel,
+ z_window_size=1,
+ batch_size=1,
+ num_workers=8,
+)
+test_data.setup("test")
+
+test_metrics = pd.DataFrame(columns=["pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"])
+
+
+# %%
+# Compute metrics directly and plot here.
+def normalize_fov(input: ArrayLike):
+ "Normalizing the fov with zero mean and unit variance"
+ mean = np.mean(input)
+ std = np.std(input)
+ return (input - mean) / std
+
+
+for i, sample in enumerate(tqdm(test_data.test_dataloader(), desc="Computing metrics per sample")):
+ phase_image = sample["source"].to(phase2fluor_model.device)
+ with torch.inference_mode(): # turn off gradient computation.
+ predicted_image = phase2fluor_model(phase_image)
+
+ target_image = sample["target"].cpu().numpy().squeeze(0) # Squeezing batch dimension.
+ predicted_image = predicted_image.cpu().numpy().squeeze(0)
+ phase_image = phase_image.cpu().numpy().squeeze(0)
+ target_mem = normalize_fov(target_image[1, 0, :, :])
+ target_nuc = normalize_fov(target_image[0, 0, :, :])
+ # slicing channel dimension, squeezing z-dimension.
+ predicted_mem = normalize_fov(predicted_image[1, :, :, :].squeeze(0))
+ predicted_nuc = normalize_fov(predicted_image[0, :, :, :].squeeze(0))
+
+ # Compute microSSIM and pearson correlation.
+ ssim_nuc = micro_structural_similarity(target_nuc, predicted_nuc)
+ ssim_mem = micro_structural_similarity(target_mem, predicted_mem)
+ pearson_nuc = np.corrcoef(target_nuc.flatten(), predicted_nuc.flatten())[0, 1]
+ pearson_mem = np.corrcoef(target_mem.flatten(), predicted_mem.flatten())[0, 1]
+
+ test_metrics.loc[i] = {
+ "pearson_nuc": pearson_nuc,
+ "microSSIM_nuc": ssim_nuc,
+ "pearson_mem": pearson_mem,
+ "microSSIM_mem": ssim_mem,
+ }
+
+# Plot the following metrics
+test_metrics.boxplot(
+ column=["pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"],
+ rot=30,
+)
+
+
+# %%
+# Adjust the image to the 0.5-99.5 percentile range.
+def process_image(image):
+ p_low, p_high = np.percentile(image, (0.5, 99.5))
+ return np.clip(image, p_low, p_high)
+
+
+# Plot the predicted image vs target image.
+channel_titles = [
+ "Phase",
+ "Target Nuclei",
+ "Target Membrane",
+ "Predicted Nuclei",
+ "Predicted Membrane",
+]
+fig, axes = plt.subplots(5, 1, figsize=(20, 20))
+
+# Plot the source, target, and predicted images for one test sample.
+for i, sample in enumerate(test_data.test_dataloader()):
+ # Plot the phase image
+ phase_image = sample["source"]
+ channel_image = phase_image[0, 0, 0]
+ p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
+ channel_image = np.clip(channel_image, p_low, p_high)
+ axes[0].imshow(channel_image, cmap="gray")
+ axes[0].axis("off")
+ axes[0].set_title(channel_titles[0])
+
+ with torch.inference_mode(): # turn off gradient computation.
+ predicted_image = phase2fluor_model(phase_image.to(phase2fluor_model.device)).cpu().numpy().squeeze(0)
+
+ target_image = sample["target"].cpu().numpy().squeeze(0)
+ phase_raw = process_image(phase_image[0, 0, 0])
+ predicted_nuclei = process_image(predicted_image[0, 0])
+ predicted_membrane = process_image(predicted_image[1, 0])
+ target_nuclei = process_image(target_image[0, 0])
+ target_membrane = process_image(target_image[1, 0])
+ # Concatenate all images side by side
+ combined_image = np.concatenate(
+ (
+ phase_raw,
+ predicted_nuclei,
+ predicted_membrane,
+ target_nuclei,
+ target_membrane,
+ ),
+ axis=1,
+ )
+
+ # Plot the phase,target nuclei, target membrane, predicted nuclei, predicted membrane
+    axes[1].imshow(target_nuclei, cmap="gray")
+    axes[1].set_title(channel_titles[1])
+    axes[2].imshow(target_membrane, cmap="gray")
+    axes[2].set_title(channel_titles[2])
+    axes[3].imshow(predicted_nuclei, cmap="gray")
+    axes[3].set_title(channel_titles[3])
+    axes[4].imshow(predicted_membrane, cmap="gray")
+    axes[4].set_title(channel_titles[4])
+
+ for ax in axes:
+ ax.axis("off")
+ plt.tight_layout()
+ plt.show()
+ break
+# %% [markdown] tags=[]
+#
+
+#
+# ### Task 2.2 Loading the pretrained model VSCyto2D
+# Here we will compare your model with the VSCyto2D pretrained model by computing the pixel-based metrics and segmentation-based metrics.
+#
+#
+# - The pretrained checkpoint was downloaded by `setup.sh` to
+# `~/data/06_image_translation/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt`.
+# If it is missing, download it directly from public.czbiohub.org.
+# Check with `ls ~/data/06_image_translation/pretrained_models/VSCyto2D/`.
+# - Load the VSCyto2D model checkpoint and the configuration file
+# - Compute the pixel-based metrics and segmentation-based metrics between the model you trained and the pretrained model
+#
+#
+
+#
+
+
+# %% tags=["task"]
+#################
+##### TODO ######
+#################
+# Let's load the pretrained model checkpoint
+pretrained_model_ckpt = top_dir / ... ## Add the path to the "VSCyto2D/epoch=399-step=23200.ckpt"
+
+# TODO: Load the phase2fluor_config just like the model you trained
+phase2fluor_config = dict() ##
+
+# TODO: Load the checkpoint. Write the architecture name. HINT: look at the previous config.
+pretrained_phase2fluor = VSUNet.load_from_checkpoint(
+ pretrained_model_ckpt,
+ architecture=...,
+ model_config=phase2fluor_config,
+ accelerator="gpu",
+)
+# TODO: Setup the dataloader in evaluation/predict mode
+#
+
+# %% tags=["solution"]
+# #######################
+# ##### SOLUTION ########
+# #######################
+
+pretrained_model_ckpt = top_dir / "06_image_translation/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt"
+
+phase2fluor_config = dict(
+ in_channels=1,
+ out_channels=2,
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=1,
+ pretraining=False,
+)
+# Load the model checkpoint
+pretrained_phase2fluor = VSUNet.load_from_checkpoint(
+ pretrained_model_ckpt,
+ architecture="UNeXt2_2D",
+ model_config=phase2fluor_config,
+ accelerator="gpu",
+)
+pretrained_phase2fluor.eval()
+
+### Re-load your trained model
+# NOTE: assuming the latest checkpoint is from your latest training run and model
+phase2fluor_model_ckpt = natsorted(
+ glob(str(training_top_dir / "06_image_translation/logs/phase2fluor/version*/checkpoints/*.ckpt"))
+)[-1]
+
+# NOTE: if your model didn't go past epoch 5, you lost your checkpoint, or you didn't train anything,
+# Uncomment the next lines
+# phase2fluor_model_ckpt = natsorted(glob(
+# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
+# ))[-1]
+
+phase2fluor_config = dict(
+ in_channels=1,
+ out_channels=2,
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=1,
+ pretraining=False,
+)
+# Load the model checkpoint
+phase2fluor_model = VSUNet.load_from_checkpoint(
+ phase2fluor_model_ckpt,
+ architecture="UNeXt2_2D",
+ model_config=phase2fluor_config,
+ accelerator="gpu",
+)
+phase2fluor_model.eval()
+# %% [markdown] tags=[]
+#
+#
+# ### Questions
+# 1. Can we evaluate a model's performance based on its segmentations?
+# 2. Look up IoU (Jaccard index), the Dice coefficient, and AP metrics. LINK: https://metrics-reloaded.dkfz.de/metric-library
+#
+# We will compare the performance of your trained model against the pretrained model using pixel-based metrics as above and
+# segmentation-based metrics (mAP@0.5, Dice, accuracy, and Jaccard index).
+#
+# %% [markdown] tags=["solution"]
+#
+# - IoU (Intersection over Union): Also referred to as the Jaccard index, is essentially a method to quantify the percent overlap between the target and predicted masks.
+# It is calculated as the intersection of the target and predicted masks divided by the union of the target and predicted masks.
+# - Dice Coefficient: Metric used to evaluate the similarity between two sets.
+# It is calculated as twice the intersection of the target and predicted masks divided by the sum of the target and predicted masks.
+# - mAP (mean Average Precision): The mean Average Precision (mAP) is a metric used to evaluate the performance of object detection models.
+# It is calculated as the average precision across all classes and is used to measure the accuracy of the model in localizing objects.
+#
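+#
+# A minimal worked example of IoU and Dice on toy binary masks (a sketch, in plain numpy):
+#
+# ```python
+# pred = np.array([[1, 1, 0], [1, 0, 0], [0, 0, 0]], dtype=bool)
+# target = np.array([[1, 1, 0], [0, 1, 0], [0, 0, 0]], dtype=bool)
+#
+# intersection = np.logical_and(pred, target).sum()      # 2
+# union = np.logical_or(pred, target).sum()              # 4
+# iou = intersection / union                             # 0.5
+# dice = 2 * intersection / (pred.sum() + target.sum())  # 4 / 6 ≈ 0.67
+# print(iou, dice)
+# ```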
+# %% [markdown] tags=[]
+# ### Let's compute the metrics for the test dataset
+# Before you run the following code, make sure you have the pretrained model loaded and the test data is ready.
+
+# The following code computes:
+# - the pixel-based metrics (pearson correlation, SSIM)
+# - segmentation-based metrics (mAP@0.5, dice, accuracy, jaccard index)
+
+
+# #### Note:
+# - The segmentation-based metrics are computed using the cellpose stock `nuclei` model
+# - The metrics will be stored in the `test_pixel_metrics` and `test_segmentation_metrics` dataframes
+# - The segmentations will be stored in the `segmentation_store` zarr file
+# - Analyze the code while it runs.
+# %%
+# Create cellpose model once for reuse
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+cellpose_model = models.CellposeModel(gpu=True if device.type == "cuda" else False, device=device)
+
+
+# Define the function to compute the cellpose segmentation
+def cellpose_segmentation(prediction: ArrayLike, target: ArrayLike) -> Tuple[torch.ShortTensor, torch.ShortTensor]:
+ # NOTE these are hardcoded for this notebook and A549 dataset
+
+ # Convert 2D arrays to 3D format expected by cellpose v4.0.1+
+ # Add channel dimension and replicate to 3 channels (RGB format)
+ if prediction.ndim == 2:
+ prediction = np.tile(prediction, (3, 1, 1)) # Shape: (3, H, W)
+ if target.ndim == 2:
+ target = np.tile(target, (3, 1, 1)) # Shape: (3, H, W)
+
+ cp_nuc_kwargs = {
+ "diameter": 65,
+ "cellprob_threshold": 0.0,
+ }
+
+ pred_label, _, _ = cellpose_model.eval(prediction, **cp_nuc_kwargs)
+ target_label, _, _ = cellpose_model.eval(target, **cp_nuc_kwargs)
+
+ pred_label = pred_label.astype(np.int32)
+ target_label = target_label.astype(np.int32)
+ pred_label = torch.ShortTensor(pred_label)
+ target_label = torch.ShortTensor(target_label)
+
+ return (pred_label, target_label)
+
+
+# %%
+# Setting the paths for the test data and the output segmentation
+test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr"
+output_segmentation_path = training_top_dir / "06_image_translation/pretrained_model_segmentations.zarr"
+
+# Creating the dataframes to store the pixel and segmentation metrics
+test_pixel_metrics = pd.DataFrame(
+ columns=["model", "fov", "pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"]
+)
+test_segmentation_metrics = pd.DataFrame(
+ columns=[
+ "model",
+ "fov",
+ "masks_per_fov",
+ "accuracy",
+ "dice",
+ "jaccard",
+ "mAP",
+ "mAP_50",
+ "mAP_75",
+ "mAR_100",
+ ]
+)
+# Opening the test dataset
+test_dataset = open_ome_zarr(test_data_path)
+
+# Creating an output store for the predictions and segmentations
+segmentation_store = open_ome_zarr(
+ output_segmentation_path,
+ channel_names=["nuc_pred", "mem_pred", "nuc_labels"],
+ mode="w",
+ layout="hcs",
+)
+
+# Looking at the test dataset
+print("Test dataset:")
+test_dataset.print_tree()
+channel_names = test_dataset.channel_names
+print(f"Channel names: {channel_names}")
+
+# Finding the channel indices for the corresponding channel names
+phase_cidx = channel_names.index("Phase3D")
+nuc_cidx = channel_names.index("Nucl")
+mem_cidx = channel_names.index("Mem")
+nuc_label_cidx = channel_names.index("nuclei_segmentation")
+
+
+# %%
+def min_max_scale(image: ArrayLike) -> ArrayLike:
+ "Normalizing the image using min-max scaling"
+ min_val = image.min()
+ max_val = image.max()
+ return (image - min_val) / (max_val - min_val)
+
+
+# %% [markdown]
+# ## Visualize segmentation comparison: Fluorescence vs Virtual Staining vs Pretrained
+# Let's compare nucleus and membrane segmentation from the fluorescence images, your trained model, and the pretrained model
+
+# %%
+# Get a sample FOV for visualization
+positions = list(test_dataset.positions())
+sample_fov, sample_pos = positions[0] # Use first FOV as example
+
+T, C, Z, Y, X = sample_pos.data.shape
+Z_slice = slice(Z // 2, Z // 2 + 1)
+
+# Get the data
+sample_phase = sample_pos.data[:, phase_cidx : phase_cidx + 1, Z_slice]
+sample_nucleus = sample_pos.data[0, nuc_cidx : nuc_cidx + 1, Z_slice]
+sample_membrane = sample_pos.data[0, mem_cidx : mem_cidx + 1, Z_slice]
+
+# Crop 300x300 pixels from center
+center_y, center_x = sample_nucleus.shape[2] // 2, sample_nucleus.shape[3] // 2
+crop_size = 300
+y_start = max(0, center_y - crop_size // 2)
+y_end = min(sample_nucleus.shape[2], center_y + crop_size // 2)
+x_start = max(0, center_x - crop_size // 2)
+x_end = min(sample_nucleus.shape[3], center_x + crop_size // 2)
+
+# Crop fluorescence data
+sample_nucleus_crop = min_max_scale(sample_nucleus[0, 0, y_start:y_end, x_start:x_end])
+sample_membrane_crop = min_max_scale(sample_membrane[0, 0, y_start:y_end, x_start:x_end])
+
+# Generate virtual stained data from phase (trained model)
+sample_phase_tensor = torch.tensor(sample_phase, dtype=torch.float32).to(device)
+with torch.inference_mode():
+ predicted_image = phase2fluor_model(sample_phase_tensor)
+predicted_nuc_crop = min_max_scale(predicted_image.cpu().numpy()[0, 0, 0, y_start:y_end, x_start:x_end])
+predicted_mem_crop = min_max_scale(predicted_image.cpu().numpy()[0, 1, 0, y_start:y_end, x_start:x_end])
+
+# Generate virtual stained data from pretrained model
+with torch.inference_mode():
+ predicted_image_pretrained = pretrained_phase2fluor(sample_phase_tensor)
+predicted_nuc_pretrained_crop = min_max_scale(
+ predicted_image_pretrained.cpu().numpy()[0, 0, 0, y_start:y_end, x_start:x_end]
+)
+predicted_mem_pretrained_crop = min_max_scale(
+ predicted_image_pretrained.cpu().numpy()[0, 1, 0, y_start:y_end, x_start:x_end]
+)
+
+# Run segmentation on all nuclei
+fluor_nuc_seg, _ = cellpose_segmentation(sample_nucleus_crop, sample_nucleus_crop)
+virtual_nuc_seg, _ = cellpose_segmentation(predicted_nuc_crop, predicted_nuc_crop)
+pretrained_nuc_seg, _ = cellpose_segmentation(predicted_nuc_pretrained_crop, predicted_nuc_pretrained_crop)
+
+# Run segmentation on all membranes (using nucleus parameters for consistency)
+fluor_mem_seg, _ = cellpose_segmentation(sample_membrane_crop, sample_membrane_crop)
+virtual_mem_seg, _ = cellpose_segmentation(predicted_mem_crop, predicted_mem_crop)
+pretrained_mem_seg, _ = cellpose_segmentation(predicted_mem_pretrained_crop, predicted_mem_pretrained_crop)
+
+# Convert to numpy
+fluor_nuc_seg = fluor_nuc_seg.numpy()
+virtual_nuc_seg = virtual_nuc_seg.numpy()
+pretrained_nuc_seg = pretrained_nuc_seg.numpy()
+fluor_mem_seg = fluor_mem_seg.numpy()
+virtual_mem_seg = virtual_mem_seg.numpy()
+pretrained_mem_seg = pretrained_mem_seg.numpy()
+
+# Create 3x4 visualization
+fig, axes = plt.subplots(3, 4, figsize=(16, 12))
+
+# Row 1: Fluorescence data
+axes[0, 0].imshow(sample_nucleus_crop, cmap="gray")
+axes[0, 0].set_title("Fluorescence Nucleus")
+axes[0, 0].axis("off")
+
+fluor_nuc_overlay = label2rgb(fluor_nuc_seg, sample_nucleus_crop, bg_label=0)
+axes[0, 1].imshow(fluor_nuc_overlay)
+axes[0, 1].set_title("Nucleus Segmentation")
+axes[0, 1].axis("off")
+
+axes[0, 2].imshow(sample_membrane_crop, cmap="gray")
+axes[0, 2].set_title("Fluorescence Membrane")
+axes[0, 2].axis("off")
+
+fluor_mem_overlay = label2rgb(fluor_mem_seg, sample_membrane_crop, bg_label=0)
+axes[0, 3].imshow(fluor_mem_overlay)
+axes[0, 3].set_title("Membrane Segmentation")
+axes[0, 3].axis("off")
+
+# Row 2: Virtual stained data (trained)
+axes[1, 0].imshow(predicted_nuc_crop, cmap="gray")
+axes[1, 0].set_title("Virtual Nucleus (Trained)")
+axes[1, 0].axis("off")
+
+virtual_nuc_overlay = label2rgb(virtual_nuc_seg, predicted_nuc_crop, bg_label=0)
+axes[1, 1].imshow(virtual_nuc_overlay)
+axes[1, 1].set_title("Nucleus Segmentation")
+axes[1, 1].axis("off")
+
+axes[1, 2].imshow(predicted_mem_crop, cmap="gray")
+axes[1, 2].set_title("Virtual Membrane (Trained)")
+axes[1, 2].axis("off")
+
+virtual_mem_overlay = label2rgb(virtual_mem_seg, predicted_mem_crop, bg_label=0)
+axes[1, 3].imshow(virtual_mem_overlay)
+axes[1, 3].set_title("Membrane Segmentation")
+axes[1, 3].axis("off")
+
+# Row 3: Virtual stained data (pretrained)
+axes[2, 0].imshow(predicted_nuc_pretrained_crop, cmap="gray")
+axes[2, 0].set_title("Virtual Nucleus (Pretrained)")
+axes[2, 0].axis("off")
+
+pretrained_nuc_overlay = label2rgb(pretrained_nuc_seg, predicted_nuc_pretrained_crop, bg_label=0)
+axes[2, 1].imshow(pretrained_nuc_overlay)
+axes[2, 1].set_title("Nucleus Segmentation")
+axes[2, 1].axis("off")
+
+axes[2, 2].imshow(predicted_mem_pretrained_crop, cmap="gray")
+axes[2, 2].set_title("Virtual Membrane (Pretrained)")
+axes[2, 2].axis("off")
+
+pretrained_mem_overlay = label2rgb(pretrained_mem_seg, predicted_mem_pretrained_crop, bg_label=0)
+axes[2, 3].imshow(pretrained_mem_overlay)
+axes[2, 3].set_title("Membrane Segmentation")
+axes[2, 3].axis("off")
+
+plt.suptitle(f"Complete Segmentation Comparison - FOV: {sample_fov}", fontsize=16)
+plt.tight_layout()
+plt.show()
+
+print("Nucleus segmentation counts:")
+print(f" Fluorescence: {len(np.unique(fluor_nuc_seg)) - 1} nuclei")
+print(f" Virtual (trained): {len(np.unique(virtual_nuc_seg)) - 1} nuclei")
+print(f" Virtual (pretrained): {len(np.unique(pretrained_nuc_seg)) - 1} nuclei")
+
+print("\nMembrane segmentation counts:")
+print(f" Fluorescence: {len(np.unique(fluor_mem_seg)) - 1} objects")
+print(f" Virtual (trained): {len(np.unique(virtual_mem_seg)) - 1} objects")
+print(f" Virtual (pretrained): {len(np.unique(pretrained_mem_seg)) - 1} objects")
+
+# %% [markdown]
+# Now let's compute metrics across all FOVs
+
+# %%
+# Iterate through the test dataset positions to compute pixel- and segmentation-based metrics for both models.
+total_positions = len(positions)
+
+# Initializing the progress bar with the total number of positions
+with tqdm(total=total_positions, desc="Processing FOVs") as pbar:
+ # Iterating through the test dataset positions
+ for fov, pos in positions:
+ T, C, Z, Y, X = pos.data.shape
+ Z_slice = slice(Z // 2, Z // 2 + 1)
+ # Getting the arrays and the center slices
+ phase_image = pos.data[:, phase_cidx : phase_cidx + 1, Z_slice]
+ target_nucleus = pos.data[0, nuc_cidx : nuc_cidx + 1, Z_slice]
+ target_membrane = pos.data[0, mem_cidx : mem_cidx + 1, Z_slice]
+ target_nuc_label = pos.data[0, nuc_label_cidx : nuc_label_cidx + 1, Z_slice]
+
+ # normalize the phase
+ phase_image = normalize_fov(phase_image)
+
+ # Running the prediction for both models
+ phase_image = torch.from_numpy(phase_image).type(torch.float32)
+ phase_image = phase_image.to(phase2fluor_model.device)
+ with torch.inference_mode(): # turn off gradient computation.
+ predicted_image_phase2fluor = phase2fluor_model(phase_image)
+ predicted_image_pretrained = pretrained_phase2fluor(phase_image)
+
+ # Loading and Normalizing the target and predictions for both models
+ predicted_image_phase2fluor = predicted_image_phase2fluor.cpu().numpy().squeeze(0)
+ predicted_image_pretrained = predicted_image_pretrained.cpu().numpy().squeeze(0)
+ phase_image = phase_image.cpu().numpy().squeeze(0)
+
+ target_mem = min_max_scale(target_membrane[0, 0])
+ target_nuc = min_max_scale(target_nucleus[0, 0])
+
+ # Normalizing the dataset using min-max scaling
+ predicted_mem_phase2fluor = min_max_scale(predicted_image_phase2fluor[1, :, :, :].squeeze(0))
+ predicted_nuc_phase2fluor = min_max_scale(predicted_image_phase2fluor[0, :, :, :].squeeze(0))
+
+ predicted_mem_pretrained = min_max_scale(predicted_image_pretrained[1, :, :, :].squeeze(0))
+ predicted_nuc_pretrained = min_max_scale(predicted_image_pretrained[0, :, :, :].squeeze(0))
+
+ ####### Pixel-based Metrics ############
+ # Compute microSSIM and Pearson correlation for phase2fluor_model
+ pbar.set_description(f"Processing FOV {fov} - Computing Pixel Metrics")
+ pbar.refresh()
+ ssim_nuc_phase2fluor = micro_structural_similarity(target_nuc, predicted_nuc_phase2fluor)
+ ssim_mem_phase2fluor = micro_structural_similarity(target_mem, predicted_mem_phase2fluor)
+ pearson_nuc_phase2fluor = np.corrcoef(target_nuc.flatten(), predicted_nuc_phase2fluor.flatten())[0, 1]
+ pearson_mem_phase2fluor = np.corrcoef(target_mem.flatten(), predicted_mem_phase2fluor.flatten())[0, 1]
+
+ test_pixel_metrics.loc[len(test_pixel_metrics)] = {
+ "model": "phase2fluor",
+ "fov": fov,
+ "pearson_nuc": pearson_nuc_phase2fluor,
+ "microSSIM_nuc": ssim_nuc_phase2fluor,
+ "pearson_mem": pearson_mem_phase2fluor,
+ "microSSIM_mem": ssim_mem_phase2fluor,
+ }
+ # Compute microSSIM and Pearson correlation for pretrained_model
+ ssim_nuc_pretrained = micro_structural_similarity(target_nuc, predicted_nuc_pretrained)
+ ssim_mem_pretrained = micro_structural_similarity(target_mem, predicted_mem_pretrained)
+ pearson_nuc_pretrained = np.corrcoef(target_nuc.flatten(), predicted_nuc_pretrained.flatten())[0, 1]
+ pearson_mem_pretrained = np.corrcoef(target_mem.flatten(), predicted_mem_pretrained.flatten())[0, 1]
+
+ test_pixel_metrics.loc[len(test_pixel_metrics)] = {
+ "model": "pretrained_phase2fluor",
+ "fov": fov,
+ "pearson_nuc": pearson_nuc_pretrained,
+ "microSSIM_nuc": ssim_nuc_pretrained,
+ "pearson_mem": pearson_mem_pretrained,
+ "microSSIM_mem": ssim_mem_pretrained,
+ }
+
+ ###### Segmentation based metrics #########
+ # Load the manually curated nuclei target label
+ pbar.set_description(f"Processing FOV {fov} - Computing Segmentation Metrics")
+ pbar.refresh()
+ pred_label, target_label = cellpose_segmentation(predicted_nuc_phase2fluor, target_nucleus)
+ # Binary labels
+ pred_label_binary = pred_label > 0
+ target_label_binary = target_label > 0
+
+ # Use Coco metrics to get mean average precision
+ coco_metrics = mean_average_precision(pred_label, target_label)
+ # Find unique number of labels
+ num_masks_fov = len(np.unique(pred_label))
+
+ test_segmentation_metrics.loc[len(test_segmentation_metrics)] = {
+ "model": "phase2fluor",
+ "fov": fov,
+ "masks_per_fov": num_masks_fov,
+ "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(),
+ "dice": dice_score(
+ pred_label_binary.long()[None],
+ target_label_binary.long()[None],
+ num_classes=2,
+ input_format="index",
+ average="micro",
+ ).item(),
+ "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(),
+ "mAP": coco_metrics["map"].item(),
+ "mAP_50": coco_metrics["map_50"].item(),
+ "mAP_75": coco_metrics["map_75"].item(),
+ "mAR_100": coco_metrics["mar_100"].item(),
+ }
+
+ pred_label, target_label = cellpose_segmentation(predicted_nuc_pretrained, target_nucleus)
+
+ # Binary labels
+ pred_label_binary = pred_label > 0
+ target_label_binary = target_label > 0
+
+ # Use Coco metrics to get mean average precision
+ coco_metrics = mean_average_precision(pred_label, target_label)
+ # Find unique number of labels
+ num_masks_fov = len(np.unique(pred_label))
+
+ test_segmentation_metrics.loc[len(test_segmentation_metrics)] = {
+ "model": "phase2fluor_pretrained",
+ "fov": fov,
+ "masks_per_fov": num_masks_fov,
+ "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(),
+ "dice": dice_score(
+ pred_label_binary.long()[None],
+ target_label_binary.long()[None],
+ num_classes=2,
+ input_format="index",
+ average="micro",
+ ).item(),
+ "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(),
+ "mAP": coco_metrics["map"].item(),
+ "mAP_50": coco_metrics["map_50"].item(),
+ "mAP_75": coco_metrics["map_75"].item(),
+ "mAR_100": coco_metrics["mar_100"].item(),
+ }
+
+ # Save the predictions and segmentations
+ position = segmentation_store.create_position(*Path(fov).parts[-3:])
+ output_array = np.zeros((T, 3, 1, Y, X), dtype=np.float32)
+ output_array[0, 0, 0] = predicted_nuc_pretrained
+ output_array[0, 1, 0] = predicted_mem_pretrained
+ output_array[0, 2, 0] = np.array(pred_label)
+ position.create_image("0", output_array)
+
+ # Update the progress bar
+ pbar.set_description("Processing FOVs")
+ pbar.update(1)
+
+# Close the OME-Zarr files
+test_dataset.close()
+segmentation_store.close()
+# %%
+# Save the test metrics into a dataframe
+pixel_metrics_path = training_top_dir / "06_image_translation/VS_metrics_pixel.csv"
+segmentation_metrics_path = training_top_dir / "06_image_translation/VS_metrics_segments.csv"
+test_pixel_metrics.to_csv(pixel_metrics_path)
+test_segmentation_metrics.to_csv(segmentation_metrics_path)
+
+# %% [markdown] tags=[]
+#
+# ### Task 2.3: Compare the model's metrics
+#
+# In the previous section, we computed the pixel-based metrics and segmentation-based metrics.
+# Now we will compare the performance of the model you trained with the pretrained model by plotting the boxplots.
+
+# After you plot the metrics, answer the following:
+#
+# - What do these metrics tell us about the performance of the model?
+# - How do you interpret the differences in the metrics between the models?
+# - How does your model compare to the pretrained model? How can you improve it?
+#
+#
+
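+# %% [markdown]
+# Before plotting, a compact numeric summary can be easier to scan than boxplots.
+# A minimal sketch using the metric dataframes built above:
+
+# %%
+# Median of each metric per model (non-numeric columns such as "fov" are excluded)
+print(test_pixel_metrics.groupby("model").median(numeric_only=True))
+print(test_segmentation_metrics.groupby("model").median(numeric_only=True))
+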
+# %%
+# Boxplot of the pixel-based metrics
+test_pixel_metrics.boxplot(
+ by="model",
+ column=["pearson_nuc", "microSSIM_nuc", "pearson_mem", "microSSIM_mem"],
+ rot=30,
+ figsize=(8, 8),
+)
+plt.suptitle("Model Pixel Metrics")
+plt.show()
+# Boxplot of the segmentation-based metrics
+test_segmentation_metrics.boxplot(
+ by="model",
+ column=["jaccard", "accuracy", "mAP_75", "mAP_50"],
+ rot=30,
+ figsize=(8, 8),
+)
+plt.suptitle("Model Segmentation Metrics")
+plt.show()
+
+# %% [markdown] tags=["task"]
+#
+# #### Questions
+#
+# - What do these metrics tell us about the performance of the model?
+# - How do you interpret the differences in the metrics between the models?
+# - How does your model compare to the pretrained model? How can you improve it?
+#
+#
+
+# %% [markdown]
+# ### Plotting the predictions and segmentations
+#
+# ### Task 2.4: Visualize the predictions and segmentations
+#
+# Here we will plot the predictions and segmentations side by side for the pretrained and trained models.
+#
+# - How does your model, the pretrained model and the ground truth compare?
+# - How do the segmentations compare?
+#
+# Feel free to modify the crop size and the Y, X slicing to view different areas of the FOV.
+#
+# %% tags=["task"]
+
+# Get the shape of the 2D image
+Y, X = phase_image.shape[-2:]
+######## TODO ##########
+# Modify the crop size and Y,X slicing to view different areas of the FOV
+
+crop = 256
+y_slice = slice(Y // 2 - crop // 2, Y // 2 + crop // 2)
+x_slice = slice(X // 2 - crop // 2, X // 2 + crop // 2)
+#######################
+# Plotting side by side comparisons
+fig, axs = plt.subplots(4, 3, figsize=(15, 20))
+
+# First row: phase_image, target_nuc, target_mem
+axs[0, 0].imshow(phase_image[0, 0, y_slice, x_slice], cmap="gray")
+axs[0, 0].set_title("Phase Image")
+axs[0, 1].imshow(target_nuc[y_slice, x_slice], cmap="gray")
+axs[0, 1].set_title("Target Nucleus")
+axs[0, 2].imshow(target_mem[y_slice, x_slice], cmap="gray")
+axs[0, 2].set_title("Target Membrane")
+
+# Second row: target_nuc, pred_nuc_phase2fluor, pred_nuc_pretrained
+axs[1, 0].imshow(target_nuc[y_slice, x_slice], cmap="gray")
+axs[1, 0].set_title("Target Nucleus")
+axs[1, 1].imshow(predicted_nuc_phase2fluor[y_slice, x_slice], cmap="gray")
+axs[1, 1].set_title("Pred Nucleus Phase2Fluor")
+axs[1, 2].imshow(predicted_nuc_pretrained[y_slice, x_slice], cmap="gray")
+axs[1, 2].set_title("Pred Nucleus Pretrained")
+
+# Third row: target_mem, pred_mem_phase2fluor, pred_mem_pretrained
+axs[2, 0].imshow(target_mem[y_slice, x_slice], cmap="gray")
+axs[2, 0].set_title("Target Membrane")
+axs[2, 1].imshow(predicted_mem_phase2fluor[y_slice, x_slice], cmap="gray")
+axs[2, 1].set_title("Pred Membrane Phase2Fluor")
+axs[2, 2].imshow(predicted_mem_pretrained[y_slice, x_slice], cmap="gray")
+axs[2, 2].set_title("Pred Membrane Pretrained")
+
+# Fourth row: target_nuc, segment_nuc, segment_nuc2
+axs[3, 0].imshow(target_nuc[y_slice, x_slice], cmap="gray")
+axs[3, 0].set_title("Target Nucleus")
+axs[3, 1].imshow(label2rgb(np.array(target_label[y_slice, x_slice], dtype="int")), cmap="gray")
+axs[3, 1].set_title("Segmented Nucleus (Target)")
+axs[3, 2].imshow(label2rgb(np.array(pred_label[y_slice, x_slice], dtype="int")), cmap="gray")
+axs[3, 2].set_title("Segmented Nucleus")
+
+# Hide axes ticks
+for ax in axs.flat:
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+plt.tight_layout()
+plt.show()
+
+
+# %% [markdown] tags=[]
+#
+# ## Checkpoint 2
+#
+# Congratulations! You have completed the second checkpoint. You have:
+# - Visualized the predictions and segmentations of the model.
+# - Evaluated the performance of the model using pixel-based metrics and segmentation-based metrics.
+# - Compared the performance of the model you trained with the pretrained model.
+#
+#
+
+# %% [markdown] tags=[]
+#
+#
+# ### Task 2.5: Evaluate a fluorescence to phase model
+# In this section, we will explore the inverse transformation using fluorescence images
+# (nuclei + membrane) to predict the phase image.
+#
+# Learning Goals:
+#
+# - Understand the concept of fluorescence to phase transformations in image translation
+# - Load a pretrained model for the reverse task (fluor → phase)
+# - Compare input fluorescence channels with predicted phase
+# - Analyze why the phase prediction is not perfect
+#
+# We'll use a pretrained model that was trained to predict phase from fluorescence channels.
+#
+
+# %% [markdown] tags=[]
+#
+# #### Questions
+#
+# - How much information is lost in the phase to fluorescence transformation?
+# - Why might perfect reconstruction not be possible?
+# - Can multiple phase patterns produce similar fluorescence signals?
+#
+#
+
+# %%
+# Path to the pretrained fluorescence to phase model checkpoint
+fluor2phase_model_path = top_dir / "06_image_translation/pretrained_models/AIMBL_Demo/fluor2phase_step668.ckpt"
+
+
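+# %% [markdown] tags=[]
+# If you are unsure about a checkpoint's input/output channels, you can inspect its
+# raw weights before building the model. This is a generic PyTorch sketch: the key
+# names depend on the architecture, and depending on your torch version the
+# `weights_only` argument may be unnecessary.
+
+# %%
+ckpt = torch.load(fluor2phase_model_path, map_location="cpu", weights_only=False)
+state_dict = ckpt["state_dict"]
+# The first and last parameter shapes hint at the stem (input) and head (output) channels
+print(list(state_dict.keys())[:3])
+print(list(state_dict.keys())[-3:])
+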
+# %% tags=["task"]
+# Load a pretrained model for fluorescence to phase translation
+from pathlib import Path
+
+import torch
+
+# #######################
+# ##### TODO ########
+# #######################
+# TODO: Load the pretrained fluorescence to phase model
+# HINT: Look for pretrained models in the VisCy repository or use a model checkpoint
+# HINT: The model should take 2 input channels (nuclei + membrane) and output 1 channel (phase)
+# HINT: Use similar architecture as before but with different input/output channels
+
+# For now, we'll create a placeholder - replace with actual model loading
+print("Loading pretrained fluorescence-to-phase model...")
+
+# TODO: Replace this with actual model loading code
+fluor2phase_config = dict(
+ in_channels=..., # Nuclei + Membrane channels
+ out_channels=..., # Phase channel
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=1,
+)
+fluor2phase_model = VSUNet.load_from_checkpoint(
+ fluor2phase_model_path, model_config=fluor2phase_config, architecture="fcmae"
+)
+assert fluor2phase_model is not None, (
+    "Fluorescence to phase model not loaded. Check the model config and the path to the model checkpoint."
+)
+fluor2phase_model.eval()
+
+# %% tags=["task"]
+# Test the fluorescence to phase model on our test data
+
+source_channel_fluor = ["TODO", "TODO"]
+target_channel_labelfree = ["TODO"]
+
+test_data_fluor2phase = HCSDataModule(
+ test_data_path,
+ source_channel=source_channel_fluor,
+ target_channel=target_channel_labelfree,
+ z_window_size=1,
+ batch_size=1,
+ num_workers=8,
+)
+test_data_fluor2phase.setup("test")
+
+
+# Get a test sample
+sample = next(iter(test_data_fluor2phase.test_dataloader()))
+
+# #######################
+# ##### TODO ########
+# #######################
+# TODO: Extract the input channels (fluorescence) and target (phase)
+# HINT: Print the keys of the `sample` dictionary
+# HINT: Input should be nuclei and membrane channels concatenated
+# HINT: Target should be the original phase image
+
+fluor_input = ... # TODO: Source
+target_phase = ... # TODO: Target
+
+# TODO: Make prediction with the fluorescence to phase model
+# NOTE: The `fluor2phase_model`, returns a tuple. Select the first item with `[0]`
+with torch.inference_mode():
+ predicted_phase = ...
+
+# #######################
+# ##### TODO ########
+# #######################
+# Calculate metrics between predicted and target phase
+# HINT: Use SSIM and Pearson correlation as before
+
+# TODO: Normalize data range to 0-1
+###### YOUR CODE HERE ######
+
+# TODO: Calculate SSIM and Pearson correlation
+###### YOUR CODE HERE ######
+
+# TODO: Print metrics
+print("Phase Reconstruction Metrics:")
+print(f"SSIM: {ssim_phase:.3f}")
+print(f"Pearson Correlation: {pearson_phase:.3f}")
+
+
+# %% tags=["solution"]
+# Load a pretrained model for fluorescence to phase translation
+from pathlib import Path
+
+import torch
+
+# Load the pretrained fluorescence to phase model
+print("Loading pretrained fluorescence-to-phase model...")
+
+# Note: This assumes a pretrained model is available. In practice, you would:
+# 1. Download from VisCy releases or train your own
+# 2. Adjust the path accordingly
+
+# Define the model configuration that matches the pretrained checkpoint
+fluor2phase_config = dict(
+ in_channels=2, # Nuclei + Membrane channels
+ out_channels=1, # Phase channel
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=1,
+)
+
+# Load the pretrained checkpoint with this configuration
+fluor2phase_model_path = top_dir / "06_image_translation/pretrained_models/AIMBL_Demo/fluor2phase_step668.ckpt"
+assert fluor2phase_model_path.exists(), "Fluorescence-to-phase model checkpoint not found. Please check the path."
+fluor2phase_model = VSUNet.load_from_checkpoint(
+ fluor2phase_model_path, model_config=fluor2phase_config, architecture="fcmae"
+)
+fluor2phase_model.eval()
+
+# %% tags=["solution"]
+# Test the fluorescence to phase model on our test data
+
+source_channel_fluor = ["Nucl", "Mem"]
+target_channel_labelfree = ["Phase3D"]
+
+test_data_fluor2phase = HCSDataModule(
+ test_data_path,
+ source_channel=source_channel_fluor,
+ target_channel=target_channel_labelfree,
+ z_window_size=1,
+ batch_size=1,
+ num_workers=8,
+)
+test_data_fluor2phase.setup("test")
+
+# Get a test sample
+sample = next(iter(test_data_fluor2phase.test_dataloader()))
+
+# Extract input channels (fluorescence nuclei and membrane) and target (phase)
+fluor_input = sample["source"].to(fluor2phase_model.device)
+target_image = sample["target"].cpu().numpy().squeeze(0)
+
+# Run inference
+with torch.inference_mode():
+ predicted_phase = fluor2phase_model(fluor_input)[0]
+
+fluor_input = fluor_input.cpu().numpy()
+predicted_image = predicted_phase.cpu().numpy().squeeze(0)
+target_phase = rescale_intensity(target_image[0, 0], out_range=(0, 1))
+predicted_phase = rescale_intensity(predicted_image[0, 0], out_range=(0, 1))
+ssim_phase = metrics.structural_similarity(target_phase, predicted_phase, data_range=1)
+pearson_phase = np.corrcoef(target_phase.flatten(), predicted_phase.flatten())[0, 1]
+
+print("Phase Reconstruction Metrics:")
+print(f"SSIM: {ssim_phase:.3f}")
+print(f"Pearson Correlation: {pearson_phase:.3f}")
+
+# %%
+# Visualize the fluorescence to phase transformation results. Modify as you see fit.
+
+fig, axs = plt.subplots(2, 3, figsize=(15, 10))
+
+axs[0, 0].imshow(fluor_input[0, 0, 0], cmap="gray")
+axs[0, 0].set_title("Input: Nuclei Channel")
+axs[0, 1].imshow(fluor_input[0, 1, 0], cmap="gray")
+axs[0, 1].set_title("Input: Membrane Channel")
+axs[0, 2].imshow(fluor_input[0, 0, 0] + fluor_input[0, 1, 0], cmap="gray")
+axs[0, 2].set_title("Combined Fluorescence\n(Nuclei + Membrane)")
+
+axs[1, 0].imshow(target_phase, cmap="gray")
+axs[1, 0].set_title("Target Phase Image")
+axs[1, 1].imshow(predicted_phase, cmap="gray")
+axs[1, 1].set_title(f"Predicted Phase\nSSIM: {ssim_phase:.3f}")
+axs[1, 2].imshow(np.abs(target_phase - predicted_phase), cmap="magma")
+axs[1, 2].set_title("Absolute Difference\n|Target - Predicted|")
+
+for ax in axs.flat:
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+plt.tight_layout()
+plt.show()
+
+# %% [markdown] tags=[]
+#
+# #### Analysis Questions: Why is Phase Reconstruction Imperfect?
+#
+# Looking at your results, consider these questions:
+#
+#
+# - Does the fluorescence image contain all the information needed to reconstruct the phase?
+# - What structures are visible in phase but not in fluorescence channels?
+# - Which has higher information content: phase or fluorescence images?
+# - What does the reconstruction error map tell you about what's difficult to predict?
+#
+#
+
+# %% [markdown] tags=[]
+#
+# #### Key Insights from the Fluorescence-to-Phase Model
+#
+# This exploration reveals fundamental limitations in image-to-image translation:
+#
+# - Phase images contain rich structural information about unlabeled cellular components
+# - Fluorescence only captures specific labeled structures (nuclei, membranes, etc.)
+# - Fluorescence-to-phase translation is an ill-posed problem: multiple phase images could produce similar fluorescence patterns (see the toy sketch below)
+# - Models can only predict based on correlations learned during training
+# - Structural details not correlated with fluorescence signals cannot be recovered
+#
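+
+# %% [markdown] tags=[]
+# A toy illustration of the ill-posedness mentioned above: a many-to-one forward
+# map sends distinct inputs to identical observations, so no inverse model can
+# recover the input uniquely. (Purely illustrative; `np.abs` below is a
+# hypothetical stand-in, not the actual imaging physics.)
+
+# %%
+x1 = np.array([1.0, -2.0])
+x2 = np.array([-1.0, 2.0])
+forward = np.abs  # hypothetical many-to-one "phase -> fluorescence" mapping
+print(forward(x1), forward(x2))  # identical observations from different inputs
+
+# %% [markdown] tags=[]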
+#
+# #### Now, let's return to the `phase2fluor` model!
+#
+#
+
+# %% [markdown] tags=[]
+#
+# ## Bonus: Test Time Augmentation (TTA)
+#
+# Test Time Augmentation is a technique where you apply multiple augmentations to a single test image,
+# make predictions on each augmented version, and then combine the results (usually by averaging).
+#
+# **In this section we will:**
+#
+# - Use `Rotate90` and `Flip` for deterministic transformations
+# - Apply transforms, make predictions, then apply inverse transforms
+# - Average all predictions to get the final TTA result that is more robust to geometric variations.
+#
+#
+# Reference: N. Moshkov et al. (2020) https://www.nature.com/articles/s41598-020-61808-3
+#
+# Hint: You can use the `Rotate90` and `Flip` transforms from MONAI.
+# Example forward transform: `Rotate90(k=1, spatial_axes=(-1, -2))`
+# Example inverse transform: `Rotate90(k=3, spatial_axes=(-1, -2))`
+#
+#
+
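+# %% [markdown] tags=[]
+# Before wiring TTA into the model, it is worth sanity-checking that each inverse
+# transform really undoes its forward partner. A minimal sketch on a dummy array,
+# using the same transform signatures as the hints above:
+
+# %%
+import numpy as np
+from monai.transforms import Flip, Rotate90
+
+dummy = np.random.rand(1, 1, 8, 8).astype(np.float32)  # (C, Z, Y, X)
+pairs = [
+    (Rotate90(k=1, spatial_axes=(-1, -2)), Rotate90(k=3, spatial_axes=(-1, -2))),
+    (Flip(spatial_axis=-1), Flip(spatial_axis=-1)),
+]
+for fwd, inv in pairs:
+    roundtrip = np.asarray(inv(fwd(dummy)))
+    assert np.allclose(roundtrip, dummy), "inverse does not undo forward"
+print("All transform pairs round-trip correctly.")
+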
+# %% tags=["task"]
+from monai.transforms import (
+ Flip,
+ Rotate90,
+)
+
+# Get a test sample
+sample = next(iter(test_data.test_dataloader()))
+source_tensor = sample["source"].to(phase2fluor_model.device)
+target_tensor = sample["target"]
+target_nuc = target_tensor[0, 0].cpu().numpy()
+target_mem = target_tensor[0, 1].cpu().numpy()
+
+# Saving the single prediction without TTA for later comparison
+with torch.inference_mode():
+ single_pred = phase2fluor_model(source_tensor)
+ single_pred_nuc = single_pred[0, 0].cpu().numpy()
+ single_pred_mem = single_pred[0, 1].cpu().numpy()
+
+# TODO: Define TTA transforms using MONAI as a list of tuples (forward, inverse)
+###### YOUR CODE HERE ######
+transform_list = [("TODO", "TODO")]
+
+# TODO: Apply test-time augmentation
+# 1. Get original prediction (no augmentation)
+# 2. For each transform:
+# - Apply transform to input
+# - Run inference
+# - De-apply transform to prediction
+# 3. Average all predictions
+
+predictions = []
+
+for forward_transform, inverse_transform in transform_list:
+ # Apply transform to each sample in batch
+ augmented_batch = []
+ for i in range(source_tensor.shape[0]):
+ # Apply the forward and store them
+ ###### YOUR CODE HERE ######
+ aug_img = ...
+ augmented_batch.append(aug_img)
+ augmented_source = torch.stack(augmented_batch).to(source_tensor.device)
+
+ # TODO: Run inference on augmented input
+ with torch.inference_mode():
+ ###### YOUR CODE HERE ######
+ augmented_pred = ...
+
+ # TODO: De-apply transform to prediction
+ deaugmented_batch = []
+    for i in range(augmented_pred.shape[0]):
+        ###### YOUR CODE HERE ######
+        deaug_pred = ...
+        deaugmented_batch.append(deaug_pred)
+    deaugmented_pred = torch.stack(deaugmented_batch)
+
+ predictions.append(deaugmented_pred.cpu().numpy())
+
+# TODO: Average all predictions or take the median
+###### YOUR CODE HERE ######
+averaged_pred = ...
+
+# TODO: Extract nucleus and membrane predictions
+###### YOUR CODE HERE ######
+tta_pred_nuc = ...
+tta_pred_mem = ...
+
+# %% tags=["task"]
+# TODO: Compare TTA results with single prediction
+# Calculate metrics (SSIM, Pearson correlation) for both approaches. Do not forget to normalize the data range to 0-1.
+
+# TODO Normalize data range to 0-1
+###### YOUR CODE HERE ######
+
+# TODO Calculate metrics
+###### YOUR CODE HERE ######
+
+# TODO: TTA prediction metrics
+###### YOUR CODE HERE ######
+
+# Print comparison
+print("\nMetrics Comparison:")
+print(f"{'Metric':<20} {'Single':<10} {'TTA':<10} {'Improvement':<12}")
+print("-" * 55)
+print(f"{'SSIM Nucleus':<20} {ssim_nuc_single:.3f} {ssim_nuc_tta:.3f} {ssim_nuc_tta - ssim_nuc_single:+.3f}")
+print(f"{'SSIM Membrane':<20} {ssim_mem_single:.3f} {ssim_mem_tta:.3f} {ssim_mem_tta - ssim_mem_single:+.3f}")
+print(
+ f"{'Pearson Nucleus':<20} {pearson_nuc_single:.3f} {pearson_nuc_tta:.3f} {pearson_nuc_tta - pearson_nuc_single:+.3f}"
+)
+print(
+ f"{'Pearson Membrane':<20} {pearson_mem_single:.3f} {pearson_mem_tta:.3f} {pearson_mem_tta - pearson_mem_single:+.3f}"
+)
+
+# %% tags=["solution"]
+
+# Normalize data range to 0-1
+target_nuc[0] = rescale_intensity(target_nuc[0], in_range="image", out_range=(0, 1))
+target_mem[0] = rescale_intensity(target_mem[0], in_range="image", out_range=(0, 1))
+single_pred_nuc[0] = rescale_intensity(single_pred_nuc[0], in_range="image", out_range=(0, 1))
+single_pred_mem[0] = rescale_intensity(single_pred_mem[0], in_range="image", out_range=(0, 1))
+tta_pred_nuc[0] = rescale_intensity(tta_pred_nuc[0], in_range="image", out_range=(0, 1))
+tta_pred_mem[0] = rescale_intensity(tta_pred_mem[0], in_range="image", out_range=(0, 1))
+
+# Calculate metrics
+ssim_nuc_single = metrics.structural_similarity(target_nuc[0], single_pred_nuc[0], data_range=1)
+ssim_mem_single = metrics.structural_similarity(target_mem[0], single_pred_mem[0], data_range=1)
+pearson_nuc_single = np.corrcoef(target_nuc[0].flatten(), single_pred_nuc[0].flatten())[0, 1]
+pearson_mem_single = np.corrcoef(target_mem[0].flatten(), single_pred_mem[0].flatten())[0, 1]
+
+# TTA prediction metrics
+ssim_nuc_tta = metrics.structural_similarity(target_nuc[0], tta_pred_nuc[0], data_range=1)
+ssim_mem_tta = metrics.structural_similarity(target_mem[0], tta_pred_mem[0], data_range=1)
+pearson_nuc_tta = np.corrcoef(target_nuc[0].flatten(), tta_pred_nuc[0].flatten())[0, 1]
+pearson_mem_tta = np.corrcoef(target_mem[0].flatten(), tta_pred_mem[0].flatten())[0, 1]
+
+# Print comparison
+print("\nMetrics Comparison:")
+print(f"{'Metric':<20} {'Single':<10} {'TTA':<10} {'Improvement':<12}")
+print("-" * 55)
+print(f"{'SSIM Nucleus':<20} {ssim_nuc_single:.3f} {ssim_nuc_tta:.3f} {ssim_nuc_tta - ssim_nuc_single:+.3f}")
+print(f"{'SSIM Membrane':<20} {ssim_mem_single:.3f} {ssim_mem_tta:.3f} {ssim_mem_tta - ssim_mem_single:+.3f}")
+print(
+ f"{'Pearson Nucleus':<20} {pearson_nuc_single:.3f} {pearson_nuc_tta:.3f} {pearson_nuc_tta - pearson_nuc_single:+.3f}"
+)
+print(
+ f"{'Pearson Membrane':<20} {pearson_mem_single:.3f} {pearson_mem_tta:.3f} {pearson_mem_tta - pearson_mem_single:+.3f}"
+)
+
+# %%
+# Visualize the single-prediction vs. TTA comparison
+# TODO: Modify as you see fit, e.g., to compute the metrics on the full FOV
+
+fig, axs = plt.subplots(3, 3, figsize=(15, 15))
+
+# First row: Input phase and targets
+axs[0, 0].imshow(source_tensor[0, 0, 0].cpu().numpy(), cmap="gray")
+axs[0, 0].set_title("Input Phase")
+axs[0, 1].imshow(target_nuc[0], cmap="gray")
+axs[0, 1].set_title("Target Nucleus")
+axs[0, 2].imshow(target_mem[0], cmap="gray")
+axs[0, 2].set_title("Target Membrane")
+
+# Second row: Single predictions
+axs[1, 0].imshow(source_tensor[0, 0, 0].cpu().numpy(), cmap="gray")
+axs[1, 0].set_title("Input Phase")
+axs[1, 1].imshow(single_pred_nuc[0], cmap="gray")
+axs[1, 1].set_title(f"Single Pred Nucleus\nSSIM: {ssim_nuc_single:.3f}")
+axs[1, 2].imshow(single_pred_mem[0], cmap="gray")
+axs[1, 2].set_title(f"Single Pred Membrane\nSSIM: {ssim_mem_single:.3f}")
+
+# Third row: TTA predictions
+axs[2, 0].imshow(source_tensor[0, 0, 0].cpu().numpy(), cmap="gray")
+axs[2, 0].set_title("Input Phase")
+axs[2, 1].imshow(tta_pred_nuc[0], cmap="gray")
+axs[2, 1].set_title(f"TTA Pred Nucleus\nSSIM: {ssim_nuc_tta:.3f}")
+axs[2, 2].imshow(tta_pred_mem[0], cmap="gray")
+axs[2, 2].set_title(f"TTA Pred Membrane\nSSIM: {ssim_mem_tta:.3f}")
+
+# Remove ticks
+for ax in axs.flat:
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+plt.tight_layout()
+plt.show()
+
+# %% tags=["solution"]
+# Solution: Test Time Augmentation with MONAI transforms
+
+# Get a test sample
+sample = next(iter(test_data.test_dataloader()))
+source_tensor = sample["source"].to(phase2fluor_model.device)
+target_tensor = sample["target"]
+target_nuc = target_tensor[0, 0].cpu().numpy()
+target_mem = target_tensor[0, 1].cpu().numpy()
+
+predictions = []
+
+# Original prediction without augmentation
+with torch.inference_mode():
+ original_pred = phase2fluor_model(source_tensor)
+ predictions.append(original_pred.cpu().numpy())
+
+# Define the TTA transforms and the inverse transforms as a list of tuples (forward, inverse)
+transform_list = [
+ (Rotate90(k=1, spatial_axes=(-1, -2)), Rotate90(k=3, spatial_axes=(-1, -2))),
+ (Rotate90(k=2, spatial_axes=(-1, -2)), Rotate90(k=2, spatial_axes=(-1, -2))),
+ (Rotate90(k=3, spatial_axes=(-1, -2)), Rotate90(k=1, spatial_axes=(-1, -2))),
+ (Flip(spatial_axis=-2), Flip(spatial_axis=-2)),
+ (Flip(spatial_axis=-1), Flip(spatial_axis=-1)),
+]
+
+for forward_transform, inverse_transform in transform_list:
+ # Apply transform to each sample in batch
+ augmented_batch = []
+ for i in range(source_tensor.shape[0]):
+ img = source_tensor[i].cpu().numpy()
+ aug_img = forward_transform(img)
+ augmented_batch.append(aug_img)
+ augmented_source = torch.stack(augmented_batch).to(source_tensor.device)
+
+ # Run inference on augmented input
+ with torch.inference_mode():
+ augmented_pred = phase2fluor_model(augmented_source)
+
+ # De-apply transform to prediction
+ deaugmented_batch = []
+ for i in range(augmented_pred.shape[0]):
+ pred = augmented_pred[i].cpu().numpy()
+ deaug_pred = inverse_transform(pred)
+ deaugmented_batch.append(deaug_pred)
+ deaugmented_pred = torch.stack(deaugmented_batch)
+
+ predictions.append(deaugmented_pred.cpu().numpy())
+
+# Average all predictions
+averaged_pred = np.stack(predictions).mean(axis=0)
+
+# Extract nucleus and membrane predictions
+tta_pred_nuc = averaged_pred[0, 0]
+tta_pred_mem = averaged_pred[0, 1]
+
+# Compare with single prediction (no TTA)
+with torch.inference_mode():
+ single_pred = phase2fluor_model(source_tensor)
+ single_pred_nuc = single_pred[0, 0].cpu().numpy()
+ single_pred_mem = single_pred[0, 1].cpu().numpy()
+
+
+# %% [markdown] tags=[]
+#
+# #### Discussion Questions for Test Time Augmentation
+#
+# - Did TTA improve the metrics? By how much?
+# - What are the trade-offs of using TTA? (hint: think about computation time vs. accuracy)
+# - When would TTA be most beneficial in fluorescence microscopy?
+# - How could you modify the TTA strategy to be more effective for this specific virtual staining task?
+# - What other MONAI transforms could be useful for TTA in this context? (e.g., slight rotations, scaling)
+# - Are there any hallucinations that are removed with TTA?
+#
+#
+
+# %% [markdown] tags=[]
+#
+# #### Bonus Section Complete!
+#
+# You have successfully implemented Test Time Augmentation using MONAI transforms!
+#
+# Key takeaways:
+#
+# - TTA is particularly useful when prediction quality is critical and computational budget allows
+# - Multiple geometric augmentations can reduce prediction variance and improve robustness
+# - TTA leverages deterministic transforms (`Rotate90`, `Flip`) instead of random ones
+# - The computational cost increases linearly with the number of TTA transforms
+#
+#
+
+# %% [markdown] tags=[]
+# # Part 3: Visualizing the encoder and decoder features & exploring the model's range of validity
+#
+# - In this section, we will visualize the encoder and decoder features of the model you trained.
+# - We will also explore the model's range of validity by looking at the feature maps of the encoder and decoder.
+#
+# %% [markdown] tags=[]
+#
+# ### Task 3.1: Let's look at what the model is learning
+#
+# - If you are unfamiliar with Principal Component Analysis (PCA), you may want to review it before continuing.
+# - Run the next cells. We will visualize the encoder feature maps of the trained model,
+#   using PCA to map the first 3 principal components of each feature map to RGB.
+#
+#
+#
+
+# %%
+"""
+Script to visualize the encoder feature maps of a trained model.
+Using PCA to visualize feature maps is inspired by
+https://doi.org/10.48550/arXiv.2304.07193 (Oquab et al., 2023).
+"""
+from typing import NamedTuple # noqa: E402
+
+from monai.networks.layers import GaussianFilter # noqa: E402
+from skimage.exposure import rescale_intensity # noqa: E402
+from sklearn.decomposition import PCA # noqa: E402
+
+
+def feature_map_pca(feature_map: np.ndarray, n_components: int = 8) -> PCA:
+    """
+    Compute PCA on a feature map.
+    :param np.ndarray feature_map: (C, H, W) feature map
+ :param int n_components: number of components to keep
+ :return: PCA: fit sklearn PCA object
+ """
+ # (C, H, W) -> (C, H*W)
+ feat = feature_map.reshape(feature_map.shape[0], -1)
+ pca = PCA(n_components=n_components)
+ pca.fit(feat)
+ return pca
+
+
+def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray:
+ pca = feature_map_pca(feat[0], n_components=n_components)
+ pc_first_3 = pca.components_[:3].reshape(3, *feat.shape[-2:])
+ return np.stack([rescale_intensity(pc, out_range=(0, 1)) for pc in pc_first_3], axis=-1)
+
+
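+# %% [markdown]
+# Quick sanity check of the PCA-to-RGB mapping on random "features"
+# (shapes chosen arbitrarily for illustration):
+
+# %%
+dummy_feat = np.random.rand(1, 16, 32, 32).astype(np.float32)  # (N, C, H, W)
+dummy_rgb = pcs_to_rgb(dummy_feat, n_components=4)
+print(dummy_rgb.shape)  # (32, 32, 3), each channel rescaled to [0, 1]
+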
+# %%
+# Load the test dataset
+test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr"
+test_dataset = open_ome_zarr(test_data_path)
+
+# Looking at the test dataset
+print("Test dataset:")
+test_dataset.print_tree()
+
+# %% [markdown] tags=[]
+#
+# - Change the `fov` and `crop` size to visualize the feature maps of the encoder and decoder.
+#   Note: the crop should be a multiple of 384.
+#
+# %%
+# Load one position
+row = 0
+col = 0
+center_index = 2
+n = 1
+crop = 384 * n
+fov = 10
+
+# normalize phase
+norm_meta = test_dataset.zattrs["normalization"]["Phase3D"]["dataset_statistics"]
+
+# Get the OME-Zarr metadata
+Y, X = test_dataset[f"0/0/{fov}"].data.shape[-2:]
+print("Channels:", test_dataset.channel_names)
+phase_idx = test_dataset.channel_names.index("Phase3D")
+assert crop <= Y and crop <= X, "Crop size larger than the image. Check the image shape"
+
+phase_img = test_dataset[f"0/0/{fov}/0"][
+ :,
+ phase_idx : phase_idx + 1,
+ 0:1,
+ Y // 2 - crop // 2 : Y // 2 + crop // 2,
+ X // 2 - crop // 2 : X // 2 + crop // 2,
+]
+fluo = test_dataset[f"0/0/{fov}/0"][
+ 0,
+ 1:3,
+ 0,
+ Y // 2 - crop // 2 : Y // 2 + crop // 2,
+ X // 2 - crop // 2 : X // 2 + crop // 2,
+]
+
+phase_img = (phase_img - norm_meta["median"]) / norm_meta["iqr"]
+plt.imshow(phase_img[0, 0, 0], cmap="gray")
+
+# %% [markdown] tags=[]
+#
+# For the following tasks we will use the pretrained model to extract the encoder and decoder features
+# Extra: If you are done with the whole checkpoint, you can try to look at what your trained model learned.
+#
+# %%
+
+# Loading the pretrained model
+pretrained_model_ckpt = top_dir / "06_image_translation/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt"
+# model config as before
+phase2fluor_config = dict(
+ in_channels=1,
+ out_channels=2,
+ encoder_blocks=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ decoder_conv_blocks=2,
+ stem_kernel_size=(1, 2, 2),
+ in_stack_depth=1,
+ pretraining=False,
+)
+
+# load model
+model = VSUNet.load_from_checkpoint(
+ pretrained_model_ckpt,
+ architecture="UNeXt2_2D",
+ model_config=phase2fluor_config.copy(),
+ accelerator="gpu",
+)
+
+# %% tags=[]
+# Extract features
+with torch.inference_mode():
+ # encoder
+ encoder_features = model.model.encoder(torch.from_numpy(phase_img.astype(np.float32)).to(model.device))[0]
+ encoder_features_np = [f.detach().cpu().numpy() for f in encoder_features]
+
+ # Print the encoder features shapes
+ for f in encoder_features_np:
+ print(f.shape)
+
+ # decoder
+ features = encoder_features.copy()
+ features.reverse()
+ feat = features[0]
+ features.append(None)
+ decoder_features_np = []
+ for skip, stage in zip(features[1:], model.model.decoder.decoder_stages):
+ feat = stage(feat, skip)
+ decoder_features_np.append(feat.detach().cpu().numpy())
+ for f in decoder_features_np:
+ print(f.shape)
+ prediction = model.model.head(feat).detach().cpu().numpy()
+
+
+# Defining the colors for plotting
+class Color(NamedTuple):
+ r: float
+ g: float
+ b: float
+
+
+# Defining the colors for plotting the PCA
+BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902)
+BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r)
+GREEN = Color(0.0, 1.0, 0.0)
+MAGENTA = Color(1.0, 0.0, 1.0)
+
+
+# Defining the functions to rescale the image and composite the nuclear and membrane images
+def rescale_clip(image: np.ndarray) -> np.ndarray:
+ return rescale_intensity(image, out_range=(0, 1))[..., None].repeat(3, axis=-1)
+
+
+def composite_nuc_mem(image: np.ndarray, nuc_color: Color, mem_color: Color) -> np.ndarray:
+ c_nuc = rescale_clip(image[0]) * nuc_color
+ c_mem = rescale_clip(image[1]) * mem_color
+ return rescale_intensity(c_nuc + c_mem, out_range=(0, 1))
+
+
+def clip_p(image: np.ndarray) -> np.ndarray:
+ return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))
+
+
+def clip_highlight(image: np.ndarray) -> np.ndarray:
+ return rescale_intensity(image.clip(0, np.percentile(image, 99.5)))
+
+
+# Plot the PCA to RGB of the feature maps
+f, ax = plt.subplots(10, 1, figsize=(5, 25))
+n_components = 4
+ax[0].imshow(phase_img[0, 0, 0], cmap="gray")
+ax[0].set_title(f"Phase {phase_img.shape[1:]}")
+ax[-1].imshow(clip_p(composite_nuc_mem(fluo, GREEN, MAGENTA)))
+ax[-1].set_title("Fluorescence")
+
+for level, feat in enumerate(encoder_features_np):
+ ax[level + 1].imshow(pcs_to_rgb(feat, n_components=n_components))
+ ax[level + 1].set_title(f"Encoder stage {level + 1} {feat.shape[1:]}")
+
+for level, feat in enumerate(decoder_features_np):
+ ax[5 + level].imshow(pcs_to_rgb(feat, n_components=n_components))
+ ax[5 + level].set_title(f"Decoder stage {level + 1} {feat.shape[1:]}")
+
+pred_comp = composite_nuc_mem(prediction[0, :, 0], BOP_BLUE, BOP_ORANGE)
+ax[-2].imshow(clip_p(pred_comp))
+ax[-2].set_title(f"Prediction {prediction.shape[1:]}")
+
+for a in ax.ravel():
+ a.axis("off")
+plt.tight_layout()
+
+# %% [markdown] tags=["task"]
+#
+# ### Task 3.2: Select a sample batch to test the range of validity of the model
+#
+# - Run the next cell to set up your dataloader for `test`
+# - Select a test batch from the `test_dataloader` by changing the `batch_number`
+# - Examine the plot of the source and target images of the batch
+#
+# Note: the 2D images have different focus.
+#
+
+# %%
+YX_PATCH_SIZE = (256 * 2, 256 * 2)
+source_channel = ["Phase3D"]
+target_channel = ["Nucl", "Mem"]
+
+normalizations = [
+ NormalizeSampled(
+ keys=source_channel,
+ level="fov_statistics",
+ subtrahend="mean",
+ divisor="std",
+ ),
+ NormalizeSampled(
+ keys=target_channel,
+ level="fov_statistics",
+ subtrahend="median",
+ divisor="iqr",
+ ),
+]
+
+# Re-load the dataloader
+phase2fluor_2D_data = HCSDataModule(
+ data_path,
+ source_channel=source_channel,
+ target_channel=target_channel,
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=1,
+ num_workers=8,
+ yx_patch_size=YX_PATCH_SIZE,
+ augmentations=[],
+ normalizations=normalizations,
+)
+phase2fluor_2D_data.setup("test")
+# %% tags=[]
+# ########## TODO ##############
+batch_number = 3 # Change this to see different batches of data
+# #######################
+y_slice = slice(Y // 2 - 256 * n // 2, Y // 2 + 256 * n // 2)
+x_slice = slice(X // 2 - 256 * n // 2, X // 2 + 256 * n // 2)
+
+# Iterate through the test dataloader to get the desired batch
+i = 0
+for batch in phase2fluor_2D_data.test_dataloader():
+ # break if we reach the desired batch
+ if i == batch_number - 1:
+ break
+ i += 1
+
+# Plot the batch source and target images
+f, ax = plt.subplots(1, 2, figsize=(8, 12))
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0].imshow(
+ batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(),
+ cmap="gray",
+ vmin=-15,
+ vmax=15,
+)
+ax[1].imshow(clip_highlight(target_composite[0, y_slice, x_slice]))
+for a in ax.ravel():
+ a.axis("off")
+f.tight_layout()
+plt.show()
+
+# %% [markdown] tags=[]
+#
+#
+# ### Task 3.3: Using the selected batch to test the model's range of validity
+#
+# - Given the selected batch, use `monai.networks.layers.GaussianFilter` to blur the images with different sigmas.
+#   Check the MONAI documentation for `GaussianFilter`; a small usage sketch follows this cell.
+# - Plot the source and predicted images, comparing the source, target, and added perturbations
+# - How are the model's predictions affected by the perturbations?
+#
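+# %% [markdown] tags=[]
+# A quick sketch of the `GaussianFilter` layer on a dummy 2D tensor, so that the
+# perturbation below is easier to write. The task uses a 5D (B, C, Z, Y, X) tensor,
+# so you will need to adjust `spatial_dims` and provide one sigma per spatial
+# dimension:
+
+# %%
+import torch
+from monai.networks.layers import GaussianFilter
+
+dummy_img = torch.randn(1, 1, 64, 64)  # (B, C, Y, X)
+blur_2d = GaussianFilter(spatial_dims=2, sigma=2.0)  # a scalar sigma applies to every spatial dim
+blurred = blur_2d(dummy_img)
+print(dummy_img.std().item(), blurred.std().item())  # blurring reduces variance
+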
+# %% tags=["task"]
+# ########## TODO ##############
+# Try out different multiples of 256 to visualize larger/smaller crops
+n = 3
+# ##############################
+# Center cropping the image
+y_slice = slice(Y // 2 - 256 * n // 2, Y // 2 + 256 * n // 2)
+x_slice = slice(X // 2 - 256 * n // 2, X // 2 + 256 * n // 2)
+
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(),
+ cmap="gray",
+ vmin=-15,
+ vmax=15,
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice]))
+ax[0, 0].set_title("Source and target")
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1, 0].set_title("No perturbation")
+
+# Select a sigma for the Gaussian filtering
+# ########## TODO ##############
+# Tensor dimensions (B, C, Z, Y, X).
+# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigmas
+# Hint: Spatial (Z, Y, X)
+gaussian_blur = GaussianFilter(...)
+# #############################
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+
+# %% tags=["solution"]
+# ########## SOLUTION ##############
+# Try out different multiples of 256 to visualize larger/smaller crops
+n = 3
+# ##############################
+# Center cropping the image
+y_slice = slice(Y // 2 - 256 * n // 2, Y // 2 + 256 * n // 2)
+x_slice = slice(X // 2 - 256 * n // 2, X // 2 + 256 * n // 2)
+
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(),
+ cmap="gray",
+ vmin=-15,
+ vmax=15,
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice]))
+ax[0, 0].set_title("Source and target")
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1, 0].set_title("No perturbation")
+
+
+# Select a sigma for the Gaussian filtering
+# ########## SOLUTION ##############
+# Tensor dimensions (B, C, Z, Y, X).
+# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigma
+# Hint: Spatial (Z, Y, X). Apply the same sigma to Y, X
+gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0, 2, 2))
+# #############################
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+
+# %% [markdown] tags=[]
+#
+# ### Task 3.4: Using the selected batch to test the model's range of validity
+#
+# - Scale the pixel values of the phase image up/down
+# - Plot the source and predicted images, comparing the source, target, and added perturbations
+# - How are the model's predictions affected by the perturbations?
+#
+
+# %% tags=["task"]
+n = 3
+y_slice = slice(Y // 2, Y // 2 + 256 * n)
+x_slice = slice(X // 2, X // 2 + 256 * n)
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(),
+ cmap="gray",
+ vmin=-15,
+ vmax=15,
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice]))
+ax[0, 0].set_title("Source and target")
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1, 0].set_title("No perturbation")
+
+
+# Rescale the pixel value up/down
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ # ########## TODO ##############
+ # Hint: Scale the phase intensity up/down until the model breaks
+ phase = phase * ...
+ # #######################
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+
+# %% tags=["solution"]
+n = 3
+y_slice = slice(Y // 2, Y // 2 + 256 * n)
+x_slice = slice(X // 2, X // 2 + 256 * n)
+f, ax = plt.subplots(3, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(),
+ cmap="gray",
+ vmin=-15,
+ vmax=15,
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice]))
+ax[0, 0].set_title("Source and target")
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1, 0].set_title("No perturbation")
+
+
+# Rescale the pixel value up/down
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ # ########## SOLUTION ##############
+ # Hint: Scale the phase intensity up/down until the model breaks
+ phase = phase * 10
+ # #######################
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+
+# %% [markdown]
+#
+# #### Questions
+#
+# - How are the model's predictions affected by the blurring and scaling perturbations?
+#
+
+# %% tags=["solution"]
+# ########## SOLUTIONS FOR ALL POSSIBLE PLOTTINGS ##############
+# This plots all perturbations
+
+n = 3
+y_slice = slice(Y // 2, Y // 2 + 256 * n)
+x_slice = slice(X // 2, X // 2 + 256 * n)
+f, ax = plt.subplots(6, 2, figsize=(8, 12))
+
+target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
+ax[0, 0].imshow(
+ batch["source"][0, 0, 0, y_slice, x_slice].cpu().numpy(),
+ cmap="gray",
+ vmin=-15,
+ vmax=15,
+)
+ax[0, 1].imshow(clip_highlight(target_composite[0, y_slice, x_slice]))
+ax[0, 0].set_title("Source and target")
+
+# no perturbation
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[1, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[1, 1].imshow(pred_composite[0])
+ax[1, 0].set_title("No perturbation")
+
+
+# 2-sigma gaussian blur
+gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0, 2, 2))
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[2, 1].imshow(pred_composite[0])
+ax[2, 0].set_title("Gaussian Blur Sigma=2")
+
+
+# 5-sigma gaussian blur
+gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0, 5, 5))
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ phase = gaussian_blur(phase)
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[3, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[3, 1].imshow(pred_composite[0])
+ax[3, 0].set_title("Gaussian Blur Sigma=5")
+
+
+# 0.1x scaling
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ phase = phase * 0.1
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[4, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[4, 1].imshow(pred_composite[0])
+ax[4, 0].set_title("0.1x scaling")
+
+# 10x scaling
+with torch.inference_mode():
+ phase = batch["source"].to(model.device)[:, :, :, y_slice, x_slice]
+ phase = phase * 10
+ pred = model(phase).cpu().numpy()
+pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
+ax[5, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
+ax[5, 1].imshow(pred_composite[0])
+ax[5, 0].set_title("10x scaling")
+
+for a in ax.ravel():
+ a.axis("off")
+
+f.tight_layout()
+# %% [markdown] tags=[]
+#
+# ## 🎉 The end of the notebook 🎉
+#
+# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned.
diff --git a/applications/cytoland/examples/phase_contrast/README.md b/applications/cytoland/examples/phase_contrast/README.md
new file mode 100644
index 000000000..0131624a9
--- /dev/null
+++ b/applications/cytoland/examples/phase_contrast/README.md
@@ -0,0 +1,37 @@
+# Demo: Virtual staining of phase contrast data
+
+## Overview
+
+This demo showcases generalization to Zernike phase contrast images: the VSCyto3D model is applied, with and without training-time augmentations, to phase contrast data that was not part of its training.
+
+## Setup
+
+Run the setup script to create the environment for this exercise and download the dataset.
+```bash
+source setup.sh
+```
+
+Activate your environment:
+```bash
+conda activate vs_Phc
+```
+
+## Use VS Code

+Install VS Code, install the Jupyter extension inside VS Code, and set up [cell mode](https://code.visualstudio.com/docs/python/jupyter-support-py). Open [solution.py](solution.py) and run the script interactively.
+
+## Use Jupyter Notebook
+
+Launch a Jupyter environment:
+
+```bash
+jupyter notebook
+```
+
+...and continue with the instructions in the notebook.
+
+If `vs_Phc` is not available as a kernel in jupyter, run:
+
+```bash
+python -m ipykernel install --user --name=vs_Phc
+```
diff --git a/applications/cytoland/examples/phase_contrast/prepare-exercise.sh b/applications/cytoland/examples/phase_contrast/prepare-exercise.sh
new file mode 100644
index 000000000..526b9fa3c
--- /dev/null
+++ b/applications/cytoland/examples/phase_contrast/prepare-exercise.sh
@@ -0,0 +1,8 @@
+# Run ruff format on .py files
+# ruff format solution.py
+
+# Convert .py to ipynb
+
+# "cell_metadata_filter": "all" preserve cell tags including our solution tags
+jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update solution.py
+jupyter nbconvert solution.ipynb --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags task --to notebook --output solution.ipynb
diff --git a/applications/cytoland/examples/phase_contrast/setup.sh b/applications/cytoland/examples/phase_contrast/setup.sh
new file mode 100644
index 000000000..f0bfef845
--- /dev/null
+++ b/applications/cytoland/examples/phase_contrast/setup.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env -S bash -i
+
+START_DIR=$(pwd)
+
+conda deactivate
+# Create conda environment
+conda create -y --name vs_Phc python=3.11
+
+# Install ipykernel in the environment.
+conda install -y ipykernel nbformat nbconvert ruff jupytext ipywidgets --name vs_Phc
+# Specifying the environment explicitly.
+# conda activate sometimes doesn't work from within shell scripts.
+
+# Install cytoland (pulls in viscy-data, viscy-models, viscy-transforms, viscy-utils).
+# Run this from the root of the VisCy monorepo checkout.
+# Find path to the environment - conda activate doesn't work from within shell scripts.
+ENV_PATH=$(conda info --envs | grep vs_Phc | awk '{print $NF}')
+$ENV_PATH/bin/pip install -e "applications/cytoland[metrics]"
+
+# Create the directory structure
+mkdir -p ~/data/vs_PhC/test
+mkdir -p ~/data/vs_PhC/models
+
+# Change to the target directory
+# Download the OME-Zarr dataset recursively
+cd ~/data/vs_PhC/test
+wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto3D/test/HEK_H2B_CAAX_PhC_40x_registered.zarr/"
+
+# Get the models
+cd ~/data/vs_PhC/models
+wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/no_augmentations/best_epoch=30-step=6076.ckpt"
+wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/epoch=48-step=18130.ckpt"
+
+
+# Change back to the starting directory
+cd $START_DIR
diff --git a/applications/cytoland/examples/phase_contrast/solution.py b/applications/cytoland/examples/phase_contrast/solution.py
new file mode 100644
index 000000000..19f5395eb
--- /dev/null
+++ b/applications/cytoland/examples/phase_contrast/solution.py
@@ -0,0 +1,225 @@
+# %% [markdown] tags=[]
+# # Virtual staining of phase contrast images using VSCyto3D with and without augmentations
+#
+# Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco
+#
+# ## Overview
+#
+# This notebook demonstrates how to use the VSCyto3D model to virtually stain phase contrast images. The phase contrast images were not part of the training.
+# We will use the VSCyto3D model to predict the nuclei and cell membrane channels from a phase contrast image with two models:
+# - One model trained without augmentations
+# - One model trained with augmentations
+#
+
+# %% tags=[]
+# Imports
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from iohub import open_ome_zarr
+from lightning.pytorch import seed_everything
+
+# Cytoland and VisCy modular classes for the trainer and model
+from cytoland.engine import VSUNet
+from viscy_data.hcs import HCSDataModule
+from viscy_transforms import NormalizeSampled
+from viscy_utils.trainer import VisCyTrainer
+
+# seed random number generators for reproducibility.
+seed_everything(42, workers=True)
+# %%
+# Paths to data and log directory
+top_dir = (
+ Path("~/data/vs_PhC").expanduser()
+)  # If this fails, make sure this points to your data directory in the shared mounting point inside /dlmbl/data
+
+# Path to the training data
+data_path = top_dir / "test/HEK_H2B_CAAX_PhC_40x_registered.zarr"
+
+# %% [markdown] tags=[]
+# ## Load OME-Zarr Dataset
+
+# There should be 34 FOVs in the dataset.
+#
+# Each FOV consists of 3 channels of 2048x2048 images,
+# saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout)
+# specified by the Open Microscopy Environment Next Generation File Format
+# (OME-NGFF).
+#
+# The 3 channels correspond to the QPI, nuclei, and cell membrane. The nuclei were stained with DAPI and the cell membrane with Cellmask.
+#
+# - The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.`
+# - These datasets only have 1 level in the pyramid (highest resolution) which is '0'.
+# %%
+# Open the dataset and look at its structure
+dataset = open_ome_zarr(data_path)
+dataset.print_tree()
+# %%
+row = 0
+col = 3
+field = "000000" # TODO Change this for a different FOV
+
+# NOTE: this dataset only has one level
+pyramid_level = 0
+
+fov_path = f"{row}/{col}/{field}"
+input_data_path = Path(data_path) / fov_path
+image = dataset[fov_path][pyramid_level].numpy()
+
+n_channels = len(dataset.channel_names)
+Z, Y, X = image.shape[-3:]
+figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))
+title_names = ["PhC", "TXR", "Y5"]
+for i in range(n_channels):
+    channel_image = image[0, i, Z // 2]
+    # Invert the phase contrast channel
+    if i == 0:
+        channel_image = channel_image * -1
+    # Adjust contrast to the 0.5th and 99.5th percentile of pixel values.
+    p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
+    channel_image = np.clip(channel_image, p_low, p_high)
+    axes[i].imshow(channel_image, cmap="gray")
+    axes[i].axis("off")
+    axes[i].set_title(title_names[i])
+plt.tight_layout()
+
+# %% [markdown] tags=[]
+# ## Setup the data module
+# Here we will instantiate the `HCSDataModule` that reads the OME-Zarr dataset and prepares the data for inference.
+# %%
+# Reduce the batch size if encountering out-of-memory errors
+BATCH_SIZE = 5
+# NOTE: Set the number of workers to 0 for Windows and macOS
+# since multiprocessing only works with a
+# `if __name__ == '__main__':` guard.
+# On Linux, set it to the number of CPU cores to maximize performance.
+NUM_WORKERS = 0
+source_channel_name = "BF"
+
+# %% [markdown]
+"""
+For this example we will use the following parameters:
+### For more information on the VSCyto3D model:
+See ``viscy.unet.networks.unext2``
+([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/unext2.py#L252))
+for configuration details.
+"""
+# %%
+# Setup the data module.
+data_module = HCSDataModule(
+ data_path=input_data_path,
+ source_channel=source_channel_name,
+ target_channel=["Nuclei", "Membrane"],
+ z_window_size=5,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=NUM_WORKERS,
+ architecture="UNeXt2",
+ normalizations=[
+ NormalizeSampled(
+ [source_channel_name],
+ level="fov_statistics",
+ subtrahend="median",
+ divisor="iqr",
+ )
+ ],
+)
+data_module.prepare_data()
+data_module.setup(stage="predict")
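+
+# %%
+# Optional sanity check: draw one batch from the predict dataloader to confirm
+# the source patch shape matches the model input. (This assumes the predict
+# batches are dicts with a "source" key; uncomment to run.)
+# batch = next(iter(data_module.predict_dataloader()))
+# print(batch["source"].shape)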
+
+# %% [markdown] tags=[]
+# ## Setup the _VSCyto3D_ model with and without augmentations
+# We will load the model checkpoints and run inference on the phase contrast image.
+# The model trained with augmentations shows better performance in predicting the nuclei and cell membrane channels.
+# The phase contrast images were not part of the training data for the `VSCyto3D` model.
+# %%
+
+# TODO: change if you want to use a different GPU
+GPU_ID = 0
+
+# TODO: point to the downloaded model checkpoints
+no_augmentation_model_ckpt = top_dir / "models/no_augmentations/best_epoch=30-step=6076.ckpt"
+VSCyto3D_model_ckpt = top_dir / "models/epoch=48-step=18130.ckpt"
+
+# Dictionary that specifies key parameters of the model.
+config_VSCyto3D = {
+ "in_channels": 1,
+ "out_channels": 2,
+ "in_stack_depth": 5,
+ "backbone": "convnextv2_tiny",
+ "stem_kernel_size": (5, 4, 4),
+ "decoder_mode": "pixelshuffle",
+ "head_expansion_ratio": 4,
+ "head_pool": True,
+}
+
+# Model without augmentation
+model_VSCyto3D_no_augmentation = VSUNet.load_from_checkpoint(
+ no_augmentation_model_ckpt, architecture="UNeXt2", model_config=config_VSCyto3D
+)
+model_VSCyto3D_no_augmentation.eval()
+# Model with augmentation
+model_VSCyto3D_w_augmentation = VSUNet.load_from_checkpoint(
+ VSCyto3D_model_ckpt, architecture="UNeXt2", model_config=config_VSCyto3D
+)
+model_VSCyto3D_w_augmentation.eval()
+
+# Setup the Trainer
+trainer = VisCyTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed")
+
+# Crop a centered region of n x n patches (patch_size pixels each) for inference
+n = 5
+patch_size = 256
+y_slice = slice(Y // 2 - patch_size * n // 2, Y // 2 + patch_size * n // 2)
+x_slice = slice(X // 2 - patch_size * n // 2, X // 2 + patch_size * n // 2)
+
+# Get the phase contrast channel as a 5-slice window centered at Z // 2,
+# matching the model's `in_stack_depth` of 5.
+c_idx = dataset.channel_names.index(source_channel_name)
+phase_image = image[0:1, c_idx : c_idx + 1, Z // 2 - 2 : Z // 2 + 3, y_slice, x_slice]
+# Normalize with the pre-computed FOV statistics (median and IQR) and invert the contrast
+median = dataset[fov_path].zattrs["normalization"][source_channel_name]["fov_statistics"]["median"]
+iqr = dataset[fov_path].zattrs["normalization"][source_channel_name]["fov_statistics"]["iqr"]
+phase_image = ((phase_image - median) / iqr) * -1
+
+# Load the image to device
+device = model_VSCyto3D_no_augmentation.device
+phase_image = torch.tensor(phase_image).to(device)
+
+# Run inference on the given volume
+with torch.inference_mode(): # turn off gradient computation.
+ pred_no_augmentation = model_VSCyto3D_no_augmentation(phase_image)
+ pred_w_augmentation = model_VSCyto3D_w_augmentation(phase_image)
+
+pred_no_augmentation = pred_no_augmentation.cpu().detach().numpy()
+pred_w_augmentation = pred_w_augmentation.cpu().detach().numpy()
+phase_image = phase_image.cpu().detach().numpy()
+clim_max = 30
+clim_min = -20
+
+# Plot the predicted images with model without augmentations
+fig, ax = plt.subplots(2, 3, figsize=(12, 12))
+ax[0, 0].imshow(phase_image[0, 0, 2, :, :], cmap="gray", vmin=clim_min, vmax=clim_max)
+ax[0, 0].axis("off")
+ax[0, 0].set_title("Phase Contrast")
+for i in range(2):
+ ax[0, i + 1].imshow(pred_no_augmentation[0, i, 2, :, :], cmap="gray")
+ ax[0, i + 1].axis("off")
+ax[0, 1].set_title("VS_Nuclei without augmentations")
+ax[0, 2].set_title("VS_Membrane without augmentations")
+
+# Plot the predicted images with VSCyto3D with augmentations
+ax[1, 0].imshow(phase_image[0, 0, 2, :, :], cmap="gray", vmin=clim_min, vmax=clim_max)
+ax[1, 0].axis("off")
+ax[1, 0].set_title("Phase Contrast")
+for i in range(2):
+ ax[1, i + 1].imshow(
+ pred_w_augmentation[0, i, 2, :, :],
+ cmap="gray",
+ )
+ ax[1, i + 1].axis("off")
+ax[1, 1].set_title("VS_Nuclei with augmentations")
+ax[1, 2].set_title("VS_Membrane with augmentations")
+
+plt.tight_layout()
diff --git a/applications/cytoland/examples/vcp_tutorials/README.md b/applications/cytoland/examples/vcp_tutorials/README.md
new file mode 100644
index 000000000..c9e2eaf6c
--- /dev/null
+++ b/applications/cytoland/examples/vcp_tutorials/README.md
@@ -0,0 +1,21 @@
+# Virtual Cell Platform Tutorials
+
+This directory contains tutorial notebooks for the Virtual Cell Platform,
+available as both Python scripts and Jupyter notebooks.
+
+- [Quick Start](quick_start.ipynb):
+get started with model inference in Python with an A549 cell dataset.
+- [CLI inference and visualization](hek293t.ipynb):
+run inference from the CLI on a HEK293T cell dataset and visualize the results.
+- [Virtual staining _in vivo_](neuromast.ipynb):
+compare virtual staining and fluorescence in a time-lapse dataset of the zebrafish neuromast.
+
+## Development
+
+Development happens in the Python scripts,
+which are converted to Jupyter notebooks with:
+
+```sh
+# TODO: change the file name at the end to be the script to convert
+jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' --update quick_start.py
+```
diff --git a/applications/cytoland/examples/vcp_tutorials/hek293t.py b/applications/cytoland/examples/vcp_tutorials/hek293t.py
new file mode 100644
index 000000000..41decac02
--- /dev/null
+++ b/applications/cytoland/examples/vcp_tutorials/hek293t.py
@@ -0,0 +1,226 @@
+# %% [markdown]
+"""
+# Cytoland Tutorial: Virtual Staining of HEK293T Cells with VSCyto3D
+
+**Estimated time to complete:** 15 minutes
+"""
+
+# %% [markdown]
+"""
+# Learning Goals
+
+* Download the VSCyto3D model and an example dataset containing HEK293T cell images.
+* Pre-compute normalization statistics for the images using the `viscy preprocess` command line interface (CLI).
+* Run inference for joint virtual staining of cell nuclei and plasma membrane via the `viscy predict` CLI.
+* Compare virtually and experimentally stained cells and see how virtual staining can rescue missing labels.
+"""
+
+# %% [markdown]
+"""
+# Prerequisites
+
+Python>=3.11
+"""
+
+# %% [markdown]
+"""
+# Introduction
+
+See the [model card](https://virtualcellmodels.cziscience.com/paper/cytoland2025)
+for more details about the Cytoland models.
+
+VSCyto3D is a 3D UNeXt2 model that has been trained on A549, HEK293T, and hiPSC cells using the Cytoland approach.
+This model enables users to jointly stain cell nuclei and plasma membranes from 3D label-free images
+for downstream analysis such as cell segmentation and tracking without the need for human annotation of volumetric data.
+"""
+
+# %% [markdown]
+"""
+# Setup
+
+The commands below will install the required packages and download the example dataset and model checkpoint.
+It may take a **few minutes** to download all the files.
+
+## Setup Google Colab
+
+To run this quick-start guide using Google Colab,
+choose the 'T4' GPU runtime from the "Connect" dropdown menu
+in the upper-right corner of this notebook for faster execution.
+Using a GPU significantly speeds up running model inference, but CPU compute can also be used.
+
+## Setup Local Environment
+
+The commands below assume a Unix-like shell with `wget` installed.
+On Windows, the files can be downloaded manually from the URLs.
+"""
+
+# %%
+# Install VisCy with the optional dependencies for this example
+# See the [repository](https://github.com/mehta-lab/VisCy) for more details
+# Here stackview and ipycanvas are installed for visualization
+# !pip install -U "viscy[metrics,visual]==0.4.0a3" stackview ipycanvas==0.11
+
+# %%
+# Restart kernel if running in Google Colab
+# This is required to use the packages installed above
+# The 'kernel crashed' message is expected here
+if "get_ipython" in globals():
+ session = get_ipython() # noqa: F821
+ if "google.colab" in str(session):
+ print("Shutting down colab session.")
+ session.kernel.do_shutdown(restart=True)
+
+# %%
+# Download the example dataset
+# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto3D/test/HEK293T-Phase3D-H2B-CAAX-example.zarr/"
+
+# %%
+# Rename the downloaded dataset to what the example prediction config expects (`input.ome.zarr`)
+# And validate the OME-Zarr metadata with iohub
+# !mv HEK293T-Phase3D-H2B-CAAX-example.zarr input.ome.zarr
+# !iohub info -v input.ome.zarr
+
+# %%
+# Download the VSCyto3D model checkpoint and prediction config
+# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/epoch=83-step=14532-loss=0.492.ckpt"
+# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D/predict.yml"
+
+# %% [markdown]
+"""
+# Use Case
+
+## Example Dataset
+
+The HEK293T example dataset used in this quick-start guide contains
+quantitative phase and paired fluorescence images of cell nuclei and plasma membrane.
+It is a subset (one cropped region of interest) from a test set used to evaluate the VSCyto3D model.
+The full dataset can be downloaded from the
+[BioImage Archive](https://www.ebi.ac.uk/biostudies/BioImages/studies/S-BIAD1702).
+
+Refer to our [preprint](https://doi.org/10.1101/2024.05.31.596901) for more details
+about how the dataset and model were generated.
+
+## Using Custom Data
+
+The model only requires label-free images for inference.
+To run inference on your own data,
+convert them into the OME-Zarr data format using iohub or other
+[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion),
+and edit the `predict.yml` file to specify the input data path.
+Specifically, the `data.init_args.data_path` field should be updated:
+
+```diff
+- data_path: input.ome.zarr
++ data_path: /path/to/your.ome.zarr
+```
+
+The image may need to be resampled to roughly match the voxel size of the example dataset
+(0.2x0.1x0.1 µm, ZYX).
+"""
+
+# %% [markdown]
+"""
+# Run Model Inference
+
+On Google Colab, the preprocessing step takes about **1 minute**,
+and the inference step takes about **2 minutes** (T4 GPU).
+"""
+
+# %%
+# Run the CLI command to pre-compute normalization statistics
+# This includes the median and interquartile range (IQR)
+# Used to shift and scale the intensity distribution of the input images
+# !viscy preprocess --data_path=input.ome.zarr
+
+# %%
+# Run the CLI command to run inference
+# !viscy predict -c predict.yml
+
+# %% [markdown]
+"""
+# Analysis of Model Outputs
+
+Visualize the experimental and virtually stained images using the `stackview` package.
+"""
+
+# %% [markdown]
+"""
+Visualizing large 3D multichannel images in a Jupyter notebook
+**is prone to performance issues and may crash the notebook** if the images are too large
+(the free Colab instances have limited CPU cores and memory).
+The visualization code below is only intended for demonstration.
+We strongly recommend downloading the images (from the 'files' bar in Colab)
+and using a standalone viewer such as [napari](https://napari.org/).
+"""
+
+# %%
+
+import numpy as np # noqa: E402
+import stackview # noqa: E402
+from iohub import open_ome_zarr # noqa: E402
+from skimage.exposure import rescale_intensity # noqa: E402
+
+try:
+ from google.colab import output
+
+ output.enable_custom_widget_manager()
+except ImportError:
+ pass
+
+
+# %%
+# open the images
+def split_and_rescale_channels(timepoint: np.ndarray) -> tuple[np.ndarray, ...]:
+    return tuple(rescale_intensity(channel, out_range=(0, 1)) for channel in timepoint)
+
+
+fov_name = "plate/0/11"
+input_image = open_ome_zarr("input.ome.zarr")[fov_name]["0"]
+prediction_image = open_ome_zarr("prediction.ome.zarr")[fov_name]["0"]
+
+phase, fluor_nucleus, fluor_membrane = split_and_rescale_channels(input_image[0])
+vs_nucleus, vs_membrane = split_and_rescale_channels(prediction_image[0])
+
+# %%
+# Drag the slider to start rendering
+# Click on the numbered buttons to toggle the channels
+stackview.switch(
+ # the 0, 1, 2, 3, 4 buttons will correspond to these 5 channels
+ # We apply a gamma adjustment to the phase channel to improve visibility in the overlay
+ images=[phase**2.5, fluor_nucleus, fluor_membrane, vs_nucleus, vs_membrane],
+ colormap=["gray", "pure_green", "pure_magenta", "pure_blue", "pure_yellow"],
+ toggleable=True,
+ zoom_factor=0.5,
+ display_min=0.0,
+ display_max=0.9,
+)
+
+# %% [markdown]
+"""
+Note how the experimental fluorescence is missing for a subset of cells.
+This is due to loss of genetic labeling.
+The virtually stained images are not affected by this issue and robustly label all cells.
+"""
+
+# %% [markdown]
+"""
+# Summary
+
+In the above example, we demonstrated how to use the VSCyto3D model
+for virtual staining of cell nuclei and plasma membranes, which can rescue missing labels.
+"""
+
+# %% [markdown]
+"""
+## Contact & Feedback
+
+For issues or feedback about this tutorial please contact Ziwen Liu at [ziwen.liu@czbiohub.org](mailto:ziwen.liu@czbiohub.org).
+
+## Responsible Use
+
+We are committed to advancing the responsible development and use of artificial intelligence.
+Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services.
+
+Should you have any security or privacy issues or questions related to the services,
+please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively.
+"""
diff --git a/applications/cytoland/examples/vcp_tutorials/neuromast.py b/applications/cytoland/examples/vcp_tutorials/neuromast.py
new file mode 100644
index 000000000..4ac547b83
--- /dev/null
+++ b/applications/cytoland/examples/vcp_tutorials/neuromast.py
@@ -0,0 +1,317 @@
+# %% [markdown]
+"""
+# Cytoland Tutorial: Virtual Staining of Zebrafish Neuromasts with VSNeuromast
+
+**Estimated time to complete:** 15 minutes
+"""
+
+# %% [markdown]
+"""
+# Learning Goals
+
+* Download the VSNeuromast model and an example dataset containing time-lapse images of zebrafish neuromasts.
+* Pre-compute normalization statistics for the images using the `viscy preprocess` command-line interface (CLI).
+* Run inference for joint virtual staining of cell nuclei and plasma membrane via the `viscy predict` CLI.
+* Visualize the effect of photobleaching in fluorescence imaging and how virtual staining can mitigate this issue.
+"""
+
+# %% [markdown]
+"""
+# Prerequisites
+
+Python>=3.11
+"""
+
+# %% [markdown]
+"""
+# Introduction
+
+The zebrafish neuromasts are sensory organs on the lateral lines.
+Given their relatively simple structure and high accessibility to live imaging,
+they are used as a model system to study organogenesis _in vivo_.
+However, multiplexed long-term fluorescence imaging at high spatial-temporal resolution
+is often limited by photobleaching and phototoxicity.
+Also, engineering fish lines with a combination of landmark fluorescent labels
+(e.g. nuclei and plasma membrane) and functional reporters increases experimental complexity.
+\
+VSNeuromast is a 3D UNeXt2 model that has been trained on images of
+zebrafish neuromasts using the Cytoland approach.
+(See the [model card](https://virtualcellmodels.cziscience.com/paper/cytoland2025)
+for more details about the Cytoland models.)
+This model enables users to jointly stain cell nuclei and plasma membranes from 3D label-free images
+for downstream analysis such as cell segmentation and tracking.
+"""
+
+# %% [markdown]
+"""
+# Setup
+
+The commands below will install the required packages and download the example dataset and model checkpoint.
+It may take a **few minutes** to download all the files.
+
+## Setup Google Colab
+
+To run this quick-start guide using Google Colab,
+choose the 'T4' GPU runtime from the "Connect" dropdown menu
+in the upper-right corner of this notebook for faster execution.
+Using a GPU significantly speeds up running model inference, but CPU compute can also be used.
+
+## Setup Local Environment
+
+The commands below assume a Unix-like shell with `wget` installed.
+On Windows, the files can be downloaded manually from the URLs.
+"""
+
+# %%
+# Install VisCy with the optional dependencies for this example
+# See the [repository](https://github.com/mehta-lab/VisCy) for more details
+# !pip install -U "viscy[metrics,visual]==0.4.0a3"
+
+# %%
+# Restart kernel if running in Google Colab
+# This is required to use the packages installed above
+# The 'kernel crashed' message is expected here
+if "get_ipython" in globals():
+ session = get_ipython()
+ if "google.colab" in str(session):
+ print("Shutting down colab session.")
+ session.kernel.do_shutdown(restart=True)
+
+# %%
+# Download the example dataset
+# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSNeuromast/test/isim-bleaching-example.zarr/"
+
+# %%
+# Rename the downloaded dataset to what the example prediction config expects (`input.ome.zarr`)
+# And validate the OME-Zarr metadata with iohub
+# !mv isim-bleaching-example.zarr input.ome.zarr
+# !iohub info -v input.ome.zarr
+
+# %%
+# Download the VSNeuromast model checkpoint and prediction config
+# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSNeuromast/epoch=64-step=24960.ckpt"
+# !wget "https://public.czbiohub.org/comp.micro/viscy/VS_models/VSNeuromast/predict.yml"
+
+# %% [markdown]
+"""
+# Use Case
+
+## Example Dataset
+
+The neuromast example dataset used in this tutorial contains
+quantitative phase and paired fluorescence images of the cell nuclei and the plasma membrane.
+\
+**It is a subsampled time-lapse from a test set used to evaluate the VSNeuromast model.**
+\
+The full dataset can be downloaded from the
+[BioImage Archive](https://www.ebi.ac.uk/biostudies/BioImages/studies/S-BIAD1702).
+
+Refer to our [preprint](https://doi.org/10.1101/2024.05.31.596901) for more details
+about how the dataset and model were generated.
+
+## Using Custom Data
+
+The model only requires label-free images for inference.
+To run inference on your own data,
+convert them into the [OME-Zarr](https://ngff.openmicroscopy.org/)
+data format using iohub or other
+[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion),
+and edit the `predict.yml` file to specify the input data path.
+Specifically, the `data.init_args.data_path` field should be updated:
+
+```diff
+- data_path: input.ome.zarr
++ data_path: /path/to/your.ome.zarr
+```
+
+The image may need to be resampled to roughly match the voxel size of the example dataset
+(0.25x0.108x0.108 µm, ZYX).
+"""
+
+# %% [markdown]
+"""
+# Run Model Inference
+
+On Google Colab, the preprocessing step takes about **1 minute**,
+and the inference step takes about **2 minutes** (T4 GPU).
+"""
+
+# %%
+# Run the CLI command to pre-compute normalization statistics
+# This includes the median and interquartile range (IQR)
+# Used to shift and scale the intensity distribution of the input images
+# !viscy preprocess --data_path=input.ome.zarr
+
+# %%
+# Run the CLI command to run inference
+# !viscy predict -c predict.yml
+
+# %% [markdown]
+"""
+# Analysis of Model Outputs
+
+1. Visualize predicted images over time and compare with the fluorescence images.
+2. Measure photobleaching in the fluorescence images
+and how virtual staining can mitigate this issue.
+Since most pixels in the images are background,
+we will use the 99th percentile (brightest 1%)
+of the intensity distribution as a proxy for foreground signal.
+"""
+
+# %%
+# imports
+import matplotlib.pyplot as plt
+import numpy as np
+from cmap import Colormap
+from iohub import open_ome_zarr
+from numpy.typing import NDArray
+from skimage.exposure import rescale_intensity
+
+
+def render_rgb(image: np.ndarray, colormap: Colormap) -> tuple[NDArray, plt.cm.ScalarMappable]:
+ """Render a 2D grayscale image as RGB using a colormap.
+
+ Parameters
+ ----------
+ image : np.ndarray
+ intensity image
+ colormap : Colormap
+ colormap
+
+ Returns
+ -------
+ tuple[NDArray, plt.cm.ScalarMappable]
+ rendered RGB image and the color mapping
+ """
+ image = rescale_intensity(image, out_range=(0, 1))
+ image = colormap(image)
+ mappable = plt.cm.ScalarMappable(norm=plt.Normalize(0, 1), cmap=colormap.to_matplotlib())
+ return image, mappable
+
+
+# %%
+# read a single Z-slice for visualization
+z_slice = 30
+
+with open_ome_zarr("input.ome.zarr/0/3/0") as fluor_store:
+ fluor_nucleus = fluor_store[0][:, 1, z_slice]
+ fluor_membrane = fluor_store[0][:, 0, z_slice]
+
+with open_ome_zarr("prediction.ome.zarr/0/3/0") as vs_store:
+ vs_nucleus = vs_store[0][:, 0, z_slice]
+ vs_membrane = vs_store[0][:, 1, z_slice]
+
+
+# Render the images as RGB in false colors
+vs_nucleus_rgb, vs_nucleus_mappable = render_rgb(vs_nucleus, Colormap("bop_blue"))
+vs_membrane_rgb, vs_membrane_mappable = render_rgb(vs_membrane, Colormap("bop_orange"))
+merged_vs = (vs_nucleus_rgb + vs_membrane_rgb).clip(0, 1)
+
+fluor_nucleus_rgb, fluor_nucleus_mappable = render_rgb(fluor_nucleus, Colormap("green"))
+fluor_membrane_rgb, fluor_membrane_mappable = render_rgb(fluor_membrane, Colormap("magenta"))
+merged_fluor = (fluor_nucleus_rgb + fluor_membrane_rgb).clip(0, 1)
+
+# Plot
+fig = plt.figure(figsize=(12, 7), layout="constrained")
+
+images = {"fluorescence": merged_fluor, "virtual staining": merged_vs}
+
+for row, (subfig, (name, img)) in enumerate(zip(fig.subfigures(nrows=2, ncols=1), images.items())):
+ subfig.suptitle(name)
+ cax_nuc = subfig.add_axes([1, 0.55, 0.02, 0.3])
+ cax_mem = subfig.add_axes([1, 0.15, 0.02, 0.3])
+ axes = subfig.subplots(ncols=len(merged_vs))
+ for t, ax in enumerate(axes):
+ if row == 1:
+ ax.set_title(f"{t * 30} min", y=-0.1)
+ ax.imshow(img[t])
+ ax.axis("off")
+ if row == 0:
+ subfig.colorbar(fluor_nucleus_mappable, cax=cax_nuc, label="Nuclei (GFP)")
+ subfig.colorbar(fluor_membrane_mappable, cax=cax_mem, label="Membrane (mScarlett)")
+ elif row == 1:
+ subfig.colorbar(vs_nucleus_mappable, cax=cax_nuc, label="Nuclei (VS)")
+ subfig.colorbar(vs_membrane_mappable, cax=cax_mem, label="Membrane (VS)")
+
+plt.show()
+
+# %% [markdown]
+"""
+The plasma membrane fluorescence decreases over time,
+while the virtual staining remains stable.
+How significant is this effect? Is it consistent with photobleaching?
+Analysis below will answer these questions.
+"""
+
+
+# %%
+def highlight_intensity_normalized(fov_path: str, channel_name: str) -> list[float]:
+ """
+ Compute highlight (99th percentile) intensity of each timepoint,
+ normalized to the first timepoint.
+
+ Parameters
+ ----------
+ fov_path : str
+ Path to the field of view (FOV).
+ channel_name : str
+ Name of the channel to compute highlight intensity for.
+
+ Returns
+ -------
+    list[float]
+        Highlight intensities normalized to the first timepoint.
+ """
+ with open_ome_zarr(fov_path) as fov:
+ channel_index = fov.get_channel_index(channel_name)
+ channel = fov["0"].dask_array()[:, channel_index]
+ highlights = []
+ for t, volume in enumerate(channel):
+ highlights.append(np.percentile(volume.compute(), 99))
+ return [h / highlights[0] for h in highlights]
+
+
+# %%
+# Plot intensity over time
+mean_fl = highlight_intensity_normalized("input.ome.zarr/0/3/0", "mScarlett")
+mean_vs = highlight_intensity_normalized("prediction.ome.zarr/0/3/0", "membrane_prediction")
+time = np.arange(0, 100, 30)
+
+plt.plot(time, mean_fl, label="membrane fluorescence")
+plt.plot(time, mean_vs, label="membrane virtual staining")
+plt.xlabel("time / min")
+plt.ylabel("normalized highlight intensity")
+plt.legend()
+
+# %% [markdown]
+"""
+Here the highlight intensity of the fluorescence images decreases over time,
+following an exponential decay pattern, indicating photobleaching.
+The virtual staining is not affected by this issue.
+(The object drifts slightly over time, so some inherent noise is expected.)
+"""
+
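+# %% [markdown]
+"""
+As an optional check, the normalized fluorescence highlights computed above can
+be fit with a single exponential decay using `scipy.optimize.curve_fit`
+(a minimal sketch; the initial guess for the decay constant is arbitrary).
+"""
+
+# %%
+from scipy.optimize import curve_fit  # noqa: E402
+
+
+def exp_decay(t: np.ndarray, tau: float) -> np.ndarray:
+    """Single exponential decay normalized to 1 at t = 0."""
+    return np.exp(-t / tau)
+
+
+# Fit only the fluorescence curve; the virtual staining curve stays near 1
+(tau_fl,), _ = curve_fit(exp_decay, time, mean_fl, p0=[60.0])
+print(f"Estimated fluorescence decay constant: {tau_fl:.1f} min")
+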
+# %% [markdown]
+"""
+# Summary
+
+In the above example, we demonstrated how to use the VSNeuromast model
+for virtual staining of cell nuclei and plasma membranes of the zebrafish neuromast _in vivo_,
+which can avoid photobleaching in long-term live imaging.
+"""
+
+# %% [markdown]
+"""
+## Contact & Feedback
+
+For issues or feedback about this tutorial please contact Ziwen Liu at [ziwen.liu@czbiohub.org](mailto:ziwen.liu@czbiohub.org).
+
+## Responsible Use
+
+We are committed to advancing the responsible development and use of artificial intelligence.
+Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services.
+
+Should you have any security or privacy issues or questions related to the services,
+please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively.
+"""
diff --git a/applications/cytoland/examples/vcp_tutorials/quick_start.py b/applications/cytoland/examples/vcp_tutorials/quick_start.py
new file mode 100644
index 000000000..58d48c548
--- /dev/null
+++ b/applications/cytoland/examples/vcp_tutorials/quick_start.py
@@ -0,0 +1,332 @@
+# %% [markdown]
+"""
+# Quick Start: Cytoland
+
+**Estimated time to complete:** 15 minutes
+"""
+
+# %% [markdown]
+"""
+# Learning Goals
+
+* Download the VSCyto2D model and an example dataset containing A549 cell images.
+* Run VSCyto2D model inference for joint virtual staining of cell nuclei and plasma membrane.
+* Visualize and compare virtually and experimentally stained cells.
+"""
+
+# %% [markdown]
+"""
+# Prerequisites
+Python>=3.11
+
+"""
+
+# %% [markdown]
+"""
+# Introduction
+
+## Model
+
+The Cytoland virtual staining models are a collection of models (VSCyto2D, VSCyto3D, and VSNeuromast)
+used to predict cellular landmarks (e.g., nuclei and plasma membranes)
+from label-free images (e.g. quantitative phase, Zernike phase contrast, and brightfield).
+This quick-start guide focuses on the VSCyto2D model.
+
+VSCyto2D is a 2D UNeXt2 model that has been trained on A549, HEK293T, and BJ-5ta cells.
+This model enables users to jointly stain cell nuclei and plasma membranes from 2D label-free images
+that are commonly generated for image-based screens.
+
+Alternative models are optimized for different sample types and imaging conditions:
+
+* [VSCyto3D](https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto3D):
+3D UNeXt2 model for joint virtual staining of cell nuclei and plasma membrane
+from high-resolution volumetric images.
+* [VSNeuromast](https://public.czbiohub.org/comp.micro/viscy/VS_models/VSNeuromast):
+3D UNeXt2 model for joint virtual staining of nuclei and plasma membrane in zebrafish neuromasts.
+
+## Example Dataset
+
+The A549 example dataset used in this quick-start guide contains
+quantitative phase and paired fluorescence images of cell nuclei and plasma membrane.
+It is stored in OME-Zarr format and can be downloaded from
+[here](https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/test/a549_hoechst_cellmask_test.zarr).
+It has pre-computed statistics for normalization, generated using the `viscy preprocess` CLI.
+
+Refer to our [preprint](https://doi.org/10.1101/2024.05.31.596901) for more details
+about how the dataset and model were generated.
+
+## User Data
+
+The VSCyto2D model only requires label-free images for inference.
+To run inference on your own data,
+convert them into the OME-Zarr data format using iohub or other
+[tools](https://ngff.openmicroscopy.org/tools/index.html#file-conversion),
+and run [pre-processing](https://github.com/mehta-lab/VisCy/blob/main/docs/usage.md#preprocessing)
+with the `viscy preprocess` CLI.
+"""
+
+# %% [markdown]
+"""
+# Setup
+
+The commands below will install the required packages and download the example dataset and model checkpoint.
+It may take a few minutes to download all the files.
+
+## Setup Google Colab
+
+To run this quick-start guide using Google Colab,
+choose the 'T4' GPU runtime from the "Connect" dropdown menu
+in the upper-right corner of this notebook for faster execution.
+Using a GPU significantly speeds up running model inference, but CPU compute can also be used.
+
+## Setup Local Environment
+
+The commands below assume a Unix-like shell with `wget` installed.
+On Windows, the files can be downloaded manually from the URLs.
+"""
+
+# %%
+# Install VisCy with the optional dependencies for this example
+# See the [repository](https://github.com/mehta-lab/VisCy) for more details
+# !pip install "viscy[metrics,visual]==0.4.0a3"
+
+# %%
+# restart kernel if running in Google Colab
+if "get_ipython" in globals():
+ session = get_ipython() # noqa: F821
+ if "google.colab" in str(session):
+ print("Shutting down colab session.")
+ session.kernel.do_shutdown(restart=True)
+
+# %%
+# Validate installation
+# !viscy --help
+
+# %%
+# Download the example dataset
+# !wget -m -np -nH --cut-dirs=5 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VS_datasets/VSCyto2D/test/a549_hoechst_cellmask_test.zarr/"
+# Download the model checkpoint
+# !wget https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/VSCyto2D/epoch=399-step=23200.ckpt
+
+# %% [markdown]
+"""
+# Run Model Inference
+
+The following code will run inference on a single field of view (FOV) of the example dataset.
+This can also be achieved by using the VisCy CLI.
+"""
+
+# %%
+from pathlib import Path # noqa: E402
+
+from iohub import open_ome_zarr # noqa: E402
+from torchview import draw_graph # noqa: E402
+
+from cytoland.engine import FcmaeUNet # noqa: E402
+from viscy_data.hcs import HCSDataModule # noqa: E402
+from viscy_transforms import NormalizeSampled # noqa: E402
+from viscy_utils.callbacks import HCSPredictionWriter # noqa: E402
+from viscy_utils.trainer import VisCyTrainer # noqa: E402
+
+# %%
+# NOTE: Nothing needs to be changed in this code block for the example to work.
+# If using your own data, please modify the paths below.
+
+# TODO: Set download paths, by default the working directory is used
+root_dir = Path()
+# TODO: modify the path to the input dataset
+input_data_path = root_dir / "a549_hoechst_cellmask_test.zarr"
+# TODO: modify the path to the model checkpoint
+model_ckpt_path = root_dir / "epoch=399-step=23200.ckpt"
+# TODO: modify the path to save the predictions
+output_path = root_dir / "a549_prediction.zarr"
+# TODO: Choose an FOV
+fov = "0/0/0"
+
+
+# %%
+# Configure the data module for loading example images in prediction mode.
+# See API documentation for how to use it with a different dataset.
+# For example, view the documentation for the HCSDataModule class by running:
+# ?HCSDataModule
+
+# %%
+# Setup the data module to use the example dataset
+data_module = HCSDataModule(
+ # Path to HCS or Single-FOV OME-Zarr dataset
+ data_path=input_data_path / fov,
+ # Name of the input phase channel
+ source_channel="Phase3D",
+ # Desired name of the output channels
+ target_channel=["Membrane", "Nuclei"],
+ # Axial input size, 1 for 2D models
+ z_window_size=1,
+ # Batch size
+ # Adjust based on available memory (reduce if seeing OOM errors)
+ batch_size=8,
+ # Number of workers for data loading
+ # Set to 0 for Windows and macOS if running in a notebook,
+    # since multiprocessing only works with an `if __name__ == '__main__':` guard.
+ # On Linux, set it based on available CPU cores to maximize performance.
+ num_workers=4,
+ # Normalization strategy
+ # This one uses pre-computed statistics from `viscy preprocess`
+ # to subtract the median and divide by the interquartile range (IQR).
+ # It can also be replaced by other MONAI transforms.
+ normalizations=[
+ NormalizeSampled(
+ ["Phase3D"],
+ level="fov_statistics",
+ subtrahend="median",
+ divisor="iqr",
+ )
+ ],
+)
+
+# %%
+# Load the VSCyto2D model from the downloaded checkpoint
+# VSCyto2D is fine-tuned from an FCMAE-pretrained UNeXt2 model.
+# See this module for options to configure the model:
+
+# ?FullyConvolutionalMAE
+
+# %%
+vs_cyto_2d = FcmaeUNet.load_from_checkpoint(
+ # checkpoint path
+ model_ckpt_path,
+ model_config={
+ # number of input channels
+ # must match the number of channels in the input data
+ "in_channels": 1,
+ # number of output channels
+ # must match the number of target channels in the data module
+ "out_channels": 2,
+ # number of ConvNeXt v2 blocks in each stage of the encoder
+ "encoder_blocks": [3, 3, 9, 3],
+ # feature map channels in each stage of the encoder
+ "dims": [96, 192, 384, 768],
+ # number of ConvNeXt v2 blocks in each stage of the decoder
+ "decoder_conv_blocks": 2,
+ # kernel size in the stem layer
+ "stem_kernel_size": [1, 2, 2],
+ # axial size of the input image
+ # must match the Z-window size in the data module
+ "in_stack_depth": 1,
+ # whether to perform masking (for FCMAE pre-training)
+ "pretraining": False,
+ },
+)
+
+# %%
+# Visualize the model graph
+model_graph = draw_graph(
+ vs_cyto_2d,
+ (vs_cyto_2d.example_input_array),
+ graph_name="VSCyto2D",
+ roll=True,
+ depth=3,
+ expand_nested=True,
+)
+
+model_graph.visual_graph
+
+# %%
+# Setup the trainer for prediction
+# The trainer can be further configured to better utilize the available hardware,
+# for example by using GPUs and half precision.
+# Callbacks can also be used to customize logging and prediction writing.
+# See the API documentation for more details:
+# ?VisCyTrainer
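+
+# %%
+# A minimal sketch: if a CUDA GPU is available, the trainer can be configured
+# for GPU execution with mixed precision, mirroring the settings used in the
+# other Cytoland tutorials. The `gpu_trainer` name is illustrative; the next
+# cell creates the trainer actually used for prediction below.
+import torch  # noqa: E402
+
+if torch.cuda.is_available():
+    gpu_trainer = VisCyTrainer(
+        accelerator="gpu",
+        devices=1,
+        precision="16-mixed",
+        callbacks=[HCSPredictionWriter(output_path)],
+    )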
+
+# %%
+# Initialize the trainer
+# The prediction writer callback will save the predictions to an OME-Zarr store
+trainer = VisCyTrainer(callbacks=[HCSPredictionWriter(output_path)])
+
+# Run prediction
+trainer.predict(model=vs_cyto_2d, datamodule=data_module, return_predictions=False)
+
+# %% [markdown]
+"""
+# Model Outputs
+
+The model outputs are also stored in an OME-Zarr store.
+It can be visualized in an image viewer such as [napari](https://napari.org/).
+Below we show a snapshot in the notebook.
+"""
+
+# %%
+# Read images from Zarr stores
+# Choose the ROI for better visualization
+y_slice = slice(0, 512)
+x_slice = slice(0, 512)
+
+# Open the prediction store and get the 2D images from 5D arrays (t,c,z,y,x)
+with open_ome_zarr(output_path / fov) as vs_store:
+ vs_nucleus = vs_store[0][0, 0, 0, y_slice, x_slice]
+ vs_membrane = vs_store[0][0, 1, 0, y_slice, x_slice]
+
+# Open the experimental fluorescence dataset
+with open_ome_zarr(input_data_path / fov) as fluor_store:
+ fluor_nucleus = fluor_store[0][0, 1, 0, y_slice, x_slice]
+ fluor_membrane = fluor_store[0][0, 2, 0, y_slice, x_slice]
+
+# %%
+# Plot
+import matplotlib.pyplot as plt # noqa: E402
+import numpy as np # noqa: E402
+from cmap import Colormap # noqa: E402
+from skimage.exposure import rescale_intensity # noqa: E402
+
+
+def render_rgb(image: np.ndarray, colormap: Colormap):
+ image = rescale_intensity(image, out_range=(0, 1))
+ image = colormap(image)
+ return image
+
+
+# Render the images as RGB in false colors
+vs_nucleus_rgb = render_rgb(vs_nucleus, Colormap("bop_blue"))
+vs_membrane_rgb = render_rgb(vs_membrane, Colormap("bop_orange"))
+merged_vs = (vs_nucleus_rgb + vs_membrane_rgb).clip(0, 1)
+
+fluor_nucleus_rgb = render_rgb(fluor_nucleus, Colormap("green"))
+fluor_membrane_rgb = render_rgb(fluor_membrane, Colormap("magenta"))
+merged_fluor = (fluor_nucleus_rgb + fluor_membrane_rgb).clip(0, 1)
+
+# Plot
+# Show the individual channels and then fused in a grid
+fig, ax = plt.subplots(2, 3, figsize=(15, 10))
+
+# Virtual staining plots
+ax[0, 0].imshow(vs_nucleus_rgb)
+ax[0, 0].set_title("VS Nuclei")
+ax[0, 1].imshow(vs_membrane_rgb)
+ax[0, 1].set_title("VS Membrane")
+ax[0, 2].imshow(merged_vs)
+ax[0, 2].set_title("VS Nuclei+Membrane")
+
+# Experimental fluorescence plots
+ax[1, 0].imshow(fluor_nucleus_rgb)
+ax[1, 0].set_title("Experimental Fluorescence Nuclei")
+ax[1, 1].imshow(fluor_membrane_rgb)
+ax[1, 1].set_title("Experimental Fluorescence Membrane")
+ax[1, 2].imshow(merged_fluor)
+ax[1, 2].set_title("Experimental Fluorescence Nuclei+Membrane")
+
+# turn off axes
+for a in ax.flatten():
+ a.axis("off")
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+"""
+## Responsible Use
+
+We are committed to advancing the responsible development and use of artificial intelligence.
+Please follow our [Acceptable Use Policy](https://virtualcellmodels.cziscience.com/acceptable-use-policy) when engaging with our services.
+
+Should you have any security or privacy issues or questions related to the services,
+please reach out to our team at [security@chanzuckerberg.com](mailto:security@chanzuckerberg.com) or [privacy@chanzuckerberg.com](mailto:privacy@chanzuckerberg.com) respectively.
+"""
diff --git a/pyproject.toml b/pyproject.toml
index 5caca7b62..1ec272574 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,6 +76,7 @@ lint.per-file-ignores."**/*.ipynb" = [ "D", "E402", "E501", "PD" ]
lint.per-file-ignores."**/__init__.py" = [ "D104", "F401" ]
lint.per-file-ignores."**/docs/**" = [ "I" ]
lint.per-file-ignores."**/evaluation/**" = [ "D", "E501", "NPY002", "PD011" ]
+lint.per-file-ignores."**/examples/**" = [ "D", "E402", "E501", "F821" ]
lint.per-file-ignores."**/tests/**" = [ "D" ]
lint.pydocstyle.convention = "numpy"