From 4f9df7bac174a50695028b03154eb57e109967ef Mon Sep 17 00:00:00 2001 From: wang2yn84 Date: Wed, 29 Apr 2026 21:41:50 -0700 Subject: [PATCH] Add colocated mode to agentic cli. --- docs/agentic_rl.md | 47 ++ docs/launching.md | 14 +- docs/performance.md | 38 +- docs/rollout.md | 10 + .../deepscaler/run_deepscaler_disagg_v5p16.sh | 2 - examples/deepswe/run_deepswe_disagg_v5p_32.sh | 2 - examples/rl/grpo/gsm8k/run_qwen3_8b.sh | 9 +- tests/cli/grpo_main_test.py | 124 ++- tests/rl/agentic/agentic_grpo_learner_test.py | 268 +++++- tunix/cli/grpo_main.py | 90 +- tunix/rl/agentic/agentic_rl_learner.py | 783 +++++++++++++----- tunix/rl/rl_cluster.py | 4 +- tunix/rl/rl_learner.py | 27 +- tunix/sft/checkpoint_manager.py | 6 +- 14 files changed, 1133 insertions(+), 291 deletions(-) diff --git a/docs/agentic_rl.md b/docs/agentic_rl.md index 3a3b4a7ff..ef535fe0d 100644 --- a/docs/agentic_rl.md +++ b/docs/agentic_rl.md @@ -120,6 +120,53 @@ generating trajectories with stale parameters. Batch vs Async Rollout

