From b72729bb152b7b3426299405950b3af300d765a9 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 1 Oct 2025 14:36:04 -0700 Subject: [PATCH] Fixed a couple bugs with `pathwaysutils.jax` PiperOrigin-RevId: 813918268 --- pathwaysutils/jax/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index 6863ccc..75ec4f8 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -77,14 +77,14 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable try: # jax>=0.8.0 - from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top + from jaxlib import _pathways as jaxlib_pathways # pylint: disable=g-import-not-at-top - jaxlib_pathways = _pathways - del _pathways -except ModuleNotFoundError: +except ImportError: # jax<0.8.0 jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0") +del jax +del Any del _FakeJaxModule