diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 4bfb8e3edd..6c40fed3fe 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -85,6 +85,9 @@ cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") exposure, has_skimage = optional_import("skimage.exposure") +# NOTE: cucim is deliberately NOT imported at module level. +# Module-level cucim imports caused very slow import times and other buggy behaviour. +# Keep cucim imports inside the functions that need them. __all__ = [ "allow_missing_keys_mode", diff --git a/monai/utils/module.py b/monai/utils/module.py index a64f73cd6b..c8851714ce 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -17,6 +17,7 @@ import pdb import re import sys +import traceback as traceback_mod import warnings from collections.abc import Callable, Collection, Hashable, Iterable, Mapping from functools import partial, wraps @@ -368,8 +369,9 @@ def optional_import( OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version'). """ - tb = None + had_exception = False exception_str = "" + tb_str = "" if name: actual_cmd = f"from {module} import {name}" else: @@ -384,8 +386,12 @@ def optional_import( if name: # user specified to load class/function/... from the module the_module = getattr(the_module, name) except Exception as import_exception: # any exceptions during import - tb = import_exception.__traceback__ + tb_str = "".join( + traceback_mod.format_exception(type(import_exception), import_exception, import_exception.__traceback__) + ) + import_exception.__traceback__ = None exception_str = f"{import_exception}" + had_exception = True else: # found the module if version_args and version_checker(pkg, f"{version}", version_args): return the_module, True @@ -394,7 +400,7 @@ def optional_import( # preparing lazy error message msg = descriptor.format(actual_cmd) - if version and tb is None: # a pure version issue + if version and not had_exception: # a pure version issue msg += f" (requires '{module} {version}' by '{version_checker.__name__}')" if exception_str: msg += f" ({exception_str})" @@ -407,10 +413,9 @@ def __init__(self, *_args, **_kwargs): + "\n\nFor details about installing the optional dependencies, please visit:" + "\n https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies" ) - if tb is None: - self._exception = OptionalImportError(_default_msg) - else: - self._exception = OptionalImportError(_default_msg).with_traceback(tb) + if tb_str: + _default_msg += f"\n\nOriginal traceback:\n{tb_str}" + self._exception = OptionalImportError(_default_msg) def __getattr__(self, name): """ diff --git a/tests/utils/test_optional_import.py b/tests/utils/test_optional_import.py index 2f640f88d0..b5bc914c92 100644 --- a/tests/utils/test_optional_import.py +++ b/tests/utils/test_optional_import.py @@ -11,7 +11,9 @@ from __future__ import annotations +import gc import unittest +import weakref from parameterized import parameterized @@ -75,6 +77,34 @@ def versioning(module, ver, a): nn, flag = optional_import("torch", "1.1", version_checker=versioning, name="nn", version_args=test_args) self.assertTrue(flag) + def test_no_traceback_leak(self): + """Verify optional_import does not retain references to stack frames (issue #7480).""" + + class _Marker: + pass + + def _do_import(): + marker = _Marker() + ref = weakref.ref(marker) + # Call optional_import for a module that does not exist. + # If the traceback is leaked, `marker` stays alive via frame references. + _mod, flag = optional_import("nonexistent_module_for_leak_test") + self.assertFalse(flag) + return ref + + ref = _do_import() + gc.collect() + self.assertIsNone(ref(), "optional_import is leaking frame references via traceback") + + def test_failed_import_shows_traceback_string(self): + """Verify the error message includes the original traceback as a string.""" + mod, flag = optional_import("nonexistent_module_for_tb_test") + self.assertFalse(flag) + with self.assertRaises(OptionalImportError) as ctx: + _ = mod.something + self.assertIn("Original traceback", str(ctx.exception)) + self.assertIn("ModuleNotFoundError", str(ctx.exception)) + if __name__ == "__main__": unittest.main()