+### Mesh Placement and Colocation + +Agentic GRPO supports three distinct placement patterns for the actor, +reference, and rollout roles: + +1. **Shared mesh**: multiple roles reuse the exact same mesh object. This is + the most tightly colocated setup and may enable model or backbone sharing + in parts of the stack. + + Backend status: exact shared-mesh execution is currently supported only for the + vanilla rollout backend. Exact shared-mesh execution is not supported yet + for `vllm` and `sglang_jax`. + +2. **Colocated device set**: a role uses `colocate_with` to reuse another + role's device slice while still keeping its own mesh shape. For example, + the actor can use `(4, 1)` while rollout uses `(1, 4)` on the same four + devices. This is still colocated, but it is different from exact mesh + sharing. + +3. **Disaggregated placement**: each role owns a separate device slice. + +For CLI-driven GRPO and agentic GRPO runs, `colocate_with` is configured on the +role-specific model config, for example: + +```yaml +actor_model_config: + mesh: + shape: "(4,1)" + axis_names: "('fsdp','tp')" + +rollout_model_config: + colocate_with: "actor" + mesh: + shape: "(1,4)" + axis_names: "('fsdp','tp')" +``` + +This tells Tunix to allocate only one four-device owner slice and build two +different meshes on top of it. Exact mesh equality is only required when a +runtime path wants to reuse the same model instance; it is not required for +colocation itself. + +For accelerated rollout backends, treat same-device-set colocation and exact +shared-mesh reuse as different features. Today, `vllm` and `sglang_jax` +support colocated device-set placement through `colocate_with`, but exact +shared-mesh execution is not supported yet. + ### Trajectory Batching and Grouping Tunix supports batching of agentic trajectories through the `GroupQueueManager`. 
diff --git a/docs/launching.md b/docs/launching.md index f309fbef9..28b0527ac 100644 --- a/docs/launching.md +++ b/docs/launching.md @@ -254,7 +254,7 @@ This section provides a detailed explanation of the configuration parameters ava #### Model Configuration (`model_config`) -These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration. +These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration. * **`model_name`**: The unique full name identifier of the model. This corresponds to the full name and should match exactly with the model name @@ -287,6 +287,16 @@ These parameters define the base model, where to download it from, and how to sh * **`mesh`**: Defines the hardware mesh layout for distributed training. * `shape`: Tuple string defining mesh dimensions (e.g., `"(2,2)"` for a 2x2 grid). * `axis_names`: Names for mesh axes, often used for parallelism strategies (e.g., `"('fsdp','tp')"` for Fully Sharded Data Parallelism and Tensor Parallelism). +* **`colocate_with`**: Optional role-local placement override for + `actor_model_config`, `reference_model_config`, `rollout_model_config`, and + other RL roles. + * If unset, a role owns its own device slice when it has an explicit + `mesh`, or shares the actor mesh by default when it does not. + * If set to a role name such as `"actor"`, the role reuses that role's + device slice but may still define its own `mesh.shape` and + `mesh.axis_names`. + * This is different from exact mesh sharing: two roles can be colocated on + the same devices while using different mesh layouts. 
#### Tokenizer Configuration (`tokenizer_config`) @@ -338,7 +348,7 @@ General settings for the training loop, logging, and checkpointing. * **`eval_every_n_steps`**: Frequency of running evaluation steps. -* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients +* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients before performing a parameter update (simulates larger batch sizes). * **`checkpointing_options`**: diff --git a/docs/performance.md b/docs/performance.md index bf7b74cc3..d7397d556 100644 --- a/docs/performance.md +++ b/docs/performance.md @@ -156,8 +156,13 @@ and training. To further maximize the hardware utility, you can consider enablin non-active models to CPU RAM when a different component is occupying the TPU. -Enabling collocated mode is straightforward; you simply provide the same mesh to -every component when configuring the `role_to_mesh` mapping for your `rl_cluster`. +Enabling collocated mode is straightforward; the strongest form is to provide +the same mesh to every component when configuring the `role_to_mesh` mapping +for your `rl_cluster`. + +Backend status: exact shared-mesh execution is currently supported for the +vanilla rollout backend. Exact shared-mesh execution is not supported yet for +`vllm` and `sglang_jax`. ```python import numpy as np @@ -179,6 +184,35 @@ ClusterConfig( ) ``` +For CLI-driven GRPO and agentic GRPO runs, there is now a second colocated +mode: **same device set, different mesh shape**. This is configured with +`colocate_with`. + +```yaml +actor_model_config: + mesh: + shape: "(4,1)" + axis_names: "('fsdp','tp')" + +rollout_model_config: + colocate_with: "actor" + mesh: + shape: "(1,4)" + axis_names: "('fsdp','tp')" +``` + +In this configuration, actor and rollout are still colocated because they run +on the same device slice, but they do not share the exact same mesh object. +That distinction matters: + +* Same device set means the roles are colocated. 
+* Same mesh may additionally allow model or backbone sharing in some runtime + paths. + +For `vllm` and `sglang_jax`, same-device-set colocation is the only colocated +placement currently supported. Exact shared-mesh reuse is not supported yet +for those backends. + ### Disaggregated Execution Disaggregated mode partitions the TPU cluster into distinct "sub-meshes", diff --git a/docs/rollout.md b/docs/rollout.md index 3b3115888..26fd89a24 100644 --- a/docs/rollout.md +++ b/docs/rollout.md @@ -91,6 +91,11 @@ Setting `cluster_config.rollout_engine="vllm"` enables the vllm rollout/sampler. Tunix uses `tunix.rl.rollout.base_rollout.RolloutConfig` for rollout settings. The fields below are the vLLM-relevant ones. +Exact shared-mesh execution is currently supported only for the vanilla rollout +backend. For `vllm`, exact shared-mesh execution is not supported yet. The +supported colocated configuration today is same-device-set placement with an +independently shaped rollout mesh. + #### vLLM-specific fields In addition to the common sampling parameters mentioned above, the following @@ -277,6 +282,11 @@ Tunix uses `tunix.rl.rollout.base_rollout.RolloutConfig` for rollout settings. In addition to the common sampling parameters, the following fields are specific to SGLang-Jax: +Exact shared-mesh execution is currently supported only for the vanilla rollout +backend. For `sglang_jax`, exact shared-mesh execution is not supported yet. +The supported colocated configuration today is same-device-set placement with +an independently shaped rollout mesh. + - `rollout_sglang_jax_model_version` - Model id or local path used by SGLang-Jax as `model_path`. 
diff --git a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh index 7feb5b3d9..194f70c57 100755 --- a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh +++ b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh @@ -64,8 +64,6 @@ python -m tunix.cli.grpo_main \ model_config.remat_config=3 \ actor_model_config.mesh.shape="$trainer_mesh" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ rollout_model_config.mesh.shape="$rollout_mesh" \ rollout_model_config.mesh.axis_names="('fsdp','tp')" \ \ diff --git a/examples/deepswe/run_deepswe_disagg_v5p_32.sh b/examples/deepswe/run_deepswe_disagg_v5p_32.sh index 4eceb7ae2..e21a2569a 100755 --- a/examples/deepswe/run_deepswe_disagg_v5p_32.sh +++ b/examples/deepswe/run_deepswe_disagg_v5p_32.sh @@ -81,8 +81,6 @@ python -m tunix.cli.grpo_main \ model_config.remat_config=3 \ actor_model_config.mesh.shape="$trainer_mesh" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ rollout_model_config.mesh.shape="$rollout_mesh" \ rollout_model_config.mesh.axis_names="('fsdp','tp')" \ \ diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh index dfc5e7c11..e4a080f63 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh @@ -45,7 +45,11 @@ num_generations="${num_generations:-4}" train_mesh="${train_mesh:-(8,1)}" rollout_mesh="${rollout_mesh:-(1,8)}" -checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}" +# Set rollout_colocate to a role name (e.g. 
"actor") to place the rollout model on +# that role's device set (rollout keeps its own mesh shape) +rollout_colocate="${rollout_colocate:-null}" + +checkpoint_dir="${checkpoint_dir-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}" checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}" if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then checkpoint_dir="${checkpoint_dir}_${checkpoint_suffix}" @@ -79,8 +83,7 @@ python -m tunix.cli.grpo_main \ model_config.remat_config=3 \ actor_model_config.mesh.shape="$train_mesh" \ actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ + rollout_model_config.colocate_with="$rollout_colocate" \ rollout_model_config.mesh.shape="$rollout_mesh" \ rollout_model_config.mesh.axis_names="('fsdp','tp')" \ \ diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py index b68d98117..c611997a0 100644 --- a/tests/cli/grpo_main_test.py +++ b/tests/cli/grpo_main_test.py @@ -644,7 +644,6 @@ def test_cli_empty_system_prompt_stays_empty_string(self): ) self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "") - class SplitMeshConfigTest(absltest.TestCase): def test_split_mesh_uses_explicit_role_meshes(self): @@ -688,7 +687,6 @@ def test_split_mesh_uses_explicit_role_meshes(self): "shape": "(2,1)", "axis_names": "('fsdp','tp')", } - pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"} rollout_model_config = pipeline.config["rollout_model_config"] if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): rollout_model_config["mesh"] = { @@ -732,6 +730,128 @@ def __init__(self, devices, axis_names, axis_types=None): role_to_mesh[rl_cluster_lib.Role.ACTOR], ) + def test_colocate_with_reuses_device_slice_with_different_mesh(self): + extra = """ +training_mode: "agentic_grpo" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + 
train_data_path: 
"gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_grpo_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 + context_ratio: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + actor_model_config = pipeline.config["actor_model_config"] + if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig): + actor_model_config["mesh"] = { + "shape": "(2,1)", + "axis_names": "('fsdp','tp')", + } + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["colocate_with"] = "actor" + rollout_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + } + + fake_devices = list(range(4)) + + class FakeMesh: + + def __init__(self, devices, axis_names, axis_types=None): + self.devices = devices + self.axis_names = axis_names + self.axis_types = axis_types + + with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices): + with mock.patch.object( + grpo_main.jax.sharding, "Mesh", side_effect=FakeMesh + ): + role_to_mesh = pipeline.create_role_to_mesh() + + self.assertSequenceEqual( + role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(), + [0, 1], + ) + self.assertSequenceEqual( + role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(), + [0, 1], + ) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.shape, + (2, 1), + ) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape, + (1, 2), + ) + + def 
test_empty_string_colocate_with_is_treated_as_unset(self): + extra = """ +training_mode: "agentic_grpo" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_grpo_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 + context_ratio: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["colocate_with"] = "" + + self.assertEmpty(pipeline._get_colocate_with_map()) + if __name__ == "__main__": absltest.main() diff --git a/tests/rl/agentic/agentic_grpo_learner_test.py b/tests/rl/agentic/agentic_grpo_learner_test.py index 571c37073..29498fe7d 100644 --- a/tests/rl/agentic/agentic_grpo_learner_test.py +++ b/tests/rl/agentic/agentic_grpo_learner_test.py @@ -15,12 +15,14 @@ """Tests for agentic_grpo_learner.""" import asyncio +from concurrent import futures as concurrent_futures import functools import os import queue import random import shutil import tempfile +import threading import types from typing import Any, AsyncIterable, Iterable from unittest import mock @@ -43,6 +45,7 @@ from tunix.rl import function_registry from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.agentic import agentic_grpo_learner +from tunix.rl.agentic import agentic_rl_learner from tunix.rl.agentic.agents.agent_types import Action, Step from tunix.rl.agentic.agents.base_agent import 
ConversationAgentBase from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult @@ -224,10 +227,13 @@ def __init__(self, algo_config): self.algo_config = algo_config self.rl_cluster = mock.Mock() self.metric_fns = [] + self._full_batch_size = 2 + self._training_config = types.SimpleNamespace( + max_seq_token_per_tpu=None, + ) def _create_micro_batch_iterator(self, iterator, batch_size): - # The dataset batch size is 2, and we want to test micro-batching - # of size 1, as consumed by _orchestrator_producer. + del batch_size for batch in iterator: for i in range(len(batch["prompts"])): yield jax.tree.map(lambda x, index=i: x[index : index + 1], batch) @@ -254,27 +260,15 @@ async def _orchestrator_producer( num_generations: int = 1, collect_mode: str = "Token", ): + del orchestrator, num_generations, collect_mode i = 0 - if hasattr(prompt_iterator, "__aiter__"): - async for example in prompt_iterator: - group = [ - types.SimpleNamespace( - pair_index=i * self.algo_config.num_generations + j - ) - for j in range(self.algo_config.num_generations) - ] - yield group, [example] - i += 1 - else: - for example in prompt_iterator: - group = [ - types.SimpleNamespace( - pair_index=i * self.algo_config.num_generations + j - ) - for j in range(self.algo_config.num_generations) - ] - yield group, [example] - i += 1 + async for example in prompt_iterator: + group = [ + types.SimpleNamespace(pair_index=i * 2 + j) + for j in range(2) + ] + yield group, [example] + i += 1 algo_config = agentic_grpo_learner.GRPOConfig( num_generations=2, @@ -282,6 +276,7 @@ async def _orchestrator_producer( ) trainer = _MockTrainer(algo_config) + rollout_queue = queue_lib.SimpleDataQueue(maxsize=0) train_data_queue = queue_lib.SimpleDataQueue(maxsize=0) dataset = _dummy_dataset(MySource(data=[i for i in range(2)]), batch_size=2) prompt_queue = queue.Queue() @@ -289,17 +284,238 @@ async def _orchestrator_producer( prompt_queue.put(item) prompt_queue.put(None) - 
asyncio.run(trainer._producer(mock.Mock(), prompt_queue, train_data_queue)) + ref_thread = threading.Thread( + target=trainer._reference_stage_producer, + args=( + rollout_queue, + train_data_queue, + agentic_rl_learner.StageBoundaryConfig( + buffer_inputs_until_barrier=True, + buffer_outputs_until_barrier=True, + ), + ), + daemon=True, + ) + ref_thread.start() + + asyncio.run( + trainer._rollout_stage_producer( + mock.Mock(), + prompt_queue, + rollout_queue, + ) + ) + ref_thread.join(timeout=1.0) results = [] + barrier_count = 0 while True: item = train_data_queue.get(block=True) - if item is None: + if item is agentic_rl_learner._QUEUE_CLOSED: break + if item is agentic_rl_learner._GLOBAL_BATCH_BARRIER: + barrier_count += 1 + continue results.append(item) prompt_ids = [r.prompt_ids[0] for r in results] self.assertEqual(prompt_ids, [0, 0, 0, 0, 1, 1, 1, 1]) + self.assertEqual(barrier_count, 1) + + def test_reference_stage_streams_when_devices_are_disjoint(self): + class _MockTrainer(agentic_grpo_learner.GRPOLearner): + + def __init__(self, algo_config): + self.algo_config = algo_config + + @override + def _batch_to_train_example(self, batch_results, mode): + del mode + return [types.SimpleNamespace(prompt_ids=np.array([batch_results[0]]))] + + trainer = _MockTrainer( + agentic_grpo_learner.GRPOConfig(num_generations=2, num_iterations=1) + ) + rollout_queue = queue_lib.SimpleDataQueue(maxsize=0) + train_data_queue = queue_lib.SimpleDataQueue(maxsize=0) + + ref_thread = threading.Thread( + target=trainer._reference_stage_producer, + args=( + rollout_queue, + train_data_queue, + agentic_rl_learner.StageBoundaryConfig(), + ), + daemon=True, + ) + ref_thread.start() + + rollout_queue.put(7) + + first_item = train_data_queue.get(timeout=1.0) + self.assertEqual(first_item.prompt_ids[0], 7) + + rollout_queue.put(agentic_rl_learner._GLOBAL_BATCH_BARRIER) + self.assertIs( + train_data_queue.get(timeout=1.0), + agentic_rl_learner._GLOBAL_BATCH_BARRIER, + ) + + 
rollout_queue.put(agentic_rl_learner._QUEUE_CLOSED) + self.assertIs( + train_data_queue.get(timeout=1.0), + agentic_rl_learner._QUEUE_CLOSED, + ) + ref_thread.join(timeout=1.0) + + def test_reference_stage_waits_for_barrier_when_devices_are_colocated(self): + class _MockTrainer(agentic_grpo_learner.GRPOLearner): + + def __init__(self, algo_config): + self.algo_config = algo_config + + @override + def _batch_to_train_example(self, batch_results, mode): + del mode + return [types.SimpleNamespace(prompt_ids=np.array([batch_results[0]]))] + + trainer = _MockTrainer( + agentic_grpo_learner.GRPOConfig(num_generations=2, num_iterations=1) + ) + rollout_queue = queue_lib.SimpleDataQueue(maxsize=0) + train_data_queue = queue_lib.SimpleDataQueue(maxsize=0) + + ref_thread = threading.Thread( + target=trainer._reference_stage_producer, + args=( + rollout_queue, + train_data_queue, + agentic_rl_learner.StageBoundaryConfig( + buffer_inputs_until_barrier=True, + buffer_outputs_until_barrier=False, + ), + ), + daemon=True, + ) + ref_thread.start() + + rollout_queue.put(7) + + with self.assertRaises(queue.Empty): + train_data_queue.get(block=True, timeout=0.1) + + rollout_queue.put(agentic_rl_learner._GLOBAL_BATCH_BARRIER) + self.assertEqual(train_data_queue.get(timeout=1.0).prompt_ids[0], 7) + self.assertIs( + train_data_queue.get(timeout=1.0), + agentic_rl_learner._GLOBAL_BATCH_BARRIER, + ) + + rollout_queue.put(agentic_rl_learner._QUEUE_CLOSED) + self.assertIs( + train_data_queue.get(timeout=1.0), + agentic_rl_learner._QUEUE_CLOSED, + ) + ref_thread.join(timeout=1.0) + + def test_should_buffer_stage_boundary_only_when_offloading(self): + class _MockTrainer(agentic_grpo_learner.GRPOLearner): + + def __init__(self, offload_to_cpu): + shared_mesh = types.SimpleNamespace( + devices=np.array([[types.SimpleNamespace(id=1)]], dtype=object) + ) + self.rl_cluster = types.SimpleNamespace( + cluster_config=types.SimpleNamespace( + offload_to_cpu=offload_to_cpu, + role_to_mesh={ + 
rl_cluster_lib.Role.ROLLOUT: shared_mesh, + rl_cluster_lib.Role.REFERENCE: shared_mesh, + }, + ) + ) + + trainer_without_offload = _MockTrainer(offload_to_cpu=False) + self.assertFalse( + trainer_without_offload._should_buffer_stage_boundary( + rl_cluster_lib.Role.ROLLOUT, + rl_cluster_lib.Role.REFERENCE, + ) + ) + + trainer_with_offload = _MockTrainer(offload_to_cpu=True) + self.assertTrue( + trainer_with_offload._should_buffer_stage_boundary( + rl_cluster_lib.Role.ROLLOUT, + rl_cluster_lib.Role.REFERENCE, + ) + ) + + def test_data_consumer_batch_generator(self): + trainer = mock.Mock(spec=agentic_rl_learner.AgenticRLLearner) + trainer._training_config = types.SimpleNamespace(max_seq_token_per_tpu=None) + trainer._iter_train_micro_batches_from_stream = ( + agentic_rl_learner.AgenticRLLearner._iter_train_micro_batches_from_stream + .__get__(trainer, agentic_rl_learner.AgenticRLLearner) + ) + + train_data_queue = queue_lib.SimpleDataQueue(maxsize=0) + for prompt_id in [0, 0, 1, 1]: + train_data_queue.put( + types.SimpleNamespace(prompt_ids=np.array([prompt_id])) + ) + train_data_queue.put(agentic_rl_learner._GLOBAL_BATCH_BARRIER) + train_data_queue.put(agentic_rl_learner._QUEUE_CLOSED) + + train_data_gen = agentic_rl_learner.AgenticRLLearner._data_consumer_batch_generator( + trainer, + train_data_queue, + 2, + ) + + batches = list(train_data_gen) + self.assertLen(batches, 2) + self.assertFalse(batches[0].end_of_global_batch) + self.assertTrue(batches[1].end_of_global_batch) + self.assertEqual( + [item.prompt_ids[0] for item in batches[0].train_examples], + [0, 0], + ) + self.assertEqual( + [item.prompt_ids[0] for item in batches[1].train_examples], + [1, 1], + ) + + def test_attach_stage_failure_callbacks_wake_blocked_queues(self): + trainer = mock.Mock(spec=agentic_rl_learner.AgenticRLLearner) + trainer._attach_stage_failure_callbacks = ( + agentic_rl_learner.AgenticRLLearner._attach_stage_failure_callbacks + .__get__(trainer, agentic_rl_learner.AgenticRLLearner) + 
) + + rollout_queue = queue_lib.SimpleDataQueue(maxsize=0) + train_data_queue = queue_lib.SimpleDataQueue(maxsize=0) + rollout_future = concurrent_futures.Future() + reference_future = concurrent_futures.Future() + + trainer._attach_stage_failure_callbacks( + rollout_future=rollout_future, + reference_future=reference_future, + rollout_queue=rollout_queue, + train_data_queue=train_data_queue, + ) + + rollout_future.set_exception(RuntimeError("rollout failed")) + reference_future.set_exception(RuntimeError("reference failed")) + + self.assertIs( + rollout_queue.get(timeout=1.0), + agentic_rl_learner._QUEUE_CLOSED, + ) + self.assertIs( + train_data_queue.get(timeout=1.0), + agentic_rl_learner._QUEUE_CLOSED, + ) def test_grpo_config_validation(self): with self.assertRaisesRegex( @@ -638,7 +854,7 @@ def mock_compute_rewards(prompts, completions, **kwargs): algo_config=grpo_config, chat_parser=MockChatParser(), ) - + with mock.patch.object(learner, "_compute_rewards", side_effect=mock_compute_rewards): with mock.patch.object( learner.rl_cluster, @@ -647,7 +863,7 @@ def mock_compute_rewards(prompts, completions, **kwargs): autospec=True, ): learner._process_results(trajectories) - + self.assertEqual(extracted_completions, ["msg 0", "msg 1"]) @parameterized.named_parameters( diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 3491b885b..ec240ecd0 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -73,8 +73,8 @@ class GrpoPipeline(config.HyperParameters): plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``. * role-specific ``*_model_config.mesh``: any role with an explicit mesh gets its own device slice; omitted meshes share the actor mesh by default. - * role-specific ``same_mesh_as``: optional mesh sharing like - ``reference_model_config.same_mesh_as: actor``. + * role-specific ``colocate_with``: share another role's device set while + still allowing a different mesh shape on that same device set. 
* ``sglang_jax_config`` / ``vllm_config``: engine-specific rollout params. * ``chat_parser_config.type``: ``"default"`` or ``"qwen"``. * ``agent_class_path`` / ``env_class_path``: dotted Python paths to load @@ -116,21 +116,19 @@ def _resolve_split_role(self, role_name: str) -> rl_cluster_lib.Role: ) return self._SPLIT_ROLE_ALIASES[normalized] - def _get_same_mesh_as_map( + def _get_colocate_with_map( self, ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: - same_mesh_as = {} + colocate_with = {} for role, model_key in self._ROLE_TO_MODEL_KEY.items(): model_cfg = self.config.get(model_key, {}) or {} - target_name = model_cfg.get("same_mesh_as") - if target_name is None: + target_name = model_cfg.get("colocate_with") + if target_name is None or not str(target_name).strip(): continue - target_role = self._resolve_split_role(str(target_name)) if role == rl_cluster_lib.Role.ACTOR: - raise ValueError("Actor must own its mesh.") - same_mesh_as[role] = target_role - - return same_mesh_as + raise ValueError("Actor must own its device set.") + colocate_with[role] = self._resolve_split_role(str(target_name)) + return colocate_with def _is_role_active(self, role: rl_cluster_lib.Role) -> bool: if role in ( @@ -145,10 +143,10 @@ def _is_role_active(self, role: rl_cluster_lib.Role) -> bool: def _resolve_mesh_owners( self, ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: - same_mesh_as = self._get_same_mesh_as_map() + colocate_with = self._get_colocate_with_map() base_owners = {} for role, model_key in self._ROLE_TO_MODEL_KEY.items(): - if not self._is_role_active(role) and role not in same_mesh_as: + if not self._is_role_active(role): continue has_mesh = bool(self.config.get(model_key, {}).get("mesh")) base_owners[role] = ( @@ -162,11 +160,11 @@ def resolve_owner( seen: set[rl_cluster_lib.Role], ) -> rl_cluster_lib.Role: if role in seen: - raise ValueError("same_mesh_as contains a cycle.") - if role not in same_mesh_as: + raise ValueError("colocate_with contains a 
cycle.") + if role not in colocate_with: return base_owners[role] seen.add(role) - target_role = same_mesh_as[role] + target_role = colocate_with[role] if target_role not in base_owners: raise ValueError( f"Role {target_role.value!r} is not active in this config." @@ -174,24 +172,18 @@ def resolve_owner( return resolve_owner(target_role, seen) role_to_owner = {} - for role, model_key in self._ROLE_TO_MODEL_KEY.items(): - if role not in base_owners: - continue - has_mesh = bool(self.config.get(model_key, {}).get("mesh")) - if role in same_mesh_as: - if has_mesh: - raise ValueError( - f"{model_key}.mesh is specified, so it must own a separate mesh " - "and cannot also use same_mesh_as." - ) - else: - role_to_owner[role] = resolve_owner(role, set()) - continue + for role in base_owners: role_to_owner[role] = resolve_owner(role, set()) return role_to_owner - def _create_role_to_mesh(self): + def create_role_to_mesh(self): + """Build role→mesh mapping. + + Any role with an explicit ``*.mesh`` config gets a dedicated device slice. + Roles without a mesh share the actor mesh by default. + """ devices = list(jax.devices()) + colocate_with = self._get_colocate_with_map() role_to_owner = self._resolve_mesh_owners() owner_order = [] for role in self._ROLE_TO_MODEL_KEY: @@ -229,22 +221,32 @@ def _create_role_to_mesh(self): len(devices) - device_offset, ) logging.info( - "Mesh device allocation: %s", - { - self._ROLE_TO_MODEL_KEY[owner]: len(owner_to_device_slice[owner]) - for owner in owner_order - }, + "Mesh device allocation: %s | colocate_with=%s | role_to_owner=%s", + { + self._ROLE_TO_MODEL_KEY[owner]: len(owner_to_device_slice[owner]) + for owner in owner_order + }, + { + role.value: owner.value + for role, owner in colocate_with.items() + }, + { + role.value: owner.value + for role, owner in role_to_owner.items() + }, ) - return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} - - def create_role_to_mesh(self): - """Build role→mesh mapping. 
+ role_to_mesh = {} + for role, owner in role_to_owner.items(): + model_key = self._ROLE_TO_MODEL_KEY[role] + has_mesh = bool(self.config.get(model_key, {}).get("mesh")) + if role == owner or not has_mesh: + role_to_mesh[role] = owner_to_mesh[owner] + else: + role_to_mesh[role] = self.create_mesh( + model_key, devices=owner_to_device_slice[owner] + ) + return role_to_mesh - Any role with an explicit ``*.mesh`` config gets a dedicated device slice. - Roles without a mesh share the actor mesh by default, or can point at - another role via ``same_mesh_as``. - """ - return self._create_role_to_mesh() # ------------------------------------------------------------------ # Rollout config diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py index 1d7f15067..7b7ddd331 100644 --- a/tunix/rl/agentic/agentic_rl_learner.py +++ b/tunix/rl/agentic/agentic_rl_learner.py @@ -15,9 +15,9 @@ """Base class for Agentic RL Learners.""" from __future__ import annotations -import abc -import time import asyncio +import abc +from concurrent import futures as concurrent_futures from concurrent.futures import ThreadPoolExecutor import contextlib import copy @@ -25,6 +25,7 @@ import itertools import queue import threading +import time from typing import Any, AsyncIterator, Callable, Dict, Generic, Iterable, Iterator, List, Sequence, Type, TypeVar, Optional, Set from absl import logging @@ -58,11 +59,46 @@ MetricFn = Callable[..., rl_cluster_lib.MetricsT] +# Marks the end of one logical global batch while keeping the stage pipeline +# alive. Downstream stages must flush any buffered work for the current batch +# when they see this sentinel, but they should continue waiting for more items +# from the same queue afterward. This is what prevents sequence packing and +# train micro-batching from crossing a global-step boundary. +_GLOBAL_BATCH_BARRIER = object() + +# Marks the terminal end of a stage queue. 
Unlike _GLOBAL_BATCH_BARRIER, this +# means no more items will ever arrive on the queue, so downstream stages +# should flush any remaining buffered work and then exit instead of waiting for +# another batch. +_QUEUE_CLOSED = object() + + +def _mesh_device_keys(mesh) -> frozenset[Any]: + return frozenset( + getattr(device, "id", device) + for device in mesh.devices.flatten().tolist() + ) + + @flax.struct.dataclass(frozen=True) class TrainExample(common.TrainExample): policy_version: np.ndarray | None = None +@dataclasses.dataclass(frozen=True) +class TrainStageItem: + train_examples: List[TrainExample] | None + end_of_global_batch: bool = False + + +@dataclasses.dataclass(frozen=True) +class StageBoundaryConfig: + """Describes how a stage should buffer at its input and output edges.""" + + buffer_inputs_until_barrier: bool = False + buffer_outputs_until_barrier: bool = False + + @dataclasses.dataclass(slots=True, kw_only=True) class AgenticRLConfig(algo_config_lib.AlgorithmConfig): """Base configuration for Agentic RL algorithms. @@ -196,17 +232,6 @@ def __init__( self.rl_cluster.rollout.model(), ) ) - - # Enable async rollout if trainer and rollout are not on the same mesh. - # If they do, then doesn't make sense for the interleave because they will - # have resource contention. 
- self.can_enable_async_rollout = ( - self.rl_cluster.cluster_config.role_to_mesh[rl_cluster_lib.Role.ACTOR] - != self.rl_cluster.cluster_config.role_to_mesh[ - rl_cluster_lib.Role.ROLLOUT - ] - ) - self._rollout_micro_batch_size = ( self._training_config.rollout_micro_batch_size ) @@ -220,6 +245,7 @@ def __init__( self.policy_version = self.rl_cluster.global_steps self._rollout_sync_lock = agentic_utils.RolloutSyncLock() self._full_batch_size = 0 + self._train_micro_batch_size = 1 loop_queue = queue.Queue() @@ -386,7 +412,7 @@ def _create_agent_env_pair( return agent, env def _model_call( - self, chat_lists: List[Dict[str, str]], env: Any = None, + self, chat_lists: List[Dict[str, str]], env: Any = None, ) -> base_rollout.RolloutOutput: """Calls model generation.""" if env: @@ -589,15 +615,90 @@ def _num_generations(self) -> int: """Returns the number of generations per prompt.""" return self.algo_config.num_generations - async def _producer( + def _get_role_mesh( + self, + role: rl_cluster_lib.Role, + fallback_role: rl_cluster_lib.Role | None = None, + ): + """Returns the mesh for a role, optionally falling back to another role.""" + mesh = self.rl_cluster.cluster_config.role_to_mesh.get(role) + if mesh is not None: + return mesh + if fallback_role is None: + raise KeyError(f"No mesh configured for role: {role}") + return self.rl_cluster.cluster_config.role_to_mesh[fallback_role] + + def _roles_share_devices( + self, + upstream_role: rl_cluster_lib.Role, + downstream_role: rl_cluster_lib.Role, + *, + upstream_fallback_role: rl_cluster_lib.Role | None = None, + downstream_fallback_role: rl_cluster_lib.Role | None = None, + ) -> bool: + """Returns whether two stage roles resolve to the same device set.""" + upstream_mesh = self._get_role_mesh(upstream_role, upstream_fallback_role) + downstream_mesh = self._get_role_mesh( + downstream_role, + downstream_fallback_role, + ) + return _mesh_device_keys(upstream_mesh) == _mesh_device_keys(downstream_mesh) + + def 
_should_buffer_stage_boundary( + self, + upstream_role: rl_cluster_lib.Role, + downstream_role: rl_cluster_lib.Role, + *, + upstream_fallback_role: rl_cluster_lib.Role | None = None, + downstream_fallback_role: rl_cluster_lib.Role | None = None, + ) -> bool: + """Returns whether two adjacent stages must serialize on one boundary.""" + if not self.rl_cluster.cluster_config.offload_to_cpu: + return False + return self._roles_share_devices( + upstream_role, + downstream_role, + upstream_fallback_role=upstream_fallback_role, + downstream_fallback_role=downstream_fallback_role, + ) + + def _create_stage_boundary_config( + self, + *, + upstream_role: rl_cluster_lib.Role, + stage_role: rl_cluster_lib.Role, + downstream_role: rl_cluster_lib.Role, + upstream_fallback_role: rl_cluster_lib.Role | None = None, + stage_fallback_role: rl_cluster_lib.Role | None = None, + downstream_fallback_role: rl_cluster_lib.Role | None = None, + ) -> StageBoundaryConfig: + """Builds a reusable boundary config for a pipeline stage.""" + return StageBoundaryConfig( + buffer_inputs_until_barrier=self._should_buffer_stage_boundary( + upstream_role, + stage_role, + upstream_fallback_role=upstream_fallback_role, + downstream_fallback_role=stage_fallback_role, + ), + buffer_outputs_until_barrier=self._should_buffer_stage_boundary( + stage_role, + downstream_role, + upstream_fallback_role=stage_fallback_role, + downstream_fallback_role=downstream_fallback_role, + ), + ) + + async def _rollout_stage_producer( self, orchestrator, prompt_queue: queue.Queue[TrainingInputT | None], - train_data_queue, + rollout_queue: queue_lib.AbstractDataQueue, ): - """Produces training examples from prompts in the dataset_iterator.""" + """Produces rollout groups and emits a barrier per global batch.""" loop = asyncio.get_running_loop() async_queue_iter = self._AsyncQueueIterator(prompt_queue, loop) + full_batch_size = max(getattr(self, "_full_batch_size", 0), 1) + prompts_in_current_batch = 0 async def 
_iterate_micro_batches(): async for item in async_queue_iter: @@ -612,38 +713,450 @@ async def _iterate_micro_batches(): num_generations=self.algo_config.num_generations, collect_mode="Token", ): - try: - train_examples = self._batch_to_train_example( + rollout_queue.put(batch) + prompts_in_current_batch += 1 + if prompts_in_current_batch == full_batch_size: + rollout_queue.put(_GLOBAL_BATCH_BARRIER) + prompts_in_current_batch = 0 + finally: + if prompts_in_current_batch: + rollout_queue.put(_GLOBAL_BATCH_BARRIER) + rollout_queue.put(_QUEUE_CLOSED) + prompt_queue.put(None) + + def _global_rollout_batch_to_train_examples( + self, + rollout_batches: List[list[Any]], + mode: rl_cluster_lib.Mode, + ) -> List[TrainExample]: + """Converts one rollout global batch into raw train examples.""" + train_examples = [] + for batch in rollout_batches: + train_examples.extend( + self._batch_to_train_example( batch_results=batch, - mode=rl_cluster_lib.Mode.TRAIN, + mode=mode, ) - iterations = self.algo_config.num_iterations - for _ in range(iterations): - for train_example in train_examples: - train_data_queue.put(train_example) - except Exception as e: - if not isinstance(e, RuntimeError): - logging.exception( - "Exception in _producer while processing batch: %s", e + ) + return train_examples + + def _emit_train_examples( + self, + train_data_queue: queue_lib.AbstractDataQueue, + train_examples: List[TrainExample], + ) -> None: + """Emits train examples for all configured training iterations.""" + for _ in range(self.algo_config.num_iterations): + for train_example in train_examples: + train_data_queue.put(train_example) + + def _reference_stage_producer( + self, + rollout_queue: queue_lib.AbstractDataQueue, + train_data_queue: queue_lib.AbstractDataQueue, + boundary_config: StageBoundaryConfig, + ) -> None: + """Consumes rollout batches and emits raw train examples per global batch.""" + rollout_batches = [] + pending_train_examples = [] + try: + while True: + item = 
rollout_queue.get(block=True) + if item is _QUEUE_CLOSED: + break + if item is _GLOBAL_BATCH_BARRIER: + if rollout_batches: + train_examples = self._global_rollout_batch_to_train_examples( + rollout_batches, + rl_cluster_lib.Mode.TRAIN, ) - raise + if boundary_config.buffer_outputs_until_barrier: + pending_train_examples.extend(train_examples) + else: + self._emit_train_examples(train_data_queue, train_examples) + rollout_batches = [] + if pending_train_examples: + self._emit_train_examples( + train_data_queue, + pending_train_examples, + ) + pending_train_examples = [] + train_data_queue.put(_GLOBAL_BATCH_BARRIER) + continue + if boundary_config.buffer_inputs_until_barrier: + rollout_batches.append(item) + continue + train_examples = self._global_rollout_batch_to_train_examples( + [item], + rl_cluster_lib.Mode.TRAIN, + ) + if boundary_config.buffer_outputs_until_barrier: + pending_train_examples.extend(train_examples) + else: + self._emit_train_examples(train_data_queue, train_examples) finally: - # Signal production is complete for this batch, even if errors occurred. - train_data_queue.put(None) - # Ensure that any background threads waiting on the prompt queue are - # unblocked. 
+ if rollout_batches: + train_examples = self._global_rollout_batch_to_train_examples( + rollout_batches, + rl_cluster_lib.Mode.TRAIN, + ) + if boundary_config.buffer_outputs_until_barrier: + pending_train_examples.extend(train_examples) + else: + self._emit_train_examples(train_data_queue, train_examples) + if pending_train_examples: + self._emit_train_examples(train_data_queue, pending_train_examples) + train_data_queue.put(_GLOBAL_BATCH_BARRIER) + train_data_queue.put(_QUEUE_CLOSED) + + def _request_stage_pipeline_stop( + self, + *, + prompt_queue: queue.Queue[TrainingInputT | None] | None, + rollout_queue: queue_lib.AbstractDataQueue | None, + train_data_queue: queue_lib.AbstractDataQueue | None, + rollout_future: concurrent_futures.Future | None = None, + ) -> None: + """Best-effort cooperative stop for the staged pipeline.""" + if prompt_queue is not None: prompt_queue.put(None) + if rollout_queue is not None: + rollout_queue.put(_QUEUE_CLOSED) + if train_data_queue is not None: + train_data_queue.put(_QUEUE_CLOSED) + if rollout_future is not None and not rollout_future.done(): + rollout_future.cancel() + + def _attach_stage_failure_callbacks( + self, + *, + rollout_future: concurrent_futures.Future, + reference_future: concurrent_futures.Future, + rollout_queue: queue_lib.AbstractDataQueue, + train_data_queue: queue_lib.AbstractDataQueue, + ) -> None: + """Wakes blocked stages when an upstream worker fails or is cancelled.""" + + def _wake_rollout_consumer_on_failure( + future: concurrent_futures.Future, + ) -> None: + if not future.cancelled(): + try: + future.result() + return + except BaseException: + pass + rollout_queue.put(_QUEUE_CLOSED) + + def _wake_train_consumer_on_failure( + future: concurrent_futures.Future, + ) -> None: + if not future.cancelled(): + try: + future.result() + return + except BaseException: + pass + train_data_queue.put(_QUEUE_CLOSED) + + rollout_future.add_done_callback(_wake_rollout_consumer_on_failure) + 
reference_future.add_done_callback(_wake_train_consumer_on_failure) def _data_consumer_batch_generator( - self, queue: queue_lib.AbstractDataQueue, batch_size: int - ): - """Yields micro-batches from a queue until a None is received.""" - item_iterator = iter(lambda: queue.get(block=True), None) + self, + queue: queue_lib.AbstractDataQueue, + batch_size: int, + ) -> Iterator[TrainStageItem]: + """Yields train batches and marks the global-batch boundary explicitly.""" + + class _GlobalBatchQueueIterator: + + def __init__( + self, + data_queue: queue_lib.AbstractDataQueue, + ): + self._data_queue = data_queue + self.hit_barrier = False + self.hit_queue_closed = False + + def __iter__(self): + return self + + def __next__(self): + item = self._data_queue.get(block=True) + if item is _GLOBAL_BATCH_BARRIER: + self.hit_barrier = True + raise StopIteration + if item is _QUEUE_CLOSED: + self.hit_queue_closed = True + raise StopIteration + return [item] + while True: - batch = list(itertools.islice(item_iterator, batch_size)) - if not batch: - return # The iterator is exhausted. 
- yield batch + global_batch_iter = _GlobalBatchQueueIterator( + queue, + ) + if self._training_config.max_seq_token_per_tpu is not None: + logging.info( + "Using sequence packing with max_seq_token_per_tpu: %d", + self._training_config.max_seq_token_per_tpu, + ) + batch_iterator = rl_utils.pack_sequences( + global_batch_iter, + self._training_config.max_seq_token_per_tpu, + ) + else: + batch_iterator = self._iter_train_micro_batches_from_stream( + global_batch_iter, + batch_size, + ) + + last_batch = None + for batch in batch_iterator: + if last_batch is not None: + yield TrainStageItem(last_batch, end_of_global_batch=False) + last_batch = batch + + if last_batch is not None: + yield TrainStageItem(last_batch, end_of_global_batch=True) + elif global_batch_iter.hit_barrier: + yield TrainStageItem(None, end_of_global_batch=True) + + if global_batch_iter.hit_queue_closed: + return + + def _iter_train_micro_batches_from_stream( + self, + train_example_iterator: Iterable[List[TrainExample]], + batch_size: int, + ) -> Iterator[List[TrainExample]]: + """Re-batches a stream of raw train examples into train micro-batches.""" + pending_train_examples = [] + for train_examples in train_example_iterator: + pending_train_examples.extend(train_examples) + while len(pending_train_examples) >= batch_size: + yield pending_train_examples[:batch_size] + pending_train_examples = pending_train_examples[batch_size:] + if pending_train_examples: + yield pending_train_examples + + def _iter_train_micro_batches( + self, + train_examples: List[TrainExample], + batch_size: int, + ) -> Iterator[List[TrainExample]]: + """Splits a full training batch into trainer micro-batches.""" + for start in range(0, len(train_examples), batch_size): + yield train_examples[start : start + batch_size] + + def _finalize_global_step( + self, + full_dataset_iterator: Iterator[TrainingInputT], + prompt_queue: queue.Queue[TrainingInputT | None], + ) -> None: + """Completes sync and advances input state after one 
training batch.""" + global_step_time = time.time() - self._global_step_start_time + logging.info( + f"Global step {self.rl_cluster.global_steps} completed in" + f" {global_step_time:.2f} seconds." + ) + self.rl_cluster.buffer_metrics_async( + {"perf/global_step_time": (global_step_time, np.mean)}, + mode=rl_cluster_lib.Mode.TRAIN, + step=self.rl_cluster.global_steps, + ) + if self.should_sync_weights: + logging.info("Requesting sync lock to sync weights...") + self._rollout_sync_lock.acquire_weight_sync() + try: + logging.info("Sync lock acquired. Syncing weights.") + with self.rl_cluster.perf_v2.span( + perf_constants.WEIGHT_SYNC, + self.rl_cluster.perf_v2.all_devices, + tags={ + perf_constants.STEP: self.rl_cluster.global_steps, + }, + ): + self.rl_cluster.sync_weights() + self.policy_version += 1 + logging.info( + "Weights synced. Policy version incremented to %d.", + self.policy_version, + ) + try: + with self.rl_cluster.perf_v2.span( + perf_constants.DATA_LOADING, + tags={ + perf_constants.STEP: self.rl_cluster.global_steps, + }, + ): + batch = next(full_dataset_iterator) + self._put_prompts_to_queue(prompt_queue, batch) + except StopIteration: + prompt_queue.put(None) + finally: + self._rollout_sync_lock.release_weight_sync() + logging.info("Sync lock released.") + else: + self.rl_cluster.global_steps += 1 + try: + with self.rl_cluster.perf_v2.span( + perf_constants.DATA_LOADING, + tags={ + perf_constants.STEP: self.rl_cluster.global_steps, + }, + ): + batch = next(full_dataset_iterator) + self._put_prompts_to_queue(prompt_queue, batch) + except StopIteration: + prompt_queue.put(None) + + self.rl_cluster.buffer_metrics( + self.rl_cluster.perf_v2.export(), + mode=rl_cluster_lib.Mode.TRAIN, + ) + self._global_step_start_time = time.time() + + def _setup_train_stage_pipeline( + self, + *, + orchestrator, + prompt_queue: queue.Queue[TrainingInputT | None], + rollout_queue: queue_lib.AbstractDataQueue, + train_data_queue: queue_lib.AbstractDataQueue, + 
reference_executor: ThreadPoolExecutor, + ) -> tuple[concurrent_futures.Future, concurrent_futures.Future]: + """Starts the rollout and reference stages for the train pipeline. + + This helper only performs stage startup: + - submits the reference-stage worker to the single-thread executor + - schedules the rollout-stage coroutine on the background event loop + - installs failure callbacks so downstream blocked consumers wake up + + It deliberately does not own shutdown. `train()` is the lifecycle owner for + the staged pipeline, so any startup failure or later runtime failure is + unwound there via `_request_stage_pipeline_stop(...)` and executor cleanup. + + Returns: + The `(rollout_future, reference_future)` pair for the started upstream + stages. `train()` awaits these futures after the train-stage consumer + drains the pipeline. + """ + reference_stage_boundary_config = self._create_stage_boundary_config( + upstream_role=rl_cluster_lib.Role.ROLLOUT, + stage_role=rl_cluster_lib.Role.REFERENCE, + downstream_role=rl_cluster_lib.Role.ACTOR, + stage_fallback_role=rl_cluster_lib.Role.ACTOR, + ) + reference_future = reference_executor.submit( + self._reference_stage_producer, + rollout_queue, + train_data_queue, + reference_stage_boundary_config, + ) + rollout_future = asyncio.run_coroutine_threadsafe( + self._rollout_stage_producer( + orchestrator, + prompt_queue, + rollout_queue, + ), + self.loop, + ) + self._attach_stage_failure_callbacks( + rollout_future=rollout_future, + reference_future=reference_future, + rollout_queue=rollout_queue, + train_data_queue=train_data_queue, + ) + return rollout_future, reference_future + + def _run_train_stage_pipeline( + self, + *, + prompt_queue: queue.Queue[TrainingInputT | None], + train_data_queue: queue_lib.AbstractDataQueue, + full_dataset_iterator: Iterator[TrainingInputT], + all_eval_prompts: List[TrainingInputT], + train_micro_batch_size: int, + training_config, + skip_jit: bool, + ) -> None: + """Consumes train-stage 
items until the staged pipeline reaches completion.""" + train_data_gen = self._data_consumer_batch_generator( + train_data_queue, + train_micro_batch_size, + ) + for train_stage_item in train_data_gen: + if ( + self._training_config.max_steps + and self.rl_cluster.global_steps >= self._training_config.max_steps + ): + logging.info( + "Reached max_steps: %d >= %d", + self.rl_cluster.global_steps, + self._training_config.max_steps, + ) + prompt_queue.put(None) + break + + train_micro_batch = train_stage_item.train_examples + if train_micro_batch is not None: + self._iter_steps += 1 + + # TODO(tsbao): Re-enable this once off-policy filtering is needed. + # Filter out examples that are too old (off-policy). + # filtered_train_micro_batch = self._filter_outdated_offpolicy_examples( + # train_micro_batch + # ) + # if not filtered_train_micro_batch: + # continue + # train_micro_batch = filtered_train_micro_batch + + merged_train_micro_batch = jax.tree.map( + lambda *xs: jnp.concatenate(xs, axis=0), *train_micro_batch + ) + + current_eval_dataset = None + if ( + all_eval_prompts + and self.rl_cluster.actor_trainer.train_steps + % training_config.eval_every_n_steps + == 0 + ): + self._eval_iter_steps = 0 + eval_orchestrator = self._build_orchestrator() + + async def _eval_runner_async(current_eval_orchestrator): + eval_examples = [] + async for batch in self._orchestrator_producer( + current_eval_orchestrator, + all_eval_prompts, + num_generations=self._num_generations(), + ): + eval_example = self._batch_to_train_example( + batch, + rl_cluster_lib.Mode.EVAL, + ) + eval_examples.extend(eval_example) + return eval_examples + + eval_future = asyncio.run_coroutine_threadsafe( + _eval_runner_async(eval_orchestrator), self.loop + ) + eval_examples = eval_future.result() + self._eval_iter_steps += 1 + current_eval_dataset = eval_examples + + self.rl_cluster.update_actor( + [merged_train_micro_batch], current_eval_dataset, skip_jit + ) + if hasattr(self.rl_cluster, 
"critic_trainer"): + self.rl_cluster.update_critic( + [merged_train_micro_batch], current_eval_dataset, skip_jit + ) + + if train_stage_item.end_of_global_batch: + self._finalize_global_step(full_dataset_iterator, prompt_queue) def train( self, @@ -684,6 +1197,7 @@ def train( train_micro_batch_size = ( self._training_config.train_micro_batch_size or mini_batch_size ) + self._train_micro_batch_size = train_micro_batch_size # Rollout and compute_logps micro batch sizes have to be 1 since we only # process inidividual prompts. self._rollout_micro_batch_size = 1 @@ -718,9 +1232,13 @@ def train( training_config = self.rl_cluster.cluster_config.training_config + rollout_queue = queue_lib.SimpleDataQueue(maxsize=0) train_data_queue = queue_lib.SimpleDataQueue(maxsize=0) + reference_executor = None + rollout_future = None + reference_future = None - # 1. Start producer thread to generate rollouts and training examples. + # 1. Start the rollout and reference stages explicitly. orchestrator = self._build_orchestrator() prompt_queue = queue.Queue() @@ -735,161 +1253,38 @@ def train( prompt_queue.put(None) break - producer_future = asyncio.run_coroutine_threadsafe( - self._producer(orchestrator, prompt_queue, train_data_queue), - self.loop, - ) - - # 2. Consume training examples and train. 
- train_data_gen = self._data_consumer_batch_generator( - train_data_queue, train_micro_batch_size - ) - if self._training_config.max_seq_token_per_tpu is not None: - logging.info( - "Using sequence packing with max_seq_token_per_tpu: %d", - self._training_config.max_seq_token_per_tpu, - ) - train_data_gen = rl_utils.pack_sequences( - train_data_gen, self._training_config.max_seq_token_per_tpu + try: + reference_executor = ThreadPoolExecutor(max_workers=1) + rollout_future, reference_future = self._setup_train_stage_pipeline( + orchestrator=orchestrator, + prompt_queue=prompt_queue, + rollout_queue=rollout_queue, + train_data_queue=train_data_queue, + reference_executor=reference_executor, ) - micro_batches_since_last_sync = 0 - micro_batches_per_full_batch = full_batch_size // train_micro_batch_size - for train_micro_batch in train_data_gen: - if ( - self._training_config.max_steps - and self.rl_cluster.global_steps >= self._training_config.max_steps - ): - logging.info( - "Reached max_steps: %d >= %d", - self.rl_cluster.global_steps, - self._training_config.max_steps, - ) - prompt_queue.put(None) - break - self._iter_steps += 1 - - # TODO(tsbao): Re-enable this once off-policy filtering is needed. - # Filter out examples that are too old (off-policy). 
- # filtered_train_micro_batch = self._filter_outdated_offpolicy_examples( - # train_micro_batch - # ) - # if not filtered_train_micro_batch: - # continue - # train_micro_batch = filtered_train_micro_batch - - merged_train_micro_batch = jax.tree.map( - lambda *xs: jnp.concatenate(xs, axis=0), *train_micro_batch + self._run_train_stage_pipeline( + prompt_queue=prompt_queue, + train_data_queue=train_data_queue, + full_dataset_iterator=full_dataset_iterator, + all_eval_prompts=all_eval_prompts, + train_micro_batch_size=train_micro_batch_size, + training_config=training_config, + skip_jit=skip_jit, ) - - # --- Evaluation Logic --- - current_eval_dataset = None - if ( - all_eval_prompts - and self.rl_cluster.actor_trainer.train_steps - % training_config.eval_every_n_steps - == 0 - ): - self._eval_iter_steps = 0 - eval_orchestrator = self._build_orchestrator() - - async def _eval_runner_async(current_eval_orchestrator): - eval_examples = [] - async for batch in self._orchestrator_producer( - current_eval_orchestrator, - all_eval_prompts, - num_generations=self._num_generations(), - ): - eval_example = self._batch_to_train_example( - batch, - rl_cluster_lib.Mode.EVAL, - ) - eval_examples.extend(eval_example) - return eval_examples - - eval_future = asyncio.run_coroutine_threadsafe( - _eval_runner_async(eval_orchestrator), self.loop - ) - eval_examples = eval_future.result() - self._eval_iter_steps += 1 - current_eval_dataset = eval_examples - - # --- Training Step --- - self.rl_cluster.update_actor( - [merged_train_micro_batch], current_eval_dataset, skip_jit + rollout_future.result() + reference_future.result() + except BaseException: + self._request_stage_pipeline_stop( + prompt_queue=prompt_queue, + rollout_queue=rollout_queue, + train_data_queue=train_data_queue, + rollout_future=rollout_future, ) - if hasattr(self.rl_cluster, "critic_trainer"): - self.rl_cluster.update_critic( - [merged_train_micro_batch], current_eval_dataset, skip_jit - ) - - # --- Weight Sync 
Logic --- - micro_batches_since_last_sync += 1 - if micro_batches_since_last_sync == micro_batches_per_full_batch: - global_step_time = time.time() - self._global_step_start_time - logging.info( - f"Global step {self.rl_cluster.global_steps} completed in" - f" {global_step_time:.2f} seconds." - ) - self.rl_cluster.buffer_metrics_async( - {"perf/global_step_time": (global_step_time, np.mean)}, - mode=rl_cluster_lib.Mode.TRAIN, - step=self.rl_cluster.global_steps, - ) - if self.should_sync_weights: - logging.info("Requesting sync lock to sync weights...") - self._rollout_sync_lock.acquire_weight_sync() - try: - logging.info("Sync lock acquired. Syncing weights.") - with self.rl_cluster.perf_v2.span( - perf_constants.WEIGHT_SYNC, - self.rl_cluster.perf_v2.all_devices, - tags={ - perf_constants.STEP: self.rl_cluster.global_steps, - }, - ): - self.rl_cluster.sync_weights() - self.policy_version += 1 - logging.info( - "Weights synced. Policy version incremented to %d.", - self.policy_version, - ) - try: - with self.rl_cluster.perf_v2.span( - perf_constants.DATA_LOADING, - tags={ - perf_constants.STEP: self.rl_cluster.global_steps, - }, - ): - batch = next(full_dataset_iterator) - self._put_prompts_to_queue(prompt_queue, batch) - except StopIteration: - prompt_queue.put(None) - finally: - self._rollout_sync_lock.release_weight_sync() - logging.info("Sync lock released.") - else: - self.rl_cluster.global_steps += 1 - try: - with self.rl_cluster.perf_v2.span( - perf_constants.DATA_LOADING, - tags={ - perf_constants.STEP: self.rl_cluster.global_steps, - }, - ): - batch = next(full_dataset_iterator) - self._put_prompts_to_queue(prompt_queue, batch) - except StopIteration: - prompt_queue.put(None) - - self.rl_cluster.buffer_metrics( - self.rl_cluster.perf_v2.export(), - mode=rl_cluster_lib.Mode.TRAIN, - ) - micro_batches_since_last_sync = 0 - self._global_step_start_time = time.time() - - _ = producer_future.result() - self.rl_cluster.close() + raise + finally: + if 
reference_executor is not None: + reference_executor.shutdown(wait=False, cancel_futures=True) + self.rl_cluster.close() def _put_prompts_to_queue( self, diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py index ee31eb9fd..d786f94dc 100644 --- a/tunix/rl/rl_cluster.py +++ b/tunix/rl/rl_cluster.py @@ -538,7 +538,7 @@ def _init_cluster(self): critic_config = copy.deepcopy(self.cluster_config.training_config) critic_config.metrics_prefix = "critic" critic_config.pbar_description = "Critic Training" - if critic_config.checkpoint_root_directory is not None: + if critic_config.checkpoint_root_directory: critic_config.checkpoint_root_directory = os.path.join( critic_config.checkpoint_root_directory, "critic" ) @@ -562,7 +562,7 @@ def _init_cluster(self): actor_config = copy.deepcopy(self.cluster_config.training_config) actor_config.metrics_prefix = "actor" actor_config.pbar_description = "Actor Training" - if actor_config.checkpoint_root_directory is not None: + if actor_config.checkpoint_root_directory: actor_config.checkpoint_root_directory = os.path.join( actor_config.checkpoint_root_directory, "actor" ) diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index 0c0bfc7f8..2747fbf3a 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -47,6 +47,13 @@ MetricFn = Callable[..., rl_cluster_lib.MetricsT] + +def _mesh_device_keys(mesh) -> frozenset[Any]: + return frozenset( + getattr(device, "id", device) + for device in mesh.devices.flatten().tolist() + ) + TConfig = TypeVar("TConfig", bound=algo_config_lib.AlgorithmConfig) @@ -118,16 +125,18 @@ def __init__( self.rl_cluster.rollout.model(), ) ) - - # Enable async rollout if trainer and rollout are not on the same mesh. - # If they do, then doesn't make sense for the interleave because they will - # have resource contention. 
- self.can_enable_async_rollout = ( - self.rl_cluster.cluster_config.role_to_mesh[rl_cluster_lib.Role.ACTOR] - != self.rl_cluster.cluster_config.role_to_mesh[ - rl_cluster_lib.Role.ROLLOUT - ] + actor_mesh = self.rl_cluster.cluster_config.role_to_mesh[ + rl_cluster_lib.Role.ACTOR + ] + rollout_mesh = self.rl_cluster.cluster_config.role_to_mesh[ + rl_cluster_lib.Role.ROLLOUT + ] + self._share_actor_rollout_devices = ( + _mesh_device_keys(actor_mesh) == _mesh_device_keys(rollout_mesh) ) + + # Overlap is only safe when actor and rollout use different device sets. + self.can_enable_async_rollout = not self._share_actor_rollout_devices self.executor = futures.ThreadPoolExecutor(max_workers=1) self._last_iter_step = self.rl_cluster.actor_trainer.iter_steps diff --git a/tunix/sft/checkpoint_manager.py b/tunix/sft/checkpoint_manager.py index d069a7706..50a3e6136 100644 --- a/tunix/sft/checkpoint_manager.py +++ b/tunix/sft/checkpoint_manager.py @@ -42,12 +42,12 @@ def __init__( """Initializes the checkpoint manager. Args: - root_directory: The root directory for the checkpoint manager. If None, - the checkpoint manager will be disabled. + root_directory: The root directory for the checkpoint manager. If None + or empty, the checkpoint manager will be disabled. options: The options for the checkpoint manager. """ self._checkpoint_manager: ocp.CheckpointManager | None = None - if root_directory is not None: + if root_directory: # When using Pathways, the checkpoint manager only supports persistence # APIs now. if 'proxy' in os.getenv('JAX_PLATFORMS', ''):