Skip to content

Conversation

@Sorata-kanda
Copy link

Fixes #387

Problem

chex.assert_equal fails with TracerBoolConversionError when used inside
@chex.chexify + @jax.jit decorated functions because it attempts to
compare traced values using unittest.TestCase().assertEqual().

Solution

Added a jittable implementation using jnp.array_equal() and converted
assert_equal from a static-only assertion to a value assertion, following
the same pattern as assert_trees_all_equal.

Changes

  • Added _assert_equal_static() - original host implementation
  • Added _assert_equal_jittable() - new jittable implementation using jnp.array_equal()
  • Registered both with _value_assertion() so chexify can use the jittable version during tracing

Testing

All existing assert_equal tests pass. The original failing example now works:

@chex.chexify
@jax.jit
def f(x):
    chex.assert_equal(x, 0)
    return

f(0)  # Now works!

Fixes google-deepmind#387

Previously, assert_equal was a static-only assertion that failed with
TracerBoolConversionError when used inside @chex.chexify + @jax.jit
decorated functions, because it tried to compare traced values using
unittest.TestCase().assertEqual().

This change adds a jittable implementation using jnp.array_equal() and
registers assert_equal as a value assertion, following the same pattern
as assert_trees_all_equal.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jitted, chexified chex.assert_equal fails for value assertions

1 participant