47 changes: 47 additions & 0 deletions docs/agentic_rl.md
@@ -120,6 +120,53 @@ generating trajectories with stale parameters.
<img src="images/batch_vs_async_rollout.png" alt="Batch vs Async Rollout" width="50%">
</p>

### Mesh Placement and Colocation

Agentic GRPO supports three distinct placement patterns for the actor,
reference, and rollout roles:

1. **Shared mesh**: multiple roles reuse the exact same mesh object. This is
   the most tightly colocated setup and may enable model or backbone sharing
   in parts of the stack.

   Backend status: exact shared-mesh execution is currently supported only by
   the vanilla rollout backend; it is not yet supported for `vllm` and
   `sglang_jax`.

2. **Colocated device set**: a role uses `colocate_with` to reuse another
   role's device slice while still keeping its own mesh shape. For example,
   the actor can use `(4, 1)` while rollout uses `(1, 4)` on the same four
   devices. This is still colocated, but it is different from exact mesh
   sharing.

3. **Disaggregated placement**: each role owns a separate device slice.

For CLI-driven GRPO and agentic GRPO runs, `colocate_with` is configured on the
role-specific model config, for example:

```yaml
actor_model_config:
  mesh:
    shape: "(4,1)"
    axis_names: "('fsdp','tp')"

rollout_model_config:
  colocate_with: "actor"
  mesh:
    shape: "(1,4)"
    axis_names: "('fsdp','tp')"
```

This tells Tunix to allocate only one four-device owner slice and build two
different meshes on top of it. Exact mesh equality is only required when a
runtime path wants to reuse the same model instance; it is not required for
colocation itself.
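As an illustrative, standalone sketch of this distinction (plain JAX rather than Tunix APIs; the four-device owner slice is faked on host CPUs via `XLA_FLAGS` so the snippet runs without accelerators):

```python
import os

# Fake four host CPU devices so the sketch runs anywhere.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np

# One four-device owner slice...
devices = np.array(jax.devices()[:4])

# ...backing two meshes with different layouts, mirroring the YAML above.
actor_mesh = jax.sharding.Mesh(devices.reshape(4, 1), ("fsdp", "tp"))
rollout_mesh = jax.sharding.Mesh(devices.reshape(1, 4), ("fsdp", "tp"))

# Colocated: both meshes cover exactly the same device set.
assert set(actor_mesh.devices.flatten()) == set(rollout_mesh.devices.flatten())

# But they are not equal meshes, so exact shared-mesh reuse does not apply.
assert actor_mesh != rollout_mesh
```

The asserts make the point concrete: colocation is a property of the device set, while exact mesh equality is a stronger condition on the mesh itself.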

For accelerated rollout backends, treat same-device-set colocation and exact
shared-mesh reuse as different features. Today, `vllm` and `sglang_jax`
support colocated device-set placement through `colocate_with`, but exact
shared-mesh execution is not supported yet.

### Trajectory Batching and Grouping

Tunix supports batching of agentic trajectories through the `GroupQueueManager`.
14 changes: 12 additions & 2 deletions docs/launching.md
@@ -254,7 +254,7 @@ This section provides a detailed explanation of the configuration parameters available

#### Model Configuration (`model_config`)

These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration.

* **`model_name`**: The unique full name identifier of the model. This
corresponds to the full name and should match exactly with the model name
@@ -287,6 +287,16 @@ These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs.
* **`mesh`**: Defines the hardware mesh layout for distributed training.
  * `shape`: Tuple string defining mesh dimensions (e.g., `"(2,2)"` for a 2x2 grid).
  * `axis_names`: Names for mesh axes, often used for parallelism strategies (e.g., `"('fsdp','tp')"` for Fully Sharded Data Parallelism and Tensor Parallelism).
* **`colocate_with`**: Optional role-local placement override for
  `actor_model_config`, `reference_model_config`, `rollout_model_config`, and
  other RL roles.
  * If unset, a role owns its own device slice when it has an explicit
    `mesh`, or shares the actor mesh by default when it does not.
  * If set to a role name such as `"actor"`, the role reuses that role's
    device slice but may still define its own `mesh.shape` and
    `mesh.axis_names`.
  * This is different from exact mesh sharing: two roles can be colocated on
    the same devices while using different mesh layouts.
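For instance, a hypothetical rollout role colocated on the actor's devices while keeping its own layout (shapes are illustrative):

```yaml
actor_model_config:
  mesh:
    shape: "(4,1)"
    axis_names: "('fsdp','tp')"

rollout_model_config:
  colocate_with: "actor"   # reuse the actor's device slice
  mesh:
    shape: "(1,4)"         # but keep an independent mesh layout
    axis_names: "('fsdp','tp')"
```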


#### Tokenizer Configuration (`tokenizer_config`)
@@ -338,7 +348,7 @@ General settings for the training loop, logging, and checkpointing.

* **`eval_every_n_steps`**: Frequency of running evaluation steps.

* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients
  before performing a parameter update (simulates larger batch sizes).

* **`checkpointing_options`**:
38 changes: 36 additions & 2 deletions docs/performance.md
@@ -156,8 +156,13 @@ and training. To further maximize the hardware utility, you can consider enabling
non-active models to CPU RAM when a
different component is occupying the TPU.

Enabling collocated mode is straightforward; you simply provide the same mesh to
every component when configuring the `role_to_mesh` mapping for your `rl_cluster`.
Enabling collocated mode is straightforward; the strongest form is to provide
the same mesh to every component when configuring the `role_to_mesh` mapping
for your `rl_cluster`.

Backend status: exact shared-mesh execution is currently supported only by the
vanilla rollout backend; it is not yet supported for `vllm` and `sglang_jax`.

```python
import numpy as np
@@ -179,6 +184,35 @@ ClusterConfig(
)
```

For CLI-driven GRPO and agentic GRPO runs, there is now a second colocated
mode: **same device set, different mesh shape**. This is configured with
`colocate_with`.

```yaml
actor_model_config:
  mesh:
    shape: "(4,1)"
    axis_names: "('fsdp','tp')"

rollout_model_config:
  colocate_with: "actor"
  mesh:
    shape: "(1,4)"
    axis_names: "('fsdp','tp')"
```

In this configuration, actor and rollout are still colocated because they run
on the same device slice, but they do not share the exact same mesh object.
That distinction matters:

* Same device set means the roles are colocated.
* Same mesh may additionally allow model or backbone sharing in some runtime
  paths.

For `vllm` and `sglang_jax`, same-device-set colocation is the only colocated
placement currently supported; exact shared-mesh reuse is not yet available
for those backends.
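As a hedged sketch of the two colocated modes (role names as plain strings and plain JAX meshes rather than Tunix's `role_to_mesh` types; host CPU devices are faked via `XLA_FLAGS` so the snippet runs anywhere):

```python
import os

# CPU-only demo: fake four host devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np

devices = np.array(jax.devices()[:4])

# Strong form: every role reuses the exact same mesh object
# (vanilla rollout backend only).
shared = jax.sharding.Mesh(devices.reshape(4, 1), ("fsdp", "tp"))
role_to_mesh_shared = {"actor": shared, "reference": shared, "rollout": shared}
assert role_to_mesh_shared["actor"] is role_to_mesh_shared["rollout"]

# Weaker form: same device set, independently shaped meshes
# (the colocated mode supported with vllm / sglang_jax today).
actor = jax.sharding.Mesh(devices.reshape(4, 1), ("fsdp", "tp"))
rollout = jax.sharding.Mesh(devices.reshape(1, 4), ("fsdp", "tp"))
role_to_mesh_colocated = {"actor": actor, "rollout": rollout}
assert set(actor.devices.flatten()) == set(rollout.devices.flatten())
assert actor != rollout
```

The identity check in the strong form is exactly what exact-mesh reuse paths rely on; the weaker form only guarantees the device sets coincide.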

### Disaggregated Execution

Disaggregated mode partitions the TPU cluster into distinct "sub-meshes",
10 changes: 10 additions & 0 deletions docs/rollout.md
@@ -91,6 +91,11 @@ Setting `cluster_config.rollout_engine="vllm"` enables the vllm rollout/sampler.
Tunix uses `tunix.rl.rollout.base_rollout.RolloutConfig` for rollout settings.
The fields below are the vLLM-relevant ones.

Exact shared-mesh execution is currently supported only for the vanilla
rollout backend, not yet for `vllm`. The supported colocated configuration
today is same-device-set placement with an independently shaped rollout mesh.
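For example, a colocated vLLM rollout under these constraints might look like this (values are illustrative; field names follow the CLI examples in these docs):

```yaml
cluster_config:
  rollout_engine: "vllm"

actor_model_config:
  mesh:
    shape: "(4,1)"
    axis_names: "('fsdp','tp')"

rollout_model_config:
  colocate_with: "actor"   # same device set as the actor
  mesh:
    shape: "(1,4)"         # independently shaped rollout mesh
    axis_names: "('fsdp','tp')"
```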

#### vLLM-specific fields

In addition to the common sampling parameters mentioned above, the following
@@ -277,6 +282,11 @@ Tunix uses `tunix.rl.rollout.base_rollout.RolloutConfig` for rollout settings.
Exact shared-mesh execution is currently supported only for the vanilla
rollout backend, not yet for `sglang_jax`. The supported colocated
configuration today is same-device-set placement with an independently shaped
rollout mesh.

In addition to the common sampling parameters, the following fields are
specific to SGLang-Jax:

- `rollout_sglang_jax_model_version`

  - Model id or local path used by SGLang-Jax as `model_path`.
2 changes: 0 additions & 2 deletions examples/deepscaler/run_deepscaler_disagg_v5p16.sh
@@ -64,8 +64,6 @@ python -m tunix.cli.grpo_main \
model_config.remat_config=3 \
actor_model_config.mesh.shape="$trainer_mesh" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
reference_model_config.mesh=null \
reference_model_config.same_mesh_as="actor" \
rollout_model_config.mesh.shape="$rollout_mesh" \
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
\
2 changes: 0 additions & 2 deletions examples/deepswe/run_deepswe_disagg_v5p_32.sh
@@ -81,8 +81,6 @@ python -m tunix.cli.grpo_main \
model_config.remat_config=3 \
actor_model_config.mesh.shape="$trainer_mesh" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
reference_model_config.mesh=null \
reference_model_config.same_mesh_as="actor" \
rollout_model_config.mesh.shape="$rollout_mesh" \
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
\
9 changes: 6 additions & 3 deletions examples/rl/grpo/gsm8k/run_qwen3_8b.sh
@@ -45,7 +45,11 @@ num_generations="${num_generations:-4}"
train_mesh="${train_mesh:-(8,1)}"
rollout_mesh="${rollout_mesh:-(1,8)}"

checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
# Set rollout_colocate to a role name (e.g. "actor") to colocate the rollout
# model on the same device slice as the actor model.
rollout_colocate="${rollout_colocate:-null}"

checkpoint_dir="${checkpoint_dir-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}"
if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then
  checkpoint_dir="${checkpoint_dir}_${checkpoint_suffix}"
@@ -79,8 +83,7 @@ python -m tunix.cli.grpo_main \
model_config.remat_config=3 \
actor_model_config.mesh.shape="$train_mesh" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
reference_model_config.mesh=null \
reference_model_config.same_mesh_as="actor" \
rollout_model_config.colocate_with="$rollout_colocate" \
rollout_model_config.mesh.shape="$rollout_mesh" \
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
\
124 changes: 122 additions & 2 deletions tests/cli/grpo_main_test.py
@@ -644,7 +644,6 @@ def test_cli_empty_system_prompt_stays_empty_string(self):
    )
    self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "")


class SplitMeshConfigTest(absltest.TestCase):

  def test_split_mesh_uses_explicit_role_meshes(self):
@@ -688,7 +687,6 @@ def test_split_mesh_uses_explicit_role_meshes(self):
          "shape": "(2,1)",
          "axis_names": "('fsdp','tp')",
      }
    pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
    rollout_model_config = pipeline.config["rollout_model_config"]
    if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
      rollout_model_config["mesh"] = {
@@ -732,6 +730,128 @@ def __init__(self, devices, axis_names, axis_types=None):
        role_to_mesh[rl_cluster_lib.Role.ACTOR],
    )

  def test_colocate_with_reuses_device_slice_with_different_mesh(self):
    extra = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
apply_chat_template_to_dataset: false
data_config:
  train_data_path: "gs://fake/train.json"
  eval_data_path: "gs://fake/eval.parquet"
  prompt_key: "prompts"
reward_functions: []
verl_compatible: false
chat_parser_config:
  type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
  num_generations: 2
  num_iterations: 1
  beta: 0.0
  epsilon: 0.2
  epsilon_high: 0.28
  system_prompt: ""
  max_concurrency: 1
  off_policy_steps: 0
  max_turns: 1
  context_ratio: 1
sglang_jax_config:
  mem_fraction_static: 0.8
vllm_config:
  hbm_utilization: 0.4
"""
    pipeline = _make_pipeline(extra)
    actor_model_config = pipeline.config["actor_model_config"]
    if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
      actor_model_config["mesh"] = {
          "shape": "(2,1)",
          "axis_names": "('fsdp','tp')",
      }
    rollout_model_config = pipeline.config["rollout_model_config"]
    if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
      rollout_model_config["colocate_with"] = "actor"
      rollout_model_config["mesh"] = {
          "shape": "(1,2)",
          "axis_names": "('fsdp','tp')",
      }

    fake_devices = list(range(4))

    class FakeMesh:

      def __init__(self, devices, axis_names, axis_types=None):
        self.devices = devices
        self.axis_names = axis_names
        self.axis_types = axis_types

    with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
      with mock.patch.object(
          grpo_main.jax.sharding, "Mesh", side_effect=FakeMesh
      ):
        role_to_mesh = pipeline.create_role_to_mesh()

    self.assertSequenceEqual(
        role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(),
        [0, 1],
    )
    self.assertSequenceEqual(
        role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(),
        [0, 1],
    )
    self.assertEqual(
        role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.shape,
        (2, 1),
    )
    self.assertEqual(
        role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape,
        (1, 2),
    )

  def test_empty_string_colocate_with_is_treated_as_unset(self):
    extra = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
apply_chat_template_to_dataset: false
data_config:
  train_data_path: "gs://fake/train.json"
  eval_data_path: "gs://fake/eval.parquet"
  prompt_key: "prompts"
reward_functions: []
verl_compatible: false
chat_parser_config:
  type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
  num_generations: 2
  num_iterations: 1
  beta: 0.0
  epsilon: 0.2
  epsilon_high: 0.28
  system_prompt: ""
  max_concurrency: 1
  off_policy_steps: 0
  max_turns: 1
  context_ratio: 1
sglang_jax_config:
  mem_fraction_static: 0.8
vllm_config:
  hbm_utilization: 0.4
"""
    pipeline = _make_pipeline(extra)
    rollout_model_config = pipeline.config["rollout_model_config"]
    if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
      rollout_model_config["colocate_with"] = ""

    self.assertEmpty(pipeline._get_colocate_with_map())


if __name__ == "__main__":
  absltest.main()