-
Notifications
You must be signed in to change notification settings - Fork 55
[Example] Clip_B and Clip_V from entropy dynamics #509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| # Entropy dynamics of RL training | ||
|
|
||
| This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392). | ||
|
|
||
| ## Data Preparation | ||
|
|
||
| We utilize the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set. We exclude 500 questions from the training set to form the validation set (denoted by dapo-validation-500). | ||
| The training set is filtered out samples from the training set with excessively high (≥ 15/16) or low (≤ 1/16) pass rates, as evaluated by Qwen2.5-7B-Instruct. | ||
|
|
||
| ## Clip_B Experiment | ||
|
|
||
| 1. Apply the patch to keep entropy information in the trainer batch: | ||
|
|
||
| ```bash | ||
| cd /path/to/Trinity-RFT | ||
| git apply examples/entropy/clipb_trainer.patch | ||
| ``` | ||
|
|
||
| 2. Update the dataset paths in the config file [`clipb.yaml`](clipb.yaml) to point to your local data. | ||
|
|
||
| 3. Run the experiment: | ||
|
|
||
| ```bash | ||
| trinity run examples/entropy/clipb.yaml | ||
| ``` | ||
|
|
||
| ## Clip_V Implementation | ||
|
|
||
| Coming soon. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| project: math_dapo | ||
| name: clipb_example | ||
| checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} | ||
| model: | ||
| model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} | ||
| max_prompt_tokens: 1024 | ||
| max_response_tokens: 7168 | ||
| algorithm: | ||
| algorithm_type: grpo_verl | ||
| advantage_fn: clipb | ||
| advantage_fn_args: | ||
| mu: 2.5 | ||
| repeat_times: 16 | ||
| kl_loss_fn_args: | ||
| kl_coef: 0.0 | ||
| cluster: | ||
| node_num: 1 | ||
| gpu_per_node: 8 | ||
| buffer: | ||
| total_epochs: 20 | ||
| batch_size: 64 | ||
| explorer_input: | ||
| taskset: | ||
| name: dapo_235 | ||
| storage_type: file | ||
| path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k | ||
| format: | ||
| prompt_key: 'question' | ||
| response_key: 'ground_truth' | ||
| rollout_args: | ||
| temperature: 1.0 | ||
| logprobs: 20 | ||
| eval_tasksets: | ||
| - name: dapo-validation-500 | ||
| storage_type: file | ||
| path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k | ||
| split: 'test' | ||
| repeat_times: 32 | ||
| format: | ||
| prompt_key: 'question' | ||
| response_key: 'ground_truth' | ||
| rollout_args: | ||
| temperature: 0.7 | ||
| - name: amc23 | ||
| storage_type: file | ||
| path: math-ai/amc23 # Path to the AMC23 dataset | ||
| split: 'test' | ||
| repeat_times: 32 | ||
| format: | ||
| prompt_key: 'question' | ||
| response_key: 'answer' | ||
| rollout_args: | ||
| temperature: 0.7 | ||
| - name: aime24 | ||
| storage_type: file | ||
| path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset | ||
| split: 'train' | ||
| repeat_times: 32 | ||
| format: | ||
| prompt_key: 'problem' | ||
| response_key: 'answer' | ||
| rollout_args: | ||
| temperature: 0.7 | ||
| - name : aime25 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| storage_type: file | ||
| path: math-ai/aime25 # Path to the AIME2025 dataset | ||
| split: 'test' | ||
| repeat_times: 32 | ||
| format: | ||
| prompt_key: 'problem' | ||
| response_key: 'answer' | ||
| rollout_args: | ||
| temperature: 0.7 | ||
| default_workflow_type: 'async_math_workflow' | ||
| default_reward_fn_type: 'math_boxed_reward' | ||
| trainer_input: | ||
| experience_buffer: | ||
| name: math_buffer | ||
| storage_type: queue | ||
| max_read_timeout: 7200 | ||
| explorer: | ||
| eval_interval: 20 | ||
| eval_on_startup: true | ||
| runner_per_model: 8 | ||
| rollout_model: | ||
| engine_type: vllm_async | ||
| engine_num: 4 | ||
| tensor_parallel_size: 1 | ||
| seed: 42 | ||
| trainer: | ||
| trainer_type: 'verl' | ||
| save_interval: 200 | ||
| trainer_config: | ||
| algorithm: | ||
| rollout_correction: | ||
| bypass_mode: false | ||
| synchronizer: | ||
| sync_method: 'nccl' | ||
| sync_interval: 1 | ||
| sync_timeout: 3200 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| --- a/trinity/trainer/verl_trainer.py | ||
| +++ b/trinity/trainer/verl_trainer.py | ||
| @@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): | ||
| } | ||
| metrics.update(old_log_prob_metrics) | ||
| - old_log_prob.batch.pop("entropys") | ||
| + # Keep entropys in batch so advantage_fn (e.g. Clip_B) can use it | ||
| + # old_log_prob.batch.pop("entropys") | ||
| batch = batch.union(old_log_prob) | ||
| if "rollout_log_probs" in batch.batch.keys(): | ||
| # TODO: we may want to add diff of probs too. | ||
|
Comment on lines
+1
to
+11
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Requiring users to manually apply a patch is not a maintainable or user-friendly approach. This change should be integrated directly into the A better long-term solution would be to make the removal of |
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,152 @@ | ||||||||||||||||||||||||||
| # -*- coding: utf-8 -*- | ||||||||||||||||||||||||||
| """Advantage computation for Clip_B | ||||||||||||||||||||||||||
| Ref: https://arxiv.org/pdf/2602.03392""" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Dict, Tuple | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||
| from verl import DataProto | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class ClipBAdvantageFn(AdvantageFn): | ||||||||||||||||||||||||||
| """Clip_B advantage: keep all positive-advantage tokens, | ||||||||||||||||||||||||||
| one-side clip negative-advantage tokens by entropy signal.""" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||
| epsilon: float = 1e-6, | ||||||||||||||||||||||||||
| mu: float = 2.5, | ||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||
| self.epsilon = epsilon | ||||||||||||||||||||||||||
| self.mu = mu | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __call__( | ||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||
| exps: "DataProto", | ||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||
| ) -> Tuple["DataProto", Dict]: | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Compute advantage for Clip_B. | ||||||||||||||||||||||||||
| exps should contain the following fields: | ||||||||||||||||||||||||||
| - token_level_rewards: `(torch.Tensor)` | ||||||||||||||||||||||||||
| shape: (bs, response_length) | ||||||||||||||||||||||||||
| - response_mask: `(torch.Tensor)` | ||||||||||||||||||||||||||
| shape: (bs, response_length) | ||||||||||||||||||||||||||
| - uid: `(torch.Tensor)` | ||||||||||||||||||||||||||
| shape: (bs,) | ||||||||||||||||||||||||||
| - rollout_log_probs: `(torch.Tensor)` | ||||||||||||||||||||||||||
| shape: (bs, response_length) | ||||||||||||||||||||||||||
| - entropys: `(torch.Tensor)` | ||||||||||||||||||||||||||
| shape: (bs, response_length) | ||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| exps: DataProto with advantages and returns added | ||||||||||||||||||||||||||
| metrics: Dict with clipping metrics | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| token_level_rewards = exps.batch["token_level_rewards"] | ||||||||||||||||||||||||||
| response_mask = exps.batch["response_mask"] | ||||||||||||||||||||||||||
| index = exps.non_tensor_batch["uid"] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| response_length = token_level_rewards.shape[-1] | ||||||||||||||||||||||||||
| scores = token_level_rewards.sum(dim=-1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| id2score = defaultdict(list) | ||||||||||||||||||||||||||
| id2mean = {} | ||||||||||||||||||||||||||
| id2std = {} | ||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||
| bsz = scores.shape[0] | ||||||||||||||||||||||||||
| for i in range(bsz): | ||||||||||||||||||||||||||
| id2score[index[i]].append(scores[i]) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| for idx in id2score: | ||||||||||||||||||||||||||
| if len(id2score[idx]) == 1: | ||||||||||||||||||||||||||
| id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device) | ||||||||||||||||||||||||||
| id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device) | ||||||||||||||||||||||||||
| elif len(id2score[idx]) > 1: | ||||||||||||||||||||||||||
| group_scores = torch.stack(id2score[idx]).to( | ||||||||||||||||||||||||||
| dtype=scores.dtype, device=scores.device | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| id2mean[idx] = torch.mean(group_scores) | ||||||||||||||||||||||||||
| id2std[idx] = torch.std(group_scores) | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's an inconsistency in how standard deviation is calculated. Here,
Suggested change
|
||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| raise ValueError(f"no score in prompt index: {idx}") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| for i in range(bsz): | ||||||||||||||||||||||||||
| scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon) | ||||||||||||||||||||||||||
| scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| exps.batch["advantages"] = scores | ||||||||||||||||||||||||||
| exps.batch["returns"] = scores.clone() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # --- BEGIN: token filtering logic --- | ||||||||||||||||||||||||||
| # Use recomputed logprobs & entropy from current model (not rollout) | ||||||||||||||||||||||||||
| LP = exps.batch["rollout_log_probs"] # [B, T], recomputed logprobs | ||||||||||||||||||||||||||
| H = exps.batch["entropys"] # [B, T], recomputed entropy | ||||||||||||||||||||||||||
| M = response_mask # [B, T], mask of valid tokens | ||||||||||||||||||||||||||
| p = LP.exp() # [B, T], probability of valid tokens | ||||||||||||||||||||||||||
| S = p * (H + LP) # [B, T], indicator | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Detach for constructing clip mask (no gradient needed) | ||||||||||||||||||||||||||
| xS = S.detach().to(torch.float32) # [B, T] | ||||||||||||||||||||||||||
| m = M.to(torch.float32) # [B, T] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Masked global mean & variance (population variance, denominator = n) | ||||||||||||||||||||||||||
| n = m.sum().clamp_min(1.0) | ||||||||||||||||||||||||||
| ES = (xS * m).sum() / n # scalar | ||||||||||||||||||||||||||
| varS = ((xS - ES) ** 2 * m).sum() / n # scalar | ||||||||||||||||||||||||||
| stdS = varS.sqrt() # scalar | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Centered signal | ||||||||||||||||||||||||||
| z = xS - ES # [B, T] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # if stdS is too small, keep all tokens; otherwise | ||||||||||||||||||||||||||
| # keep all positive-advantage tokens; one-side clip negative-advantage tokens | ||||||||||||||||||||||||||
| if stdS.item() < 1e-12: | ||||||||||||||||||||||||||
| keep = torch.ones_like(M, dtype=M.dtype) # all kept | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| A = exps.batch["advantages"].detach().to(torch.float32) # [B, T] | ||||||||||||||||||||||||||
| pos_mask = A > 0 | ||||||||||||||||||||||||||
| neg_mask = A < 0 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| keep_pos = torch.ones_like(pos_mask, dtype=torch.bool) # positive: all kept | ||||||||||||||||||||||||||
| keep_neg = z >= -(self.mu * stdS) # negative: lower-side clip | ||||||||||||||||||||||||||
| keep_zero = torch.ones_like(pos_mask, dtype=torch.bool) # zero: all kept | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero)) | ||||||||||||||||||||||||||
|
Comment on lines
+111
to
+119
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic to determine which tokens to keep can be simplified. The nested
Suggested change
|
||||||||||||||||||||||||||
| keep = keep_bool.to(M.dtype) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| M_clipped = M * keep | ||||||||||||||||||||||||||
| exps.batch["response_mask"] = M_clipped | ||||||||||||||||||||||||||
| # --- END: token filtering logic --- | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Monitoring metrics | ||||||||||||||||||||||||||
| total_tokens = m.sum().clamp_min(1.0) | ||||||||||||||||||||||||||
| frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| A = exps.batch["advantages"].detach().to(torch.float32) | ||||||||||||||||||||||||||
| pos_mask = (A > 0).to(M.dtype) | ||||||||||||||||||||||||||
| neg_mask = (A < 0).to(M.dtype) | ||||||||||||||||||||||||||
| total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0) | ||||||||||||||||||||||||||
| total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0) | ||||||||||||||||||||||||||
| frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item() | ||||||||||||||||||||||||||
| frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| metrics = { | ||||||||||||||||||||||||||
| "frac_clipped": frac_clipped, | ||||||||||||||||||||||||||
| "frac_clipped_pos": frac_clipped_pos, | ||||||||||||||||||||||||||
| "frac_clipped_neg": frac_clipped_neg, | ||||||||||||||||||||||||||
| "ES": ES.item(), | ||||||||||||||||||||||||||
| "varS": varS.item(), | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| return exps, metrics | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||
| def default_args(cls) -> Dict: | ||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||
| "epsilon": 1e-6, | ||||||||||||||||||||||||||
| "mu": 2.5, | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a typo in the arXiv link. The year should be
2402, not2602.