From b426c4f15787364663ed9635988cf81210cc3c4c Mon Sep 17 00:00:00 2001
From: Haoyu Gao
Date: Tue, 5 May 2026 14:39:49 -0700
Subject: [PATCH] Add option to recompute old policy logprobs on the trainer.

PiperOrigin-RevId: 910921777
---
 examples/deepswe/train_deepswe_nb.py          | 13 +++-
 tests/rl/agentic/agentic_grpo_learner_test.py |  1 +
 tests/rl/agentic/agentic_rl_learner_test.py   |  5 ++
 tunix/rl/agentic/agentic_grpo_learner.py      | 43 +++++++-----
 tunix/rl/agentic/agentic_rl_learner.py        |  8 +--
 tunix/rl/rl_cluster.py                        | 69 ++++++++++++++++++-
 6 files changed, 115 insertions(+), 24 deletions(-)

diff --git a/examples/deepswe/train_deepswe_nb.py b/examples/deepswe/train_deepswe_nb.py
index e847910fd..7f582bf69 100644
--- a/examples/deepswe/train_deepswe_nb.py
+++ b/examples/deepswe/train_deepswe_nb.py
@@ -185,6 +185,15 @@
     "--loss_agg_mode", type=str, default="sequence-mean-token-scale"
 )
 parser.add_argument("--advantage_estimator", type=str, default="rloo")
+parser.add_argument(
+    "--use_rollout_logps",
+    action="store_true",
+    default=False,
+    help=(
+        "Whether to use rollout-cached logprobs as old policy logps. "
+        "Default is False to recompute old logps on the actor side. "
+    ),
+)
 
 # Other
 
@@ -455,6 +464,7 @@
 )
 LOSS_AGG_MODE = args.loss_agg_mode
 ADVANTAGE_ESTIMATOR = args.advantage_estimator
+USE_ROLLOUT_LOGPS = args.use_rollout_logps
 
 # %%
 
@@ -673,7 +683,7 @@ def transform(entry):
     "top_p": TOP_P,
     "top_k": TOP_K,
     "eos_tokens": [tokenizer.encode("<|im_end|>")[0]],
-    "return_logprobs": True,
+    "return_logprobs": USE_ROLLOUT_LOGPS,
     "max_tokens_to_generate": MAX_RESPONSE_LENGTH,
 }
 
@@ -778,6 +788,7 @@ def transform(entry):
     "filter_statuses": FILTER_STATUSES,
     "loss_agg_mode": LOSS_AGG_MODE,
     "advantage_estimator": ADVANTAGE_ESTIMATOR,
+    "use_rollout_logps": USE_ROLLOUT_LOGPS,
 }
 
 grpo_config = agentic_grpo_learner.GRPOConfig(**config_kwargs)
diff --git a/tests/rl/agentic/agentic_grpo_learner_test.py b/tests/rl/agentic/agentic_grpo_learner_test.py
index 6dcced541..308465a7a 100644
--- a/tests/rl/agentic/agentic_grpo_learner_test.py
+++ b/tests/rl/agentic/agentic_grpo_learner_test.py
@@ -736,6 +736,7 @@ def __init__(
         loss_algo="grpo",
         degenerate_group_masking=masking,
         max_response_length=10,
+        use_rollout_logps=True,
     )
 
     learner = agentic_grpo_learner.GRPOLearner(
diff --git a/tests/rl/agentic/agentic_rl_learner_test.py b/tests/rl/agentic/agentic_rl_learner_test.py
index a0aceca07..29fe9dca6 100644
--- a/tests/rl/agentic/agentic_rl_learner_test.py
+++ b/tests/rl/agentic/agentic_rl_learner_test.py
@@ -26,6 +26,7 @@ class DummyLearner(agentic_rl_learner.AgenticRLLearner):
   def _process_results(self, **kwargs):
     return []
 
+
 class AgenticRLLearnerTest(parameterized.TestCase):
 
   def test_validate_rollout_config_mismatch_max_tokens(self):
@@ -41,6 +42,7 @@ def test_validate_rollout_config_mismatch_max_tokens(self):
 
     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=20,  # Mismatch: 10 != 20
+        use_rollout_logps=True,
     )
 
     with self.assertRaisesRegex(
@@ -65,6 +67,7 @@ def test_validate_rollout_config_missing_logprobs(self):
 
     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=10,
+        use_rollout_logps=True,
     )
 
     with self.assertRaisesRegex(
@@ -97,6 +100,7 @@ def test_validate_rollout_config_dict_mode(self):
 
     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=10,
+        use_rollout_logps=True,
     )
 
     with self.assertRaisesRegex(
@@ -122,6 +126,7 @@ def test_validate_rollout_config_vllm_missing_server_mode(self):
 
     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=10,
+        use_rollout_logps=True,
     )
 
     with self.assertRaisesRegex(
diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py
index 1e635d9b9..ff4f3a149 100644
--- a/tunix/rl/agentic/agentic_grpo_learner.py
+++ b/tunix/rl/agentic/agentic_grpo_learner.py
@@ -110,6 +110,7 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
   degenerate_group_masking: bool = (
       True  # Whether to mask out degenerate groups with all-0 advantages.
   )
+  use_rollout_logps: bool = False
 
   def __post_init__(self):
     if self.num_generations <= 1:
@@ -370,19 +371,20 @@ def _process_results(
               :max_response_length
           ]
       )
-      if old_logprobs is not None:
-        padded_old_logprobs.append(
-            agentic_utils.right_pad(
-                old_logprobs,
-                length=max_response_length,
-                pad=0.0,
-                dtype=old_logprobs.dtype,
-            )[:max_response_length]
-        )
-      else:
-        padded_old_logprobs.append(
-            np.zeros(max_response_length, dtype=np.float32)
-        )
+      if self.algo_config.use_rollout_logps:
+        if old_logprobs is not None:
+          padded_old_logprobs.append(
+              agentic_utils.right_pad(
+                  old_logprobs,
+                  length=max_response_length,
+                  pad=0.0,
+                  dtype=old_logprobs.dtype,
+              )[:max_response_length]
+          )
+        else:
+          padded_old_logprobs.append(
+              np.zeros(max_response_length, dtype=np.float32)
+          )
 
     prompt_ids = jnp.asarray(padded_prompt_ids)
     prompt_mask = prompt_ids != pad_value
@@ -394,12 +396,19 @@
         completion_ids.shape,
     )
 
-    if padded_old_logprobs and len(padded_old_logprobs) == len(
-        completion_tokens_list
-    ):
+    if self.algo_config.use_rollout_logps and padded_old_logprobs:
       old_per_token_logps = jnp.asarray(padded_old_logprobs)
-    else:
+    elif self.algo_config.use_rollout_logps:
       old_per_token_logps = None
+    else:
+      old_per_token_logps = self.rl_cluster.get_actor_per_token_logps(
+          prompt_tokens=prompt_ids,
+          completion_tokens=completion_ids,
+          pad_id=pad_value,
+          eos_id=eos_value,
+          micro_batch_size=None,
+          completion_mask=completion_mask,
+      )
 
     if self.algo_config.num_iterations > 1 and old_per_token_logps is None:
       raise RuntimeError(
diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py
index 54c8ebfa3..a5f966229 100644
--- a/tunix/rl/agentic/agentic_rl_learner.py
+++ b/tunix/rl/agentic/agentic_rl_learner.py
@@ -91,6 +91,7 @@ class AgenticRLConfig(algo_config_lib.AlgorithmConfig):
   episode_timeout: float = 1800.0
   filter_statuses: Optional[Set] = None
   overlong_filter: bool = False
+  use_rollout_logps: bool = False
 
 
 TConfig = TypeVar("TConfig", bound=AgenticRLConfig)
@@ -252,10 +253,11 @@ def _validate_rollout_config(self):
           f"max_response_length ({self.algo_config.max_response_length}). "
           "Please align these configurations before initializing RLCluster."
       )
-    if not config.return_logprobs:
+    if self.algo_config.use_rollout_logps and not config.return_logprobs:
       raise ValueError(
           f"RolloutConfig ({mode}) must have return_logprobs=True for "
-          "AgenticRLLearner. Please set this before initializing RLCluster."
+          "AgenticRLLearner when use_rollout_logps=True. Please set this "
+          "before initializing RLCluster."
       )
     if (
         self.rl_cluster.cluster_config.rollout_engine == "vllm"
@@ -412,8 +414,6 @@ def _model_call(
     if "pair_index" in env.extra_kwargs:
       tags[perf_constants.PAIR_INDEX] = env.extra_kwargs["pair_index"]
 
-
-
     result = self.rl_cluster.generate(
         prompts=chat_lists,
         apply_chat_template=False if self.chat_parser else True,
diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py
index 511232d1e..ec65050fb 100644
--- a/tunix/rl/rl_cluster.py
+++ b/tunix/rl/rl_cluster.py
@@ -45,6 +45,7 @@
 from tunix.perf import trace as perf_trace
 from tunix.perf.experimental import constants as perf_constants
 from tunix.perf.experimental import tracer as perf_tracer_v2
+from tunix.rl import common
 from tunix.rl import reshard
 from tunix.rl import trainer as rl_trainer
 from tunix.rl import utils as rl_utils
@@ -981,6 +982,13 @@ def get_ref_per_token_logps(
         completion_tokens,
         self.cluster_config.training_config.data_sharding_axis,
     )
+    if completion_mask is not None:
+      dest_completion_mask = sharding_utils.shard_input(
+          completion_mask,
+          self.cluster_config.training_config.data_sharding_axis,
+      )
+    else:
+      dest_completion_mask = None
     self._maybe_load_model_from_cpu(
         self.inference_worker.get_model("reference"), Role.REFERENCE
     )
@@ -996,8 +1004,8 @@
             pad_id,
             eos_id,
             completion_mask=None
-            if completion_mask is None
-            else completion_mask[batch_slice],
+            if dest_completion_mask is None
+            else dest_completion_mask[batch_slice],
             temperature=temperature,
         )
     )
@@ -1045,6 +1053,63 @@ def get_old_per_token_logps(
       self.rollout.update_params(nnx.state(model))
     return per_token_logps
 
+  def get_actor_per_token_logps(
+      self,
+      prompt_tokens: jax.Array,
+      completion_tokens: jax.Array,
+      pad_id: int,
+      eos_id: int,
+      micro_batch_size: int | None = None,
+      completion_mask: jax.Array | None = None,
+  ) -> jax.Array:
+    """Gets per-token logps from the actor model on the trainer side."""
+    batch_size = prompt_tokens.shape[0]
+    if batch_size == 0:
+      raise ValueError(
+          "Cannot get actor log probabilities from an empty batch."
+      )
+    micro_batch_size = micro_batch_size or batch_size
+    with self._get_mesh_and_logical_axis_rules_cm(Role.ACTOR):
+      dest_prompt_tokens = sharding_utils.shard_input(
+          prompt_tokens,
+          self.cluster_config.training_config.data_sharding_axis,
+      )
+      dest_completion_tokens = sharding_utils.shard_input(
+          completion_tokens,
+          self.cluster_config.training_config.data_sharding_axis,
+      )
+      if completion_mask is not None:
+        dest_completion_mask = sharding_utils.shard_input(
+            completion_mask,
+            self.cluster_config.training_config.data_sharding_axis,
+        )
+      else:
+        dest_completion_mask = None
+      self._maybe_load_model_from_cpu(self.actor_trainer.model, Role.ACTOR)
+      graphdef, state = nnx.split(self.actor_trainer.model)
+      outs = []
+      for batch_slice in rl_utils.chunk_slices_by_size(
+          stop=batch_size, step=micro_batch_size
+      ):
+        outs.append(
+            common.compute_per_token_logps(
+                graphdef,
+                state,
+                prompt_tokens=dest_prompt_tokens[batch_slice],
+                completion_tokens=dest_completion_tokens[batch_slice],
+                pad_id=pad_id,
+                eos_id=eos_id,
+                completion_mask=None
+                if dest_completion_mask is None
+                else dest_completion_mask[batch_slice],
+                stop_gradient=True,
+                return_logits=False,
+            )
+        )
+      actor_per_token_logps = jnp.concatenate(outs, axis=0)
+      self._maybe_offload_model_to_cpu(self.actor_trainer.model, Role.ACTOR)
+    return actor_per_token_logps
+
   def sync_weights(self):
     """Syncs the weights of between the sampler model and trainer model."""
     if jax.devices() and jax.default_backend() not in ["tpu", "gpu"]:
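
Usage sketch (illustrative only, not part of the diff above): the new flag
selects where the old policy logprobs come from. The snippet assumes only the
GRPOConfig fields visible in this patch; the num_generations and
max_response_length values are placeholders, and any other required config
fields are elided.

    from tunix.rl.agentic import agentic_grpo_learner

    # Default after this patch: old policy logprobs are recomputed on the
    # trainer side via RLCluster.get_actor_per_token_logps(), so the rollout
    # config no longer needs "return_logprobs": True.
    config = agentic_grpo_learner.GRPOConfig(
        num_generations=4,          # placeholder
        max_response_length=1024,   # placeholder
        use_rollout_logps=False,
    )

    # Previous behavior: reuse the rollout-cached logprobs. The rollout
    # config must then set "return_logprobs": True, otherwise
    # _validate_rollout_config() raises ValueError.
    legacy_config = agentic_grpo_learner.GRPOConfig(
        num_generations=4,          # placeholder
        max_response_length=1024,   # placeholder
        use_rollout_logps=True,
    )

One reason to prefer the recompute path is numerical consistency: a rollout
engine such as vLLM may evaluate logprobs with different kernels or precision
than the trainer, so recomputing keeps the old-logp term of the importance
ratio aligned with the training-side forward pass.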