From c7884855193d64b8de29454e7b6f8346935be44e Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Fri, 13 Feb 2026 17:31:53 +0800
Subject: [PATCH] Add ClipB example

---
 examples/entropy/README.md                    |  29 ++++
 examples/entropy/clipb.yaml                   | 100 ++++++++++++
 examples/entropy/clipb_trainer.patch          |  11 ++
 trinity/algorithm/__init__.py                 |   1 +
 trinity/algorithm/advantage_fn/__init__.py    |   1 +
 .../algorithm/advantage_fn/clipb_advantage.py | 152 ++++++++++++++++++
 trinity/algorithm/algorithm.py                |  22 +++
 trinity/common/verl_config.py                 |   1 +
 8 files changed, 317 insertions(+)
 create mode 100644 examples/entropy/README.md
 create mode 100644 examples/entropy/clipb.yaml
 create mode 100644 examples/entropy/clipb_trainer.patch
 create mode 100644 trinity/algorithm/advantage_fn/clipb_advantage.py

diff --git a/examples/entropy/README.md b/examples/entropy/README.md
new file mode 100644
index 0000000000..e38144ccd4
--- /dev/null
+++ b/examples/entropy/README.md
@@ -0,0 +1,29 @@
+# Entropy dynamics of RL training
+
+This example demonstrates the two algorithms **Clip_B** and **Clip_V** from the paper [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392).
+
+## Data Preparation
+
+We use the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set and hold out 500 questions from it to form the validation set (denoted dapo-validation-500).
+The remaining training samples are further filtered to remove questions with excessively high (≥ 15/16) or low (≤ 1/16) pass rates, as evaluated by Qwen2.5-7B-Instruct.
+
+## Clip_B Experiment
+
+The token-filtering rule used by the `clipb` advantage function is sketched in the appendix at the end of this README.
+
+1. Apply the patch to keep entropy information in the trainer batch:
+
+```bash
+cd /path/to/Trinity-RFT
+git apply examples/entropy/clipb_trainer.patch
+```
+
+2. Update the dataset paths in the config file [`clipb.yaml`](clipb.yaml) to point to your local data.
+
+3. Run the experiment:
+
+```bash
+trinity run examples/entropy/clipb.yaml
+```
+
+## Clip_V Implementation
+
+Coming soon.
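+
+## Appendix: How Clip_B Filters Tokens
+
+The snippet below is a minimal, self-contained sketch of the token-filtering rule implemented in
+`trinity/algorithm/advantage_fn/clipb_advantage.py`. The function name and tensor names here are
+illustrative only and are not part of the Trinity-RFT API.
+
+```python
+import torch
+
+
+def clipb_keep_mask(advantages, log_probs, entropy, response_mask, mu=2.5):
+    """Return a {0, 1} mask of response tokens to keep in the policy loss."""
+    s = (log_probs.exp() * (entropy + log_probs)).float()  # per-token signal S = p * (H + log p)
+    m = response_mask.float()
+    n = m.sum().clamp_min(1.0)
+    mean_s = (s * m).sum() / n                              # masked mean of S
+    std_s = (((s - mean_s) ** 2 * m).sum() / n).sqrt()      # masked (population) std of S
+    if std_s.item() < 1e-12:                                # degenerate case: keep everything
+        return torch.ones_like(response_mask)
+    keep_neg = (s - mean_s) >= -(mu * std_s)                # lower-side clip of the centered signal
+    keep = torch.where(advantages < 0, keep_neg, torch.ones_like(keep_neg))
+    return keep.to(response_mask.dtype)
+```
+
+Tokens with positive (or zero) advantage are always kept; tokens with negative advantage are dropped
+only when their centered signal falls more than `mu` standard deviations below the mean, which is
+what `mu: 2.5` in [`clipb.yaml`](clipb.yaml) controls.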
diff --git a/examples/entropy/clipb.yaml b/examples/entropy/clipb.yaml
new file mode 100644
index 0000000000..d78edda47e
--- /dev/null
+++ b/examples/entropy/clipb.yaml
@@ -0,0 +1,100 @@
+project: math_dapo
+name: clipb_example
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+model:
+  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+  max_prompt_tokens: 1024
+  max_response_tokens: 7168
+algorithm:
+  algorithm_type: grpo_verl
+  advantage_fn: clipb
+  advantage_fn_args:
+    mu: 2.5
+  repeat_times: 16
+  kl_loss_fn_args:
+    kl_coef: 0.0
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_epochs: 20
+  batch_size: 64
+  explorer_input:
+    taskset:
+      name: dapo_235
+      storage_type: file
+      path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k
+      format:
+        prompt_key: 'question'
+        response_key: 'ground_truth'
+      rollout_args:
+        temperature: 1.0
+        logprobs: 20
+    eval_tasksets:
+      - name: dapo-validation-500
+        storage_type: file
+        path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k
+        split: 'test'
+        repeat_times: 32
+        format:
+          prompt_key: 'question'
+          response_key: 'ground_truth'
+        rollout_args:
+          temperature: 0.7
+      - name: amc23
+        storage_type: file
+        path: math-ai/amc23 # Path to the AMC23 dataset
+        split: 'test'
+        repeat_times: 32
+        format:
+          prompt_key: 'question'
+          response_key: 'answer'
+        rollout_args:
+          temperature: 0.7
+      - name: aime24
+        storage_type: file
+        path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset
+        split: 'train'
+        repeat_times: 32
+        format:
+          prompt_key: 'problem'
+          response_key: 'answer'
+        rollout_args:
+          temperature: 0.7
+      - name: aime25
+        storage_type: file
+        path: math-ai/aime25 # Path to the AIME2025 dataset
+        split: 'test'
+        repeat_times: 32
+        format:
+          prompt_key: 'problem'
+          response_key: 'answer'
+        rollout_args:
+          temperature: 0.7
+    default_workflow_type: 'async_math_workflow'
+    default_reward_fn_type: 'math_boxed_reward'
+  trainer_input:
+    experience_buffer:
+      name: math_buffer
+      storage_type: queue
+      max_read_timeout: 7200
+explorer:
+  eval_interval: 20
+  eval_on_startup: true
+  runner_per_model: 8
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1
+    seed: 42
+trainer:
+  trainer_type: 'verl'
+  save_interval: 200
+  trainer_config:
+    algorithm:
+      rollout_correction:
+        bypass_mode: false
+synchronizer:
+  sync_method: 'nccl'
+  sync_interval: 1
+  sync_timeout: 3200
diff --git a/examples/entropy/clipb_trainer.patch b/examples/entropy/clipb_trainer.patch
new file mode 100644
index 0000000000..03ca08b29c
--- /dev/null
+++ b/examples/entropy/clipb_trainer.patch
@@ -0,0 +1,11 @@
+--- a/trinity/trainer/verl_trainer.py
++++ b/trinity/trainer/verl_trainer.py
+@@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
+                 }
+                 metrics.update(old_log_prob_metrics)
+-                old_log_prob.batch.pop("entropys")
++                # Keep entropys in batch so advantage_fn (e.g. Clip_B) can use it
++                # old_log_prob.batch.pop("entropys")
+                 batch = batch.union(old_log_prob)
+                 if "rollout_log_probs" in batch.batch.keys():
+                     # TODO: we may want to add diff of probs too.
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 52bb605bcd..e693684cb9 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -29,6 +29,7 @@
         "multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
         "on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm",
         "jsd": "trinity.algorithm.algorithm.JSDAlgorithm",
+        "grpo_verl": "trinity.algorithm.algorithm.GRPOverlAlgorithm",
     },
 )
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
index 239862ba58..7f59211dbb 100644
--- a/trinity/algorithm/advantage_fn/__init__.py
+++ b/trinity/algorithm/advantage_fn/__init__.py
@@ -19,6 +19,7 @@
         "rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
         "on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage",
         "jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage",
+        "clipb": "trinity.algorithm.advantage_fn.clipb_advantage.ClipBAdvantageFn",
     },
 )
diff --git a/trinity/algorithm/advantage_fn/clipb_advantage.py b/trinity/algorithm/advantage_fn/clipb_advantage.py
new file mode 100644
index 0000000000..62898ada1a
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/clipb_advantage.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+"""Advantage computation for Clip_B
+Ref: https://arxiv.org/pdf/2602.03392"""
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Tuple
+
+import torch
+
+if TYPE_CHECKING:
+    from verl import DataProto
+
+from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn
+
+
+class ClipBAdvantageFn(AdvantageFn):
+    """Clip_B advantage: keep all positive-advantage tokens,
+    one-side clip negative-advantage tokens by entropy signal."""
+
+    def __init__(
+        self,
+        epsilon: float = 1e-6,
+        mu: float = 2.5,
+    ) -> None:
+        self.epsilon = epsilon
+        self.mu = mu
+
+    def __call__(
+        self,
+        exps: "DataProto",
+        **kwargs,
+    ) -> Tuple["DataProto", Dict]:
+        """
+        Compute advantage for Clip_B.
+        exps should contain the following fields:
+            - token_level_rewards: `(torch.Tensor)`
+              shape: (bs, response_length)
+            - response_mask: `(torch.Tensor)`
+              shape: (bs, response_length)
+            - uid: `(torch.Tensor)`
+              shape: (bs,)
+            - rollout_log_probs: `(torch.Tensor)`
+              shape: (bs, response_length)
+            - entropys: `(torch.Tensor)`
+              shape: (bs, response_length)
+        Returns:
+            exps: DataProto with advantages and returns added
+            metrics: Dict with clipping metrics
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        response_mask = exps.batch["response_mask"]
+        index = exps.non_tensor_batch["uid"]
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2mean = {}
+        id2std = {}
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device)
+                    id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device)
+                elif len(id2score[idx]) > 1:
+                    group_scores = torch.stack(id2score[idx]).to(
+                        dtype=scores.dtype, device=scores.device
+                    )
+                    id2mean[idx] = torch.mean(group_scores)
+                    id2std[idx] = torch.std(group_scores)
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+
+            for i in range(bsz):
+                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon)
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores.clone()
+
+        # --- BEGIN: token filtering logic ---
+        # Use recomputed logprobs & entropy from current model (not rollout)
+        LP = exps.batch["rollout_log_probs"]  # [B, T], recomputed logprobs
+        H = exps.batch["entropys"]  # [B, T], recomputed entropy
+        M = response_mask  # [B, T], mask of valid tokens
+        p = LP.exp()  # [B, T], probability of valid tokens
+        S = p * (H + LP)  # [B, T], indicator
+
+        # Detach for constructing clip mask (no gradient needed)
+        xS = S.detach().to(torch.float32)  # [B, T]
+        m = M.to(torch.float32)  # [B, T]
+
+        # Masked global mean & variance (population variance, denominator = n)
+        n = m.sum().clamp_min(1.0)
+        ES = (xS * m).sum() / n  # scalar
+        varS = ((xS - ES) ** 2 * m).sum() / n  # scalar
+        stdS = varS.sqrt()  # scalar
+
+        # Centered signal
+        z = xS - ES  # [B, T]
+
+        # if stdS is too small, keep all tokens; otherwise
+        # keep all positive-advantage tokens; one-side clip negative-advantage tokens
+        if stdS.item() < 1e-12:
+            keep = torch.ones_like(M, dtype=M.dtype)  # all kept
+        else:
+            A = exps.batch["advantages"].detach().to(torch.float32)  # [B, T]
+            pos_mask = A > 0
+            neg_mask = A < 0
+
+            keep_pos = torch.ones_like(pos_mask, dtype=torch.bool)  # positive: all kept
+            keep_neg = z >= -(self.mu * stdS)  # negative: lower-side clip
+            keep_zero = torch.ones_like(pos_mask, dtype=torch.bool)  # zero: all kept
+
+            keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero))
+            keep = keep_bool.to(M.dtype)
+
+        M_clipped = M * keep
+        exps.batch["response_mask"] = M_clipped
+        # --- END: token filtering logic ---
+
+        # Monitoring metrics
+        total_tokens = m.sum().clamp_min(1.0)
+        frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item()
+
+        A = exps.batch["advantages"].detach().to(torch.float32)
+        pos_mask = (A > 0).to(M.dtype)
+        neg_mask = (A < 0).to(M.dtype)
+        total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0)
+        total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0)
+        frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item()
+        frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item()
+
+        metrics = {
+            "frac_clipped": frac_clipped,
+            "frac_clipped_pos": frac_clipped_pos,
+            "frac_clipped_neg": frac_clipped_neg,
+            "ES": ES.item(),
+            "varS": varS.item(),
+        }
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "epsilon": 1e-6,
+            "mu": 2.5,
+        }
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 5352bba449..35b7453222 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -540,3 +540,25 @@ def default_config(cls) -> Dict:
             "kl_loss_fn": "none",
             "entropy_loss_fn": "none",
         }
+
+
+class GRPOverlAlgorithm(AlgorithmType):
+    """GRPO algorithm, but advantage computation is done in the trainer."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    compute_advantage_in_trainer: bool = True
+    can_balance_batch: bool = True
+    schema: str = "experience"
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "advantage_fn": "grpo",
+            "sample_strategy": "default",
+            "policy_loss_fn": "ppo",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
+        }
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 1aaa4a3c4a..689b3231b4 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -175,6 +175,7 @@ class Actor:
     router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
     # do not set
     loss_agg_mode: str = "token-mean"
+    loss_scale_factor: Optional[float] = None
     clip_ratio: float = 0.2
     clip_ratio_low: Optional[float] = None
     clip_ratio_high: Optional[float] = None
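
The advantage function added above reads only plain tensors from `exps.batch` and `exps.non_tensor_batch`, so its clipping logic can be smoke-tested locally without a full training run. The snippet below is an illustrative sketch, not part of the patch: it assumes the patch is applied and Trinity-RFT is importable, and it substitutes a `SimpleNamespace` with random stand-in values for verl's `DataProto`.

```python
# Illustrative smoke test for ClipBAdvantageFn (not part of the patch).
import torch
from types import SimpleNamespace

from trinity.algorithm.advantage_fn.clipb_advantage import ClipBAdvantageFn

torch.manual_seed(0)
B, T = 4, 6  # 2 prompts x 2 rollouts each, 6 response tokens
rewards = torch.zeros(B, T)
rewards[:, -1] = torch.tensor([1.0, 0.0, 1.0, 0.0])  # sequence-level reward on the last token

exps = SimpleNamespace(  # stand-in for verl's DataProto
    batch={
        "token_level_rewards": rewards,
        "response_mask": torch.ones(B, T),
        "rollout_log_probs": -torch.rand(B, T),  # stand-in per-token log-probs
        "entropys": torch.rand(B, T),            # stand-in per-token entropies
    },
    non_tensor_batch={"uid": ["q0", "q0", "q1", "q1"]},  # groups rollouts by prompt
)

exps, metrics = ClipBAdvantageFn(mu=2.5)(exps)
print(exps.batch["advantages"][:, 0])  # group-normalized advantages
print(metrics["frac_clipped"], metrics["frac_clipped_neg"])
```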