Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/cpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ on:

permissions:
contents: read
jobs:
jobs:
run:
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -63,6 +63,10 @@ jobs:
run: |
python -m pytest tests/cli/utils/ -v --tb=short

- name: Run shared mesh and topology tests
run: |
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short

- name: Run perf tests
run: |
python -m pytest tests/perf/ -v --tb=short
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ jobs:
- name: Run tunix tests not covered by the above categories
run: |
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
if [ "${code:-0}" = "5" ]; then
echo "No tests collected (expected)."
exit 0
Expand Down
91 changes: 91 additions & 0 deletions tests/cli/grpo_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,97 @@ def __init__(self, devices, axis_names, axis_types=None):
role_to_mesh[rl_cluster_lib.Role.ACTOR],
)

def test_split_mesh_delegates_device_allocation_to_mesh_utils(self):
extra = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
apply_chat_template_to_dataset: false
data_config:
train_data_path: "gs://fake/train.json"
eval_data_path: "gs://fake/eval.parquet"
prompt_key: "prompts"
reward_functions: []
verl_compatible: false
chat_parser_config:
type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
num_generations: 2
num_iterations: 1
beta: 0.0
epsilon: 0.2
epsilon_high: 0.28
system_prompt: ""
max_concurrency: 1
off_policy_steps: 0
max_turns: 1
context_ratio: 1
sglang_jax_config:
mem_fraction_static: 0.8
vllm_config:
hbm_utilization: 0.4
"""
pipeline = _make_pipeline(extra)
actor_model_config = pipeline.config["actor_model_config"]
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
actor_model_config["mesh"] = {
"shape": "(1,2)",
"axis_names": "('fsdp','tp')",
}
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
rollout_model_config = pipeline.config["rollout_model_config"]
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
rollout_model_config["mesh"] = {
"shape": "(1,2)",
"axis_names": "('fsdp','tp')",
}

fake_devices = ["a0", "a1", "r0", "r1"]
allocated_devices = {
"actor_model_config": ["a0", "a1"],
"rollout_model_config": ["r0", "r1"],
}
created_mesh_devices = {}

def fake_create_mesh(model_key, devices=None):
created_mesh_devices[model_key] = list(devices)
return (model_key, tuple(devices))

with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
with mock.patch.object(
grpo_main.mesh_lib,
"allocate_named_mesh_device_slices",
return_value=allocated_devices,
) as allocate_mock:
with mock.patch.object(pipeline, "create_mesh", side_effect=fake_create_mesh):
role_to_mesh = pipeline.create_role_to_mesh()

allocate_mock.assert_called_once_with(
[
("actor_model_config", 2),
("rollout_model_config", 2),
],
devices=fake_devices,
)
self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"])
self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"])
self.assertEqual(
role_to_mesh[rl_cluster_lib.Role.ACTOR],
("actor_model_config", ("a0", "a1")),
)
self.assertIs(
role_to_mesh[rl_cluster_lib.Role.REFERENCE],
role_to_mesh[rl_cluster_lib.Role.ACTOR],
)
self.assertEqual(
role_to_mesh[rl_cluster_lib.Role.ROLLOUT],
("rollout_model_config", ("r0", "r1")),
)


if __name__ == "__main__":
absltest.main()
Loading
Loading