Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions monai/networks/nets/dints.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,16 @@ class DiNTS(nn.Module):
The architecture codes will be initialized as one.
- ``TopologyConstruction`` is the parent class which constructs the instance and search space.

To meet the requirements of the structure, the input size for each spatial dimension should be:
divisible by 2 ** (num_depths + 1).
Spatial Shape Constraints:
Each spatial dimension of the input must be divisible by ``2 ** (num_depths + int(use_downsample))``.

- With ``use_downsample=True`` (default) and ``num_depths=3`` (default): divisible by ``2 ** 4 = 16``.
- With ``use_downsample=False`` and ``num_depths=3``: divisible by ``2 ** 3 = 8``.

This requirement arises from the multi-resolution stem downsampling the input ``num_depths`` times
(each by a factor of 2), plus one additional factor of 2 when ``use_downsample=True``.

A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.

Args:
dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`.
Expand All @@ -346,6 +354,7 @@ class DiNTS(nn.Module):
use_downsample: use downsample in the stem.
If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8],
if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
Affects the input size divisibility requirement: ``2 ** (num_depths + int(use_downsample))``.
node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`.
+1 for multi-resolution inputs.
In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None.
Expand Down Expand Up @@ -481,13 +490,40 @@ def __init__(
def weight_parameters(self):
    """
    Collect every parameter yielded by ``self.named_parameters()`` into a list.

    Returns:
        a list of the parameter objects, in the order ``named_parameters()`` yields them;
        the parameter names are discarded.
    """
    return [weight for _, weight in self.named_parameters()]

@torch.jit.unused
def _check_input_size(self, spatial_shape):
"""
Validate that input spatial dimensions satisfy the divisibility requirement.

Each spatial dimension must be divisible by ``2 ** (num_depths + int(use_downsample))``.

Args:
spatial_shape: spatial dimensions of the input tensor (excluding batch and channel dims).

Raises:
ValueError: if any spatial dimension is not divisible by the required factor.
"""
factor = 2 ** (self.num_depths + int(self.dints_space.use_downsample))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Unify use_downsample source to avoid inconsistent validation.

_check_input_size() uses self.dints_space.use_downsample, but stem construction in DiNTS.__init__ uses the use_downsample argument. If these differ, validation and actual network path can diverge.

Proposed fix
 class DiNTS(nn.Module):
@@
     def __init__(
@@
         use_downsample: bool = True,
@@
     ):
         super().__init__()
+        if hasattr(dints_space, "use_downsample") and dints_space.use_downsample != use_downsample:
+            raise ValueError(
+                f"DiNTS.use_downsample ({use_downsample}) must match dints_space.use_downsample "
+                f"({dints_space.use_downsample})."
+            )
+        self.use_downsample = use_downsample
@@
     def _check_input_size(self, spatial_shape):
@@
-        factor = 2 ** (self.num_depths + int(self.dints_space.use_downsample))
+        factor = 2 ** (self.num_depths + int(self.use_downsample))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/nets/dints.py` at line 505, the validation uses
self.dints_space.use_downsample while the constructor uses the use_downsample
argument, which allows the two to diverge. Ensure both use the same source:
either assign the constructor argument into the config before any validation or
stem construction (e.g., set self.dints_space.use_downsample = use_downsample
early in DiNTS.__init__), or change _check_input_size() to read the
use_downsample argument passed into __init__. Update the places computing factor
(currently using self.dints_space.use_downsample) and the stem construction to
reference the unified source, so that _check_input_size, DiNTS.__init__, and the
factor calculation are consistent.

wrong_dims = [i + 2 for i, s in enumerate(spatial_shape) if s % factor != 0]
if wrong_dims:
raise ValueError(
f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
f" must be divisible by 2 ** (num_depths + int(use_downsample)) = {factor}."
)

def forward(self, x: torch.Tensor):
"""
Prediction based on dynamic arch_code.

Args:
x: input tensor.

Raises:
ValueError: if any spatial dimension of ``x`` is not divisible by
``2 ** (num_depths + int(use_downsample))``.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
self._check_input_size(x.shape[2:])
inputs = []
for d in range(self.num_depths):
# allow multi-resolution input
Expand Down
22 changes: 20 additions & 2 deletions tests/networks/nets/test_dints_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
"use_downsample": True,
"spatial_dims": 2,
},
(2, 2, 32, 16),
(2, 2, 32, 16),
(2, 2, 32, 32), # use_downsample=True, num_depths=4 -> factor=32; both dims must be divisible by 32
(2, 2, 32, 32),
]
]
if torch.cuda.is_available():
Expand Down Expand Up @@ -153,6 +153,24 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect
self.assertTrue(isinstance(net.weight_parameters(), list))


class TestDintsInputShape(unittest.TestCase):
    """Verify that ``DiNTS`` raises ``ValueError`` for inputs with an invalid spatial shape."""

    def test_invalid_input_shape_3d(self):
        # 3D case: num_depths=3 with use_downsample=True gives factor 2 ** (3 + 1) = 16,
        # and the first spatial dimension (33) is not a multiple of 16.
        search_space = TopologySearch(
            channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=True, spatial_dims=3
        )
        model = DiNTS(
            dints_space=search_space, in_channels=1, num_classes=2, use_downsample=True, spatial_dims=3
        )
        bad_input = torch.randn(1, 1, 33, 32, 32)
        self.assertRaises(ValueError, model, bad_input)

    def test_invalid_input_shape_2d(self):
        # 2D case: num_depths=3 with use_downsample=False gives factor 2 ** (3 + 0) = 8,
        # and the first spatial dimension (33) is not a multiple of 8.
        search_space = TopologySearch(
            channel_mul=0.2, num_blocks=6, num_depths=3, use_downsample=False, spatial_dims=2
        )
        model = DiNTS(
            dints_space=search_space, in_channels=1, num_classes=2, use_downsample=False, spatial_dims=2
        )
        bad_input = torch.randn(1, 1, 33, 32)
        self.assertRaises(ValueError, model, bad_input)


class TestDintsTS(unittest.TestCase):
@parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)
def test_script(self, dints_grid_params, dints_params, input_shape, _):
Expand Down