29 changes: 29 additions & 0 deletions examples/entropy/README.md
@@ -0,0 +1,29 @@
# Entropy dynamics of RL training

This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392).
Review comment (severity: medium):

There is a typo in the arXiv link. The year should be 2402, not 2602.

Suggested change
This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392).
This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2402.03392).


## Data Preparation

We use the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set and hold out 500 questions from it to form the validation set (denoted dapo-validation-500).
We then filter out training samples with excessively high (≥ 15/16) or low (≤ 1/16) pass rates, as evaluated by Qwen2.5-7B-Instruct.
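
A minimal sketch of that pass-rate filter is shown below; it is purely illustrative and assumes a hypothetical `pass_rate` column (fraction of 16 Qwen2.5-7B-Instruct rollouts judged correct) and a `train` split, neither of which is defined by this example.

```python
# Hypothetical sketch, not the script used to build the released dataset.
from datasets import load_dataset

# `split="train"` is an assumption about the processed dataset layout.
ds = load_dataset("open-r1/DAPO-Math-17k-Processed", split="train")

def moderate_difficulty(example):
    # `pass_rate` is an assumed column: fraction of 16 Qwen2.5-7B-Instruct
    # rollouts that solve the question.
    return 1 / 16 < example["pass_rate"] < 15 / 16

train_set = ds.filter(moderate_difficulty)
```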

## Clip_B Experiment

1. Apply the patch to keep entropy information in the trainer batch:

```bash
cd /path/to/Trinity-RFT
git apply examples/entropy/clipb_trainer.patch
```

2. Update the dataset paths in the config file [`clipb.yaml`](clipb.yaml) to point to your local data.

3. Run the experiment:

```bash
trinity run examples/entropy/clipb.yaml
```

## Clip_V Implementation

Coming soon.
100 changes: 100 additions & 0 deletions examples/entropy/clipb.yaml
@@ -0,0 +1,100 @@
project: math_dapo
name: clipb_example
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
  max_prompt_tokens: 1024
  max_response_tokens: 7168
algorithm:
  algorithm_type: grpo_verl
  advantage_fn: clipb
  advantage_fn_args:
    mu: 2.5
  repeat_times: 16
  kl_loss_fn_args:
    kl_coef: 0.0
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 20
  batch_size: 64
  explorer_input:
    taskset:
      name: dapo_235
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k
      format:
        prompt_key: 'question'
        response_key: 'ground_truth'
      rollout_args:
        temperature: 1.0
        logprobs: 20
    eval_tasksets:
    - name: dapo-validation-500
      storage_type: file
      path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k
      split: 'test'
      repeat_times: 32
      format:
        prompt_key: 'question'
        response_key: 'ground_truth'
      rollout_args:
        temperature: 0.7
    - name: amc23
      storage_type: file
      path: math-ai/amc23 # Path to the AMC23 dataset
      split: 'test'
      repeat_times: 32
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 0.7
    - name: aime24
      storage_type: file
      path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset
      split: 'train'
      repeat_times: 32
      format:
        prompt_key: 'problem'
        response_key: 'answer'
      rollout_args:
        temperature: 0.7
    - name : aime25
Review comment (severity: medium):

There is an extra space before the colon in name : aime25. While many YAML parsers might handle this, it's inconsistent with the rest of the file and can lead to parsing issues with stricter parsers. Please remove the space for consistency.

    - name: aime25

      storage_type: file
      path: math-ai/aime25 # Path to the AIME2025 dataset
      split: 'test'
      repeat_times: 32
      format:
        prompt_key: 'problem'
        response_key: 'answer'
      rollout_args:
        temperature: 0.7
    default_workflow_type: 'async_math_workflow'
    default_reward_fn_type: 'math_boxed_reward'
  trainer_input:
    experience_buffer:
      name: math_buffer
      storage_type: queue
      max_read_timeout: 7200
explorer:
  eval_interval: 20
  eval_on_startup: true
  runner_per_model: 8
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    seed: 42
trainer:
  trainer_type: 'verl'
  save_interval: 200
  trainer_config:
    algorithm:
      rollout_correction:
        bypass_mode: false
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 3200
11 changes: 11 additions & 0 deletions examples/entropy/clipb_trainer.patch
@@ -0,0 +1,11 @@
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
}
metrics.update(old_log_prob_metrics)
- old_log_prob.batch.pop("entropys")
+ # Keep entropys in batch so advantage_fn (e.g. Clip_B) can use it
+ # old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
Comment on lines +1 to +11
Review comment (severity: high):

Requiring users to manually apply a patch is not a maintainable or user-friendly approach. This change should be integrated directly into the trinity/trainer/verl_trainer.py file within this pull request. Instead of providing a patch, please modify the source code directly.

A better long-term solution would be to make the removal of entropys from the batch configurable. For instance, the advantage_fn could declare which fields it requires, and the trainer could conditionally avoid removing them. This would make the framework more extensible for future algorithms that might have similar requirements.
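
One way to realize that suggestion, sketched here with hypothetical names (`required_fields` and a trainer-side helper) rather than the actual Trinity-RFT API:

```python
# Hypothetical sketch: each advantage function declares the batch fields it
# needs, and the trainer only drops "entropys" when no advantage_fn asks for it.
from typing import Tuple


class AdvantageFn:
    # Assumed attribute, not part of the current interface.
    required_fields: Tuple[str, ...] = ()


class ClipBAdvantageFn(AdvantageFn):
    required_fields = ("entropys", "rollout_log_probs")


def maybe_drop_entropys(old_log_prob, advantage_fn) -> None:
    # Would replace the unconditional `old_log_prob.batch.pop("entropys")` call.
    if "entropys" not in getattr(advantage_fn, "required_fields", ()):
        old_log_prob.batch.pop("entropys")
```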

1 change: 1 addition & 0 deletions trinity/algorithm/__init__.py
@@ -29,6 +29,7 @@
"multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
"on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm",
"jsd": "trinity.algorithm.algorithm.JSDAlgorithm",
"grpo_verl": "trinity.algorithm.algorithm.GRPOverlAlgorithm",
},
)

1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -19,6 +19,7 @@
"rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
"on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage",
"jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage",
"clipb": "trinity.algorithm.advantage_fn.clipb_advantage.ClipBAdvantageFn",
},
)

152 changes: 152 additions & 0 deletions trinity/algorithm/advantage_fn/clipb_advantage.py
@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""Advantage computation for Clip_B
Ref: https://arxiv.org/pdf/2602.03392"""

from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Tuple

import torch

if TYPE_CHECKING:
    from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn


class ClipBAdvantageFn(AdvantageFn):
    """Clip_B advantage: keep all positive-advantage tokens,
    one-side clip negative-advantage tokens by entropy signal."""

    def __init__(
        self,
        epsilon: float = 1e-6,
        mu: float = 2.5,
    ) -> None:
        self.epsilon = epsilon
        self.mu = mu

    def __call__(
        self,
        exps: "DataProto",
        **kwargs,
    ) -> Tuple["DataProto", Dict]:
        """
        Compute advantage for Clip_B.
        exps should contain the following fields:
            - token_level_rewards: `(torch.Tensor)`
              shape: (bs, response_length)
            - response_mask: `(torch.Tensor)`
              shape: (bs, response_length)
            - uid: `(torch.Tensor)`
              shape: (bs,)
            - rollout_log_probs: `(torch.Tensor)`
              shape: (bs, response_length)
            - entropys: `(torch.Tensor)`
              shape: (bs, response_length)
        Returns:
            exps: DataProto with advantages and returns added
            metrics: Dict with clipping metrics
        """
        token_level_rewards = exps.batch["token_level_rewards"]
        response_mask = exps.batch["response_mask"]
        index = exps.non_tensor_batch["uid"]

        response_length = token_level_rewards.shape[-1]
        scores = token_level_rewards.sum(dim=-1)

        id2score = defaultdict(list)
        id2mean = {}
        id2std = {}
        with torch.no_grad():
            bsz = scores.shape[0]
            for i in range(bsz):
                id2score[index[i]].append(scores[i])

            for idx in id2score:
                if len(id2score[idx]) == 1:
                    id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device)
                    id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device)
                elif len(id2score[idx]) > 1:
                    group_scores = torch.stack(id2score[idx]).to(
                        dtype=scores.dtype, device=scores.device
                    )
                    id2mean[idx] = torch.mean(group_scores)
                    id2std[idx] = torch.std(group_scores)
Review comment (severity: medium):

There's an inconsistency in how standard deviation is calculated. Here, torch.std is used with its default unbiased=True, which calculates the sample standard deviation (using N-1 in the denominator). However, on line 100, the comment and implementation for varS indicate population variance (using N in the denominator). For consistency, if population statistics are intended throughout, you should use unbiased=False.

Suggested change
id2std[idx] = torch.std(group_scores)
id2std[idx] = torch.std(group_scores, unbiased=False)
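
For reference, a small check of the two estimators the comment distinguishes (sample vs. population standard deviation) in PyTorch; the values in the comments are approximate:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(torch.std(x))                  # sample std, denominator N-1 -> ~1.2910
print(torch.std(x, unbiased=False))  # population std, denominator N -> ~1.1180
```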

                else:
                    raise ValueError(f"no score in prompt index: {idx}")

            for i in range(bsz):
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon)
            scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask

        exps.batch["advantages"] = scores
        exps.batch["returns"] = scores.clone()

        # --- BEGIN: token filtering logic ---
        # Use recomputed logprobs & entropy from current model (not rollout)
        LP = exps.batch["rollout_log_probs"]  # [B, T], recomputed logprobs
        H = exps.batch["entropys"]  # [B, T], recomputed entropy
        M = response_mask  # [B, T], mask of valid tokens
        p = LP.exp()  # [B, T], probability of valid tokens
        S = p * (H + LP)  # [B, T], indicator

        # Detach for constructing clip mask (no gradient needed)
        xS = S.detach().to(torch.float32)  # [B, T]
        m = M.to(torch.float32)  # [B, T]

        # Masked global mean & variance (population variance, denominator = n)
        n = m.sum().clamp_min(1.0)
        ES = (xS * m).sum() / n  # scalar
        varS = ((xS - ES) ** 2 * m).sum() / n  # scalar
        stdS = varS.sqrt()  # scalar

        # Centered signal
        z = xS - ES  # [B, T]

        # if stdS is too small, keep all tokens; otherwise
        # keep all positive-advantage tokens; one-side clip negative-advantage tokens
        if stdS.item() < 1e-12:
            keep = torch.ones_like(M, dtype=M.dtype)  # all kept
        else:
            A = exps.batch["advantages"].detach().to(torch.float32)  # [B, T]
            pos_mask = A > 0
            neg_mask = A < 0

            keep_pos = torch.ones_like(pos_mask, dtype=torch.bool)  # positive: all kept
            keep_neg = z >= -(self.mu * stdS)  # negative: lower-side clip
            keep_zero = torch.ones_like(pos_mask, dtype=torch.bool)  # zero: all kept

            keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero))
Comment on lines +111 to +119
Review comment (severity: medium):

The logic to determine which tokens to keep can be simplified. The nested torch.where calls are equivalent to a more concise and readable boolean expression.

Suggested change
A = exps.batch["advantages"].detach().to(torch.float32) # [B, T]
pos_mask = A > 0
neg_mask = A < 0
keep_pos = torch.ones_like(pos_mask, dtype=torch.bool) # positive: all kept
keep_neg = z >= -(self.mu * stdS) # negative: lower-side clip
keep_zero = torch.ones_like(pos_mask, dtype=torch.bool) # zero: all kept
keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero))
A = exps.batch["advantages"].detach().to(torch.float32) # [B, T]
# Keep tokens with non-negative advantage, or tokens with negative advantage that satisfy the entropy-based condition.
keep_bool = (A >= 0) | (z >= -(self.mu * stdS))

            keep = keep_bool.to(M.dtype)

        M_clipped = M * keep
        exps.batch["response_mask"] = M_clipped
        # --- END: token filtering logic ---

        # Monitoring metrics
        total_tokens = m.sum().clamp_min(1.0)
        frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item()

        A = exps.batch["advantages"].detach().to(torch.float32)
        pos_mask = (A > 0).to(M.dtype)
        neg_mask = (A < 0).to(M.dtype)
        total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0)
        total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0)
        frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item()
        frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item()

        metrics = {
            "frac_clipped": frac_clipped,
            "frac_clipped_pos": frac_clipped_pos,
            "frac_clipped_neg": frac_clipped_neg,
            "ES": ES.item(),
            "varS": varS.item(),
        }
        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "epsilon": 1e-6,
            "mu": 2.5,
        }
22 changes: 22 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -540,3 +540,25 @@ def default_config(cls) -> Dict:
            "kl_loss_fn": "none",
            "entropy_loss_fn": "none",
        }


class GRPOverlAlgorithm(AlgorithmType):
    """GRPO algorithm, but advantage computation is done in trainer."""

    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = True
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "advantage_fn": "grpo",
            "sample_strategy": "default",
            "policy_loss_fn": "ppo",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }
1 change: 1 addition & 0 deletions trinity/common/verl_config.py
@@ -175,6 +175,7 @@ class Actor:
    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
    # do not set
    loss_agg_mode: str = "token-mean"
    loss_scale_factor: Optional[float] = None
    clip_ratio: float = 0.2
    clip_ratio_low: Optional[float] = None
    clip_ratio_high: Optional[float] = None