diff --git a/bec_lib/bec_lib/utils/__init__.py b/bec_lib/bec_lib/utils/__init__.py index 1662e9c18..1791ee625 100644 --- a/bec_lib/bec_lib/utils/__init__.py +++ b/bec_lib/bec_lib/utils/__init__.py @@ -1,3 +1,26 @@ -from bec_lib.utils.rpc_utils import user_access -from bec_lib.utils.scan_utils import scan_to_csv, scan_to_dict -from bec_lib.utils.threading_utils import threadlocked +from importlib import import_module +from typing import Any + +_LAZY_EXPORTS = { + "lazy_import": ("bec_lib.utils.import_utils", "lazy_import"), + "lazy_import_from": ("bec_lib.utils.import_utils", "lazy_import_from"), + "scan_to_csv": ("bec_lib.utils.scan_utils", "scan_to_csv"), + "scan_to_dict": ("bec_lib.utils.scan_utils", "scan_to_dict"), + "threadlocked": ("bec_lib.utils.threading_utils", "threadlocked"), + "user_access": ("bec_lib.utils.rpc_utils", "user_access"), +} + +__all__ = sorted(_LAZY_EXPORTS) + + +def __getattr__(name: str) -> Any: + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_name, attr_name = _LAZY_EXPORTS[name] + value = getattr(import_module(module_name), attr_name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/bec_lib/bec_lib/utils/import_utils.py b/bec_lib/bec_lib/utils/import_utils.py index e0f0d8329..4b42e584b 100644 --- a/bec_lib/bec_lib/utils/import_utils.py +++ b/bec_lib/bec_lib/utils/import_utils.py @@ -1,28 +1,40 @@ -import inspect -import sys +from collections.abc import Iterable from importlib import import_module -from typing import Any +from typing import Any, overload from bec_lib.utils.proxy import Proxy -def lazy_import(module_name): +def lazy_import(module_name: str) -> Proxy: return Proxy(lambda: import_module(module_name), init_once=True) -def lazy_import_from(module_name, from_list): - ret = (Proxy(lambda name=name: getattr(import_module(module_name), name)) for name in from_list) - if len(from_list) == 1: - return next(ret) - else: - return ret +@overload +def lazy_import_from(module_name: str, from_list: str) -> Proxy: ... -def isinstance_based_on_class_name(obj: Any, full_class_name: str): +@overload +def lazy_import_from(module_name: str, from_list: Iterable[str]) -> tuple[Proxy, ...] | Proxy: ... + + +def lazy_import_from(module_name: str, from_list: str | Iterable[str]) -> tuple[Proxy, ...] | Proxy: + names = (from_list,) if isinstance(from_list, str) else tuple(from_list) + proxies = tuple( + Proxy(lambda name=name: getattr(import_module(module_name), name), init_once=True) + for name in names + ) + if len(proxies) == 1: + return proxies[0] + return proxies + + +def isinstance_based_on_class_name(obj: Any, full_class_name: str) -> bool: """Return if object 'obj' is an instance of class named 'full_class_name' 'full_class_name' must be a string like 'class_module.class_name', the corresponding class does not need to be imported at the caller module level """ + import inspect + return full_class_name in [ f"{klass.__module__}.{klass.__name__}" for klass in inspect.getmro(type(obj)) ] diff --git a/bec_lib/tests/test_import_utils.py b/bec_lib/tests/test_import_utils.py index 1f0f5cc7d..35bdcd089 100644 --- a/bec_lib/tests/test_import_utils.py +++ b/bec_lib/tests/test_import_utils.py @@ -1,9 +1,106 @@ -from bec_lib.device import DeviceBase -from bec_lib.utils.import_utils import isinstance_based_on_class_name +import os +import subprocess +import sys +from types import SimpleNamespace + +from bec_lib.utils import import_utils + + +class DummyBase: + pass + + +class DummyChild(DummyBase): + pass + + +def _clean_pythonpath() -> str: + return os.pathsep.join(str(path) for path in sys.path if path) def test_isinstance_based_on_class_name(): - obj = DeviceBase(name="test_obj") + obj = DummyChild() + assert import_utils.isinstance_based_on_class_name( + obj, f"{DummyBase.__module__}.{DummyBase.__name__}" + ) + assert import_utils.isinstance_based_on_class_name( + obj, f"{DummyChild.__module__}.{DummyChild.__name__}" + ) + assert not import_utils.isinstance_based_on_class_name(obj, "builtins.dict") + + +def test_lazy_import_from_accepts_string_input(): + json_decoder = import_utils.lazy_import_from("json", "JSONDecoder") + assert json_decoder.__name__ == "JSONDecoder" + + +def test_lazy_import_from_single_tuple_returns_single_proxy(): + json_decoder = import_utils.lazy_import_from("json", ("JSONDecoder",)) + assert json_decoder.__name__ == "JSONDecoder" + + +def test_lazy_import_from_multiple_names_returns_tuple(): + proxies = import_utils.lazy_import_from("json", ("JSONDecoder", "JSONEncoder")) + assert isinstance(proxies, tuple) + assert [proxy.__name__ for proxy in proxies] == ["JSONDecoder", "JSONEncoder"] + + +def test_lazy_import_from_materializes_once(monkeypatch): + calls = [] + + def fake_import(module_name): + calls.append(module_name) + return SimpleNamespace(DemoClass=type("DemoClass", (), {})) + + monkeypatch.setattr(import_utils, "import_module", fake_import) + + demo_class = import_utils.lazy_import_from("demo.module", "DemoClass") + assert demo_class.__name__ == "DemoClass" + assert demo_class.__name__ == "DemoClass" + assert calls == ["demo.module"] + + +def test_lazy_import_does_not_import_module_until_use(tmp_path, monkeypatch): + module_name = "lazy_target_module" + module_path = tmp_path / "lazy_target_module.py" + module_path.write_text("VALUE = 123\n", encoding="utf-8") + monkeypatch.syspath_prepend(str(tmp_path)) + sys.modules.pop(module_name, None) + + mod = import_utils.lazy_import(module_name) + + assert module_name not in sys.modules + assert mod.VALUE == 123 + assert module_name in sys.modules + + +def test_lazy_import_from_does_not_import_module_until_use(tmp_path, monkeypatch): + module_name = "lazy_from_target_module" + module_path = tmp_path / "lazy_from_target_module.py" + module_path.write_text("class DemoClass:\n" " VALUE = 456\n", encoding="utf-8") + monkeypatch.syspath_prepend(str(tmp_path)) + sys.modules.pop(module_name, None) + + demo_cls = import_utils.lazy_import_from(module_name, "DemoClass") + + assert module_name not in sys.modules + assert demo_cls.VALUE == 456 + assert module_name in sys.modules + - assert isinstance_based_on_class_name(obj, "bec_lib.device.DeviceBase") - assert not isinstance_based_on_class_name(obj, "bec_lib.device.Status") +def test_importing_import_utils_does_not_import_scan_utils(): + # This needs a clean interpreter because sys.modules is shared by the test process. + env = os.environ | {"PYTHONPATH": _clean_pythonpath()} + proc = subprocess.run( + [ + sys.executable, + "-c", + "from bec_lib.utils.import_utils import lazy_import_from; import sys; " + "print('bec_lib.utils.scan_utils' in sys.modules)", + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + assert proc.stdout.strip() == "False"