From 68e65134b66679e5b47ce9f817b298bfd73af486 Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Thu, 7 May 2026 18:59:39 -0700 Subject: [PATCH] Revert getitem sibling clustering and support kwargs inputs Revert the graph clustering behavior introduced by 51f2f67 because the bridge-group recovery can reuse an already cluster-linked node as the root of another cluster group. The sharding optimizer treats cluster-linked nodes as not owning PuLP variables, so later flow constraints can resolve a linked key through cluster_links to a root key that was never materialized in pulp_variables. DeepSeek V3 placement then fails before solving with KeyError on the resolved cluster key. Keep the later DeepSeek V3 clustering coverage, but remove the getitem-specific expectations that depended on the reverted bridge-group behavior. Make both the DS3 local_map example and DS3 clustering coverage use the TorchTitan debug shape, so a future reintroduction of getitem sibling recovery has to handle the graph that exposed the bug. Also allow AutoParallel generated forward wrappers to accept kwargs by flattening (args, kwargs) when positional flattening does not match the traced input arity. TorchTitan GraphTrainer can pass model inputs by keyword, so the generated wrapper needs to preserve that call shape while keeping the positional-only path unchanged. stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/451, branch: sanketpurandare/stack/7 --- autoparallel/api.py | 4 +- autoparallel/graph_passes/graph_clustering.py | 153 +----------------- examples/example_ds3_local_map.py | 13 +- tests/test_graph_clustering.py | 105 ++++-------- 4 files changed, 44 insertions(+), 231 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 8f2eea55..932dc5f2 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -556,10 +556,12 @@ def inference_fn(args): self._traced_inputs, solved_input_placements, self.mesh ) - def forward(self, *args): + def forward(self, *args, **kwargs): # Flatten pytree args (e.g. dicts, nested structures) to tensor # leaves, matching how Dynamo flattened the inputs during tracing. flat_args, _ = torch.utils._pytree.tree_flatten(args) + if len(flat_args) != len(expected_inputs): + flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs)) _check_forward_args(flat_args, expected_inputs) # NB: don't close over the parameters/buffers, as the user may # reassign the module! diff --git a/autoparallel/graph_passes/graph_clustering.py b/autoparallel/graph_passes/graph_clustering.py index 4503e611..c01a09a3 100644 --- a/autoparallel/graph_passes/graph_clustering.py +++ b/autoparallel/graph_passes/graph_clustering.py @@ -11,7 +11,7 @@ import math import time from collections import defaultdict -from typing import Optional, cast +from typing import Optional import torch from torch._dynamo.graph_region_tracker import ( @@ -81,143 +81,6 @@ def _hash_node(node, strategies, input_pickler): return sha256_hash(input_pickler.dumps(key)) -def _extend_with_sibling_getitems( - region_groups: list[list[Region]], - node_to_duplicates: dict[Node, IdenticalNodes], - strategies: dict[Node, OpStrategy], - topological_ranking: dict[Node, int], -) -> set[Node]: - """Extend region groups with unclaimed getitem siblings of clustered nodes. - - The backward-BFS expansion only reaches getitem users that happen to be on - the main data path. Sibling tuple projections (e.g. logsumexp, RNG state - from SDPA) are left orphaned even though their producer is already aligned - across regions. This post-pass recovers them in two ways: - - 1. If a getitem's producer is already in a region, find matching unclaimed - getitems across all other regions and append them in-place. - 2. If a getitem's producer is NOT in any region but its duplicate getitems - (from node_to_duplicates) ARE clustered, create a small bridge group - that links the orphan to a clustered sibling. This handles the case - where the BFS created N-1 regions out of N identical layers. - - Returns the set of bridge root nodes — already-clustered nodes that are - reused as the root region of a bridge group and therefore intentionally - appear in two groups. - """ - claimed: set[Node] = set() - for region_group in region_groups: - for region in region_group: - claimed.update(region) - - # Case 1: extend existing regions with unclaimed getitem siblings. - for region_group in region_groups: - root_region = region_group[0] - num_regions = len(region_group) - - for pos in range(len(root_region)): - root_producer = root_region[pos] - getitems_by_idx: dict[int, list[Node]] = defaultdict(list) - for user in root_producer.users: - if ( - user.target is operator.getitem - and user not in claimed - and user in strategies - ): - getitems_by_idx[cast(int, user.args[1])].append(user) - - for k, root_matches in getitems_by_idx.items(): - if len(root_matches) != 1: - continue - root_getitem = root_matches[0] - if root_getitem not in node_to_duplicates: - continue - root_dups = node_to_duplicates[root_getitem] - root_phase = root_getitem.meta.get("partitioner_tag") - - candidates = [root_getitem] - valid = True - for other_region in region_group[1:]: - other_producer = other_region[pos] - matches = [ - user - for user in other_producer.users - if ( - user.target is operator.getitem - and user.args[1] == k - and user not in claimed - and user in strategies - and user in node_to_duplicates - and node_to_duplicates[user] is root_dups - and user.meta.get("partitioner_tag") == root_phase - ) - ] - if len(matches) != 1: - valid = False - break - candidates.append(matches[0]) - - if valid and len(candidates) == num_regions: - for region, getitem_node in zip(region_group, candidates): - region.append(getitem_node) - claimed.add(getitem_node) - - for region in region_group: - region.sort(key=lambda n: topological_ranking[n]) - - # Case 2: create bridge groups for orphaned getitems whose duplicates - # are already clustered. Each bridge group pairs one clustered sibling - # (as the root region) with the orphan, so create_cluster_links maps - # the orphan's decision variables to the root's. - bridge_roots: set[Node] = set() - seen_dup_groups: set[int] = set() - for node in strategies: - if node.target is not operator.getitem: - continue - if node in claimed: - continue - if node not in node_to_duplicates: - continue - dups = node_to_duplicates[node] - group_id = id(dups) - if group_id in seen_dup_groups: - continue - seen_dup_groups.add(group_id) - - if len(dups) < 2: - continue - if not all(d in strategies for d in dups): - continue - - # Find one claimed duplicate to serve as the root. - root = None - for d in dups: - if d in claimed: - root = d - break - if root is None: - continue - - # Create a bridge group: [[root], [orphan1], [orphan2], ...] - bridge = [[root]] - for d in dups: - if d not in claimed: - bridge.append([d]) - claimed.add(d) - if len(bridge) < 2: - continue - bridge.sort(key=lambda r: topological_ranking[r[0]]) - # Ensure the root is first (create_cluster_links uses region[0] - # as the root). - root_idx = next(i for i, r in enumerate(bridge) if r[0] is root) - if root_idx != 0: - bridge[0], bridge[root_idx] = bridge[root_idx], bridge[0] - region_groups.append(bridge) - bridge_roots.add(root) - - return bridge_roots - - def get_identical_regions( graph: torch.fx.Graph, strategies: dict[Node, OpStrategy] ) -> list[list[Region]]: @@ -302,7 +165,6 @@ def _is_identical(n0: Node, n1: Node) -> bool: # overlap. t = time.time() seen_nodes: set[Node] = set() - expanded_groups: list[list[Region]] = [] for region_group in region_groups: # NOTE: this seems like it's missing in the original implementation # from PyTorch. Given that fully_expand_region_group doesn't check @@ -322,30 +184,23 @@ def _is_identical(n0: Node, n1: Node) -> bool: # sort topologically for region in region_group: region.sort(key=lambda n: topological_ranking[n]) - expanded_groups.append(region_group) region_groups = [ - region_group for region_group in expanded_groups if len(region_group[0]) > 1 + region_group for region_group in region_groups if len(region_group[0]) > 1 ] - bridge_roots = _extend_with_sibling_getitems( - region_groups, node_to_duplicates, strategies, topological_ranking - ) - # sort everything so that we have nodes in topological ranking for region_group in region_groups: region_group.sort(key=lambda rg: topological_ranking[rg[0]]) region_groups.sort(key=lambda rg: topological_ranking[rg[0][0]]) logger.debug(f"Expanded regions in {time.time() - t} s") - # sanity check that we don't have duplicate nodes. - # Bridge roots are already-clustered nodes reused as root regions in - # bridge groups (case 2 above); they intentionally appear in two groups. + # sanity check that we don't have duplicate nodes seen_nodes.clear() for region_group in region_groups: for region in region_group: for node in region: - if node in seen_nodes and node not in bridge_roots: + if node in seen_nodes: raise RuntimeError(f"Duplicate node {node} in region group") seen_nodes.add(node) return region_groups diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py index 106ce80b..1695de3f 100644 --- a/examples/example_ds3_local_map.py +++ b/examples/example_ds3_local_map.py @@ -25,7 +25,10 @@ def _seed_dtensor_rng(rng_seed: Optional[int]) -> None: def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): - seq_len = 1024 + # Match TorchTitan's DeepSeek V3 debug model shape. This example is a + # regression guard for placement/clustering issues that only appear at the + # larger debug shape used by TorchTitan GraphTrainer. + seq_len = 2048 if fake_evaluate: world_size = 256 @@ -66,11 +69,9 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): mesh_dim_names=("dp", "ep"), ) - config = make_dsv3_config( - num_experts=4, top_k=2, n_layers=4, n_dense_layers=0, max_seq_len=seq_len - ) + config = make_dsv3_config(max_seq_len=seq_len) - local_batch_size = 2 + local_batch_size = 8 global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1] with torch.device("meta"): @@ -140,7 +141,7 @@ def input_fn(): out.backward(torch.ones_like(out)) else: for i, x in enumerate(microbatches): - assert x.shape[0] == 2 + assert x.shape[0] == local_batch_size out = parallel_mod(x) assert not torch.any(torch.isnan(out)), "Found NaNs in forward output" out.backward(torch.ones_like(out)) diff --git a/tests/test_graph_clustering.py b/tests/test_graph_clustering.py index 3a4c9d28..775a03e2 100644 --- a/tests/test_graph_clustering.py +++ b/tests/test_graph_clustering.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -import operator import re from collections import Counter @@ -67,9 +66,10 @@ def _clustering_stats(graph, strats, clusters, n_layers): } -def _assert_layer_coverage(stats, n_layers, min_coverage, label): +def _assert_layer_coverage(stats, n_layers, min_coverage, label, layers=None): """Assert each repeated layer has enough nodes in clustered regions.""" - layer_totals = [stats["per_layer_total"].get(i, 0) for i in range(n_layers)] + layers = range(n_layers) if layers is None else layers + layer_totals = [stats["per_layer_total"].get(i, 0) for i in layers] if len(set(layer_totals)) != 1: raise AssertionError( f"{label}: layers have different node counts: {layer_totals}" @@ -78,7 +78,7 @@ def _assert_layer_coverage(stats, n_layers, min_coverage, label): total = layer_totals[0] if total == 0: raise AssertionError(f"{label}: no layer nodes found in clustering stats") - for layer_idx in range(n_layers): + for layer_idx in layers: clustered = stats["per_layer_clustered"].get(layer_idx, 0) coverage = clustered / total if coverage < min_coverage: @@ -152,6 +152,7 @@ def _assert_model_clustering( min_coverage, forward_layers, backward_layers, + coverage_layers=None, min_region_size=100, ): """Assert coverage, phase separation, and large fwd/bwd layer clusters.""" @@ -160,6 +161,7 @@ def _assert_model_clustering( n_layers, min_coverage=min_coverage, label=label, + layers=coverage_layers, ) _assert_cross_layer_cluster( clusters, @@ -209,9 +211,13 @@ def input_fn(): return autop, model_args -def _setup_ds3_local_map_autop(device_mesh_2d, n_layers=2): - global_batch_size = 2 * device_mesh_2d.shape[0] * device_mesh_2d.shape[1] - config = make_dsv3_config(n_layers=n_layers, n_dense_layers=0) +def _setup_ds3_local_map_autop(device_mesh_2d): + local_batch_size = 8 + seq_len = 2048 + global_batch_size = ( + local_batch_size * device_mesh_2d.shape[0] * device_mesh_2d.shape[1] + ) + config = make_dsv3_config(max_seq_len=seq_len) with torch.device("meta"): model = DeepSeekV3Model( config, @@ -230,7 +236,7 @@ def input_fn(): device="cuda", ) - return AutoParallel(model, input_fn, device_mesh_2d, dynamic=True) + return AutoParallel(model, input_fn, device_mesh_2d, dynamic=True), config def test_clustering_high_coverage(device_mesh_2d): @@ -259,8 +265,8 @@ def test_clustering_high_coverage(device_mesh_2d): backward_layers=range(1, n_layers), ) - n_layers = 2 - autop = _setup_ds3_local_map_autop(device_mesh_2d, n_layers=n_layers) + autop, config = _setup_ds3_local_map_autop(device_mesh_2d) + n_layers = len(config.layers) stats, clusters = _run_clustering( autop, n_layers, @@ -272,52 +278,16 @@ def test_clustering_high_coverage(device_mesh_2d): label="DS3", n_layers=n_layers, min_coverage=0.75, - forward_layers=range(n_layers), - backward_layers=range(n_layers), - ) - - -def test_getitem_siblings_are_clustered(device_mesh_2d): - """Getitem nodes that project sibling outputs from tuple-returning ops - (e.g. SDPA returning (output, logsumexp, rng_state)) should be clustered - together with their producer when the producer is already clustered.""" - n_layers = 4 - autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers) - with autop: - x_sharding = (Shard(0), Replicate()) - out_sharding = (Shard(0), Shard(2)) - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([out_sharding]) - - graph = autop.sharding_optimizer.graph - strats = autop.sharding_optimizer.strats - clusters = get_identical_regions(graph, strats) - - clustered_nodes = set() - for group in clusters: - for region in group: - clustered_nodes.update(region) - - # Find all getitem nodes that have strategies and belong to layers - unclustered_getitems = [] - for node in graph.nodes: - if node.target is not operator.getitem: - continue - if node not in strats: - continue - layer_idx = _get_layer_index(node) - if layer_idx is not None and node not in clustered_nodes: - unclustered_getitems.append(node) - - assert len(unclustered_getitems) == 0, ( - f"{len(unclustered_getitems)} getitem nodes in layers are not clustered: " - f"{[n.name for n in unclustered_getitems[:5]]}" + coverage_layers=range(1, n_layers), + forward_layers=range(1, n_layers), + backward_layers=range(1, n_layers), ) -def test_getitem_siblings_cluster_consistency(device_mesh_2d): - """Getitem siblings added by _extend_with_sibling_getitems should appear - in the same cluster group as their producer, with one per region.""" +def test_clustering_no_forward_backward_mixing(device_mesh_2d): + """Each cluster group's regions should contain only forward or only + backward nodes, never a mix. Expansion must not cross the phase boundary + by following saved-tensor edges from backward into forward.""" n_layers = 4 autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers) with autop: @@ -326,27 +296,12 @@ def test_getitem_siblings_cluster_consistency(device_mesh_2d): autop.add_input_constraints([x_sharding]) autop.add_output_constraints([out_sharding]) - graph = autop.sharding_optimizer.graph - strats = autop.sharding_optimizer.strats - clusters = get_identical_regions(graph, strats) + clusters = get_identical_regions( + autop.sharding_optimizer.graph, autop.sharding_optimizer.strats + ) for i, group in enumerate(clusters): - num_regions = len(group) - # Collect all getitem nodes across all regions in this group - getitems_in_group = [] - for region in group: - for node in region: - if node.target is operator.getitem: - getitems_in_group.append(node) - - if not getitems_in_group: - continue - - # Each getitem index should appear exactly once per region - # Group getitems by their tuple index - by_index = Counter(n.args[1] for n in getitems_in_group) - for idx, count in by_index.items(): - assert count == num_regions, ( - f"Cluster group {i}: getitem[{idx}] appears {count} times " - f"but there are {num_regions} regions" - ) + for j, region in enumerate(group): + tags = set(n.meta.get("partitioner_tag") for n in region) + tags.discard(None) + assert len(tags) <= 1, f"Cluster group {i}, region {j} mixes phases: {tags}"