diff --git a/tests/rl/common_test.py b/tests/rl/common_test.py
index 77e637b9d..448405c7d 100644
--- a/tests/rl/common_test.py
+++ b/tests/rl/common_test.py
@@ -453,6 +453,30 @@ def test_pad_to_length(self):
           kwargs={},
           expected_loss=(0.1 + 0.2) / 1.0,
       ),
+      dict(
+          testcase_name="seq_mean_token_sum",
+          loss_agg_mode="seq-mean-token-sum",
+          per_token_loss_list=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
+          completion_mask_list=[[1, 1, 0], [1, 1, 1]],
+          kwargs={},
+          expected_loss=0.9,
+      ),
+      dict(
+          testcase_name="seq_mean_token_sum_zero_mask",
+          loss_agg_mode="seq-mean-token-sum",
+          per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
+          completion_mask_list=[[0, 0], [0, 0]],
+          kwargs={},
+          expected_loss=0.0,
+      ),
+      dict(
+          testcase_name="seq_mean_token_sum_partial_zero_mask",
+          loss_agg_mode="seq-mean-token-sum",
+          per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
+          completion_mask_list=[[1, 1], [0, 0]],
+          kwargs={},
+          expected_loss=0.3,
+      ),
       dict(
           testcase_name="sequence_mean_token_sum_norm_partial_zero_mask",
           loss_agg_mode="sequence-mean-token-sum-norm",
diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py
index 1e635d9b9..654f6b4b3 100644
--- a/tunix/rl/agentic/agentic_grpo_learner.py
+++ b/tunix/rl/agentic/agentic_grpo_learner.py
@@ -69,7 +69,7 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
     policy_loss_fn: Name of the policy loss function.
     loss_agg_mode: Method for aggregating the loss. Supported values:
       "token-mean", "sequence-mean-token-mean", "sequence-mean-token-scale",
-      "sequence-mean-token-sum-norm".
+      "seq-mean-token-sum", "sequence-mean-token-sum-norm".
     num_generations: Number of samples per prompt (G in the paper). Must be > 1.
     num_iterations: Number of GRPO iterations per batch (μ in the paper).
     beta: KL penalty coefficient.
@@ -624,7 +624,6 @@ def grpo_loss_fn(
   per_token_logps = jnp.astype(per_token_logps, jnp.float32)
   # TODO(tsbao): We should handle token level advantages.
   advantages = jnp.astype(train_example.advantages, jnp.float32)
-
   if train_example.old_per_token_logps is None:
     old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
   else:
diff --git a/tunix/rl/common.py b/tunix/rl/common.py
index 09d24650d..a8d507848 100644
--- a/tunix/rl/common.py
+++ b/tunix/rl/common.py
@@ -526,6 +526,12 @@ def aggregate_loss(
         norm, min=1e-6
     )
     loss = seq_loss.sum() / non_zero_rows
+  elif loss_agg_mode == "seq-mean-token-sum":
+    # 1) sum token losses within each sequence
+    # 2) average only across sequences that have at least one valid token
+    seq_loss = (per_token_loss * completion_mask).sum(axis=-1)
+    seq_mask = (completion_mask.sum(axis=-1) > 0).astype(jnp.float32)
+    loss = (seq_loss * seq_mask).sum() / jnp.clip(seq_mask.sum(), min=1e-6)
   elif loss_agg_mode == "sequence-mean-token-sum-norm":
     # Get custom normalization factor from kwargs, default to number of
     # non-empty rows.
@@ -538,7 +544,8 @@ def aggregate_loss(
   else:
     raise ValueError(
         f"Unsupported loss aggregation mode: {loss_agg_mode}. Supported modes:"
         " 'token-mean', 'sequence-mean-token-mean',"
-        " 'sequence-mean-token-scale', 'sequence-mean-token-sum-norm'."
+        " 'sequence-mean-token-scale', 'seq-mean-token-sum',"
+        " 'sequence-mean-token-sum-norm'."
     )
   return loss
diff --git a/tunix/rl/grpo/grpo_learner.py b/tunix/rl/grpo/grpo_learner.py
index 5d6b64d0a..fd30a5660 100644
--- a/tunix/rl/grpo/grpo_learner.py
+++ b/tunix/rl/grpo/grpo_learner.py
@@ -50,8 +50,10 @@ class GRPOConfig(algo_config_lib.AlgorithmConfig):
     algo_variant: The algorithm variant to use. Default: `grpo`.
     advantage_estimator: The advantage estimator to use. Default: `grpo`.
     policy_loss_fn: The policy loss function to use. Default: `grpo`.
-    loss_agg_mode: The aggregation mode for the loss function. Default:
-      `sequence-mean-token-mean`.
+    loss_agg_mode: The aggregation mode for the loss function. Supported values
+      include `token-mean`, `sequence-mean-token-mean`,
+      `sequence-mean-token-scale`, `seq-mean-token-sum`, and
+      `sequence-mean-token-sum-norm`. Default: `sequence-mean-token-mean`.
     reward_manager: The reward manager to use. Default: `sequence-level`.
     loss_algo: The loss algorithm to use. To be deprecated.
     num_generations: The number of times the policy generates multiple responses
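
Note (not part of the patch): a minimal standalone sketch of the new
"seq-mean-token-sum" branch added to aggregate_loss in tunix/rl/common.py,
checked against the expected values from the test cases above. The helper name
seq_mean_token_sum is hypothetical and exists only for this illustration.

import jax.numpy as jnp


def seq_mean_token_sum(per_token_loss, completion_mask):
  # Sum token losses within each sequence.
  seq_loss = (per_token_loss * completion_mask).sum(axis=-1)
  # Count only sequences that have at least one valid token.
  seq_mask = (completion_mask.sum(axis=-1) > 0).astype(jnp.float32)
  # The clip guards against division by zero when every row is fully masked,
  # which is what makes the "zero_mask" test case come out to 0.0.
  return (seq_loss * seq_mask).sum() / jnp.clip(seq_mask.sum(), min=1e-6)


# "seq_mean_token_sum" case: ((0.1 + 0.2) + (0.4 + 0.5 + 0.6)) / 2 = 0.9.
print(seq_mean_token_sum(
    jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
    jnp.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]),
))

# "seq_mean_token_sum_partial_zero_mask" case: row two is fully masked and is
# excluded from the mean, so the loss is (0.1 + 0.2) / 1 = 0.3, not 0.15.
print(seq_mean_token_sum(
    jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    jnp.array([[1.0, 1.0], [0.0, 0.0]]),
))

Excluding fully masked rows from the denominator keeps empty or fully padded
generations from dragging the per-sequence mean toward zero, which is why the
partial-zero-mask test expects 0.3 rather than 0.15.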