diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index bdd833af9d..61bfa1ab8a 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -45,14 +45,17 @@ def point_based_window_inferer( patch inference and average output stitching, and finally returns the segmented mask. Args: - inputs: [1CHWD], input image to be processed. + inputs: [1CWHD], input image to be processed (spatial axes are in + Width, Height, Depth order; i.e., for [N, C, W, H, D], Width is + axis 2, Height is axis 3, Depth is axis 4, matching arrays + returned by MONAI's NIfTI/ITK readers). roi_size: the spatial window size for inferences. When its components have None or non-positives, the corresponding inputs dimension will be used. if the components of the `roi_size` are non-positive values, the transform will use the corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. sw_batch_size: the batch size to run window slices. - predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. + predictor: the model. For vista3D, the output is [B, 1, W, H, D] which needs to be transposed to [1, B, W, H, D]. Add transpose=True in kwargs for vista3d. point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points. point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes. @@ -61,13 +64,13 @@ def point_based_window_inferer( prompt_class: [B]. The same as class_vector representing the point class and inform point head about supported class or zeroshot, not used for automatic segmentation. If None, point head is default to supported class segmentation. - prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks. + prev_mask: [1, B, W, H, D]. The value is before sigmoid. An optional tensor of previously segmented masks. point_start: only use points starting from this number. All points before this number is used to generate prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask. center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point. margin: if center_only is false, this value is the distance between point to the patch boundary. Returns: - stitched_output: [1, B, H, W, D]. The value is before sigmoid. + stitched_output: [1, B, W, H, D]. The value is before sigmoid. Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. """ if not point_coords.shape[0] == 1: diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index cc6cdcdead..fa9ab5523c 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -226,8 +226,11 @@ def resample_if_needed( transformation computed from ``affine`` and ``target_affine``. This function assumes the NIfTI dimension notations. Spatially it - supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D - respectively. When saving multiple time steps or multiple channels, + supports up to three dimensions, that is, ``W``, ``WH``, ``WHD`` for + 1D, 2D, 3D respectively (equivalently ``X``, ``XY``, ``XYZ``; axis 0 + is columns/Width, axis 1 is rows/Height, axis 2 is Depth/slices, + matching the array order returned by NIfTI/ITK readers). When saving + multiple time steps or multiple channels, time and/or modality axes should be appended after the first three dimensions. For example, shape of 2D eight-class segmentation probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in @@ -303,8 +306,8 @@ def convert_to_channel_last( ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension. a sequence of integers indicates multiple non-spatial dimensions. squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel - has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. - If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`. + has been moved to the end). So if input is `(W,H,D,C)` and C==1, then it will be saved as `(W,H,D)`. + If D is also 1, it will be saved as `(W,H)`. If ``False``, image will always be saved as `(W,H,D,C)`. spatial_ndim: modifying the spatial dims if needed, so that output to have at least this number of spatial dims. If ``None``, the output will have the same number of spatial dimensions as the input. diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 20e2d74c8c..a24f47ab71 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -288,8 +288,8 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter | Summ class TensorBoardImageHandler(TensorBoardHandler): """ TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images. - 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, - for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' + 2D output (shape in Batch, channel, W, H) will be shown as simple image using the first element in the batch, + for 3D to ND output (shape in Batch, channel, W, H, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. @@ -350,7 +350,7 @@ def __init__( index: plot which element in a data batch, default is the first element. max_channels: number of channels to plot. frame_dim: if plotting 3D image as GIF, specify the dimension used as frames, - expect input data shape as `NCHWD`, default to `-3` (the first spatial dim) + expect input data shape as `NCWHD`, default to `-3` (the first spatial dim) max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index df922b1eca..309813d094 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -50,14 +50,16 @@ def _image3_animated_gif( Args: tag: Data identifier - image: 3D image tensors expected to be in `HWD` format + image: 3D image tensors expected to be in `WHD` format (axis 0 is + columns/Width, axis 1 is rows/Height, axis 2 is Depth, matching + arrays returned by MONAI's NIfTI/ITK readers). writer: the tensorboard writer to plot image - frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`. + frame_dim: the dimension used as frames for GIF image, expect data shape as `WHD`, default to `0`. scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ if len(image.shape) != 3: - raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3") + raise AssertionError("3D image tensors expected to be in `WHD` format, len(image.shape) != 3") image_np, *_ = convert_data_type(image, output_type=np.ndarray) ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)] @@ -85,14 +87,15 @@ def make_animated_gif_summary( frame_dim: int = -3, scale_factor: float = 1.0, ) -> Summary: - """Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary. + """Creates an animated gif out of an image tensor in 'CWHD' format and returns Summary. Args: tag: Data identifier - image: The image, expected to be in `CHWD` format + image: The image, expected to be in `CWHD` format (channel-first; spatial axes are + Width, Height, Depth, matching arrays returned by MONAI's NIfTI/ITK readers). writer: the tensorboard writer to plot image max_out: maximum number of image channels to animate through - frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`, + frame_dim: the dimension used as frames for GIF image, expect input data shape as `CWHD`, default to `-3` (the first spatial dim) scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will scale it to displayable range @@ -122,14 +125,16 @@ def add_animated_gif( scale_factor: float = 1.0, global_step: int | None = None, ) -> None: - """Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter. + """Creates an animated gif out of an image tensor in 'CWHD' format and writes it with SummaryWriter. Args: writer: Tensorboard SummaryWriter to write to tag: Data identifier - image_tensor: tensor for the image to add, expected to be in `CHWD` format + image_tensor: tensor for the image to add, expected to be in `CWHD` format (channel-first; + spatial axes are Width, Height, Depth, matching arrays returned by MONAI's + NIfTI/ITK readers). max_out: maximum number of image channels to animate through - frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`, + frame_dim: the dimension used as frames for GIF image, expect input data shape as `CWHD`, default to `-3` (the first spatial dim) scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will scale it to displayable range @@ -168,7 +173,7 @@ def plot_2d_or_3d_image( index: plot which element in the input data batch, default is the first element. max_channels: number of channels to plot. frame_dim: if plotting 3D image as GIF, specify the dimension used as frames, - expect input data shape as `NCHWD`, default to `-3` (the first spatial dim) + expect input data shape as `NCWHD`, default to `-3` (the first spatial dim) max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. tag: tag of the plotted image on TensorBoard. """ diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py index e79fbba847..01db048c27 100644 --- a/monai/visualize/utils.py +++ b/monai/visualize/utils.py @@ -53,19 +53,24 @@ def matshow3d( Create a 3D volume figure as a grid of images. Args: - volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`. - Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg. - A list of channel-first (C, H[, W, D]) arrays can also be passed in, + volume: 3D volume to display. data shape can be `BCWHD`, `CWHD` or `WHD` + (axis 0 is columns/Width, axis 1 is rows/Height, axis 2 is Depth, matching + arrays returned by MONAI's NIfTI/ITK readers). + Higher dimensional arrays will be reshaped into (-1, spatial0, spatial1, [C]), + `C` depends on `channel_dim` arg. + A list of channel-first (C, W[, H, D]) arrays can also be passed in, in which case they will be displayed as a padded and stacked volume. fig: matplotlib figure or Axes to use. If None, a new figure will be created. title: title of the figure. figsize: size of the figure. frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used. frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to - the `-3` dimension. dim and reshape to (-1, H, W) shape to construct frames, default to `-3`. + the `-3` dimension, then reshaped to (-1, spatial0, spatial1) to construct frames, + default to `-3`. channel_dim: if not None, explicitly specify the channel dimension to be transposed to the - last dimensionas shape (-1, H, W, C). this can be used to plot RGB color image. - if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W). + last dimension as shape (-1, spatial0, spatial1, C). this can be used to plot RGB color image. + if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as + shape (-1, spatial0, spatial1). note that it can only support 3D input image. default is None. vmin: `vmin` for the matplotlib `imshow`. vmax: `vmax` for the matplotlib `imshow`.