Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ def unwrap(self):

@struct.pytree_dataclass
class _SliceThunk(struct.Struct):
start: Any = dataclasses.field(metadata={"pytree_node": False})
stop: Any = dataclasses.field(metadata={"pytree_node": False})
step: Any = dataclasses.field(metadata={"pytree_node": False})
start: int | None = dataclasses.field(metadata={"pytree_node": False})
stop: int | None = dataclasses.field(metadata={"pytree_node": False})
step: int | None = dataclasses.field(metadata={"pytree_node": False})

def unwrap(self):
return slice(self.start, self.stop, self.step)
Expand Down Expand Up @@ -567,7 +567,14 @@ def get(self, **kwargs) -> NamedArrayBase:
if isinstance(c, jax.Array | np.ndarray | NamedArrayBase | int):
index_thunks.append(_DynamicThunk(c))
elif isinstance(c, slice):
index_thunks.append(_SliceThunk(c.start, c.stop, c.step))
# Slices must be either integers or None.
# We cast to integers immediately here to avoid accidentally passing
# JAX arrays through as static metadata.
# See https://github.com/jax-ml/jax/issues/28311.
start = None if c.start is None else int(c.start)
stop = None if c.stop is None else int(c.stop)
step = None if c.step is None else int(c.step)
index_thunks.append(_SliceThunk(start, stop, step))
else:
index_thunks.append(_StaticThunk(c))

Expand Down
26 changes: 26 additions & 0 deletions tests/core/named_axes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for penzai.named_axes."""

import collections
import dataclasses
import re

from absl.testing import absltest
Expand Down Expand Up @@ -1170,6 +1171,31 @@ def f(carry, x):
self.assertEqual(stacked_out["out_narr"].named_shape, {"seq": 7, "c": 4})
self.assertEqual(stacked_out["out_narr"].positional_shape, (5,))

def test_jax_compilation_with_captured_slice_index(self):
# https://github.com/google-deepmind/penzai/issues/117

@pz.pytree_dataclass
class Indexer(pz.Struct):
index: int = dataclasses.field(metadata={"pytree_node": False})

def outer_fn():
indexer = Indexer(jnp.array(3))

def inner_fn(x: pz.nx.NamedArray):
return x[{"foo": pz.slice[: indexer.index]}]

return jax.jit(inner_fn)

# Test that both calls succeed and produce the expected result, instead of
# causing an infinite hang.
test_fn = outer_fn()
result1 = test_fn(pz.nx.ones({"foo": 10}))
expected = pz.nx.ones({"foo": 3})
chex.assert_trees_all_equal(result1, expected)
test_fn = outer_fn()
result2 = test_fn(pz.nx.ones({"foo": 10}))
chex.assert_trees_all_equal(result2, expected)


if __name__ == "__main__":
absltest.main()