Skip to content

Commit 4438b2c

Browse files
lukebaumanncopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 852927594
1 parent 3b65bf4 commit 4438b2c

4 files changed

Lines changed: 8 additions & 42 deletions

File tree

pathwaysutils/jax/__init__.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
`pathwaysutils`'s compatibility window.
1818
"""
1919

20-
import functools
21-
from typing import Any
2220

23-
import jax
21+
import functools
2422

2523

2624
class _FakeJaxFunction:
@@ -47,36 +45,6 @@ def __call__(self, *args, **kwargs):
4745
raise ImportError(self.error_message)
4846

4947

50-
try:
51-
# jax>=0.7.0
52-
from jax.extend import backend # pylint: disable=g-import-not-at-top
53-
54-
register_backend_cache = backend.register_backend_cache
55-
56-
del backend
57-
except AttributeError:
58-
# jax<0.7.0
59-
from jax._src import util # pylint: disable=g-import-not-at-top
60-
61-
def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable=unused-argument
62-
return util.cache_clearing_funs.add(cache.cache_clear)
63-
64-
del util
65-
66-
try:
67-
# jax>=0.7.1
68-
from jax.extend import backend # pylint: disable=g-import-not-at-top
69-
70-
ifrt_proxy = backend.ifrt_proxy
71-
del backend
72-
except AttributeError:
73-
# jax<0.7.1
74-
from jax.lib import xla_extension # pylint: disable=g-import-not-at-top
75-
76-
ifrt_proxy = xla_extension.ifrt_proxy
77-
del xla_extension
78-
79-
8048
try:
8149
# jax>=0.8.0
8250
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
@@ -129,7 +97,5 @@ def ifrt_reshard_available() -> bool:
12997
del jax
13098

13199

132-
del jax
133-
del Any
134100
del _FakeJaxFunction
135101
del functools

pathwaysutils/lru_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import functools
1717
from typing import Any, Callable
1818

19-
from pathwaysutils import jax as pw_jax
19+
from jax.extend import backend
2020

2121

2222
def lru_cache(
@@ -38,7 +38,7 @@ def wrap(f):
3838

3939
wrapper.cache_clear = cached.cache_clear
4040
wrapper.cache_info = cached.cache_info
41-
pw_jax.register_backend_cache(wrapper, "Pathways LRU cache")
41+
backend.register_backend_cache(wrapper, "Pathways LRU cache")
4242
return wrapper
4343

4444
return wrap

pathwaysutils/proxy_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
import jax
1717
from jax.extend import backend
18-
from pathwaysutils import jax as pw_jax
18+
from jax.extend.backend import ifrt_proxy
1919

2020

2121
def register_backend_factory():
2222
backend.register_backend_factory(
2323
"proxy",
24-
lambda: pw_jax.ifrt_proxy.get_client(
24+
lambda: ifrt_proxy.get_client(
2525
jax.config.read("jax_backend_target"),
26-
pw_jax.ifrt_proxy.ClientConnectionOptions(),
26+
ifrt_proxy.ClientConnectionOptions(),
2727
),
2828
priority=-1,
2929
)

pathwaysutils/test/proxy_backend_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import jax
1919
from jax.extend import backend
20-
from pathwaysutils import jax as pw_jax
20+
from jax.extend.backend import ifrt_proxy
2121
from pathwaysutils import proxy_backend
2222

2323
from absl.testing import absltest
@@ -38,7 +38,7 @@ def test_no_proxy_backend_registration_raises_error(self):
3838
def test_proxy_backend_registration(self):
3939
self.enter_context(
4040
mock.patch.object(
41-
pw_jax.ifrt_proxy,
41+
ifrt_proxy,
4242
"get_client",
4343
return_value=mock.MagicMock(),
4444
)

0 commit comments

Comments
 (0)