Skip to content

Fix issue 389 vmap scalar assertions#434

Open
Sikandar1310291 wants to merge 2 commits intogoogle-deepmind:mainfrom
Sikandar1310291:fix-issue-389-vmap-scalar-assertions
Open

Fix issue 389 vmap scalar assertions#434
Sikandar1310291 wants to merge 2 commits intogoogle-deepmind:mainfrom
Sikandar1310291:fix-issue-389-vmap-scalar-assertions

Conversation

@Sikandar1310291
Copy link

Summary

This PR fixes issue #389 by adding vmap compatibility to assert_scalar_positive, assert_scalar_non_negative, and assert_scalar_negative.

Changes

  • Converted three scalar assertions from static to value assertions
  • Added jittable predicate functions using checkify.check
  • These assertions can now be used with jax.vmap, jax.jit, and jax.pmap when wrapped with @chex.chexify
  • Added comprehensive test suite with 9 tests

Example (from issue #389)

import jax.numpy as jnp
import jax
from chex import assert_scalar_positive, chexify

x_vector = jnp.array([1., 1.])

@chexify
@jax.jit
def test_vmap():
    jax.vmap(assert_scalar_positive)(x_vector)
    return x_vector

test_vmap()  # Now works!

- Convert assert_scalar_positive, assert_scalar_non_negative, and
  assert_scalar_negative from static to value assertions
- Add jittable predicate functions using checkify.check
- Enable usage with jax.vmap, jax.jit, and jax.pmap when wrapped with @chex.chexify
- Add comprehensive test suite with 9 tests covering vmap, jit, pmap, and nested vmap
- Maintain full backward compatibility for non-jitted usage
- Fixes issue google-deepmind#389
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants