diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index f87fcebea8..93c2820ed5 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -333,8 +333,16 @@ class DiNTS(nn.Module): The architecture codes will be initialized as one. - ``TopologyConstruction`` is the parent class which constructs the instance and search space. - To meet the requirements of the structure, the input size for each spatial dimension should be: - divisible by 2 ** (num_depths + 1). + Spatial Shape Constraints: + Each spatial dimension of the input must be divisible by ``2 ** (num_depths + int(use_downsample))``. + + - With ``use_downsample=True`` (default) and ``num_depths=3`` (default): divisible by ``2 ** 4 = 16``. + - With ``use_downsample=False`` and ``num_depths=3``: divisible by ``2 ** 3 = 8``. + + This requirement arises from the multi-resolution stem downsampling the input ``num_depths`` times + (each by a factor of 2), plus one additional factor of 2 when ``use_downsample=True``. + + A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint. Args: dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`. @@ -346,6 +354,7 @@ class DiNTS(nn.Module): use_downsample: use downsample in the stem. If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8], if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. + Affects the input size divisibility requirement: ``2 ** (num_depths + int(use_downsample))``. node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`. +1 for multi-resolution inputs. In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None. @@ -481,13 +490,40 @@ def __init__( def weight_parameters(self): return [param for name, param in self.named_parameters()] + @torch.jit.unused + def _check_input_size(self, spatial_shape): + """ + Validate that input spatial dimensions satisfy the divisibility requirement. + + Each spatial dimension must be divisible by ``2 ** (num_depths + int(use_downsample))``. + + Args: + spatial_shape: spatial dimensions of the input tensor (excluding batch and channel dims). + + Raises: + ValueError: if any spatial dimension is not divisible by the required factor. + """ + factor = 2 ** (self.num_depths + int(self.dints_space.use_downsample)) + wrong_dims = [i + 2 for i, s in enumerate(spatial_shape) if s % factor != 0] + if wrong_dims: + raise ValueError( + f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})" + f" must be divisible by 2 ** (num_depths + int(use_downsample)) = {factor}." + ) + def forward(self, x: torch.Tensor): """ Prediction based on dynamic arch_code. Args: x: input tensor. + + Raises: + ValueError: if any spatial dimension of ``x`` is not divisible by + ``2 ** (num_depths + int(use_downsample))``. """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + self._check_input_size(x.shape[2:]) inputs = [] for d in range(self.num_depths): # allow multi-resolution input diff --git a/tests/networks/nets/test_dints_network.py b/tests/networks/nets/test_dints_network.py index 80ade00db7..2a088dd0f5 100644 --- a/tests/networks/nets/test_dints_network.py +++ b/tests/networks/nets/test_dints_network.py @@ -84,8 +84,8 @@ "use_downsample": True, "spatial_dims": 2, }, - (2, 2, 32, 16), - (2, 2, 32, 16), + (2, 2, 32, 32), # use_downsample=True, num_depths=4 -> factor=32; both dims must be divisible by 32 + (2, 2, 32, 32), ] ] if torch.cuda.is_available(): @@ -153,6 +153,24 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect self.assertTrue(isinstance(net.weight_parameters(), list)) +class TestDintsInputShape(unittest.TestCase): + def test_invalid_input_shape_3d(self): + # num_depths=3, use_downsample=True -> factor = 2**(3+1) = 16 + # 33 is not divisible by 16 + grid = TopologySearch(channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=True, spatial_dims=3) + net = DiNTS(dints_space=grid, in_channels=1, num_classes=2, use_downsample=True, spatial_dims=3) + with self.assertRaises(ValueError): + net(torch.randn(1, 1, 33, 32, 32)) + + def test_invalid_input_shape_2d(self): + # num_depths=3, use_downsample=False -> factor = 2**(3+0) = 8 + # 33 is not divisible by 8 + grid = TopologySearch(channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=False, spatial_dims=2) + net = DiNTS(dints_space=grid, in_channels=1, num_classes=2, use_downsample=False, spatial_dims=2) + with self.assertRaises(ValueError): + net(torch.randn(1, 1, 33, 32)) + + class TestDintsTS(unittest.TestCase): @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) def test_script(self, dints_grid_params, dints_params, input_shape, _):