11 changes: 7 additions & 4 deletions monai/apps/vista3d/inferer.py
@@ -45,14 +45,17 @@ def point_based_window_inferer(
     patch inference and average output stitching, and finally returns the segmented mask.

     Args:
-        inputs: [1CHWD], input image to be processed.
+        inputs: [1CWHD], input image to be processed (spatial axes are in
+            Width, Height, Depth order; i.e., for [N, C, W, H, D], Width is
+            axis 2, Height is axis 3, Depth is axis 4, matching arrays
+            returned by MONAI's NIfTI/ITK readers).
         roi_size: the spatial window size for inferences.
             When its components have None or non-positives, the corresponding inputs dimension will be used.
             if the components of the `roi_size` are non-positive values, the transform will use the
             corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
             to `(32, 64)` if the second spatial dimension size of img is `64`.
         sw_batch_size: the batch size to run window slices.
-        predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
+        predictor: the model. For vista3D, the output is [B, 1, W, H, D] which needs to be transposed to [1, B, W, H, D].
             Add transpose=True in kwargs for vista3d.
         point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
         point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
@@ -61,13 +64,13 @@ def point_based_window_inferer(
         prompt_class: [B]. The same as class_vector representing the point class and inform point head about
             supported class or zeroshot, not used for automatic segmentation. If None, point head is default
             to supported class segmentation.
-        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
+        prev_mask: [1, B, W, H, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
         point_start: only use points starting from this number. All points before this number is used to generate
             prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask.
         center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point.
         margin: if center_only is false, this value is the distance between point to the patch boundary.
     Returns:
-        stitched_output: [1, B, H, W, D]. The value is before sigmoid.
+        stitched_output: [1, B, W, H, D]. The value is before sigmoid.
         Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
     """
     if not point_coords.shape[0] == 1:
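The transpose the updated `predictor` docstring describes can be sketched with a toy numpy array: the model's `[B, 1, W, H, D]` output is swapped into the `[1, B, W, H, D]` layout used for stitching. Shapes here are purely illustrative, and B=3 is used only to make the axis swap visible; as the docstring notes, the inferer itself supports only B=1.

```python
import numpy as np

# Toy predictor output in [B, 1, W, H, D] layout (B=3 for illustration only;
# point_based_window_inferer requires B == 1 in practice).
model_output = np.zeros((3, 1, 4, 5, 6))

# Swap axes 0 and 1 to obtain the [1, B, W, H, D] layout the stitching
# step expects, which is what transpose=True requests for vista3d.
stitch_input = np.swapaxes(model_output, 0, 1)
```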
11 changes: 7 additions & 4 deletions monai/data/image_writer.py
@@ -226,8 +226,11 @@ def resample_if_needed(
     transformation computed from ``affine`` and ``target_affine``.

     This function assumes the NIfTI dimension notations. Spatially it
-    supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D
-    respectively. When saving multiple time steps or multiple channels,
+    supports up to three dimensions, that is, ``W``, ``WH``, ``WHD`` for
+    1D, 2D, 3D respectively (equivalently ``X``, ``XY``, ``XYZ``; axis 0
+    is columns/Width, axis 1 is rows/Height, axis 2 is Depth/slices,
+    matching the array order returned by NIfTI/ITK readers). When saving
+    multiple time steps or multiple channels,
     time and/or modality axes should be appended after the first three
     dimensions. For example, shape of 2D eight-class segmentation
     probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in
@@ -303,8 +306,8 @@ def convert_to_channel_last(
     ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension.
     a sequence of integers indicates multiple non-spatial dimensions.
     squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel
-        has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`.
-        If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`.
+        has been moved to the end). So if input is `(W,H,D,C)` and C==1, then it will be saved as `(W,H,D)`.
+        If D is also 1, it will be saved as `(W,H)`. If ``False``, image will always be saved as `(W,H,D,C)`.
     spatial_ndim: modifying the spatial dims if needed, so that output to have at least
         this number of spatial dims. If ``None``, the output will have the same number of
         spatial dimensions as the input.
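The `squeeze_end_dims` behaviour described in the `convert_to_channel_last` hunk can be illustrated with a small stand-in function. This is a sketch of the documented behaviour only, not MONAI's actual implementation: trailing singleton dimensions are dropped after the channel has been moved to the end.

```python
import numpy as np

def squeeze_end_dims(data: np.ndarray) -> np.ndarray:
    """Sketch of the documented behaviour: drop trailing singleton dims."""
    while data.ndim > 1 and data.shape[-1] == 1:
        data = data[..., 0]
    return data

# (W, H, D, C) with C == 1 saves as (W, H, D); if D is also 1, as (W, H).
out_3d = squeeze_end_dims(np.zeros((64, 64, 8, 1)))
out_2d = squeeze_end_dims(np.zeros((64, 64, 1, 1)))
```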
6 changes: 3 additions & 3 deletions monai/handlers/tensorboard_handlers.py
@@ -288,8 +288,8 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter | Summ
 class TensorBoardImageHandler(TensorBoardHandler):
     """
     TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images.
-    2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch,
-    for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images'
+    2D output (shape in Batch, channel, W, H) will be shown as simple image using the first element in the batch,
+    for 3D to ND output (shape in Batch, channel, W, H, D) input, each of ``self.max_channels`` number of images'
     last three dimensions will be shown as animated GIF along the last axis (typically Depth).
     And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.

@@ -350,7 +350,7 @@ def __init__(
             index: plot which element in a data batch, default is the first element.
             max_channels: number of channels to plot.
             frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
-                expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
+                expect input data shape as `NCWHD`, default to `-3` (the first spatial dim)
             max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
         """
         super().__init__(summary_writer=summary_writer, log_dir=log_dir)
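The `frame_dim` convention in the handler's docstring can be checked with a toy array: for `NCWHD` data, the default `frame_dim=-3` addresses the first spatial dimension (Width), so the GIF yields one Height-by-Depth frame per Width index. Shapes below are illustrative.

```python
import numpy as np

# NCWHD data: N, C, W, H, D. frame_dim=-3 selects W, the first spatial dim.
data = np.zeros((1, 1, 10, 32, 48))

# Take one channel of one batch element, then move the frame axis first,
# yielding 10 frames of shape (32, 48) for the animated GIF.
volume = data[0, 0]
frames = np.moveaxis(volume, -3, 0)
```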
25 changes: 15 additions & 10 deletions monai/visualize/img2tensorboard.py
@@ -50,14 +50,16 @@ def _image3_animated_gif(

     Args:
         tag: Data identifier
-        image: 3D image tensors expected to be in `HWD` format
+        image: 3D image tensors expected to be in `WHD` format (axis 0 is
+            columns/Width, axis 1 is rows/Height, axis 2 is Depth, matching
+            arrays returned by MONAI's NIfTI/ITK readers).
         writer: the tensorboard writer to plot image
-        frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.
+        frame_dim: the dimension used as frames for GIF image, expect data shape as `WHD`, default to `0`.
         scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will
             scale it to displayable range
     """
     if len(image.shape) != 3:
-        raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3")
+        raise AssertionError("3D image tensors expected to be in `WHD` format, len(image.shape) != 3")

     image_np, *_ = convert_data_type(image, output_type=np.ndarray)
     ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)]
@@ -85,14 +87,15 @@ def make_animated_gif_summary(
     frame_dim: int = -3,
     scale_factor: float = 1.0,
 ) -> Summary:
-    """Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary.
+    """Creates an animated gif out of an image tensor in 'CWHD' format and returns Summary.

     Args:
         tag: Data identifier
-        image: The image, expected to be in `CHWD` format
+        image: The image, expected to be in `CWHD` format (channel-first; spatial axes are
+            Width, Height, Depth, matching arrays returned by MONAI's NIfTI/ITK readers).
         writer: the tensorboard writer to plot image
         max_out: maximum number of image channels to animate through
-        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
+        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CWHD`,
             default to `-3` (the first spatial dim)
         scale_factor: amount to multiply values by.
             if the image data is between 0 and 1, using 255 for this value will scale it to displayable range
@@ -122,14 +125,16 @@ def add_animated_gif(
     scale_factor: float = 1.0,
     global_step: int | None = None,
 ) -> None:
-    """Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.
+    """Creates an animated gif out of an image tensor in 'CWHD' format and writes it with SummaryWriter.

     Args:
         writer: Tensorboard SummaryWriter to write to
         tag: Data identifier
-        image_tensor: tensor for the image to add, expected to be in `CHWD` format
+        image_tensor: tensor for the image to add, expected to be in `CWHD` format (channel-first;
+            spatial axes are Width, Height, Depth, matching arrays returned by MONAI's
+            NIfTI/ITK readers).
         max_out: maximum number of image channels to animate through
-        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
+        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CWHD`,
             default to `-3` (the first spatial dim)
         scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will
             scale it to displayable range
@@ -168,7 +173,7 @@ def plot_2d_or_3d_image(
         index: plot which element in the input data batch, default is the first element.
         max_channels: number of channels to plot.
         frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
-            expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
+            expect input data shape as `NCWHD`, default to `-3` (the first spatial dim)
         max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
         tag: tag of the plotted image on TensorBoard.
     """
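The frame-splitting line quoted in the `_image3_animated_gif` hunk (`np.moveaxis(image_np, frame_dim, 0)` followed by uint8 scaling) can be exercised standalone. The volume shape and `scale_factor` below are illustrative: values in `[0, 1]` are mapped into displayable uint8 range and split into one 2D frame per index along `frame_dim`.

```python
import numpy as np

# WHD volume with values in [0, 1]; frame_dim=0 (the default) takes one
# H-by-D frame per Width index, scaled to uint8 as in the diff above.
image_np = np.random.rand(8, 16, 16)
scale_factor, frame_dim = 255.0, 0
ims = [(i * scale_factor).astype(np.uint8, copy=False)
       for i in np.moveaxis(image_np, frame_dim, 0)]
```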
17 changes: 11 additions & 6 deletions monai/visualize/utils.py
@@ -53,19 +53,24 @@ def matshow3d(
     Create a 3D volume figure as a grid of images.

     Args:
-        volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`.
-            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.
-            A list of channel-first (C, H[, W, D]) arrays can also be passed in,
+        volume: 3D volume to display. data shape can be `BCWHD`, `CWHD` or `WHD`
+            (axis 0 is columns/Width, axis 1 is rows/Height, axis 2 is Depth, matching
+            arrays returned by MONAI's NIfTI/ITK readers).
+            Higher dimensional arrays will be reshaped into (-1, spatial0, spatial1, [C]),
+            `C` depends on `channel_dim` arg.
+            A list of channel-first (C, W[, H, D]) arrays can also be passed in,
             in which case they will be displayed as a padded and stacked volume.
         fig: matplotlib figure or Axes to use. If None, a new figure will be created.
         title: title of the figure.
         figsize: size of the figure.
         frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used.
         frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to
-            the `-3` dimension. dim and reshape to (-1, H, W) shape to construct frames, default to `-3`.
+            the `-3` dimension, then reshaped to (-1, spatial0, spatial1) to construct frames,
+            default to `-3`.
         channel_dim: if not None, explicitly specify the channel dimension to be transposed to the
-            last dimensionas shape (-1, H, W, C). this can be used to plot RGB color image.
-            if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W).
+            last dimension as shape (-1, spatial0, spatial1, C). this can be used to plot RGB color image.
+            if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as
+            shape (-1, spatial0, spatial1).
             note that it can only support 3D input image. default is None.
         vmin: `vmin` for the matplotlib `imshow`.
         vmax: `vmax` for the matplotlib `imshow`.
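The flattening the `matshow3d` docstring describes (with `channel_dim=None`, leading batch/channel/frame dimensions collapse into a single frame axis) can be sketched with a plain reshape. This is an illustration of the documented shape contract, not MONAI's exact code.

```python
import numpy as np

# BCWHD input: B, C, W, H, D. With channel_dim=None, the leading B, C and
# frame dims flatten into one axis, yielding (-1, spatial0, spatial1) frames.
volume = np.zeros((2, 3, 4, 32, 48))
frames = volume.reshape(-1, *volume.shape[-2:])
```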