diff --git a/autoparallel/api.py b/autoparallel/api.py index 8f2eea55..932dc5f2 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -556,10 +556,12 @@ def inference_fn(args): self._traced_inputs, solved_input_placements, self.mesh ) - def forward(self, *args): + def forward(self, *args, **kwargs): # Flatten pytree args (e.g. dicts, nested structures) to tensor # leaves, matching how Dynamo flattened the inputs during tracing. flat_args, _ = torch.utils._pytree.tree_flatten(args) + if len(flat_args) != len(expected_inputs): + flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs)) _check_forward_args(flat_args, expected_inputs) # NB: don't close over the parameters/buffers, as the user may # reassign the module! diff --git a/autoparallel/graph_passes/graph_clustering.py b/autoparallel/graph_passes/graph_clustering.py index 4503e611..c01a09a3 100644 --- a/autoparallel/graph_passes/graph_clustering.py +++ b/autoparallel/graph_passes/graph_clustering.py @@ -11,7 +11,7 @@ import math import time from collections import defaultdict -from typing import Optional, cast +from typing import Optional import torch from torch._dynamo.graph_region_tracker import ( @@ -81,143 +81,6 @@ def _hash_node(node, strategies, input_pickler): return sha256_hash(input_pickler.dumps(key)) -def _extend_with_sibling_getitems( - region_groups: list[list[Region]], - node_to_duplicates: dict[Node, IdenticalNodes], - strategies: dict[Node, OpStrategy], - topological_ranking: dict[Node, int], -) -> set[Node]: - """Extend region groups with unclaimed getitem siblings of clustered nodes. - - The backward-BFS expansion only reaches getitem users that happen to be on - the main data path. Sibling tuple projections (e.g. logsumexp, RNG state - from SDPA) are left orphaned even though their producer is already aligned - across regions. This post-pass recovers them in two ways: - - 1. If a getitem's producer is already in a region, find matching unclaimed - getitems across all other regions and append them in-place. - 2. If a getitem's producer is NOT in any region but its duplicate getitems - (from node_to_duplicates) ARE clustered, create a small bridge group - that links the orphan to a clustered sibling. This handles the case - where the BFS created N-1 regions out of N identical layers. - - Returns the set of bridge root nodes — already-clustered nodes that are - reused as the root region of a bridge group and therefore intentionally - appear in two groups. - """ - claimed: set[Node] = set() - for region_group in region_groups: - for region in region_group: - claimed.update(region) - - # Case 1: extend existing regions with unclaimed getitem siblings. - for region_group in region_groups: - root_region = region_group[0] - num_regions = len(region_group) - - for pos in range(len(root_region)): - root_producer = root_region[pos] - getitems_by_idx: dict[int, list[Node]] = defaultdict(list) - for user in root_producer.users: - if ( - user.target is operator.getitem - and user not in claimed - and user in strategies - ): - getitems_by_idx[cast(int, user.args[1])].append(user) - - for k, root_matches in getitems_by_idx.items(): - if len(root_matches) != 1: - continue - root_getitem = root_matches[0] - if root_getitem not in node_to_duplicates: - continue - root_dups = node_to_duplicates[root_getitem] - root_phase = root_getitem.meta.get("partitioner_tag") - - candidates = [root_getitem] - valid = True - for other_region in region_group[1:]: - other_producer = other_region[pos] - matches = [ - user - for user in other_producer.users - if ( - user.target is operator.getitem - and user.args[1] == k - and user not in claimed - and user in strategies - and user in node_to_duplicates - and node_to_duplicates[user] is root_dups - and user.meta.get("partitioner_tag") == root_phase - ) - ] - if len(matches) != 1: - valid = False - break - candidates.append(matches[0]) - - if valid and len(candidates) == num_regions: - for region, getitem_node in zip(region_group, candidates): - region.append(getitem_node) - claimed.add(getitem_node) - - for region in region_group: - region.sort(key=lambda n: topological_ranking[n]) - - # Case 2: create bridge groups for orphaned getitems whose duplicates - # are already clustered. Each bridge group pairs one clustered sibling - # (as the root region) with the orphan, so create_cluster_links maps - # the orphan's decision variables to the root's. - bridge_roots: set[Node] = set() - seen_dup_groups: set[int] = set() - for node in strategies: - if node.target is not operator.getitem: - continue - if node in claimed: - continue - if node not in node_to_duplicates: - continue - dups = node_to_duplicates[node] - group_id = id(dups) - if group_id in seen_dup_groups: - continue - seen_dup_groups.add(group_id) - - if len(dups) < 2: - continue - if not all(d in strategies for d in dups): - continue - - # Find one claimed duplicate to serve as the root. - root = None - for d in dups: - if d in claimed: - root = d - break - if root is None: - continue - - # Create a bridge group: [[root], [orphan1], [orphan2], ...] - bridge = [[root]] - for d in dups: - if d not in claimed: - bridge.append([d]) - claimed.add(d) - if len(bridge) < 2: - continue - bridge.sort(key=lambda r: topological_ranking[r[0]]) - # Ensure the root is first (create_cluster_links uses region[0] - # as the root). - root_idx = next(i for i, r in enumerate(bridge) if r[0] is root) - if root_idx != 0: - bridge[0], bridge[root_idx] = bridge[root_idx], bridge[0] - region_groups.append(bridge) - bridge_roots.add(root) - - return bridge_roots - - def get_identical_regions( graph: torch.fx.Graph, strategies: dict[Node, OpStrategy] ) -> list[list[Region]]: @@ -302,7 +165,6 @@ def _is_identical(n0: Node, n1: Node) -> bool: # overlap. t = time.time() seen_nodes: set[Node] = set() - expanded_groups: list[list[Region]] = [] for region_group in region_groups: # NOTE: this seems like it's missing in the original implementation # from PyTorch. Given that fully_expand_region_group doesn't check @@ -322,30 +184,23 @@ def _is_identical(n0: Node, n1: Node) -> bool: # sort topologically for region in region_group: region.sort(key=lambda n: topological_ranking[n]) - expanded_groups.append(region_group) region_groups = [ - region_group for region_group in expanded_groups if len(region_group[0]) > 1 + region_group for region_group in region_groups if len(region_group[0]) > 1 ] - bridge_roots = _extend_with_sibling_getitems( - region_groups, node_to_duplicates, strategies, topological_ranking - ) - # sort everything so that we have nodes in topological ranking for region_group in region_groups: region_group.sort(key=lambda rg: topological_ranking[rg[0]]) region_groups.sort(key=lambda rg: topological_ranking[rg[0][0]]) logger.debug(f"Expanded regions in {time.time() - t} s") - # sanity check that we don't have duplicate nodes. - # Bridge roots are already-clustered nodes reused as root regions in - # bridge groups (case 2 above); they intentionally appear in two groups. + # sanity check that we don't have duplicate nodes seen_nodes.clear() for region_group in region_groups: for region in region_group: for node in region: - if node in seen_nodes and node not in bridge_roots: + if node in seen_nodes: raise RuntimeError(f"Duplicate node {node} in region group") seen_nodes.add(node) return region_groups diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py index 106ce80b..1695de3f 100644 --- a/examples/example_ds3_local_map.py +++ b/examples/example_ds3_local_map.py @@ -25,7 +25,10 @@ def _seed_dtensor_rng(rng_seed: Optional[int]) -> None: def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): - seq_len = 1024 + # Match TorchTitan's DeepSeek V3 debug model shape. This example is a + # regression guard for placement/clustering issues that only appear at the + # larger debug shape used by TorchTitan GraphTrainer. + seq_len = 2048 if fake_evaluate: world_size = 256 @@ -66,11 +69,9 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): mesh_dim_names=("dp", "ep"), ) - config = make_dsv3_config( - num_experts=4, top_k=2, n_layers=4, n_dense_layers=0, max_seq_len=seq_len - ) + config = make_dsv3_config(max_seq_len=seq_len) - local_batch_size = 2 + local_batch_size = 8 global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1] with torch.device("meta"): @@ -140,7 +141,7 @@ def input_fn(): out.backward(torch.ones_like(out)) else: for i, x in enumerate(microbatches): - assert x.shape[0] == 2 + assert x.shape[0] == local_batch_size out = parallel_mod(x) assert not torch.any(torch.isnan(out)), "Found NaNs in forward output" out.backward(torch.ones_like(out)) diff --git a/tests/test_graph_clustering.py b/tests/test_graph_clustering.py index 3a4c9d28..775a03e2 100644 --- a/tests/test_graph_clustering.py +++ b/tests/test_graph_clustering.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -import operator import re from collections import Counter @@ -67,9 +66,10 @@ def _clustering_stats(graph, strats, clusters, n_layers): } -def _assert_layer_coverage(stats, n_layers, min_coverage, label): +def _assert_layer_coverage(stats, n_layers, min_coverage, label, layers=None): """Assert each repeated layer has enough nodes in clustered regions.""" - layer_totals = [stats["per_layer_total"].get(i, 0) for i in range(n_layers)] + layers = range(n_layers) if layers is None else layers + layer_totals = [stats["per_layer_total"].get(i, 0) for i in layers] if len(set(layer_totals)) != 1: raise AssertionError( f"{label}: layers have different node counts: {layer_totals}" @@ -78,7 +78,7 @@ def _assert_layer_coverage(stats, n_layers, min_coverage, label): total = layer_totals[0] if total == 0: raise AssertionError(f"{label}: no layer nodes found in clustering stats") - for layer_idx in range(n_layers): + for layer_idx in layers: clustered = stats["per_layer_clustered"].get(layer_idx, 0) coverage = clustered / total if coverage < min_coverage: @@ -152,6 +152,7 @@ def _assert_model_clustering( min_coverage, forward_layers, backward_layers, + coverage_layers=None, min_region_size=100, ): """Assert coverage, phase separation, and large fwd/bwd layer clusters.""" @@ -160,6 +161,7 @@ def _assert_model_clustering( n_layers, min_coverage=min_coverage, label=label, + layers=coverage_layers, ) _assert_cross_layer_cluster( clusters, @@ -209,9 +211,13 @@ def input_fn(): return autop, model_args -def _setup_ds3_local_map_autop(device_mesh_2d, n_layers=2): - global_batch_size = 2 * device_mesh_2d.shape[0] * device_mesh_2d.shape[1] - config = make_dsv3_config(n_layers=n_layers, n_dense_layers=0) +def _setup_ds3_local_map_autop(device_mesh_2d): + local_batch_size = 8 + seq_len = 2048 + global_batch_size = ( + local_batch_size * device_mesh_2d.shape[0] * device_mesh_2d.shape[1] + ) + config = make_dsv3_config(max_seq_len=seq_len) with torch.device("meta"): model = DeepSeekV3Model( config, @@ -230,7 +236,7 @@ def input_fn(): device="cuda", ) - return AutoParallel(model, input_fn, device_mesh_2d, dynamic=True) + return AutoParallel(model, input_fn, device_mesh_2d, dynamic=True), config def test_clustering_high_coverage(device_mesh_2d): @@ -259,8 +265,8 @@ def test_clustering_high_coverage(device_mesh_2d): backward_layers=range(1, n_layers), ) - n_layers = 2 - autop = _setup_ds3_local_map_autop(device_mesh_2d, n_layers=n_layers) + autop, config = _setup_ds3_local_map_autop(device_mesh_2d) + n_layers = len(config.layers) stats, clusters = _run_clustering( autop, n_layers, @@ -272,52 +278,16 @@ def test_clustering_high_coverage(device_mesh_2d): label="DS3", n_layers=n_layers, min_coverage=0.75, - forward_layers=range(n_layers), - backward_layers=range(n_layers), - ) - - -def test_getitem_siblings_are_clustered(device_mesh_2d): - """Getitem nodes that project sibling outputs from tuple-returning ops - (e.g. SDPA returning (output, logsumexp, rng_state)) should be clustered - together with their producer when the producer is already clustered.""" - n_layers = 4 - autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers) - with autop: - x_sharding = (Shard(0), Replicate()) - out_sharding = (Shard(0), Shard(2)) - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([out_sharding]) - - graph = autop.sharding_optimizer.graph - strats = autop.sharding_optimizer.strats - clusters = get_identical_regions(graph, strats) - - clustered_nodes = set() - for group in clusters: - for region in group: - clustered_nodes.update(region) - - # Find all getitem nodes that have strategies and belong to layers - unclustered_getitems = [] - for node in graph.nodes: - if node.target is not operator.getitem: - continue - if node not in strats: - continue - layer_idx = _get_layer_index(node) - if layer_idx is not None and node not in clustered_nodes: - unclustered_getitems.append(node) - - assert len(unclustered_getitems) == 0, ( - f"{len(unclustered_getitems)} getitem nodes in layers are not clustered: " - f"{[n.name for n in unclustered_getitems[:5]]}" + coverage_layers=range(1, n_layers), + forward_layers=range(1, n_layers), + backward_layers=range(1, n_layers), ) -def test_getitem_siblings_cluster_consistency(device_mesh_2d): - """Getitem siblings added by _extend_with_sibling_getitems should appear - in the same cluster group as their producer, with one per region.""" +def test_clustering_no_forward_backward_mixing(device_mesh_2d): + """Each cluster group's regions should contain only forward or only + backward nodes, never a mix. Expansion must not cross the phase boundary + by following saved-tensor edges from backward into forward.""" n_layers = 4 autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers) with autop: @@ -326,27 +296,12 @@ def test_getitem_siblings_cluster_consistency(device_mesh_2d): autop.add_input_constraints([x_sharding]) autop.add_output_constraints([out_sharding]) - graph = autop.sharding_optimizer.graph - strats = autop.sharding_optimizer.strats - clusters = get_identical_regions(graph, strats) + clusters = get_identical_regions( + autop.sharding_optimizer.graph, autop.sharding_optimizer.strats + ) for i, group in enumerate(clusters): - num_regions = len(group) - # Collect all getitem nodes across all regions in this group - getitems_in_group = [] - for region in group: - for node in region: - if node.target is operator.getitem: - getitems_in_group.append(node) - - if not getitems_in_group: - continue - - # Each getitem index should appear exactly once per region - # Group getitems by their tuple index - by_index = Counter(n.args[1] for n in getitems_in_group) - for idx, count in by_index.items(): - assert count == num_regions, ( - f"Cluster group {i}: getitem[{idx}] appears {count} times " - f"but there are {num_regions} regions" - ) + for j, region in enumerate(group): + tags = set(n.meta.get("partitioner_tag") for n in region) + tags.discard(None) + assert len(tags) <= 1, f"Cluster group {i}, region {j} mixes phases: {tags}"