Fix issue 389 vmap scalar assertions#434
Open
Sikandar1310291 wants to merge 2 commits intogoogle-deepmind:mainfrom
Open
Fix issue 389 vmap scalar assertions#434Sikandar1310291 wants to merge 2 commits intogoogle-deepmind:mainfrom
Sikandar1310291 wants to merge 2 commits intogoogle-deepmind:mainfrom
Conversation
- Convert assert_scalar_positive, assert_scalar_non_negative, and assert_scalar_negative from static to value assertions - Add jittable predicate functions using checkify.check - Enable usage with jax.vmap, jax.jit, and jax.pmap when wrapped with @chex.chexify - Add comprehensive test suite with 9 tests covering vmap, jit, pmap, and nested vmap - Maintain full backward compatibility for non-jitted usage - Fixes issue google-deepmind#389
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR fixes issue #389 by adding vmap compatibility to assert_scalar_positive, assert_scalar_non_negative, and assert_scalar_negative.
Changes
checkify.checkjax.vmap,jax.jit, andjax.pmapwhen wrapped with@chex.chexifyExample (from issue #389)