@parameterized.named_parameters(
    {
        "testcase_name": "no_seq_packing",
        "prompt_tokens": np.array([[1, 2], [3, 4]]),
        "completion_tokens": np.array([[5, 6, 7], [8, 9, 10]]),
        "segment_ids": None,
        "segment_positions": None,
        "expected_shape": (2, 5),
        "expected_scores": np.array(
            [[1, 2, 5, 6, 7], [3, 4, 8, 9, 10]], dtype=np.float32
        ),
    },
    {
        "testcase_name": "with_seq_packing",
        # Zero-width prompt: every token of each row is supplied through
        # `completion_tokens` in this case.
        "prompt_tokens": np.zeros((2, 0), dtype=np.int32),
        "completion_tokens": np.array([[1, 2, 5, 6, 7], [3, 4, 8, 9, 10]]),
        "segment_ids": np.array(
            [[1, 1, 2, 2, 2], [1, 1, 2, 2, 2]], dtype=np.int32
        ),
        "segment_positions": np.array(
            [[0, 1, 0, 1, 2], [0, 1, 0, 1, 2]], dtype=np.int32
        ),
        "expected_shape": (2, 5),
        "expected_scores": np.array(
            [[1, 2, 5, 6, 7], [3, 4, 8, 9, 10]], dtype=np.float32
        ),
    },
)
def test_compute_score(
    self,
    prompt_tokens,
    completion_tokens,
    segment_ids,
    segment_positions,
    expected_shape,
    expected_scores,
):
  """Checks the shape and values returned by `common.compute_score`.

  The model is replaced by a stub that echoes its `input_ids` as float32
  with one trailing axis appended, so the expected scores in each
  parameter set are simply the token ids laid out per row. Two setups are
  covered: separate prompt/completion arrays without segment information,
  and an empty prompt with explicit `segment_ids`/`segment_positions`.
  """

  class EchoModel(nnx.Module):
    """Stub model whose output is its own input ids cast to float32."""

    def __call__(
        self,
        input_ids,
        positions,
        attention_mask=None,
        segment_ids=None,
        cache=None,
    ):
      # Only `input_ids` is used; the remaining arguments are accepted so
      # the stub can be called with them and are otherwise ignored.
      ids_as_float = input_ids.astype(jnp.float32)
      # A trailing axis is appended; the second tuple element is None.
      return ids_as_float[..., None], None

  actual_scores = common.compute_score(
      EchoModel(),
      prompt_tokens,
      completion_tokens,
      pad_id=0,
      eos_id=-1,
      segment_ids=segment_ids,
      segment_positions=segment_positions,
  )

  self.assertEqual(actual_scores.shape, expected_shape)
  np.testing.assert_allclose(actual_scores, expected_scores)