5 changes: 3 additions & 2 deletions tests/rl/agentic/agentic_grpo_learner_test.py
@@ -43,6 +43,7 @@
from tunix.rl import function_registry
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.agentic import agentic_grpo_learner
from tunix.rl.algo_core import grpo_core
from tunix.rl.agentic.agents.agent_types import Action, Step
from tunix.rl.agentic.agents.base_agent import ConversationAgentBase
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
@@ -1899,15 +1900,15 @@ def _patch_process_results(

def test_compute_rloo_advantages(self):
rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
advantages = agentic_grpo_learner.compute_rloo_advantages(
advantages = grpo_core.compute_rloo_advantages(
rewards, num_generations=3
)
expected_value = jnp.array([-1.5, 0.0, 1.5, -1.5, 0.0, 1.5])
np.testing.assert_allclose(advantages, expected_value)

def test_compute_rloo_advantages_low_generations(self):
rewards = jnp.array([1.0, 2.0])
advantages = agentic_grpo_learner.compute_rloo_advantages(
advantages = grpo_core.compute_rloo_advantages(
rewards, num_generations=1
)
np.testing.assert_allclose(advantages, jnp.zeros_like(rewards))
2 changes: 1 addition & 1 deletion tests/rl/grpo/dapo_learner_test.py
@@ -109,7 +109,7 @@ def test_diff_loss(self):

self.assertIn("kl", dapo_aux)
self.assertIn("kl", grpo_aux)
self.assertEqual(dapo_aux["kl"], 0.0) # DAPO does not have KL term.



class TestDAPOConfigPostInit(parameterized.TestCase):
2 changes: 1 addition & 1 deletion tests/rl/grpo/drgrpo_learner_test.py
@@ -102,7 +102,7 @@ def test_drgrpo_advantage_estimator(self):
)
# Dr. GRPO advantages are not scaled by the standard deviation.
# Std. across groups above is the same by construction.
std_factor = jnp.array([1.0, 2.0]).std(ddof=1) + 1e-4
std_factor = jnp.array([1.0, 2.0]).std(ddof=1) + 1e-6
np.testing.assert_allclose(grpo_advantages * std_factor, drgrpo_advantages)

def test_drgrpo_loss_fn(self):
5 changes: 3 additions & 2 deletions tests/rl/grpo/grpo_learner_test.py
@@ -35,6 +35,7 @@
from tunix.perf.experimental import tracer as perf_tracer_v2
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo import grpo_learner as grpo_lib
from tunix.rl.algo_core import grpo_core
from tunix.rl.queue import data_queue as queue_lib
from tunix.rl.rollout import base_rollout
from tunix.sft import profiler
@@ -1234,9 +1235,9 @@ def test_compute_advantages(self):

rng = jax.random.PRNGKey(0)
rewards = jax.random.uniform(rng, shape=(1, 6))
advantages = grpo_lib.compute_advantages(rewards, num_generations=3)
advantages = grpo_core.compute_advantages(rewards, num_generations=3)
expected_value = jnp.array(
[[0.307407, -1.117304, 0.809897, 1.094044, -0.22857, -0.865474]]
[[0.307498, -1.117636, 0.810138, 1.094526, -0.228671, -0.865855]]
)
np.testing.assert_allclose(advantages, expected_value, rtol=1e-5, atol=1e-5)

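Aside, not part of this diff: the ~0.03% shift in the golden advantage values above is consistent with the std-normalization epsilon moving from 1e-4 to 1e-6, the same constant updated in drgrpo_learner_test.py earlier in this change. A quick way to check that reading (an assumption on my part, using a ddof=1 group std as in the implementations further down):

import jax
import jax.numpy as jnp

rewards = jax.random.uniform(jax.random.PRNGKey(0), shape=(1, 6)).flatten()
mean = rewards.reshape(-1, 3).mean(axis=-1).repeat(3)
std = rewards.reshape(-1, 3).std(axis=-1, ddof=1).repeat(3)
print((rewards - mean) / (std + 1e-4))  # should roughly match the old golden values
print((rewards - mean) / (std + 1e-6))  # should roughly match the new golden values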
2 changes: 1 addition & 1 deletion tests/rl/ppo/ppo_helpers_test.py
@@ -18,7 +18,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl.ppo import ppo_helpers
from tunix.rl.algo_core import utils as ppo_helpers


def _ref_compute_gae_advantages(
206 changes: 2 additions & 204 deletions tunix/rl/agentic/agentic_grpo_learner.py
@@ -37,6 +37,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl import algo_core # pylint: disable=unused-import
from tunix.perf.experimental import constants as perf_constants
from tunix.rl import common
from tunix.rl import function_registry
@@ -48,7 +49,7 @@
from tunix.rl.agentic.agents import model_agent
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.environments import task_environment
from tunix.rl.ppo import ppo_helpers
from tunix.rl.algo_core import utils as ppo_helpers
from tunix.utils import trajectory_logger


@@ -553,216 +554,13 @@ def _process_results(
return [combined_batch]


@function_registry.register_policy_loss_fn("agentic_grpo")
def grpo_loss_fn(
model,
train_example,
algo_config,
pad_id,
eos_id,
):
"""GRPO loss function.

The loss aims to maximize the expected advantage of the chosen actions while
constraining the policy updates to stay within a certain range of the
reference policy.

Args:
model: The policy model to be trained.
train_example: A `TrainExample` instance containing the processed input
data, including prompt IDs, completion IDs, masks, advantages, and
per-token log probabilities from the reference and policy models.
algo_config: The algorithm config.
pad_id: The pad ID from tokenizer.
eos_id: The eos ID from tokenizer.

Returns:
A tuple containing the loss and an aux dictionary.
"""
beta = algo_config.beta
epsilon = algo_config.epsilon
loss_algo = algo_config.loss_algo
epsilon_high = (
algo_config.epsilon_high
if hasattr(algo_config, "epsilon_high")
else epsilon
)
epsilon_c = (
algo_config.epsilon_c
if hasattr(algo_config, "epsilon_c")
else 3.0
)
loss_aggregation_mode = algo_config.loss_agg_mode

completion_ids, completion_mask = (
train_example.completion_ids,
train_example.completion_mask,
)

# TODO(tsbao): split can be avoided with updated peft_trainer model handling.
graphdef, state = nnx.split(model)
per_token_logps, logits = common.compute_per_token_logps(
graphdef,
state,
prompt_tokens=train_example.prompt_ids,
completion_tokens=completion_ids,
pad_id=pad_id,
eos_id=eos_id,
completion_mask=completion_mask,
stop_gradient=False,
return_logits=True,
segment_ids=getattr(train_example, "segment_ids", None),
segment_positions=getattr(train_example, "segment_positions", None),
)
per_token_logps = jnp.astype(per_token_logps, jnp.float32)
# TODO(tsbao): We should handle token level advantages.
advantages = jnp.astype(train_example.advantages, jnp.float32)

if train_example.old_per_token_logps is None:
old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
else:
old_per_token_logps = jnp.astype(
train_example.old_per_token_logps, jnp.float32
)

seq_importance_ratio = per_token_logps - old_per_token_logps
# Record KL divergence before clipping.
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)

seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)

# TODO(sizhi): Refactor this to a separate function.
if loss_algo == "gspo-token":
seq_importance_ratio = (seq_importance_ratio * completion_mask).sum(
axis=-1
) / jnp.clip(completion_mask.sum(-1), min=1)
seq_importance_ratio = (
per_token_logps
- jax.lax.stop_gradient(per_token_logps)
+ jnp.expand_dims(jax.lax.stop_gradient(seq_importance_ratio), axis=-1)
)
seq_importance_ratio = jnp.clip(seq_importance_ratio, max=10.0)

is_ratio = jnp.exp(seq_importance_ratio)

# Advantages must be broadcast against seq_length.
# When sequence packing is used, advantages are already 2D [B, seq_length].
# When unpacked, they are 1D [B].
adv = advantages if advantages.ndim == 2 else jnp.expand_dims(advantages, 1)

pg_loss_1 = -adv * is_ratio
pg_loss_2 = -adv * jnp.clip(is_ratio, 1 - epsilon, 1 + epsilon_high)

per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)

clipped_fraction = ppo_helpers.masked_mean(
jnp.greater(pg_loss_2, pg_loss_1), completion_mask
)

# dual-clip ppo loss
pg_loss_3 = -epsilon_c * adv

# pg_clipfrac_lower measures how often dual-clip ppo kicks in.
# It kicks in when the standard clipped loss is larger than pg_loss_3
# for instances with negative advantages.
unreduced_pg_clipfrac_lower = (
(per_token_loss > pg_loss_3) & (adv < 0.0)
).astype(jnp.float32)
pg_clipfrac_lower = common.aggregate_loss(
unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
)

pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss)
per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss)
loss = common.aggregate_loss(
per_token_loss, completion_mask, loss_aggregation_mode
)
aux = {
"kl": 0.0,
"kl_loss": 0.0,
"pg_loss": loss,
"pg_clipfrac": clipped_fraction,
"ppo_kl": ppo_kl,
"pg_clipfrac_lower": pg_clipfrac_lower,
}
# We do not always compute KL divergence (e.g. when beta is 0.0 unless
# force_compute_kl is True).
if train_example.ref_per_token_logps is not None:
kl = common.compute_kl_divergence(
per_token_logps,
train_example.ref_per_token_logps,
algo_config.kl_loss_mode,
)
# Log mean KL.
aux["kl"] = jnp.astype(
(kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
jnp.float32,
)
kl_loss = common.aggregate_loss(
kl, completion_mask, loss_aggregation_mode
)
aux["kl_loss"] = kl_loss
if beta is not None and beta != 0.0:
loss = loss + beta * kl_loss

token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
entropy_loss = common.aggregate_loss(
token_entropy, completion_mask, loss_aggregation_mode
)
aux["entropy"] = entropy_loss

return loss, aux
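
Aside, not part of this diff: a minimal standalone sketch of the dual-clip behaviour described in the comments above (the helper name and toy values are made up for illustration):

import jax.numpy as jnp

def dual_clip_pg_loss(adv, ratio, eps=0.2, eps_c=3.0):
  # Standard PPO clipping: pessimistic maximum of the unclipped and clipped objectives.
  pg1 = -adv * ratio
  pg2 = -adv * jnp.clip(ratio, 1 - eps, 1 + eps)
  loss = jnp.maximum(pg1, pg2)
  # Dual clip: for negative advantages, cap the loss at -eps_c * adv so one very
  # large importance ratio cannot dominate the gradient.
  return jnp.where(adv < 0.0, jnp.minimum(-eps_c * adv, loss), loss)

# With adv = -1 and ratio = 10, standard clipping alone gives a loss of 10;
# dual clip caps it at eps_c = 3.
print(dual_clip_pg_loss(jnp.array([-1.0]), jnp.array([10.0])))  # [3.]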


@function_registry.register_advantage_estimator("agentic_grpo")
def compute_advantages(rewards: jax.Array, num_generations: int) -> jax.Array:
"""Compute group relative advantages.

Args:
rewards: reward functions output.
num_generations: Number of generations.

Returns:
Group relative advantages.
"""
rewards = jnp.astype(rewards, jnp.float32)
mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=-1)
std_grouped_rewards = rewards.reshape(-1, num_generations).std(
axis=-1, ddof=1
)

mean_grouped_rewards = mean_grouped_rewards.repeat(num_generations)
std_grouped_rewards = std_grouped_rewards.repeat(num_generations)
return (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-6)
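
For illustration, not part of this diff (values made up), the function above maps one group of rewards to mean-centered, std-scaled advantages:

rewards = jnp.array([1.0, 2.0, 3.0])   # a single group of 3 generations
adv = compute_advantages(rewards, num_generations=3)
# group mean = 2.0, ddof=1 std = 1.0  ->  adv ≈ [-1.0, 0.0, 1.0]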


@function_registry.register_advantage_estimator("agentic_rloo")
def compute_rloo_advantages(
rewards: jax.Array, num_generations: int
) -> jax.Array:
"""Compute RLOO (REINFORCE Leave-One-Out) advantages.

RLOO computes a baseline for each completion by averaging the rewards of all
other completions to the same prompt.

Args:
rewards: reward functions output.
num_generations: Number of generations.

Returns:
RLOO advantages.
"""
if num_generations < 2:
# RLOO requires at least 2 samples to calculate a baseline.
return jnp.zeros_like(rewards)

reshaped_rewards = rewards.reshape(-1, num_generations)
loo_mean = (
reshaped_rewards.sum(axis=-1, keepdims=True) - reshaped_rewards
) / (num_generations - 1)
rloo_advantages = reshaped_rewards - loo_mean

return rloo_advantages.flatten()
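
Aside, not part of this diff: the expected values asserted in agentic_grpo_learner_test.py above follow directly from the leave-one-out baseline. For rewards [1, 2, 3, 4, 5, 6] with num_generations=3, the first group is [1, 2, 3]; the baseline for reward 1 is (2 + 3) / 2 = 2.5, giving an advantage of -1.5, and likewise 0.0 and 1.5 for rewards 2 and 3, so the result is [-1.5, 0.0, 1.5, -1.5, 0.0, 1.5].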


GrpoConfig = GRPOConfig
5 changes: 5 additions & 0 deletions tunix/rl/algo_core/__init__.py
@@ -0,0 +1,5 @@
"""Algorithm core implementations for RL and Agentic RL learners."""
from tunix.rl.algo_core import utils
from tunix.rl.algo_core import grpo_core
from tunix.rl.algo_core import ppo_core
from tunix.rl.algo_core import drgrpo_core
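
Aside, not part of this change: the test updates earlier in this diff show the intended call sites after the move, e.g. (here rewards stands for any jnp array of per-completion rewards):

from tunix.rl.algo_core import grpo_core
from tunix.rl.algo_core import utils as ppo_helpers

advantages = grpo_core.compute_advantages(rewards, num_generations=3)
rloo_advantages = grpo_core.compute_rloo_advantages(rewards, num_generations=3)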
32 changes: 32 additions & 0 deletions tunix/rl/algo_core/drgrpo_core.py
@@ -0,0 +1,32 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DrGRPO core algorithm implementations."""

import jax
from tunix.rl import function_registry

@function_registry.register_advantage_estimator("drgrpo")
def compute_advantages(rewards: jax.Array, num_generations: int) -> jax.Array:
"""Group relative advantages -- done right.

Args:
rewards: reward functions output.
num_generations: Number of generations.

Returns:
Group relative advantages.
"""
mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=1)
return rewards - mean_grouped_rewards.repeat(num_generations)
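
For illustration, not part of this diff — a contrast with the GRPO estimator, using made-up values: for rewards [1.0, 3.0] with num_generations=2, the group mean is 2.0, so compute_advantages above returns [-1.0, 1.0]; GRPO would additionally divide by the ddof=1 group std (≈1.414), giving roughly [-0.707, 0.707]. That is why drgrpo_learner_test.py above multiplies the GRPO advantages by the std factor before comparing.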