diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index 6863ccc..75ec4f8 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -77,14 +77,14 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable try: # jax>=0.8.0 - from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top + from jaxlib import _pathways as jaxlib_pathways # pylint: disable=g-import-not-at-top - jaxlib_pathways = _pathways - del _pathways -except ModuleNotFoundError: +except ImportError: # jax<0.8.0 jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0") +del jax +del Any del _FakeJaxModule