tests/rl/common_test.py (24 additions, 0 deletions)

@@ -453,6 +453,30 @@ def test_pad_to_length(self):
           kwargs={},
           expected_loss=(0.1 + 0.2) / 1.0,
       ),
+      dict(
+          testcase_name="seq_mean_token_sum",
+          loss_agg_mode="seq-mean-token-sum",
+          per_token_loss_list=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
+          completion_mask_list=[[1, 1, 0], [1, 1, 1]],
+          kwargs={},
+          expected_loss=0.9,
+      ),
+      dict(
+          testcase_name="seq_mean_token_sum_zero_mask",
+          loss_agg_mode="seq-mean-token-sum",
+          per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
+          completion_mask_list=[[0, 0], [0, 0]],
+          kwargs={},
+          expected_loss=0.0,
+      ),
+      dict(
+          testcase_name="seq_mean_token_sum_partial_zero_mask",
+          loss_agg_mode="seq-mean-token-sum",
+          per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
+          completion_mask_list=[[1, 1], [0, 0]],
+          kwargs={},
+          expected_loss=0.3,
+      ),
       dict(
           testcase_name="sequence_mean_token_sum_norm_partial_zero_mask",
           loss_agg_mode="sequence-mean-token-sum-norm",
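The expected values in the new cases follow directly from the `seq-mean-token-sum` rule (per-sequence masked sums, averaged over sequences with at least one valid token): in the first case the sequence sums are 0.3 and 1.5, so the loss is (0.3 + 1.5) / 2 = 0.9; with a fully zero mask no sequence survives and the loss is 0.0; and in the partial case only the first sequence (sum 0.3) counts, giving 0.3.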
tunix/rl/agentic/agentic_grpo_learner.py (1 addition, 2 deletions)

@@ -69,7 +69,7 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
     policy_loss_fn: Name of the policy loss function.
     loss_agg_mode: Method for aggregating the loss. Supported values:
       "token-mean", "sequence-mean-token-mean", "sequence-mean-token-scale",
-      "sequence-mean-token-sum-norm".
+      "seq-mean-token-sum", "sequence-mean-token-sum-norm".
     num_generations: Number of samples per prompt (G in the paper). Must be > 1.
     num_iterations: Number of GRPO iterations per batch (μ in the paper).
     beta: KL penalty coefficient.
@@ -624,7 +624,6 @@ def grpo_loss_fn(
   per_token_logps = jnp.astype(per_token_logps, jnp.float32)
   # TODO(tsbao): We should handle token level advantages.
   advantages = jnp.astype(train_example.advantages, jnp.float32)
-
   if train_example.old_per_token_logps is None:
     old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
   else:
tunix/rl/common.py (8 additions, 1 deletion)

@@ -526,6 +526,12 @@ def aggregate_loss(
         norm, min=1e-6
     )
     loss = seq_loss.sum() / non_zero_rows
+  elif loss_agg_mode == "seq-mean-token-sum":
+    # 1) sum token losses within each sequence
+    # 2) average only across sequences that have at least one valid token
+    seq_loss = (per_token_loss * completion_mask).sum(axis=-1)
+    seq_mask = (completion_mask.sum(axis=-1) > 0).astype(jnp.float32)
+    loss = (seq_loss * seq_mask).sum() / jnp.clip(seq_mask.sum(), min=1e-6)
   elif loss_agg_mode == "sequence-mean-token-sum-norm":
     # Get custom normalization factor from kwargs, default to number of
     # non-empty rows.
@@ -538,7 +544,8 @@
     raise ValueError(
         f"Unsupported loss aggregation mode: {loss_agg_mode}. Supported modes:"
         " 'token-mean', 'sequence-mean-token-mean',"
-        " 'sequence-mean-token-scale', 'sequence-mean-token-sum-norm'."
+        " 'sequence-mean-token-scale', 'seq-mean-token-sum',"
+        " 'sequence-mean-token-sum-norm'."
     )
   return loss

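As a sanity check on the new branch, here is a minimal standalone sketch of the `seq-mean-token-sum` aggregation. The helper name `seq_mean_token_sum` is illustrative (the PR adds this logic inline in `aggregate_loss`), and the example inputs are the values from the new test cases above:

```python
import jax.numpy as jnp


def seq_mean_token_sum(per_token_loss, completion_mask):
  """Illustrative replica of the new `seq-mean-token-sum` branch."""
  # Sum token losses within each sequence; masked tokens contribute zero.
  seq_loss = (per_token_loss * completion_mask).sum(axis=-1)
  # Flag sequences that have at least one valid token.
  seq_mask = (completion_mask.sum(axis=-1) > 0).astype(jnp.float32)
  # Average over surviving sequences; the clip keeps an all-masked batch at 0.0
  # instead of dividing by zero.
  return (seq_loss * seq_mask).sum() / jnp.clip(seq_mask.sum(), min=1e-6)


# Mirrors the "seq_mean_token_sum" test case: (0.3 + 1.5) / 2 = 0.9.
print(seq_mean_token_sum(
    jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
    jnp.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]),
))
# Mirrors the partial-zero-mask case: only the first row counts, so 0.3.
print(seq_mean_token_sum(
    jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    jnp.array([[1.0, 1.0], [0.0, 0.0]]),
))
```

The `seq_mask` guard is what distinguishes this mode from a plain mean of sequence sums: fully masked sequences neither contribute loss nor inflate the denominator.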
tunix/rl/grpo/grpo_learner.py (4 additions, 2 deletions)

@@ -50,8 +50,10 @@ class GRPOConfig(algo_config_lib.AlgorithmConfig):
     algo_variant: The algorithm variant to use. Default: `grpo`.
     advantage_estimator: The advantage estimator to use. Default: `grpo`.
     policy_loss_fn: The policy loss function to use. Default: `grpo`.
-    loss_agg_mode: The aggregation mode for the loss function. Default:
-      `sequence-mean-token-mean`.
+    loss_agg_mode: The aggregation mode for the loss function. Supported values
+      include `token-mean`, `sequence-mean-token-mean`,
+      `sequence-mean-token-scale`, `seq-mean-token-sum`, and
+      `sequence-mean-token-sum-norm`. Default: `sequence-mean-token-mean`.
     reward_manager: The reward manager to use. Default: `sequence-level`.
     loss_algo: The loss algorithm to use. To be deprecated.
     num_generations: The number of times the policy generates multiple responses