From c9fcfc1df3e70cd96f4ee280f01b0db58ceb5c8b Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Tue, 28 Apr 2026 05:22:52 +0000
Subject: [PATCH 1/3] feat: add cli example for ppo

Co-authored-by: sizhit2 <32147610+sizhit2@users.noreply.github.com>
---
 examples/rl/README.md                        |   7 +
 examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml |  80 +++
 examples/rl/ppo/gsm8k/run_gemma2_2b.sh       |  54 ++
 plan.md                                      |   5 +
 tests/cli/ppo_main_test.py                   | 135 +++++
 tunix/cli/ppo_main.py                        | 594 +++++++++++++++++++
 6 files changed, 875 insertions(+)
 create mode 100644 examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml
 create mode 100755 examples/rl/ppo/gsm8k/run_gemma2_2b.sh
 create mode 100644 plan.md
 create mode 100644 tests/cli/ppo_main_test.py
 create mode 100644 tunix/cli/ppo_main.py

diff --git a/examples/rl/README.md b/examples/rl/README.md
index 16a0bda73..9a997dcc2 100644
--- a/examples/rl/README.md
+++ b/examples/rl/README.md
@@ -29,3 +29,10 @@ does not match the exact configuration listed in the table.
 | GRPO | **Qwen3 0.6b**| LoRA | v5e-1 | Num of generation = 4, batch_size = 1 | Train: fsdp Rollout: tp | |
 | GRPO | **Qwen3 14b** | Full | v5p-2 | Num of generation = 4, batch_size = 4 | Train: fsdp Rollout: tp | |
 | GRPO | **Qwen3 14b** | LoRA | v5p-2 | Num of generation = 4, batch_size = 4 | Train: fsdp Rollout: tp | |
+
+## PPO
+
+| Algo | Model | Type | Min Resources | Max Training Micro Batch Size | Sharding | Launch Script |
+| :--- | :------------ | :---- | :------------ | :-------------------------------- | :-------------------- | :---------------------------- |
+| PPO | **Gemma2-2b** | Full | v5e-4 | Num of generation = 4, batch_size = 1 | Train: fsdp Rollout: tp | *[run_gemma2_2b.sh](ppo/gsm8k/run_gemma2_2b.sh)* |
+| PPO | **Gemma2-2b** | LoRA | v5e-4 | Num of generation = 4, batch_size = 1 | Train: fsdp Rollout: tp | *[run_gemma2_2b.sh](ppo/gsm8k/run_gemma2_2b.sh)* |
diff --git a/examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml b/examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml
new file mode 100644
index 000000000..7708076b4
--- /dev/null
+++ b/examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml
@@ -0,0 +1,80 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
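#
# PPO example for the instruction-tuned Gemma2-2b on GSM8K. The meshes below
# describe an 8-device (2,4) layout with ('fsdp','tp') axes, shared by the
# actor and rollout roles; the values in this file are a working example
# rather than tuned defaults.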
+ +model_config: + model_name: "gemma2_2b_it" + model_id: "gemma2_2b_it" + model_path: "google/gemma-2/flax/gemma2-2b-it" + model_source: "kaggle" + mesh: + shape: "(2,4)" + axis_names: "('fsdp','tp')" + rng_seed: 42 +actor_model_config: + lora_config: + rank: 64 + alpha: 64.0 + module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" + mesh: + shape: "(2,4)" + axis_names: "('fsdp','tp')" +rollout_model_config: + mesh: + shape: "(2,4)" + axis_names: "('fsdp','tp')" +tokenizer_config: + tokenizer_type: "sentencepiece" + add_bos: False +dataset_name: "gsm8k" +batch_size: 1 +num_batches: 3738 +num_test_batches: 100 +num_train_epochs: 1 +rl_training_config: + actor_optimizer_config: + opt_type: "adamw" + peak_value: 3e-6 + schedule_type: "warmup_cosine_decay_schedule" + init_value: 0.0 + end_value: 0.0 + warmup_ratio: 0.1 + warmup_steps: 374 + decay_steps: 3738 + b1: 0.9 + b2: 0.99 + weight_decay: 0.1 + max_grad_norm: 0.1 + eval_every_n_steps: 10 + max_steps: 3738 + metrics_logging_options: + flush_every_n_steps: 20 + checkpointing_options: + save_interval_steps: 500 + max_to_keep: 4 + profiler_options: {} +rollout_config: + total_generation_steps: 768 + max_prompt_length: 256 + temperature: 0.9 + top_p: 1.0 + top_k: 50 +rollout_engine: "vanilla" +offload_to_cpu: False +ppo_config: + num_generations: 2 + num_iterations: 1 + beta: 0.08 + epsilon: 0.2 +reward_functions: + - "tunix/cli/reward_fn/gsm8k.py" diff --git a/examples/rl/ppo/gsm8k/run_gemma2_2b.sh b/examples/rl/ppo/gsm8k/run_gemma2_2b.sh new file mode 100755 index 000000000..174c4d8f2 --- /dev/null +++ b/examples/rl/ppo/gsm8k/run_gemma2_2b.sh @@ -0,0 +1,54 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
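#
# Usage sketch (the variables below are this script's own env-var knobs; the
# values shown are illustrative). Run from the repo root so that the relative
# paths passed to tunix.cli.ppo_main resolve:
#   batch_size=2 num_batches=100 bash examples/rl/ppo/gsm8k/run_gemma2_2b.sh
# Any variable left unset falls back to the defaults declared below.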
+ + +set -x # Enable xtrace + +batch_size=${batch_size:-1} +num_batches=${num_batches:-3738} +num_train_epochs=${num_train_epochs:-1} +warmup_ratio=${warmup_ratio:-0.1} +train_fraction=${train_fraction:-1.0} + +echo "Using parameters:" +echo " Batch Size: $batch_size" +echo " Num Batches: $num_batches" +echo " Num Epochs: $num_train_epochs" +echo " Warmup Ratio: $warmup_ratio" +echo " Train Fraction: $train_fraction" + +max_steps_float=$(awk "BEGIN {print $batch_size * $num_batches * $num_train_epochs * $train_fraction}") + +max_steps=$(printf "%.0f" "$max_steps_float") + + +warmup_steps=$(awk "BEGIN {printf \"%.0f\", $warmup_ratio * $max_steps}") + +echo "Max steps: $max_steps" +echo "Rounded warmup steps: $warmup_steps" + +python3 -m tunix.cli.ppo_main \ + tunix/cli/base_config.yaml \ + override_config_file=examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml \ + model_config.model_download_path="/tmp/models/gemma2-2b" \ + model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/1" \ + tokenizer_config.tokenizer_path="/tmp/models/gemma2-2b/models/google/gemma-2/flax/gemma2-2b-it/1/tokenizer.model" \ + batch_size=$batch_size \ + num_batches=$num_batches \ + num_train_epochs=$num_train_epochs \ + rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ + rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \ + rl_training_config.actor_optimizer_config.decay_steps=$max_steps \ + rl_training_config.max_steps=$max_steps \ + rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/ppo" diff --git a/plan.md b/plan.md new file mode 100644 index 000000000..fd6a803b8 --- /dev/null +++ b/plan.md @@ -0,0 +1,5 @@ +1. **Understand PPO vs GRPO for Agentic RL:** The user wants PPO to be used in agentic mode. Currently, the codebase has `agentic_grpo_learner.py` which computes advantages for group-wise generation. For PPO, advantages are usually computed using GAE (Generalized Advantage Estimation) or similar, based on value function estimates. However, the exact implementation details of `ppo_learner.py` should be inspected. +2. **Review `tunix/rl/ppo/ppo_learner.py`:** Understand how PPO computes advantages and updates policies in the standard mode. We need to implement `agentic_ppo_learner.py` to match the API of `agentic_rl_learner.py` and `ppo_learner.py`. +3. **Implement `tunix/cli/ppo_main.py`:** Create the CLI entry point for PPO. It should support both `ppo` and `agentic_ppo` modes. +4. **Add Examples:** Create `examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml` and `examples/rl/ppo/gsm8k/run_gemma2_2b.sh`. +5. **Update README and Notebook:** Add PPO to `examples/rl/README.md` and create `examples/ppo_gemma.ipynb`. diff --git a/tests/cli/ppo_main_test.py b/tests/cli/ppo_main_test.py new file mode 100644 index 000000000..d641eea9e --- /dev/null +++ b/tests/cli/ppo_main_test.py @@ -0,0 +1,135 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
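#
# These tests exercise config parsing and dispatch only; run them directly
# with `python tests/cli/ppo_main_test.py` (absltest.main() below is the
# entry point).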
+
+"""Tests for the ppo_main CLI entry point.

Covers config parsing for the PPO pipeline and dispatch on ``training_mode``.
"""

import os
import tempfile
from unittest import mock

from absl.testing import absltest
from tunix.cli import ppo_main

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_pipeline(extra_yaml: str) -> ppo_main.PPOPipeline:
  """Write a minimal valid YAML and instantiate PPOPipeline against it."""
  base = """
model_config:
  model_name: "test_model"
  model_id: "test/model"
  model_source: "huggingface"
  model_display: false
  rng_seed: 0
  intermediate_ckpt_dir: "/tmp/ckpt"

actor_model_config:
  mesh:
    shape: "(1,1)"
    axis_names: "('fsdp','tp')"

reference_model_config:
  mesh:
    shape: "(1,1)"
    axis_names: "('fsdp','tp')"

rollout_model_config:
  mesh:
    shape: "(1,1)"
    axis_names: "('fsdp','tp')"

tokenizer_config:
  tokenizer_type: "huggingface"
  tokenizer_path: "test/model"
  add_bos: false
  add_eos: false

rollout_engine: "vanilla"
offload_to_cpu: false

rollout_config:
  max_prompt_length: 256
  total_generation_steps: 512
  temperature: 1.0
  top_p: null
  top_k: null

rl_training_config:
  max_steps: 1
  eval_every_n_steps: 1
  mini_batch_size: 1
  train_micro_batch_size: 1
  actor_optimizer_config:
    opt_type: "adamw"
    learning_rate: 1.0e-6
    schedule_type: "warmup_cosine_decay_schedule"
    init_value: 0.0
    end_value: 0.0
    warmup_ratio: 0.1
    b1: 0.9
    b2: 0.99
    weight_decay: 0.01
    max_grad_norm: 1.0
  metrics_logging_options:
    log_dir: "/tmp/tb_test"
    flush_every_n_steps: 1
  checkpointing_options:
    save_interval_steps: 100
    max_to_keep: 1
  checkpoint_root_directory: "/tmp/ckpt_test"

batch_size: 1
num_batches: 1
num_train_epochs: 1
train_fraction: 1.0
"""
  with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
    f.write(base + extra_yaml)
    path = f.name

  # Patch HF_TOKEN so tokenizer validation passes.
  with mock.patch.dict(os.environ, {"HF_TOKEN": "fake"}):
    pipeline = ppo_main.PPOPipeline(["", path])
  os.unlink(path)
  return pipeline


class DispatchTest(absltest.TestCase):

  def test_ppo_mode_is_parsed(self):
    yaml = """
training_mode: "ppo"
ppo_config:
  num_generations: 2
  num_iterations: 1
"""
    pipeline = _make_pipeline(yaml)
    self.assertEqual(pipeline.config["training_mode"], "ppo")

  def test_unsupported_mode_raises(self):
    yaml = """
training_mode: "definitely_not_a_mode"
"""
    # run_ppo_trainer dispatches on training_mode and must reject unknown
    # modes with ValueError.
    with self.assertRaises(ValueError):
      _make_pipeline(yaml).run_ppo_trainer()


if __name__ == "__main__":
  absltest.main()
diff --git a/tunix/cli/ppo_main.py b/tunix/cli/ppo_main.py
new file mode 100644
index 000000000..4d2a14c4c
--- /dev/null
+++ b/tunix/cli/ppo_main.py
@@ -0,0 +1,594 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

"""Main entry point for PPO training.

Set ``training_mode: "ppo"`` (the default) for standard single-turn PPO.
``training_mode: "agentic_ppo"`` (multi-turn agentic PPO, used by recipes
such as DeepScaleR and DeepSWE) is described by the config schema below, but
it is not yet dispatched by ``run_ppo_trainer`` and currently raises
``ValueError``.

Usage::

    # Standard PPO
    python -m tunix.cli.ppo_main examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml
"""

import dataclasses
import importlib
import os
from types import ModuleType
from typing import Any

from absl import app
from absl import flags
from absl import logging
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
from tunix.cli import config
from tunix.cli.utils import data as data_lib
from tunix.cli.utils import model as model_lib
from tunix.examples.data import math_dataset as example_data
from tunix.perf import export as perf_export
from tunix.perf import metrics as perf_metrics
from tunix.perf.experimental import export as perf_export_v2
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout


_PATHWAYS_BNS = flags.DEFINE_string(
    "pathways_bns", None, "BNS address of the Pathways server."
)


class PPOPipeline(config.HyperParameters):
  """Runs standard PPO or agentic PPO depending on ``training_mode``.

  ``training_mode: "ppo"`` (default) — standard single-turn PPO using
  PPOLearner. All existing YAML configs continue to work unchanged.

  ``training_mode: "agentic_ppo"`` — multi-turn agentic PPO (dispatch not yet
  wired in; see ``run_ppo_trainer``). Additional config sections are
  recognised:

  * ``agentic_ppo_config``: PPOConfig fields (num_generations, beta, …)
    plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``.
  * role-specific ``*_model_config.mesh``: any role with an explicit mesh gets
    its own device slice; omitted meshes share the actor mesh by default.
  * role-specific ``same_mesh_as``: optional mesh sharing like
    ``reference_model_config.same_mesh_as: actor``.
  * ``sglang_jax_config`` / ``vllm_config``: engine-specific rollout params.
  * ``chat_parser_config.type``: ``"default"`` or ``"qwen"``.
  * ``agent_class_path`` / ``env_class_path``: dotted Python paths to load
    agent and env classes dynamically.
  * ``data_module``: dotted module path; the module must expose
    ``create_dataset(**data_config) -> grain.MapDataset`` and optionally a
    ``batch_fn`` used as ``custom_batch_fn`` in post_init_dataset.
  * ``kubernetes_config``: optional Kubernetes env-var and kube-config setup.
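
  A minimal mesh-sharing sketch (the shapes here are illustrative, not
  defaults from any shipped config)::

      actor_model_config:
        mesh:
          shape: "(2,4)"
          axis_names: "('fsdp','tp')"
      reference_model_config:
        same_mesh_as: "actor"     # shares the actor's device slice
      rollout_model_config:
        mesh:                     # explicit mesh, so it gets its own slice
          shape: "(1,4)"
          axis_names: "('fsdp','tp')"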
+ """ + + def __init__(self, argv: list[str], **kwargs): + self.data_module: ModuleType | None = None + super().__init__(argv, **kwargs) + + # ------------------------------------------------------------------ + # Mesh + # ------------------------------------------------------------------ + _ROLE_TO_MODEL_KEY = { + rl_cluster_lib.Role.ACTOR: "actor_model_config", + rl_cluster_lib.Role.CRITIC: "critic_model_config", + rl_cluster_lib.Role.REFERENCE: "reference_model_config", + rl_cluster_lib.Role.REWARD: "reward_model_config", + rl_cluster_lib.Role.ROLLOUT: "rollout_model_config", + } + _SPLIT_ROLE_ALIASES = { + "actor": rl_cluster_lib.Role.ACTOR, + "critic": rl_cluster_lib.Role.CRITIC, + "reference": rl_cluster_lib.Role.REFERENCE, + "reward": rl_cluster_lib.Role.REWARD, + "rollout": rl_cluster_lib.Role.ROLLOUT, + } + + def _resolve_split_role(self, role_name: str) -> rl_cluster_lib.Role: + normalized = role_name.strip().lower() + if normalized not in self._SPLIT_ROLE_ALIASES: + valid_roles = sorted(self._SPLIT_ROLE_ALIASES) + raise ValueError( + f"Unknown role name {role_name!r}. Expected one of {valid_roles}." + ) + return self._SPLIT_ROLE_ALIASES[normalized] + + def _get_same_mesh_as_map( + self, + ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: + same_mesh_as = {} + for role, model_key in self._ROLE_TO_MODEL_KEY.items(): + model_cfg = self.config.get(model_key, {}) or {} + target_name = model_cfg.get("same_mesh_as") + if target_name is None: + continue + target_role = self._resolve_split_role(str(target_name)) + if role == rl_cluster_lib.Role.ACTOR: + raise ValueError("Actor must own its mesh.") + same_mesh_as[role] = target_role + + return same_mesh_as + + def _is_role_active(self, role: rl_cluster_lib.Role) -> bool: + if role in ( + rl_cluster_lib.Role.ACTOR, + rl_cluster_lib.Role.REFERENCE, + rl_cluster_lib.Role.ROLLOUT, + ): + return True + model_key = self._ROLE_TO_MODEL_KEY[role] + return model_key in self.config + + def _resolve_mesh_owners( + self, + ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: + same_mesh_as = self._get_same_mesh_as_map() + base_owners = {} + for role, model_key in self._ROLE_TO_MODEL_KEY.items(): + if not self._is_role_active(role) and role not in same_mesh_as: + continue + has_mesh = bool(self.config.get(model_key, {}).get("mesh")) + base_owners[role] = ( + role + if role == rl_cluster_lib.Role.ACTOR or has_mesh + else rl_cluster_lib.Role.ACTOR + ) + + def resolve_owner( + role: rl_cluster_lib.Role, + seen: set[rl_cluster_lib.Role], + ) -> rl_cluster_lib.Role: + if role in seen: + raise ValueError("same_mesh_as contains a cycle.") + if role not in same_mesh_as: + return base_owners[role] + seen.add(role) + target_role = same_mesh_as[role] + if target_role not in base_owners: + raise ValueError( + f"Role {target_role.value!r} is not active in this config." + ) + return resolve_owner(target_role, seen) + + role_to_owner = {} + for role, model_key in self._ROLE_TO_MODEL_KEY.items(): + if role not in base_owners: + continue + has_mesh = bool(self.config.get(model_key, {}).get("mesh")) + if role in same_mesh_as: + if has_mesh: + raise ValueError( + f"{model_key}.mesh is specified, so it must own a separate mesh " + "and cannot also use same_mesh_as." 
+ ) + else: + role_to_owner[role] = resolve_owner(role, set()) + continue + role_to_owner[role] = resolve_owner(role, set()) + return role_to_owner + + def _create_role_to_mesh(self): + devices = list(jax.devices()) + role_to_owner = self._resolve_mesh_owners() + owner_order = [] + for role in self._ROLE_TO_MODEL_KEY: + if role not in role_to_owner: + continue + owner = role_to_owner[role] + if owner not in owner_order: + owner_order.append(owner) + + owner_to_mesh = {} + owner_to_device_slice = {} + device_offset = 0 + for owner in owner_order: + model_key = self._ROLE_TO_MODEL_KEY[owner] + axis_shapes, _ = self._parse_mesh_config(model_key) + required_devices = int(np.prod(axis_shapes)) + next_offset = device_offset + required_devices + if next_offset > len(devices): + raise ValueError( + f"Mesh allocation requires {next_offset} devices after allocating" + f" {model_key}, but only {len(devices)} are available." + ) + assigned_devices = devices[device_offset:next_offset] + owner_to_device_slice[owner] = assigned_devices + owner_to_mesh[owner] = self.create_mesh( + model_key, devices=assigned_devices + ) + device_offset = next_offset + + if device_offset < len(devices): + logging.warning( + "Mesh allocation used %d of %d devices; %d devices remain unused.", + device_offset, + len(devices), + len(devices) - device_offset, + ) + logging.info( + "Mesh device allocation: %s", + { + self._ROLE_TO_MODEL_KEY[owner]: len(owner_to_device_slice[owner]) + for owner in owner_order + }, + ) + return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} + + def create_role_to_mesh(self): + """Build role→mesh mapping. + + Any role with an explicit ``*.mesh`` config gets a dedicated device slice. + Roles without a mesh share the actor mesh by default, or can point at + another role via ``same_mesh_as``. + """ + return self._create_role_to_mesh() + + # ------------------------------------------------------------------ + # Rollout config + # ------------------------------------------------------------------ + + def create_rollout_config( + self, + role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, + ) -> base_rollout.RolloutConfig: + """Build RolloutConfig from YAML. + + Standard mode: pass rollout_config fields through with kv_cache_size = + max_prompt_length + total_generation_steps + 256. + + Agentic mode: same base, but multi-turn KV cache = + max_prompt + total_generation_steps * context_ratio * max_turns. + Engine-specific extras (sglang_jax_config, vllm_config) are also applied. 
+ """ + rollout_cfg = self.config["rollout_config"] + mode = self.config.get("training_mode", "ppo") + engine = self.config.get("rollout_engine", "vanilla") + + valid_fields = { + f.name for f in dataclasses.fields(base_rollout.RolloutConfig) + } + + # Base pass-through (same as original create_rollout_config) + filtered = {k: v for k, v in rollout_cfg.items() if k in valid_fields} + if "total_generation_steps" in rollout_cfg: + filtered["max_tokens_to_generate"] = rollout_cfg["total_generation_steps"] + + max_prompt = rollout_cfg.get("max_prompt_length", 0) + max_response = rollout_cfg.get("total_generation_steps", 0) + + # Standard: kv_cache_size = max_prompt + max_response + 256 + if max_prompt and max_response: + filtered["kv_cache_size"] = max_prompt + max_response + 256 + + return base_rollout.RolloutConfig(**filtered) + + + return {} + + # ------------------------------------------------------------------ + # Standard PPO helpers (unchanged) + # ------------------------------------------------------------------ + + def create_cluster_config( + self, + *, + role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh], + rollout_config: base_rollout.RolloutConfig | None = None, + ): + if rollout_config is None: + rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh) + return rl_cluster_lib.ClusterConfig( + role_to_mesh=role_to_mesh, + rollout_engine=self.config["rollout_engine"], + offload_to_cpu=self.config["offload_to_cpu"], + training_config=self.create_rl_training_config(), + rollout_config=rollout_config, + ) + + def create_rl_training_config(self): + base_key = "rl_training_config" + constructed_rl_training_config = self.obtain_training_config_dict(base_key) + + base_config = self.config[base_key] + if base_config.get("actor_optimizer_config"): + constructed_rl_training_config["actor_optimizer"] = self.create_optimizer( + base_key, "actor_optimizer_config" + ) + if base_config.get("critic_optimizer_config"): + constructed_rl_training_config["critic_optimizer"] = ( + self.create_optimizer(base_key, "critic_optimizer_config") + ) + + return rl_cluster_lib.RLTrainingConfig(**constructed_rl_training_config) + + def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig): + perf_metrics_options = cluster_config.training_config.perf_metrics_options + if not perf_metrics_options: + return None + + perf_config = perf_metrics.PerfMetricsConfig() + + if perf_metrics_options.enable_perf_v1: + custom_export_fn_path = perf_metrics_options.custom_export_fn_path + if custom_export_fn_path: + perf_config.custom_export_fn = self._get_function_from_path( + custom_export_fn_path + ) + if perf_config.custom_export_fn is None: + raise ValueError( + "Could not load custom export function from" + f" {custom_export_fn_path}" + ) + else: + perf_config.custom_export_fn = ( + perf_export.PerfMetricsExport.from_cluster_config(cluster_config) + ) + + if perf_metrics_options.enable_perf_v2: + custom_export_fn_path_v2 = perf_metrics_options.custom_export_fn_path_v2 + if custom_export_fn_path_v2: + perf_config.custom_export_fn_v2 = self._get_function_from_path( + custom_export_fn_path_v2 + ) + if perf_config.custom_export_fn_v2 is None: + raise ValueError( + "Could not load custom export function v2 from" + f" {custom_export_fn_path_v2}" + ) + else: + perf_config.custom_export_fn_v2 = ( + perf_export_v2.PerfMetricsExport.from_cluster_config( + cluster_config=cluster_config, + enable_trace_writer=perf_metrics_options.enable_trace_writer, + trace_dir=perf_metrics_options.trace_dir, 
            ).export_metrics
        )
    return perf_config

  def create_rl_cluster(self, tokenizer):
    role_to_mesh = self.create_role_to_mesh()
    rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh)
    # LoRA is not supported for the reference model; drop it if configured.
    if self.config["reference_model_config"].get("lora_config"):
      logging.warning(
          "LoRA config is not supported for the reference model. Disabling"
          " LoRA."
      )
      del self.config["reference_model_config"]["lora_config"]
    reference_model, _ = model_lib.create_model(
        self.config["reference_model_config"],
        self.config["tokenizer_config"],
        role_to_mesh[rl_cluster_lib.Role.REFERENCE],
    )
    if self.config["actor_model_config"].get("lora_config", None):
      actor_model = model_lib.apply_lora_to_model(
          reference_model,
          role_to_mesh[rl_cluster_lib.Role.ACTOR],
          self.config["actor_model_config"]["lora_config"],
      )
    else:
      # Full fine-tuning: the actor starts as a deep copy of the reference.
      graph_def, params = nnx.split(reference_model)
      actor_model = nnx.merge(
          graph_def,
          jax.tree.map(jnp.copy, params),
      )

    cluster_config = self.create_cluster_config(
        role_to_mesh=role_to_mesh,
        rollout_config=rollout_config,
    )
    perf_config = self.create_perf_config(cluster_config)
    return rl_cluster_lib.RLCluster(
        actor=actor_model,
        reference=reference_model,
        tokenizer=tokenizer,
        cluster_config=cluster_config,
        perf_config=perf_config,
    )

  def compute_params(self, dataset):
    rl_training_config: dict[str, Any] = self.config.get(
        "rl_training_config", {}
    )

    # Prefer an explicit max_steps; otherwise the dataset length is needed to
    # derive one.
    max_steps = rl_training_config.get("max_steps") or None
    if max_steps is None and not hasattr(dataset, "__len__"):
      raise ValueError(
          "max_steps must be specified since the dataset length cannot be"
          " determined."
      )

    dataset_length = len(dataset) if hasattr(dataset, "__len__") else None

    batch_size = self.config.get("batch_size", 1)
    num_batches = self.config.get("num_batches")
    if not num_batches:
      if dataset_length is None:
        raise ValueError(
            "num_batches must be specified since the dataset length cannot"
            " be determined."
        )
      num_batches = dataset_length // batch_size
      logging.info(
          "Dynamically computed num_batches=%d with batch_size=%d",
          num_batches,
          batch_size,
      )
    num_train_epochs = self.config.get("num_train_epochs")
    if not num_train_epochs:
      num_train_epochs = 1

    train_fraction = self.config.get("train_fraction")
    if not train_fraction:
      train_fraction = 0.8
    elif train_fraction < 0.0 or train_fraction > 1.0:
      logging.warning(
          "train_fraction %.2f out of expected range (0, 1]; setting to 0.8.",
          train_fraction,
      )
      train_fraction = 0.8

    allowed_max_steps = int(num_batches * num_train_epochs * train_fraction)
    if not max_steps:
      max_steps = allowed_max_steps
      logging.info(
          "Dynamically computed max_steps=%d from num_batches=%d,"
          " num_train_epochs=%d and train_fraction=%.2f",
          max_steps,
          num_batches,
          num_train_epochs,
          train_fraction,
      )
    elif max_steps > allowed_max_steps:
      raise ValueError(
          f"max_steps={max_steps} exceeds the maximum allowed value"
          f" {allowed_max_steps}."
      )

    rl_training_config["max_steps"] = max_steps
    actor_opt: dict[str, Any] = rl_training_config.get(
        "actor_optimizer_config", {}
    )
    if actor_opt and not actor_opt.get("decay_steps"):
      actor_opt["decay_steps"] = max_steps
    if actor_opt and not actor_opt.get("warmup_steps"):
      # warmup_ratio lives under actor_optimizer_config in the example YAMLs;
      # fall back to a top-level setting, then to 0.1.
      warmup_ratio = actor_opt.get(
          "warmup_ratio", self.config.get("warmup_ratio", 0.1)
      )
      actor_opt["warmup_steps"] = int(warmup_ratio * max_steps)

  # ------------------------------------------------------------------
  # Tokenizer and dataset helpers
  # ------------------------------------------------------------------

  def _get_tokenizer(self):
    return model_lib.create_tokenizer(
        self.config["tokenizer_config"],
        self.config["tokenizer_config"]["tokenizer_path"],
    )

  def _get_data_module(self):
    if self.data_module is None:
      self.data_module = importlib.import_module(self.config["data_module"])
    return self.data_module

  def _get_dataset(self, tokenizer):
    apply_chat_template_to_dataset = self.config.get(
        "apply_chat_template_to_dataset"
    )
    if apply_chat_template_to_dataset is None:
      raise ValueError("apply_chat_template_to_dataset must be set.")

    if self.config.get("data_module", None):
      data_module = self.config.get("data_module", None)
      dataset = data_lib.get_dataset_from_module(
          data_module,
          tokenizer,
          apply_chat_template_to_dataset=apply_chat_template_to_dataset,
          **(self.config.get("data_config") or {}),
      )
    elif self.config["data_source"] == "local":
      dataset = example_data.create_dataset(
          data_source=self.config["data_source"],
          dataset=self.config["data_directory"],
          tokenizer=tokenizer,
          apply_chat_template_to_dataset=apply_chat_template_to_dataset,
      )
    elif self.config["data_source"] == "tfds":
      dataset = example_data.create_dataset(
          data_source=self.config["data_source"],
          dataset=self.config["dataset_name"],
          tfds_download=self.config["tfds_download"],
          split=self.config.get("train_split", self.config.get("split", "train")),
          apply_chat_template_to_dataset=apply_chat_template_to_dataset,
      )
    elif self.config["data_source"] == "huggingface":
      dataset = example_data.create_dataset(
          data_source=self.config["data_source"],
          dataset=self.config["dataset_name"],
          tokenizer=tokenizer,
          split=self.config.get("train_split", self.config.get("split", "train")),
          apply_chat_template_to_dataset=apply_chat_template_to_dataset,
      )
    else:
      raise ValueError(f"Unsupported data_source {self.config['data_source']}")

    return dataset

  # ------------------------------------------------------------------
  # Standard PPO training
  # ------------------------------------------------------------------

  def run_ppo_trainer(self):
    mode = self.config.get("training_mode", "ppo")
    if mode == "ppo":
      from tunix.rl.ppo import ppo_learner

      ppo_trainer = ppo_learner.PPOLearner(
          rl_cluster=self.obtain_rl_cluster(),
          ppo_config=ppo_learner.PPOConfig(**self.config["ppo_config"]),
          reward_fns=self.obtain_reward_fn(),
      )
      dataset = self.obtain_dataset()
      ppo_trainer.train(dataset)
    else:
      raise ValueError(f"Unsupported training_mode {mode!r}")


def _setup_jax_pathways(pathways_bns: str):
  """Sets up JAX with Pathways."""
  flags.FLAGS.pathways_ifrt = True
  jax.config.update("jax_xla_backend", "pathways")
  jax.config.update("jax_backend_target", pathways_bns)


def _setup_pathways_on_cloud():
  from tunix.utils import pathways_utils

  def on_config_ready(
      target: str, platform: str, config: dict[str, Any]
  ) -> None:
    del platform
    from jax._src.lib import xla_extension

    xla_extension.set_ifrt_config(config)
    jax.config.update("jax_xla_backend", "pathways")
    jax.config.update("jax_backend_target", target)

  pathways_utils.init_pathways_environment(
      on_config_ready_callback=on_config_ready
  )


def main(argv, **kwargs):
  if os.getenv("JAX_PLATFORMS") == "proxy":
    _setup_pathways_on_cloud()

  pipeline = PPOPipeline(argv, **kwargs)
  logging.info(
      "--- Launching PPO pipeline with following config ---\n"
      "%r\n--------------------------",
      pipeline.config,
  )
  pipeline.run_ppo_trainer()


if __name__ == "__main__":
  app.run(main)

From e462f22dccd3b69b52fc12612553be38afac0932 Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Tue, 28 Apr 2026 05:38:05 +0000
Subject: [PATCH 2/3] feat: add cli example for ppo

Co-authored-by: sizhit2 <32147610+sizhit2@users.noreply.github.com>
---
 plan.md | 5 -----
 1 file changed, 5 deletions(-)
 delete mode 100644 plan.md

diff --git a/plan.md b/plan.md
deleted file mode 100644
index fd6a803b8..000000000
--- a/plan.md
+++ /dev/null
@@ -1,5 +0,0 @@
-1. **Understand PPO vs GRPO for Agentic RL:** The user wants PPO to be used in agentic mode. Currently, the codebase has `agentic_grpo_learner.py` which computes advantages for group-wise generation. For PPO, advantages are usually computed using GAE (Generalized Advantage Estimation) or similar, based on value function estimates. However, the exact implementation details of `ppo_learner.py` should be inspected.
-2. **Review `tunix/rl/ppo/ppo_learner.py`:** Understand how PPO computes advantages and updates policies in the standard mode. We need to implement `agentic_ppo_learner.py` to match the API of `agentic_rl_learner.py` and `ppo_learner.py`.
-3. **Implement `tunix/cli/ppo_main.py`:** Create the CLI entry point for PPO. It should support both `ppo` and `agentic_ppo` modes.
-4. **Add Examples:** Create `examples/rl/ppo/gsm8k/configs/gemma2_2b.yaml` and `examples/rl/ppo/gsm8k/run_gemma2_2b.sh`.
-5. **Update README and Notebook:** Add PPO to `examples/rl/README.md` and create `examples/ppo_gemma.ipynb`.

From ddc3345693db948d71415d7407d207b6ed4af1cf Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Tue, 28 Apr 2026 06:14:26 +0000
Subject: [PATCH 3/3] feat: add cli example for ppo

Co-authored-by: sizhit2 <32147610+sizhit2@users.noreply.github.com>
---
 tunix/cli/ppo_main.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/tunix/cli/ppo_main.py b/tunix/cli/ppo_main.py
index 4d2a14c4c..a89d48ea3 100644
--- a/tunix/cli/ppo_main.py
+++ b/tunix/cli/ppo_main.py
@@ -55,11 +55,6 @@ from tunix.rl.rollout import base_rollout
 
-_PATHWAYS_BNS = flags.DEFINE_string(
-    "pathways_bns", None, "BNS address of the Pathways server."
-)
-
-
 class PPOPipeline(config.HyperParameters):
   """Runs standard PPO or agentic PPO depending on ``training_mode``.