From a83d7d9164f6bc03d6bc9f79e5230f0e562d0b56 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Sun, 8 Jun 2025 21:53:56 -0700 Subject: [PATCH] Named axes: Cast slice coordinates to integers or None. Only slicing with integers or None are supported. Previously, other values such as scalar JAX arrays were passed through, but this can cause confusing behavior and expose bugs because JAX does not generally allow putting JAX arrays in static PyTree metadata. Casting them immediately forces the values to be concrete at trace time and prevents confusing bugs later. Fixes https://github.com/google-deepmind/penzai/issues/117 --- penzai/core/named_axes.py | 15 +++++++++++---- tests/core/named_axes_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 7070a56..5194fe1 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -407,9 +407,9 @@ def unwrap(self): @struct.pytree_dataclass class _SliceThunk(struct.Struct): - start: Any = dataclasses.field(metadata={"pytree_node": False}) - stop: Any = dataclasses.field(metadata={"pytree_node": False}) - step: Any = dataclasses.field(metadata={"pytree_node": False}) + start: int | None = dataclasses.field(metadata={"pytree_node": False}) + stop: int | None = dataclasses.field(metadata={"pytree_node": False}) + step: int | None = dataclasses.field(metadata={"pytree_node": False}) def unwrap(self): return slice(self.start, self.stop, self.step) @@ -567,7 +567,14 @@ def get(self, **kwargs) -> NamedArrayBase: if isinstance(c, jax.Array | np.ndarray | NamedArrayBase | int): index_thunks.append(_DynamicThunk(c)) elif isinstance(c, slice): - index_thunks.append(_SliceThunk(c.start, c.stop, c.step)) + # Slices must be either integers or None. + # We cast to integers immediately here to avoid accidentally passing + # JAX arrays through as static metadata. + # See https://github.com/jax-ml/jax/issues/28311. + start = None if c.start is None else int(c.start) + stop = None if c.stop is None else int(c.stop) + step = None if c.step is None else int(c.step) + index_thunks.append(_SliceThunk(start, stop, step)) else: index_thunks.append(_StaticThunk(c)) diff --git a/tests/core/named_axes_test.py b/tests/core/named_axes_test.py index 3b06b09..3c6da3f 100644 --- a/tests/core/named_axes_test.py +++ b/tests/core/named_axes_test.py @@ -15,6 +15,7 @@ """Tests for penzai.named_axes.""" import collections +import dataclasses import re from absl.testing import absltest @@ -1170,6 +1171,31 @@ def f(carry, x): self.assertEqual(stacked_out["out_narr"].named_shape, {"seq": 7, "c": 4}) self.assertEqual(stacked_out["out_narr"].positional_shape, (5,)) + def test_jax_compilation_with_captured_slice_index(self): + # https://github.com/google-deepmind/penzai/issues/117 + + @pz.pytree_dataclass + class Indexer(pz.Struct): + index: int = dataclasses.field(metadata={"pytree_node": False}) + + def outer_fn(): + indexer = Indexer(jnp.array(3)) + + def inner_fn(x: pz.nx.NamedArray): + return x[{"foo": pz.slice[: indexer.index]}] + + return jax.jit(inner_fn) + + # Test that both calls succeed and produce the expected result, instead of + # causing an infinite hang. + test_fn = outer_fn() + result1 = test_fn(pz.nx.ones({"foo": 10})) + expected = pz.nx.ones({"foo": 3}) + chex.assert_trees_all_equal(result1, expected) + test_fn = outer_fn() + result2 = test_fn(pz.nx.ones({"foo": 10})) + chex.assert_trees_all_equal(result2, expected) + if __name__ == "__main__": absltest.main()