diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 7070a56..5194fe1 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -407,9 +407,9 @@ def unwrap(self): @struct.pytree_dataclass class _SliceThunk(struct.Struct): - start: Any = dataclasses.field(metadata={"pytree_node": False}) - stop: Any = dataclasses.field(metadata={"pytree_node": False}) - step: Any = dataclasses.field(metadata={"pytree_node": False}) + start: int | None = dataclasses.field(metadata={"pytree_node": False}) + stop: int | None = dataclasses.field(metadata={"pytree_node": False}) + step: int | None = dataclasses.field(metadata={"pytree_node": False}) def unwrap(self): return slice(self.start, self.stop, self.step) @@ -567,7 +567,14 @@ def get(self, **kwargs) -> NamedArrayBase: if isinstance(c, jax.Array | np.ndarray | NamedArrayBase | int): index_thunks.append(_DynamicThunk(c)) elif isinstance(c, slice): - index_thunks.append(_SliceThunk(c.start, c.stop, c.step)) + # Slices must be either integers or None. + # We cast to integers immediately here to avoid accidentally passing + # JAX arrays through as static metadata. + # See https://github.com/jax-ml/jax/issues/28311. + start = None if c.start is None else int(c.start) + stop = None if c.stop is None else int(c.stop) + step = None if c.step is None else int(c.step) + index_thunks.append(_SliceThunk(start, stop, step)) else: index_thunks.append(_StaticThunk(c)) diff --git a/tests/core/named_axes_test.py b/tests/core/named_axes_test.py index 3b06b09..3c6da3f 100644 --- a/tests/core/named_axes_test.py +++ b/tests/core/named_axes_test.py @@ -15,6 +15,7 @@ """Tests for penzai.named_axes.""" import collections +import dataclasses import re from absl.testing import absltest @@ -1170,6 +1171,31 @@ def f(carry, x): self.assertEqual(stacked_out["out_narr"].named_shape, {"seq": 7, "c": 4}) self.assertEqual(stacked_out["out_narr"].positional_shape, (5,)) + def test_jax_compilation_with_captured_slice_index(self): + # https://github.com/google-deepmind/penzai/issues/117 + + @pz.pytree_dataclass + class Indexer(pz.Struct): + index: int = dataclasses.field(metadata={"pytree_node": False}) + + def outer_fn(): + indexer = Indexer(jnp.array(3)) + + def inner_fn(x: pz.nx.NamedArray): + return x[{"foo": pz.slice[: indexer.index]}] + + return jax.jit(inner_fn) + + # Test that both calls succeed and produce the expected result, instead of + # causing an infinite hang. + test_fn = outer_fn() + result1 = test_fn(pz.nx.ones({"foo": 10})) + expected = pz.nx.ones({"foo": 3}) + chex.assert_trees_all_equal(result1, expected) + test_fn = outer_fn() + result2 = test_fn(pz.nx.ones({"foo": 10})) + chex.assert_trees_all_equal(result2, expected) + if __name__ == "__main__": absltest.main()