13 changes: 12 additions & 1 deletion examples/deepswe/train_deepswe_nb.py
@@ -185,6 +185,15 @@
     "--loss_agg_mode", type=str, default="sequence-mean-token-scale"
 )
 parser.add_argument("--advantage_estimator", type=str, default="rloo")
+parser.add_argument(
+    "--use_rollout_logps",
+    type=bool,
+    default=False,
+    help=(
+        "Whether to use rollout-cached logprobs as old policy logps. "
+        "Default is False to recompute old logps on the actor side."
+    ),
+)


 # Other
@@ -455,6 +464,7 @@
 )
 LOSS_AGG_MODE = args.loss_agg_mode
 ADVANTAGE_ESTIMATOR = args.advantage_estimator
+USE_ROLLOUT_LOGPS = args.use_rollout_logps


 # %%
@@ -673,7 +683,7 @@ def transform(entry):
     "top_p": TOP_P,
     "top_k": TOP_K,
     "eos_tokens": [tokenizer.encode("<|im_end|>")[0]],
-    "return_logprobs": True,
+    "return_logprobs": USE_ROLLOUT_LOGPS,
     "max_tokens_to_generate": MAX_RESPONSE_LENGTH,
 }

@@ -778,6 +788,7 @@ def transform(entry):
     "filter_statuses": FILTER_STATUSES,
     "loss_agg_mode": LOSS_AGG_MODE,
     "advantage_estimator": ADVANTAGE_ESTIMATOR,
+    "use_rollout_logps": USE_ROLLOUT_LOGPS,
 }

 grpo_config = agentic_grpo_learner.GRPOConfig(**config_kwargs)
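A note on the flag as written: argparse's type=bool calls bool() on the raw argument string, so any non-empty value parses as True, including --use_rollout_logps False. The sketch below is not part of this PR; it shows the usual store_true workaround:

import argparse

parser = argparse.ArgumentParser()
# store_true avoids the bool("False") == True pitfall: the flag takes no
# value and is False unless passed.
parser.add_argument(
    "--use_rollout_logps",
    action="store_true",
    help="Use rollout-cached logprobs as old policy logps.",
)

assert parser.parse_args([]).use_rollout_logps is False
assert parser.parse_args(["--use_rollout_logps"]).use_rollout_logps is True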
1 change: 1 addition & 0 deletions tests/rl/agentic/agentic_grpo_learner_test.py
@@ -736,6 +736,7 @@ def __init__(
         loss_algo="grpo",
         degenerate_group_masking=masking,
         max_response_length=10,
+        use_rollout_logps=True,
     )

     learner = agentic_grpo_learner.GRPOLearner(
5 changes: 5 additions & 0 deletions tests/rl/agentic/agentic_rl_learner_test.py
@@ -26,6 +26,7 @@ class DummyLearner(agentic_rl_learner.AgenticRLLearner):
   def _process_results(self, **kwargs):
     return []

+
 class AgenticRLLearnerTest(parameterized.TestCase):

   def test_validate_rollout_config_mismatch_max_tokens(self):
@@ -41,6 +42,7 @@ def test_validate_rollout_config_mismatch_max_tokens(self):

     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=20,  # Mismatch: 10 != 20
+        use_rollout_logps=True,
     )

     with self.assertRaisesRegex(
@@ -65,6 +67,7 @@ def test_validate_rollout_config_missing_logprobs(self):

     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=10,
+        use_rollout_logps=True,
     )

     with self.assertRaisesRegex(
@@ -97,6 +100,7 @@ def test_validate_rollout_config_dict_mode(self):

     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=10,
+        use_rollout_logps=True,
     )

     with self.assertRaisesRegex(
@@ -122,6 +126,7 @@ def test_validate_rollout_config_vllm_missing_server_mode(self):

     algo_config = agentic_rl_learner.AgenticRLConfig(
         max_response_length=10,
+        use_rollout_logps=True,
     )

     with self.assertRaisesRegex(
43 changes: 26 additions & 17 deletions tunix/rl/agentic/agentic_grpo_learner.py
@@ -110,6 +110,7 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
   degenerate_group_masking: bool = (
       True  # Whether to mask out degenerate groups with all-0 advantages.
   )
+  use_rollout_logps: bool = False

   def __post_init__(self):
     if self.num_generations <= 1:
@@ -370,19 +371,20 @@ def _process_results(
              :max_response_length
          ]
      )
-      if old_logprobs is not None:
-        padded_old_logprobs.append(
-            agentic_utils.right_pad(
-                old_logprobs,
-                length=max_response_length,
-                pad=0.0,
-                dtype=old_logprobs.dtype,
-            )[:max_response_length]
-        )
-      else:
-        padded_old_logprobs.append(
-            np.zeros(max_response_length, dtype=np.float32)
-        )
+      if self.algo_config.use_rollout_logps:
+        if old_logprobs is not None:
+          padded_old_logprobs.append(
+              agentic_utils.right_pad(
+                  old_logprobs,
+                  length=max_response_length,
+                  pad=0.0,
+                  dtype=old_logprobs.dtype,
+              )[:max_response_length]
+          )
+        else:
+          padded_old_logprobs.append(
+              np.zeros(max_response_length, dtype=np.float32)
+          )

     prompt_ids = jnp.asarray(padded_prompt_ids)
     prompt_mask = prompt_ids != pad_value
@@ -394,12 +396,19 @@
         completion_ids.shape,
     )

-    if padded_old_logprobs and len(padded_old_logprobs) == len(
-        completion_tokens_list
-    ):
+    if self.algo_config.use_rollout_logps and padded_old_logprobs:
       old_per_token_logps = jnp.asarray(padded_old_logprobs)
-    else:
+    elif self.algo_config.use_rollout_logps:
       old_per_token_logps = None
+    else:
+      old_per_token_logps = self.rl_cluster.get_actor_per_token_logps(
+          prompt_tokens=prompt_ids,
+          completion_tokens=completion_ids,
+          pad_id=pad_value,
+          eos_id=eos_value,
+          micro_batch_size=None,
+          completion_mask=completion_mask,
+      )

     if self.algo_config.num_iterations > 1 and old_per_token_logps is None:
       raise RuntimeError(
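The branch above selects among three sources for old logps: rollout-cached values, None (rejected later when num_iterations > 1), or trainer-side recomputation via get_actor_per_token_logps. For intuition on where old_per_token_logps flows: GRPO uses a PPO-style clipped importance ratio between the current policy and the old policy that produced the rollouts. A minimal sketch of that ratio (standard formulation, not code from this repo):

import jax.numpy as jnp

def clipped_surrogate(new_logps, old_logps, advantages, eps=0.2):
  # Per-token importance ratio exp(log pi_new - log pi_old); with freshly
  # recomputed actor logps the ratio starts at exactly 1.0 on the first
  # optimizer step.
  ratio = jnp.exp(new_logps - old_logps)
  clipped = jnp.clip(ratio, 1.0 - eps, 1.0 + eps)
  return -jnp.mean(jnp.minimum(ratio * advantages, clipped * advantages))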
8 changes: 4 additions & 4 deletions tunix/rl/agentic/agentic_rl_learner.py
@@ -91,6 +91,7 @@ class AgenticRLConfig(algo_config_lib.AlgorithmConfig):
   episode_timeout: float = 1800.0
   filter_statuses: Optional[Set] = None
   overlong_filter: bool = False
+  use_rollout_logps: bool = False


 TConfig = TypeVar("TConfig", bound=AgenticRLConfig)
@@ -252,10 +253,11 @@ def _validate_rollout_config(self):
           f"max_response_length ({self.algo_config.max_response_length}). "
           "Please align these configurations before initializing RLCluster."
       )
-    if not config.return_logprobs:
+    if self.algo_config.use_rollout_logps and not config.return_logprobs:
       raise ValueError(
           f"RolloutConfig ({mode}) must have return_logprobs=True for "
-          "AgenticRLLearner. Please set this before initializing RLCluster."
+          "AgenticRLLearner when use_rollout_logps=True. Please set this "
+          "before initializing RLCluster."
       )
     if (
         self.rl_cluster.cluster_config.rollout_engine == "vllm"
@@ -412,8 +414,6 @@ def _model_call(
     if "pair_index" in env.extra_kwargs:
       tags[perf_constants.PAIR_INDEX] = env.extra_kwargs["pair_index"]

-
-
     result = self.rl_cluster.generate(
         prompts=chat_lists,
         apply_chat_template=False if self.chat_parser else True,
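The validation change makes the return_logprobs requirement conditional: with the default use_rollout_logps=False, a RolloutConfig that skips logprob capture now passes validation, so the rollout engine no longer has to materialize logprobs the learner would never consume. A condensed sketch of the new rule (simplified from the check above):

def check_rollout_logps(use_rollout_logps: bool, return_logprobs: bool) -> None:
  # Rollout-cached logps are only required when the learner will use them;
  # the default path recomputes old logps on the actor side instead.
  if use_rollout_logps and not return_logprobs:
    raise ValueError(
        "RolloutConfig must set return_logprobs=True when "
        "use_rollout_logps=True."
    )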
69 changes: 67 additions & 2 deletions tunix/rl/rl_cluster.py
@@ -45,6 +45,7 @@
 from tunix.perf import trace as perf_trace
 from tunix.perf.experimental import constants as perf_constants
 from tunix.perf.experimental import tracer as perf_tracer_v2
+from tunix.rl import common
 from tunix.rl import reshard
 from tunix.rl import trainer as rl_trainer
 from tunix.rl import utils as rl_utils
@@ -981,6 +982,13 @@ def get_ref_per_token_logps(
         completion_tokens,
         self.cluster_config.training_config.data_sharding_axis,
     )
+    if completion_mask is not None:
+      dest_completion_mask = sharding_utils.shard_input(
+          completion_mask,
+          self.cluster_config.training_config.data_sharding_axis,
+      )
+    else:
+      dest_completion_mask = None
     self._maybe_load_model_from_cpu(
         self.inference_worker.get_model("reference"), Role.REFERENCE
     )
@@ -996,8 +1004,8 @@
             pad_id,
             eos_id,
             completion_mask=None
-            if completion_mask is None
-            else completion_mask[batch_slice],
+            if dest_completion_mask is None
+            else dest_completion_mask[batch_slice],
             temperature=temperature,
         )
     )
@@ -1045,6 +1053,63 @@ def get_old_per_token_logps(
     self.rollout.update_params(nnx.state(model))
     return per_token_logps

+  def get_actor_per_token_logps(
+      self,
+      prompt_tokens: jax.Array,
+      completion_tokens: jax.Array,
+      pad_id: int,
+      eos_id: int,
+      micro_batch_size: int | None = None,
+      completion_mask: jax.Array | None = None,
+  ) -> jax.Array:
+    """Gets per-token logps from the actor model on the trainer side."""
+    batch_size = prompt_tokens.shape[0]
+    if batch_size == 0:
+      raise ValueError(
+          "Cannot get actor log probabilities from an empty batch."
+      )
+    micro_batch_size = micro_batch_size or batch_size
+    with self._get_mesh_and_logical_axis_rules_cm(Role.ACTOR):
+      dest_prompt_tokens = sharding_utils.shard_input(
+          prompt_tokens,
+          self.cluster_config.training_config.data_sharding_axis,
+      )
+      dest_completion_tokens = sharding_utils.shard_input(
+          completion_tokens,
+          self.cluster_config.training_config.data_sharding_axis,
+      )
+      if completion_mask is not None:
+        dest_completion_mask = sharding_utils.shard_input(
+            completion_mask,
+            self.cluster_config.training_config.data_sharding_axis,
+        )
+      else:
+        dest_completion_mask = None
+      self._maybe_load_model_from_cpu(self.actor_trainer.model, Role.ACTOR)
+      graphdef, state = nnx.split(self.actor_trainer.model)
+      outs = []
+      for batch_slice in rl_utils.chunk_slices_by_size(
+          stop=batch_size, step=micro_batch_size
+      ):
+        outs.append(
+            common.compute_per_token_logps(
+                graphdef,
+                state,
+                prompt_tokens=dest_prompt_tokens[batch_slice],
+                completion_tokens=dest_completion_tokens[batch_slice],
+                pad_id=pad_id,
+                eos_id=eos_id,
+                completion_mask=None
+                if dest_completion_mask is None
+                else dest_completion_mask[batch_slice],
+                stop_gradient=True,
+                return_logits=False,
+            )
+        )
+      actor_per_token_logps = jnp.concatenate(outs, axis=0)
+    self._maybe_offload_model_to_cpu(self.actor_trainer.model, Role.ACTOR)
+    return actor_per_token_logps
+
   def sync_weights(self):
     """Syncs the weights of between the sampler model and trainer model."""
     if jax.devices() and jax.default_backend() not in ["tpu", "gpu"]:
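A minimal usage sketch of the new get_actor_per_token_logps, assuming an initialized RLCluster instance named rl_cluster; the token ids, pad/eos values, and shapes below are hypothetical:

import jax.numpy as jnp

# Two sequences, right-padded with 0, eos id 2. Shapes are [batch, time].
prompt_tokens = jnp.array([[5, 6, 7], [8, 9, 0]])
completion_tokens = jnp.array([[11, 12, 2, 0], [13, 2, 0, 0]])

old_logps = rl_cluster.get_actor_per_token_logps(
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
    pad_id=0,
    eos_id=2,
    micro_batch_size=1,  # forward one sequence at a time to bound memory
)
# old_logps: shape [2, 4], one logp per completion token, computed under
# stop_gradient so the values act as frozen "old policy" logps in the loss.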