diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index 114aa7e06..89fae1202 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -22,7 +22,7 @@ on: permissions: contents: read -jobs: +jobs: run: runs-on: ubuntu-latest steps: @@ -63,6 +63,10 @@ jobs: run: | python -m pytest tests/cli/utils/ -v --tb=short + - name: Run shared mesh and topology tests + run: | + python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short + - name: Run perf tests run: | python -m pytest tests/perf/ -v --tb=short diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 9881707f0..304ce9d9f 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -120,7 +120,7 @@ jobs: - name: Run tunix tests not covered by the above categories run: | # This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here. - python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$? + python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$? if [ "${code:-0}" = "5" ]; then echo "No tests collected (expected)." 
exit 0 diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py index b68d98117..a5d432598 100644 --- a/tests/cli/grpo_main_test.py +++ b/tests/cli/grpo_main_test.py @@ -732,6 +732,97 @@ def __init__(self, devices, axis_names, axis_types=None): role_to_mesh[rl_cluster_lib.Role.ACTOR], ) + def test_split_mesh_delegates_device_allocation_to_mesh_utils(self): + extra = """ +training_mode: "agentic_grpo" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_grpo_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 + context_ratio: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + actor_model_config = pipeline.config["actor_model_config"] + if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig): + actor_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + } + pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"} + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + } + + fake_devices = ["a0", "a1", "r0", "r1"] + allocated_devices = { + "actor_model_config": ["a0", "a1"], + "rollout_model_config": ["r0", "r1"], + } + created_mesh_devices = {} + + def fake_create_mesh(model_key, devices=None): + created_mesh_devices[model_key] = list(devices) + return (model_key, tuple(devices)) + + with 
mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices): + with mock.patch.object( + grpo_main.mesh_lib, + "allocate_named_mesh_device_slices", + return_value=allocated_devices, + ) as allocate_mock: + with mock.patch.object(pipeline, "create_mesh", side_effect=fake_create_mesh): + role_to_mesh = pipeline.create_role_to_mesh() + + allocate_mock.assert_called_once_with( + [ + ("actor_model_config", 2), + ("rollout_model_config", 2), + ], + devices=fake_devices, + ) + self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"]) + self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"]) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ACTOR], + ("actor_model_config", ("a0", "a1")), + ) + self.assertIs( + role_to_mesh[rl_cluster_lib.Role.REFERENCE], + role_to_mesh[rl_cluster_lib.Role.ACTOR], + ) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ROLLOUT], + ("rollout_model_config", ("r0", "r1")), + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/utils/mesh_utils_test.py b/tests/utils/mesh_utils_test.py new file mode 100644 index 000000000..42c205d1e --- /dev/null +++ b/tests/utils/mesh_utils_test.py @@ -0,0 +1,879 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest import mock + +from absl.testing import absltest +from tunix.utils import mesh + + +class MeshUtilsTest(absltest.TestCase): + + def test_device_attr_calls_callable_attributes(self): + class FakeDevice: + + def coords(self): + return (1, 2, 3) + + self.assertEqual(mesh.device_attr(FakeDevice(), "coords"), (1, 2, 3)) + self.assertIsNone(mesh.device_attr(FakeDevice(), "missing")) + + def test_device_host_key_prefers_slice_and_process_metadata(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 7 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 7)) + + def test_device_host_key_falls_back_to_slice_and_task_id(self): + class FakeDevice: + + def __init__(self): + self.slice = 3 + self.task_id = 9 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (3, 9)) + + def test_device_host_key_prefers_logical_task_over_process_index(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 0 + self.logical_task = 7 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 7)) + + def test_device_host_key_prefers_task_id_over_process_index(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 0 + self.task_id = 9 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 9)) + + def test_device_host_key_returns_none_without_task_metadata(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.device_host_key(FakeDevice())) + + def test_find_candidate_coord_boxes_finds_contiguous_boxes(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (1, 0, 0)), + FakeDevice(2, (0, 1, 0)), + FakeDevice(3, (1, 1, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual( + mesh.find_candidate_coord_boxes(topology, 4), + [ + ( + (0, 0, 0), + (2, 2, 1), + ((0, 0, 0), (0, 1, 0), (1, 
0, 0), (1, 1, 0)), + ) + ], + ) + + def test_find_candidate_coord_boxes_skips_missing_coords(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (1, 0, 0)), + FakeDevice(2, (1, 1, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual(mesh.find_candidate_coord_boxes(topology, 4), []) + + def test_find_candidate_coord_boxes_can_return_multiple_candidates(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (1, 0, 0)), + FakeDevice(2, (2, 0, 0)), + FakeDevice(3, (3, 0, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual( + mesh.find_candidate_coord_boxes(topology, 2), + [ + ((0, 0, 0), (2, 1, 1), ((0, 0, 0), (1, 0, 0))), + ((1, 0, 0), (2, 1, 1), ((1, 0, 0), (2, 0, 0))), + ((2, 0, 0), (2, 1, 1), ((2, 0, 0), (3, 0, 0))), + ], + ) + + def test_find_candidate_coord_boxes_rejects_split_chip_candidates(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (0, 0, 0), 1), + FakeDevice(2, (1, 0, 0), 0), + FakeDevice(3, (1, 0, 0), 1), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual( + mesh.find_candidate_coord_boxes(topology, 1), + [], + ) + + def test_find_host_aligned_candidate_coord_boxes_respects_exact_host_shape(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [] + device_id = 0 + for x in range(4): + for y in range(4): + for z in range(2): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice(device_id, (x, y, z), 
core_on_chip)) + device_id += 1 + + topology = mesh.get_coord_topology(fake_devices) + + candidate_boxes = mesh.find_host_aligned_candidate_coord_boxes( + topology, 8, (2, 2, 1, 2) + ) + + self.assertLen(candidate_boxes, 8) + self.assertContainsSubset( + [ + ( + (0, 0, 0, 0), + (2, 2, 1, 2), + ( + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (1, 0, 0, 0), + (1, 0, 0, 1), + (1, 1, 0, 0), + (1, 1, 0, 1), + ), + ), + ( + (0, 0, 1, 0), + (2, 2, 1, 2), + ( + (0, 0, 1, 0), + (0, 0, 1, 1), + (0, 1, 1, 0), + (0, 1, 1, 1), + (1, 0, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 0), + (1, 1, 1, 1), + ), + ), + ], + candidate_boxes, + ) + + def test_candidate_uses_whole_chips_requires_all_cores(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + topology = mesh.get_coord_topology([ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (0, 0, 0), 1), + FakeDevice(2, (1, 0, 0), 0), + FakeDevice(3, (1, 0, 0), 1), + ]) + + self.assertFalse( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0, 0), (1, 0, 0, 0)], + ) + ) + self.assertTrue( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)], + ) + ) + + def test_get_coord_topology_builds_bounding_box(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (2, 1, 0)), + FakeDevice(1, (3, 1, 0)), + FakeDevice(2, (2, 2, 0)), + FakeDevice(3, (3, 2, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertIsNotNone(topology) + self.assertEqual(topology.num_dims, 3) + self.assertEqual(topology.max_shape, (2, 2, 1)) + self.assertEqual(topology.all_coords, ((2, 1, 0), (3, 1, 0), (2, 2, 0), (3, 2, 0))) + + def test_get_coord_topology_rejects_duplicate_coords(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + 
fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0))] + + self.assertIsNone(mesh.get_coord_topology(fake_devices)) + + def test_get_coord_topology_uses_core_on_chip_to_disambiguate_devices(self): + class FakeDevice: + + def __init__(self, coords, core_on_chip): + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((0, 0, 0), 1), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertIsNotNone(topology) + self.assertEqual(topology.all_coords, ((0, 0, 0, 0), (0, 0, 0, 1))) + + def test_get_coord_topology_rejects_empty_device_list(self): + self.assertIsNone(mesh.get_coord_topology([])) + + def test_get_coord_topology_rejects_mismatched_coord_dimensions(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0, 1))] + + self.assertIsNone(mesh.get_coord_topology(fake_devices)) + + def test_summarize_devices_for_logging_includes_id_coords_and_host(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index, slice_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + self.slice_index = slice_index + + self.assertEqual( + mesh.summarize_devices_for_logging([FakeDevice(11, (1, 2, 0), 5, 6)]), + [{"id": 11, "coords": (1, 2, 0), "host": (6, 5)}], + ) + + def test_group_devices_by_host_groups_equal_sized_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + grouped = mesh.group_devices_by_host([ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ]) + + self.assertEqual([[device.id for device in group] for group in grouped], [[0, 1], [2, 3]]) + + def test_allocate_named_mesh_device_slices_uses_logical_task_host_groups(self): + class FakeDevice: + + def __init__(self, device_id, logical_task): + self.id = device_id + 
self.process_index = 0 + self.logical_task = logical_task + + fake_devices = [] + for device_id in range(16): + fake_devices.append(FakeDevice(device_id, device_id % 2)) + + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 8)], + devices=fake_devices, + ) + + self.assertEqual( + [device.id for device in allocated["actor"]], + [0, 2, 4, 6, 8, 10, 12, 14], + ) + + def test_group_devices_by_host_returns_none_without_host_metadata(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.group_devices_by_host([FakeDevice()])) + + def test_group_devices_by_host_returns_none_for_inconsistent_host_sizes(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + self.assertIsNone( + mesh.group_devices_by_host([ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + ]) + ) + + def test_host_mesh_shape_infers_consistent_per_host_shape(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + FakeDevice((2, 1, 0), 1), + FakeDevice((3, 1, 0), 1), + ] + + self.assertEqual(mesh.host_mesh_shape(fake_devices), (2, 2, 1)) + + def test_host_mesh_shape_returns_none_for_sparse_host_box(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((1, 1, 0), 0), + ] + + self.assertIsNone(mesh.host_mesh_shape(fake_devices)) + + def test_host_mesh_shape_returns_none_for_inconsistent_host_shapes(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + 
FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + ] + + self.assertIsNone(mesh.host_mesh_shape(fake_devices)) + + def test_divisors_returns_sorted_unique_factors(self): + self.assertEqual(mesh._divisors(12), [1, 2, 3, 4, 6, 12]) + + def test_enumerate_box_shapes_returns_shapes_with_requested_volume(self): + self.assertEqual( + mesh._enumerate_box_shapes(4, (4, 2, 2)), + [(1, 2, 2), (2, 1, 2), (2, 2, 1), (4, 1, 1)], + ) + + def test_coord_box_score_prefers_host_aligned_boxes(self): + aligned_score = mesh._coord_box_score((0, 0, 0), (2, 2, 1), (2, 2, 1)) + unaligned_score = mesh._coord_box_score((1, 0, 0), (2, 2, 1), (2, 2, 1)) + + self.assertLess(aligned_score, unaligned_score) + + def test_select_best_candidate_coords_prefers_host_aligned_box(self): + candidate_boxes = [ + ((1, 0, 0), (2, 2, 1), ((1, 0, 0), (1, 1, 0), (2, 0, 0), (2, 1, 0))), + ((0, 0, 0), (2, 2, 1), ((0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0))), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, (2, 2, 1)), + [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)], + ) + + def test_select_best_candidate_coords_prefers_chip_host_aligned_box_with_core_dimension(self): + candidate_boxes = [ + ( + (0, 0, 0, 0), + (1, 2, 2, 2), + ( + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (0, 0, 1, 0), + (0, 0, 1, 1), + (0, 1, 1, 0), + (0, 1, 1, 1), + ), + ), + ( + (0, 0, 0, 0), + (2, 2, 1, 2), + ( + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (1, 0, 0, 0), + (1, 0, 0, 1), + (1, 1, 0, 0), + (1, 1, 0, 1), + ), + ), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, (2, 2, 1, 2)), + [ + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (1, 0, 0, 0), + (1, 0, 0, 1), + (1, 1, 0, 0), + (1, 1, 0, 1), + ], + ) + + def test_select_best_candidate_coords_prefers_more_compact_shape(self): + candidate_boxes = [ + 
((0, 0, 0), (1, 4, 1), ((0, 0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 0))), + ((0, 0, 0), (2, 2, 1), ((0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0))), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, None), + [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)], + ) + + def test_select_best_candidate_coords_uses_start_as_tiebreaker(self): + candidate_boxes = [ + ((2, 0, 0), (2, 1, 1), ((2, 0, 0), (3, 0, 0))), + ((0, 0, 0), (2, 1, 1), ((0, 0, 0), (1, 0, 0))), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, None), + [(0, 0, 0), (1, 0, 0)], + ) + + def test_select_best_candidate_coords_returns_none_without_candidates(self): + self.assertIsNone(mesh.select_best_candidate_coords([], (2, 2, 1))) + + def test_device_mesh_coords_appends_core_on_chip_when_present(self): + class FakeDevice: + + def __init__(self): + self.coords = (1, 2, 0) + self.core_on_chip = 1 + + self.assertEqual( + mesh.device_mesh_coords(FakeDevice()), + (1, 2, 0, 1), + ) + + def test_device_mesh_coords_returns_none_without_coords(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.device_mesh_coords(FakeDevice())) + + def test_known_host_mesh_shape_returns_none_for_unknown_device_family(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0, 0) + self.device_kind = "unknown" + + self.assertIsNone(mesh.known_host_mesh_shape([FakeDevice()])) + + def test_known_host_mesh_shape_returns_none_when_coord_rank_mismatches_bounds(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0) + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice() for _ in range(128)] + + self.assertIsNone(mesh.known_host_mesh_shape(fake_devices)) + + def test_resolve_per_host_mesh_shape_returns_inferred_shape(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 
0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + FakeDevice((2, 1, 0), 1), + FakeDevice((3, 1, 0), 1), + ] + + self.assertEqual(mesh.resolve_per_host_mesh_shape(fake_devices), (2, 2, 1)) + + def test_known_host_mesh_shape_uses_static_topology_metadata(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0, 0) + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice() for _ in range(128)] + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (2, 2, 1), + ) + + def test_known_host_mesh_shape_uses_single_host_bounds_for_tpu7x_2(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0))] + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (1, 1, 1), + ) + + def test_known_host_mesh_shape_appends_core_dimension_when_present(self): + class FakeDevice: + + def __init__(self, coords, core_on_chip): + self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + for x in range(4): + for y in range(4): + for z in range(4): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice((x, y, z), core_on_chip)) + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (2, 2, 1, 2), + ) + + def test_resolve_per_host_mesh_shape_raises_on_mismatch(self): + class FakeDevice: + + def __init__(self, device_id, coords, logical_task): + self.id = device_id + self.coords = coords + self.logical_task = logical_task + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (2, 0, 0), 0), + FakeDevice(3, (3, 0, 0), 0), + FakeDevice(4, (0, 0, 1), 1), + FakeDevice(5, (1, 0, 1), 1), + FakeDevice(6, (2, 0, 1), 1), + FakeDevice(7, (3, 0, 1), 1), + ] + + with self.assertRaisesRegex(ValueError, "does not match known host bounds"): + 
mesh.resolve_per_host_mesh_shape(fake_devices) + + def test_allocate_named_mesh_device_slices_prefers_coord_boxes(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0, 0)), + FakeDevice(1, (0, 0, 0, 1)), + FakeDevice(2, (1, 0, 0, 0)), + FakeDevice(3, (1, 0, 0, 1)), + FakeDevice(4, (2, 0, 0, 0)), + FakeDevice(5, (2, 0, 0, 1)), + FakeDevice(6, (3, 0, 0, 0)), + FakeDevice(7, (3, 0, 0, 1)), + FakeDevice(8, (0, 1, 0, 0)), + FakeDevice(9, (0, 1, 0, 1)), + FakeDevice(10, (1, 1, 0, 0)), + FakeDevice(11, (1, 1, 0, 1)), + FakeDevice(12, (2, 1, 0, 0)), + FakeDevice(13, (2, 1, 0, 1)), + FakeDevice(14, (3, 1, 0, 0)), + FakeDevice(15, (3, 1, 0, 1)), + ] + + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 8)], + devices=fake_devices, + ) + + self.assertEqual( + [device.id for device in allocated["actor"]], + [0, 1, 2, 3, 8, 9, 10, 11], + ) + + def test_allocate_devices_by_coords_uses_core_on_chip_dimension(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(4): + for y in range(4): + for z in range(2): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice(device_id, (x, y, z), core_on_chip)) + device_id += 1 + + allocated = mesh.allocate_devices_by_coords(fake_devices, 8) + + self.assertEqual( + [device.id for device in allocated], + [0, 1, 4, 5, 16, 17, 20, 21], + ) + + def test_allocate_devices_by_coords_returns_none_without_coord_topology(self): + class FakeDevice: + + def __init__(self, process_index): + self.process_index = process_index + + self.assertIsNone( + mesh.allocate_devices_by_coords([FakeDevice(0), FakeDevice(0)], 2) + ) + + def test_allocate_devices_by_coords_returns_best_contiguous_box(self): + class FakeDevice: + + def __init__(self, 
device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (0, 1, 0), 0), + FakeDevice(3, (1, 1, 0), 0), + FakeDevice(4, (2, 0, 0), 1), + FakeDevice(5, (3, 0, 0), 1), + FakeDevice(6, (2, 1, 0), 1), + FakeDevice(7, (3, 1, 0), 1), + ] + + allocated = mesh.allocate_devices_by_coords(fake_devices, 4) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) + + def test_allocate_named_mesh_device_slices_uses_jax_devices_by_default(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + + fake_devices = [FakeDevice(0), FakeDevice(1)] + + with mock.patch.object(mesh.jax, "devices", return_value=fake_devices): + allocated = mesh.allocate_named_mesh_device_slices([("trainer", 2)]) + + self.assertEqual([device.id for device in allocated["trainer"]], [0, 1]) + + def test_allocate_named_mesh_device_slices_uses_whole_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ] + + allocated = mesh.allocate_named_mesh_device_slices( + [("trainer", 2), ("rollout", 2)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["trainer"]], [0, 1]) + self.assertEqual([device.id for device in allocated["rollout"]], [2, 3]) + + def test_allocate_named_mesh_device_slices_allows_multiple_single_host_subslices(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (0, 1, 0), 0), + FakeDevice(3, (1, 1, 0), 0), + FakeDevice(4, (0, 2, 0), 0), + FakeDevice(5, (1, 2, 0), 0), + 
FakeDevice(6, (0, 3, 0), 0), + FakeDevice(7, (1, 3, 0), 0), + ] + + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 2), ("reference", 2), ("rollout", 2)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["actor"]], [0, 1]) + self.assertEqual([device.id for device in allocated["reference"]], [2, 3]) + self.assertEqual([device.id for device in allocated["rollout"]], [4, 5]) + + def test_allocate_named_mesh_device_slices_raises_on_host_misalignment(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ] + + with self.assertRaisesRegex(ValueError, "does not align with the detected host size"): + mesh.allocate_named_mesh_device_slices( + [("trainer", 3)], + devices=fake_devices, + ) + + def test_allocate_named_mesh_device_slices_raises_when_not_enough_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ] + + with self.assertRaisesRegex(ValueError, "but only 2 are available"): + mesh.allocate_named_mesh_device_slices( + [("trainer", 6)], + devices=fake_devices, + ) + + def test_allocate_named_mesh_device_slices_raises_when_not_enough_devices(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + + fake_devices = [FakeDevice(0), FakeDevice(1)] + + with self.assertRaisesRegex(ValueError, "but only 2 remain available"): + mesh.allocate_named_mesh_device_slices( + [("trainer", 3)], + devices=fake_devices, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/utils/topology_test.py b/tests/utils/topology_test.py new file mode 100644 index 000000000..db8ffe25b --- /dev/null +++ 
b/tests/utils/topology_test.py @@ -0,0 +1,85 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from tunix.utils import topology + + +class TopologyTest(absltest.TestCase): + + def test_normalize_device_kind_recognizes_supported_families(self): + self.assertEqual(topology._normalize_device_kind("TPU v7"), "tpu7x") + self.assertEqual(topology._normalize_device_kind("TPU v6e"), "v6e") + self.assertEqual(topology._normalize_device_kind("TPU v5e"), "v5e") + self.assertEqual(topology._normalize_device_kind("TPU v5p"), "v5p") + self.assertEqual(topology._normalize_device_kind("TPU v4"), "v4") + self.assertIsNone(topology._normalize_device_kind("gpu")) + + def test_infer_chips_per_host_bounds_returns_none_for_empty_devices(self): + self.assertIsNone(topology.infer_chips_per_host_bounds([])) + + def test_infer_chips_per_host_bounds_returns_none_for_missing_device_kind(self): + class FakeDevice: + pass + + self.assertIsNone(topology.infer_chips_per_host_bounds([FakeDevice()])) + + def test_infer_chips_per_host_bounds_uses_single_host_shapes(self): + class FakeDevice: + + def __init__(self, device_kind): + self.device_kind = device_kind + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v5e")]), + (1, 1, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v6e")]), + (1, 1, 1), + ) + self.assertEqual( + 
topology.infer_chips_per_host_bounds([FakeDevice("TPU v7"), FakeDevice("TPU v7")]), + (1, 1, 1), + ) + + def test_infer_chips_per_host_bounds_uses_multi_host_shape_otherwise(self): + class FakeDevice: + + def __init__(self, device_kind): + self.device_kind = device_kind + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v7") for _ in range(4)]), + (2, 2, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v4") for _ in range(8)]), + (2, 2, 1), + ) + + def test_infer_chips_per_host_bounds_handles_callable_device_kind(self): + class FakeDevice: + + def device_kind(self): + return "TPU v7" + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice() for _ in range(128)]), + (2, 2, 1), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 3491b885b..dab55dcea 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -30,7 +30,6 @@ python -m tunix.cli.grpo_main examples/deepswe/configs/qwen3_32b.yaml """ -import collections import dataclasses import importlib import os @@ -53,6 +52,7 @@ from tunix.perf.experimental import export as perf_export_v2 from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout +from tunix.utils import mesh as mesh_lib _PATHWAYS_BNS = flags.DEFINE_string( @@ -190,6 +190,12 @@ def resolve_owner( role_to_owner[role] = resolve_owner(role, set()) return role_to_owner + def _device_mesh_coords(self, device: Any) -> tuple[int, ...] | None: + return mesh_lib.device_mesh_coords(device) + + def _known_host_mesh_shape(self, devices: list[Any]) -> tuple[int, ...] 
| None: + return mesh_lib.known_host_mesh_shape(devices) + def _create_role_to_mesh(self): devices = list(jax.devices()) role_to_owner = self._resolve_mesh_owners() @@ -201,40 +207,24 @@ def _create_role_to_mesh(self): if owner not in owner_order: owner_order.append(owner) - owner_to_mesh = {} - owner_to_device_slice = {} - device_offset = 0 + mesh_requirements = [] for owner in owner_order: model_key = self._ROLE_TO_MODEL_KEY[owner] axis_shapes, _ = self._parse_mesh_config(model_key) - required_devices = int(np.prod(axis_shapes)) - next_offset = device_offset + required_devices - if next_offset > len(devices): - raise ValueError( - f"Mesh allocation requires {next_offset} devices after allocating" - f" {model_key}, but only {len(devices)} are available." - ) - assigned_devices = devices[device_offset:next_offset] - owner_to_device_slice[owner] = assigned_devices + mesh_requirements.append((model_key, int(np.prod(axis_shapes)))) + + allocated_devices = mesh_lib.allocate_named_mesh_device_slices( + mesh_requirements, + devices=devices, + ) + + owner_to_mesh = {} + for owner in owner_order: + model_key = self._ROLE_TO_MODEL_KEY[owner] + assigned_devices = allocated_devices[model_key] owner_to_mesh[owner] = self.create_mesh( model_key, devices=assigned_devices ) - device_offset = next_offset - - if device_offset < len(devices): - logging.warning( - "Mesh allocation used %d of %d devices; %d devices remain unused.", - device_offset, - len(devices), - len(devices) - device_offset, - ) - logging.info( - "Mesh device allocation: %s", - { - self._ROLE_TO_MODEL_KEY[owner]: len(owner_to_device_slice[owner]) - for owner in owner_order - }, - ) return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} def create_role_to_mesh(self): diff --git a/tunix/utils/mesh.py b/tunix/utils/mesh.py new file mode 100644 index 000000000..5ad66e118 --- /dev/null +++ b/tunix/utils/mesh.py @@ -0,0 +1,782 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, 
# Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared mesh device allocation helpers.

Typical usage:

    allocations = allocate_named_mesh_device_slices([
        ("actor", 8),
        ("rollout", 4),
    ])

The keys are arbitrary mesh names chosen by the caller. The integer is the
number of devices that mesh should receive.
"""

import collections
import dataclasses
from typing import Any, Sequence

from absl import logging
import jax
import numpy as np
from tunix.utils import topology

# A (mesh_name, required_device_count) pair consumed by
# allocate_named_mesh_device_slices().
MeshRequirement = tuple[str, int]


@dataclasses.dataclass(frozen=True)
class CoordTopology:
  """Normalized coord metadata for a device pool.

  Attributes:
    coord_to_device: Mapping from physical coords to device objects.
    all_coords: Normalized coord tuples for all devices.
    num_dims: Number of coord dimensions.
    max_shape: Bounding-box shape of the device pool.
    chip_coord_to_coords: Mapping from a chip coordinate (the coord tuple
      minus its trailing dimension) to the sorted coord tuples of every
      logical device sharing that chip; used to keep `core_on_chip` siblings
      together during allocation.
  """

  coord_to_device: dict[tuple[int, ...], Any]
  all_coords: tuple[tuple[int, ...], ...]
  num_dims: int
  max_shape: tuple[int, ...]
  chip_coord_to_coords: dict[tuple[int, ...], tuple[tuple[int, ...], ...]]


def device_attr(device: Any, attr_name: str) -> Any:
  """Returns a raw device attribute, calling it first if JAX exposes it lazily.

  Args:
    device: A JAX device or test double.
    attr_name: Attribute name such as "coords" or "process_index".

  Returns:
    The attribute value, or None if the attribute does not exist.
  """
  value = getattr(device, attr_name, None)
  # Some runtimes expose these as zero-arg methods instead of plain attributes.
  return value() if callable(value) else value


def device_host_key(device: Any) -> tuple[Any, ...] | None:
  """Returns a stable host grouping key for topology-aware allocation.

  Args:
    device: A JAX device or test double.

  Returns:
    A tuple of (slice_id, task_id) when that metadata is available, otherwise
    None.
  """
  # Probe task-id aliases in priority order; the first non-None value wins.
  task_id = None
  for attr_name in ("logical_task", "task_id", "process_index"):
    task_id = device_attr(device, attr_name)
    if task_id is not None:
      break
  if task_id is None:
    return None

  # slice_id may legitimately stay None (single-slice runtimes).
  slice_id = None
  for attr_name in ("slice_index", "slice"):
    slice_id = device_attr(device, attr_name)
    if slice_id is not None:
      break
  return (slice_id, task_id)


def device_mesh_coords(device: Any) -> tuple[int, ...] | None:
  """Returns physical mesh coordinates for topology-aware allocation.

  Args:
    device: A JAX device or test double.

  Returns:
    A tuple like (x, y, z) or (x, y, z, core) when the runtime exposes device
    coordinates, otherwise None.
  """
  coords = device_attr(device, "coords")
  if coords is None:
    return None

  coords = tuple(coords)
  if not coords:
    return None

  normalized_coords = tuple(int(coord) for coord in coords)
  # Append core_on_chip as a trailing dimension so two cores on one chip get
  # distinct coords.
  core_on_chip = device_attr(device, "core_on_chip")
  if core_on_chip is None:
    return normalized_coords
  return normalized_coords + (int(core_on_chip),)


def infer_core_on_chip_count(devices: Sequence[Any]) -> int | None:
  """Returns the per-chip core count when the runtime exposes it consistently."""
  chip_to_cores = collections.defaultdict(set)
  saw_any_core = False

  for device in devices:
    coords = device_attr(device, "coords")
    core_on_chip = device_attr(device, "core_on_chip")
    if coords is None:
      return None
    if core_on_chip is None:
      continue
    saw_any_core = True
    chip_to_cores[tuple(int(coord) for coord in coords)].add(int(core_on_chip))

  if not saw_any_core:
    return None

  # Only trust the count when every chip reports the same number of cores.
  core_counts = {len(core_ids) for core_ids in chip_to_cores.values()}
  if len(core_counts) != 1:
    return None
  return next(iter(core_counts))


def summarize_devices_for_logging(devices: Sequence[Any]) -> list[dict[str, Any]]:
  """Builds compact log-friendly summaries for a device list.

  Args:
    devices: Devices to summarize.

  Returns:
    A list of dictionaries containing device id, coords, and inferred host key.
  """
  summaries = []
  for device in devices:
    summaries.append({
        "id": device_attr(device, "id"),
        "coords": device_mesh_coords(device),
        "host": device_host_key(device),
    })
  return summaries


def summarize_devices_for_debug_logging(
    devices: Sequence[Any],
    limit: int = 16,
) -> list[dict[str, Any]]:
  """Builds richer device summaries for topology debugging.

  Args:
    devices: Devices to summarize.
    limit: Maximum number of devices to include.

  Returns:
    A list of dictionaries with raw device topology metadata.
  """
  summaries = []
  for device in devices[:limit]:
    summaries.append({
        "id": device_attr(device, "id"),
        "coords": device_attr(device, "coords"),
        "core_on_chip": device_attr(device, "core_on_chip"),
        "process_index": device_attr(device, "process_index"),
        "logical_task": device_attr(device, "logical_task"),
        "task_id": device_attr(device, "task_id"),
        "slice_index": device_attr(device, "slice_index"),
        "slice": device_attr(device, "slice"),
        "host": device_host_key(device),
    })
  return summaries


def summarize_host_groups_for_logging(devices: Sequence[Any]) -> dict[tuple[Any, ...], int]:
  """Summarizes device counts per derived host key for debug logging."""
  host_counts = collections.Counter()
  for device in devices:
    host_key = device_host_key(device)
    host_counts[host_key] += 1
  # Sort on str() so None keys compare against tuple keys without a TypeError.
  return dict(sorted(host_counts.items(), key=lambda item: str(item[0])))


def group_devices_by_host(devices: Sequence[Any]) -> list[list[Any]] | None:
  """Groups devices by host/task when that metadata is available.

  Args:
    devices: Candidate devices to partition.

  Returns:
    A list of equal-sized per-host device lists, or None if host metadata is
    missing or inconsistent.
  """
  host_to_devices = {}
  for device in devices:
    host_key = device_host_key(device)
    if host_key is None:
      return None
    host_to_devices.setdefault(host_key, []).append(device)

  host_sizes = {len(host_devices) for host_devices in host_to_devices.values()}
  if len(host_sizes) != 1:
    logging.warning(
        "Falling back to flat device allocation because host sizes differ: %s",
        sorted(host_sizes),
    )
    return None
  return list(host_to_devices.values())


def host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None:
  """Returns the per-host physical box shape when coords are available.

  Args:
    devices: Devices spanning one or more hosts.

  Returns:
    The shape of one host in physical coords, such as (2, 2, 1), or None when
    it cannot be inferred reliably.
  """
  host_to_coords = collections.defaultdict(list)
  for device in devices:
    host_key = device_host_key(device)
    coords = device_mesh_coords(device)
    if host_key is None or coords is None:
      return None
    host_to_coords[host_key].append(coords)

  host_shapes = set()
  for coords_list in host_to_coords.values():
    ndim = len(coords_list[0])
    mins = tuple(min(coords[i] for coords in coords_list) for i in range(ndim))
    maxs = tuple(max(coords[i] for coords in coords_list) for i in range(ndim))
    shape = tuple(max_coord - min_coord + 1 for min_coord, max_coord in zip(mins, maxs))
    # A host must fill its bounding box exactly; holes mean the shape is bogus.
    if int(np.prod(shape)) != len(coords_list):
      return None
    host_shapes.add(shape)

  # All hosts must agree on one shape for the inference to be trustworthy.
  if len(host_shapes) != 1:
    return None
  return next(iter(host_shapes))


def get_coord_topology(devices: Sequence[Any]) -> CoordTopology | None:
  """Builds normalized coord metadata for a device pool.

  Args:
    devices: Candidate devices to inspect.

  Returns:
    A CoordTopology describing the device coords and overall bounding box, or
    None when the devices do not expose a consistent coord layout.
  """
  if not devices:
    return None

  coord_to_device = {}
  all_coords = []
  for device in devices:
    coords = device_mesh_coords(device)
    if coords is None:
      logging.info(
          "Coord topology unavailable because device lacks coords: %s",
          summarize_devices_for_debug_logging([device]),
      )
      return None
    if all_coords and len(coords) != len(all_coords[0]):
      logging.info(
          "Coord topology unavailable because coord rank differs: existing_rank=%d device=%s",
          len(all_coords[0]),
          summarize_devices_for_debug_logging([device]),
      )
      return None
    if coords in coord_to_device:
      logging.info(
          "Coord topology unavailable because multiple devices share coords %s: %s",
          coords,
          summarize_devices_for_debug_logging([coord_to_device[coords], device]),
      )
      return None
    coord_to_device[coords] = device
    all_coords.append(coords)

  num_dims = len(all_coords[0])
  # Group logical devices by chip (coords minus the trailing dimension).
  chip_coord_to_coords = collections.defaultdict(list)
  for coords in all_coords:
    chip_coord_to_coords[coords[:-1]].append(coords)
  # Bounding-box extent per dimension over the whole pool.
  max_shape = tuple(
      max(coords[dim] for coords in all_coords)
      - min(coords[dim] for coords in all_coords)
      + 1
      for dim in range(num_dims)
  )
  return CoordTopology(
      coord_to_device=coord_to_device,
      all_coords=tuple(all_coords),
      num_dims=num_dims,
      max_shape=max_shape,
      chip_coord_to_coords={
          chip_coord: tuple(sorted(group_coords))
          for chip_coord, group_coords in chip_coord_to_coords.items()
      },
  )


def candidate_uses_whole_chips(
    coord_topology: CoordTopology,
    candidate_coords: Sequence[tuple[int, ...]],
) -> bool:
  """Returns whether a candidate includes all logical devices for each chip.

  When multiple logical devices share the same physical chip coordinates, a
  valid Pathways subslice must include either all of them or none of them.
  This rejects candidates that split `core_on_chip` siblings across meshes.
  """
  if coord_topology.num_dims <= 1:
    return True

  selected_coords = set(candidate_coords)
  selected_chip_coords = {coords[:-1] for coords in selected_coords}
  for chip_coord in selected_chip_coords:
    chip_group = coord_topology.chip_coord_to_coords.get(chip_coord, ())
    if any(coords not in selected_coords for coords in chip_group):
      return False
  return True


def known_host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None:
  """Returns known host bounds from static topology metadata when available.

  Args:
    devices: Devices from a single TPU slice.

  Returns:
    A known per-host physical bound such as (1, 1, 1) or (2, 2, 1), or None if
    the accelerator family is unknown.
  """
  bounds = topology.infer_chips_per_host_bounds(devices)
  if bounds is None:
    return None

  coords = device_mesh_coords(devices[0]) if devices else None
  if coords is None:
    return None

  if len(coords) == len(bounds):
    return bounds

  # Coords carry an extra trailing core dimension; extend the chip bounds with
  # the per-chip core count when it can be determined.
  if len(coords) == len(bounds) + 1:
    core_count = infer_core_on_chip_count(devices)
    if core_count is None:
      return None
    return bounds + (core_count,)

  return None


def resolve_per_host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None:
  """Resolves per-host shape and validates inferred vs known topology.

  Args:
    devices: Devices spanning one or more hosts.

  Returns:
    The inferred per-host shape when available, otherwise the known static host
    bounds.

  Raises:
    ValueError: If runtime-inferred host shape disagrees with known static host
      bounds for the device family.
  """
  inferred_shape = host_mesh_shape(devices)
  static_shape = known_host_mesh_shape(devices)
  if (
      inferred_shape is not None
      and static_shape is not None
      and inferred_shape != static_shape
  ):
    raise ValueError(
        "Inferred per-host device shape "
        f"{inferred_shape} does not match known host bounds {static_shape}."
    )
  return inferred_shape or static_shape


def _divisors(value: int) -> list[int]:
  """Returns all positive divisors of value in ascending order."""
  divisors = set()
  for candidate in range(1, int(np.sqrt(value)) + 1):
    if value % candidate == 0:
      divisors.add(candidate)
      divisors.add(value // candidate)
  return sorted(divisors)


def _enumerate_box_shapes(
    required_devices: int,
    max_shape: tuple[int, ...],
) -> list[tuple[int, ...]]:
  """Enumerates box shapes whose volume matches the requested device count."""
  shapes = []
  num_dims = len(max_shape)

  # Depth-first factorization of required_devices across dimensions, pruned by
  # the pool's bounding box.
  def build(dim_index: int, remaining: int, prefix: tuple[int, ...]):
    if dim_index == num_dims - 1:
      if remaining <= max_shape[dim_index]:
        shapes.append(prefix + (remaining,))
      return

    for size in _divisors(remaining):
      if size > max_shape[dim_index]:
        continue
      build(dim_index + 1, remaining // size, prefix + (size,))

  build(0, required_devices, ())
  return shapes


def _coord_box_score(
    start: tuple[int, ...],
    shape: tuple[int, ...],
    host_shape: tuple[int, ...] | None,
) -> tuple[Any, ...]:
  """Builds a lexicographic sort key for candidate coord boxes.

  The returned tuple is ordered so Python tuple comparison implements the
  desired ranking policy directly:

  1. Prefer host-aligned boxes when host_shape is known.
  2. Prefer boxes with a smaller maximum dimension.
  3. Prefer more compact overall shapes.
  4. Prefer earlier start coordinates as a stable tiebreaker.

  Args:
    start: Candidate box origin.
    shape: Candidate box shape.
    host_shape: Per-host physical shape such as (2, 2, 1).

  Returns:
    A tuple sort key suitable for lexicographic comparison.
  """
  chip_host_alignment = 1
  full_host_alignment = 1
  if host_shape is not None:
    # Chip alignment considers at most the first three (spatial) dims; full
    # alignment also covers any trailing core dimension.
    chip_dims = min(3, len(shape), len(host_shape))
    chip_aligned = all(
        start[dim] % host_shape[dim] == 0
        and shape[dim] % host_shape[dim] == 0
        for dim in range(chip_dims)
        if host_shape[dim] > 1
    )
    fully_aligned = all(
        start[dim] % host_shape[dim] == 0
        and shape[dim] % host_shape[dim] == 0
        for dim in range(len(shape))
        if host_shape[dim] > 1
    )
    # 0 sorts before 1, so aligned boxes rank first.
    chip_host_alignment = 0 if chip_aligned else 1
    full_host_alignment = 0 if fully_aligned else 1
  return (
      chip_host_alignment,
      full_host_alignment,
      max(shape),
      tuple(sorted(shape, reverse=True)),
      tuple(-dim for dim in shape),
      start,
  )


def select_best_candidate_coords(
    candidate_boxes: Sequence[
        tuple[tuple[int, ...], tuple[int, ...], Sequence[tuple[int, ...]]]
    ],
    host_shape: tuple[int, ...] | None,
) -> list[tuple[int, ...]] | None:
  """Selects the best candidate coord box using the mesh heuristic.

  Args:
    candidate_boxes: Sequence of (start, shape, candidate_coords) tuples.
      `start` is the box origin, `shape` is the physical box shape, and
      `candidate_coords` are the device coords inside that box.
    host_shape: Per-host physical shape such as (2, 2, 1), used to prefer
      host-aligned boxes when available.

  Returns:
    The candidate coord list for the best-ranked box, or None when there are no
    candidates.

  Notes:
    Candidate boxes are ranked by `_coord_box_score()`, which uses a
    lexicographic sort key instead of a single numeric score. This makes the
    priority order explicit and avoids arbitrary weighting between ranking
    factors.
  """
  best_candidate_coords = None
  best_score = None
  for start, shape, candidate_coords in candidate_boxes:
    score = _coord_box_score(start, shape, host_shape)
    if best_score is None or score < best_score:
      best_score = score
      best_candidate_coords = list(candidate_coords)
  return best_candidate_coords


def find_candidate_coord_boxes(
    coord_topology: CoordTopology,
    required_devices: int,
) -> list[tuple[tuple[int, ...], tuple[int, ...], tuple[tuple[int, ...], ...]]]:
  """Finds contiguous candidate coord boxes for a requested device count.

  Args:
    coord_topology: Normalized coord metadata for the candidate device pool.
    required_devices: Number of devices needed for one mesh.

  Returns:
    A list of (start, shape, candidate_coords) tuples representing contiguous
    coord boxes whose volume matches required_devices.

  Notes:
    This function only enumerates valid contiguous boxes that exist in the
    current device pool. It does not choose among them; ranking is handled by
    `select_best_candidate_coords()`.
  """
  candidate_boxes = []
  for shape in _enumerate_box_shapes(required_devices, coord_topology.max_shape):
    for start in coord_topology.coord_to_device:
      candidate_coords = []
      for offset in np.ndindex(shape):
        candidate_coord = tuple(
            start[dim] + offset[dim] for dim in range(coord_topology.num_dims)
        )
        if candidate_coord not in coord_topology.coord_to_device:
          break
        candidate_coords.append(candidate_coord)
      else:
        # for/else: only reached when every coord in the box exists.
        if candidate_uses_whole_chips(coord_topology, candidate_coords):
          candidate_boxes.append((start, shape, tuple(candidate_coords)))
  return candidate_boxes


def find_host_aligned_candidate_coord_boxes(
    coord_topology: CoordTopology,
    required_devices: int,
    host_shape: tuple[int, ...],
) -> list[tuple[tuple[int, ...], tuple[int, ...], tuple[tuple[int, ...], ...]]]:
  """Finds contiguous candidate boxes that exactly respect host bounds.

  Args:
    coord_topology: Normalized coord metadata for the candidate device pool.
    required_devices: Number of devices needed for one mesh.
    host_shape: Known per-host physical shape such as (2, 2, 1) or
      (2, 2, 1, 2).

  Returns:
    A list of valid coord boxes whose shape is an exact multiple of host_shape.
  """
  if len(host_shape) != coord_topology.num_dims:
    return []

  host_volume = int(np.prod(host_shape))
  if host_volume <= 0 or required_devices % host_volume != 0:
    return []

  # Work in units of whole hosts, then scale back to physical coords.
  host_grid_shape = tuple(
      coord_topology.max_shape[dim] // host_shape[dim]
      for dim in range(coord_topology.num_dims)
  )
  required_host_boxes = required_devices // host_volume

  candidate_boxes = []
  for host_box_shape in _enumerate_box_shapes(required_host_boxes, host_grid_shape):
    physical_shape = tuple(
        host_box_shape[dim] * host_shape[dim]
        for dim in range(coord_topology.num_dims)
    )
    for start in coord_topology.coord_to_device:
      # Only consider starts aligned to host boundaries.
      if any(
          start[dim] % host_shape[dim] != 0
          for dim in range(coord_topology.num_dims)
          if host_shape[dim] > 1
      ):
        continue

      candidate_coords = []
      for offset in np.ndindex(physical_shape):
        candidate_coord = tuple(
            start[dim] + offset[dim] for dim in range(coord_topology.num_dims)
        )
        if candidate_coord not in coord_topology.coord_to_device:
          break
        candidate_coords.append(candidate_coord)
      else:
        if candidate_uses_whole_chips(coord_topology, candidate_coords):
          candidate_boxes.append((start, physical_shape, tuple(candidate_coords)))
  return candidate_boxes


def allocate_devices_by_coords(
    devices: Sequence[Any],
    required_devices: int,
) -> list[Any] | None:
  """Allocates a contiguous physical box of devices when coords exist.

  Args:
    devices: Candidate devices to allocate from.
    required_devices: Number of devices needed for one mesh.

  Returns:
    A list of devices forming the best contiguous physical box, or None if the
    devices do not expose usable coordinates.

  Notes:
    This helper runs in three stages:

    1. Build normalized coord metadata with `get_coord_topology()`.
    2. Enumerate valid contiguous candidate boxes with
       `find_candidate_coord_boxes()`.
    3. Rank those candidates with `select_best_candidate_coords()` and map the
       winning coords back to device objects.
  """
  coord_topology = get_coord_topology(devices)
  if coord_topology is None:
    return None
  per_host_shape = resolve_per_host_mesh_shape(devices)

  # Prefer strictly host-aligned boxes; fall back to any contiguous box.
  candidate_boxes = []
  if per_host_shape is not None:
    candidate_boxes = find_host_aligned_candidate_coord_boxes(
        coord_topology,
        required_devices,
        per_host_shape,
    )
  if not candidate_boxes:
    candidate_boxes = find_candidate_coord_boxes(coord_topology, required_devices)

  best_candidate_coords = select_best_candidate_coords(
      candidate_boxes,
      per_host_shape,
  )
  if best_candidate_coords is None:
    return None

  selected_coords = set(best_candidate_coords)
  # Preserve the caller's device order while filtering to the winning box.
  return [
      device
      for device in devices
      if device_mesh_coords(device) in selected_coords
  ]


def allocate_named_mesh_device_slices(
    mesh_requirements: Sequence[MeshRequirement],
    devices: Sequence[Any] | None = None,
) -> dict[str, list[Any]]:
  """Allocates device subsets for named meshes.

  The allocator prefers coord-aligned physical boxes, then whole-host groups,
  and finally falls back to flat prefixes when no topology metadata is usable.

  Args:
    mesh_requirements: Sequence of (mesh_name, required_devices) pairs.
      Example: [("actor", 8), ("rollout", 4)]. The mesh_name is only used for
      logging and as the key in the returned dictionary.
    devices: Optional explicit device list. When omitted, this uses
      jax.devices().

  Returns:
    A dictionary mapping each mesh name to the list of devices assigned to it.

  Raises:
    ValueError: If a requested mesh cannot be assigned enough devices or if a
      host-based allocation would split hosts illegally.
  """
  all_devices = list(jax.devices() if devices is None else devices)
  logging.info(
      "Mesh allocator raw device sample: %s",
      summarize_devices_for_debug_logging(all_devices),
  )
  logging.info(
      "Mesh allocator derived host groups: %s",
      summarize_host_groups_for_logging(all_devices),
  )
  remaining_devices = list(all_devices)
  remaining_host_groups = group_devices_by_host(all_devices)
  allocations = {}
  used_device_count = 0

  for mesh_name, required_devices in mesh_requirements:
    # Stage 1: try a contiguous, topology-aware coord box.
    assigned_devices = allocate_devices_by_coords(remaining_devices, required_devices)
    if assigned_devices is not None:
      assigned_device_ids = {id(device) for device in assigned_devices}
      remaining_devices = [
          device for device in remaining_devices if id(device) not in assigned_device_ids
      ]
      # Once coord allocation succeeds, host-group bookkeeping no longer
      # reflects the remaining pool, so drop it for subsequent meshes.
      remaining_host_groups = None
      used_device_count += len(assigned_devices)

    # Stage 2: whole-host allocation when host metadata is available and the
    # request spans at least one full host.
    if remaining_host_groups:
      devices_per_host = len(remaining_host_groups[0])
      if required_devices >= devices_per_host:
        if required_devices % devices_per_host != 0:
          raise ValueError(
              f"Mesh allocation for {mesh_name} requires {required_devices} devices, "
              f"which does not align with the detected host size of {devices_per_host}. "
              "Choose a mesh shape that fits within one host or uses a whole "
              "number of hosts."
          )
        required_hosts = required_devices // devices_per_host
        if required_hosts > len(remaining_host_groups):
          raise ValueError(
              f"Mesh allocation requires {required_hosts} hosts for {mesh_name}, "
              f"but only {len(remaining_host_groups)} are available."
          )
        assigned_devices = [
            device
            for host_devices in remaining_host_groups[:required_hosts]
            for device in host_devices
        ]
        remaining_host_groups = remaining_host_groups[required_hosts:] or None

        assigned_device_ids = {id(device) for device in assigned_devices}
        remaining_devices = [
            device for device in remaining_devices if id(device) not in assigned_device_ids
        ]
        used_device_count += len(assigned_devices)

    # Stage 3: flat prefix fallback when no topology-aware path applied.
    if assigned_devices is None:
      if required_devices > len(remaining_devices):
        raise ValueError(
            f"Mesh allocation requires {required_devices} devices for {mesh_name}, "
            f"but only {len(remaining_devices)} remain available."
        )
      assigned_devices = remaining_devices[:required_devices]
      remaining_devices = remaining_devices[required_devices:]
      used_device_count += len(assigned_devices)

      # A flat prefix may have split a host; drop any host group that lost
      # devices so later host-based allocations stay whole-host.
      if remaining_host_groups:
        assigned_device_ids = {id(device) for device in assigned_devices}
        remaining_host_groups = [
            host_devices
            for host_devices in remaining_host_groups
            if all(id(device) not in assigned_device_ids for device in host_devices)
        ] or None

    allocations[mesh_name] = assigned_devices
    logging.info(
        "Allocated devices for %s: %s",
        mesh_name,
        summarize_devices_for_logging(assigned_devices),
    )

  if used_device_count < len(all_devices):
    logging.warning(
        "Mesh allocation used %d of %d devices; %d devices remain unused.",
        used_device_count,
        len(all_devices),
        len(all_devices) - used_device_count,
    )
  logging.info(
      "Mesh device allocation: %s",
      {mesh_name: len(assigned_devices) for mesh_name, assigned_devices in allocations.items()},
  )
  return allocations

# --- tunix/utils/topology.py ---
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal accelerator topology helpers used by Tunix mesh allocation.""" + +from typing import Any, Sequence + +_SINGLE_HOST_BOUNDS = (1, 1, 1) +_MULTI_HOST_BOUNDS = (2, 2, 1) + + +def _device_attr(device: Any, attr_name: str) -> Any: + """Returns a raw device attribute, calling it first when exposed lazily.""" + value = getattr(device, attr_name, None) + return value() if callable(value) else value + + +def _normalize_device_kind(device_kind: str) -> str | None: + device_kind = device_kind.lower() + if "v7" in device_kind: + return "tpu7x" + if "v6e" in device_kind or "v6" in device_kind: + return "v6e" + if "v5e" in device_kind: + return "v5e" + if "v5" in device_kind: + return "v5p" + if "v4" in device_kind: + return "v4" + return None + + +def infer_chips_per_host_bounds( + devices: Sequence[Any], +) -> tuple[int, ...] | None: + if not devices: + return None + + device_kind = _device_attr(devices[0], "device_kind") + if not isinstance(device_kind, str): + return None + + family = _normalize_device_kind(device_kind) + if family is None: + return None + + device_count = len(devices) + if family in {"v5e", "v6e"} and device_count == 1: + return _SINGLE_HOST_BOUNDS + if family == "tpu7x" and device_count == 2: + return _SINGLE_HOST_BOUNDS + return _MULTI_HOST_BOUNDS