Skip to content

Commit 827a4a7

Browse files
authored
Add spatial shape constraint docs and test for SwinUNETR (#6771) (#8817)
Fixes #6771 . ### Description Adds documentation of spatial shape constraints for `SwinUNETR`. Each input spatial dimension must be divisible by `patch_size ** 5` (32 by default with `patch_size=2`). The runtime validation logic already existed in `_check_input_size()` but was undocumented. This PR adds a `Spatial Shape Constraints` section to the class docstring, updates the `patch_size` arg description in `__init__`, and adds a test to verify that `forward()` raises `ValueError` for invalid spatial shapes. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Adrian Caderno <adriancaderno@gmail.com>
1 parent 5a2d0a7 commit 827a4a7

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

monai/networks/nets/swin_unetr.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ class SwinUNETR(nn.Module):
4747
Swin UNETR based on: "Hatamizadeh et al.,
4848
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
4949
<https://arxiv.org/abs/2201.01266>"
50+
51+
Spatial Shape Constraints:
52+
Each spatial dimension of the input must be divisible by ``patch_size ** 5``.
53+
With the default ``patch_size=2``, this means each spatial dimension must be divisible by **32**
54+
(i.e., 2^5 = 32). This requirement comes from the patch embedding step followed by 4 stages
55+
of PatchMerging downsampling, each halving the spatial resolution.
56+
57+
For a custom ``patch_size``, the divisibility requirement is ``patch_size ** 5``.
58+
59+
Examples of valid 3D input sizes (with default ``patch_size=2``):
60+
``(32, 32, 32)``, ``(64, 64, 64)``, ``(96, 96, 96)``, ``(128, 128, 128)``, ``(64, 32, 192)``.
61+
62+
A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.
5063
"""
5164

5265
def __init__(
@@ -76,7 +89,8 @@ def __init__(
7689
Args:
7790
in_channels: dimension of input channels.
7891
out_channels: dimension of output channels.
79-
patch_size: size of the patch token.
92+
patch_size: size of the patch token. Input spatial dimensions must be divisible by
93+
``patch_size ** 5`` (e.g., divisible by 32 when ``patch_size=2``).
8094
feature_size: dimension of network feature size.
8195
depths: number of layers in each stage.
8296
num_heads: number of attention heads.
@@ -108,6 +122,10 @@ def __init__(
108122
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
109123
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
110124
125+
Raises:
126+
ValueError: When a spatial dimension of the input is not divisible by ``patch_size ** 5``.
127+
Use ``net._check_input_size(spatial_shape)`` to validate a shape before inference.
128+
111129
"""
112130

113131
super().__init__()

tests/networks/nets/test_swin_unetr.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,17 @@ def test_ill_arg(self):
9090
with self.assertRaises(ValueError):
9191
SwinUNETR(in_channels=1, out_channels=3, feature_size=24, norm_name="instance", drop_rate=-1)
9292

93+
@skipUnless(has_einops, "Requires einops")
94+
def test_invalid_input_shape(self):
95+
# spatial dims not divisible by patch_size**5 (default patch_size=2, so must be divisible by 32)
96+
net = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=3)
97+
with self.assertRaises(ValueError):
98+
net(torch.randn(1, 1, 33, 64, 64)) # 33 is not divisible by 32
99+
100+
net_2d = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=2)
101+
with self.assertRaises(ValueError):
102+
net_2d(torch.randn(1, 1, 48, 33)) # 33 is not divisible by 32
103+
93104
def test_patch_merging(self):
94105
dim = 10
95106
t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))

0 commit comments

Comments
 (0)