Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,12 @@ def inference_fn(args):
self._traced_inputs, solved_input_placements, self.mesh
)

def forward(self, *args):
def forward(self, *args, **kwargs):
# Flatten pytree args (e.g. dicts, nested structures) to tensor
# leaves, matching how Dynamo flattened the inputs during tracing.
flat_args, _ = torch.utils._pytree.tree_flatten(args)
if len(flat_args) != len(expected_inputs):
flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs))
_check_forward_args(flat_args, expected_inputs)
# NB: don't close over the parameters/buffers, as the user may
# reassign the module!
Expand Down
153 changes: 4 additions & 149 deletions autoparallel/graph_passes/graph_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import math
import time
from collections import defaultdict
from typing import Optional, cast
from typing import Optional

import torch
from torch._dynamo.graph_region_tracker import (
Expand Down Expand Up @@ -81,143 +81,6 @@ def _hash_node(node, strategies, input_pickler):
return sha256_hash(input_pickler.dumps(key))


def _extend_with_sibling_getitems(
region_groups: list[list[Region]],
node_to_duplicates: dict[Node, IdenticalNodes],
strategies: dict[Node, OpStrategy],
topological_ranking: dict[Node, int],
) -> set[Node]:
"""Extend region groups with unclaimed getitem siblings of clustered nodes.

The backward-BFS expansion only reaches getitem users that happen to be on
the main data path. Sibling tuple projections (e.g. logsumexp, RNG state
from SDPA) are left orphaned even though their producer is already aligned
across regions. This post-pass recovers them in two ways:

1. If a getitem's producer is already in a region, find matching unclaimed
getitems across all other regions and append them in-place.
2. If a getitem's producer is NOT in any region but its duplicate getitems
(from node_to_duplicates) ARE clustered, create a small bridge group
that links the orphan to a clustered sibling. This handles the case
where the BFS created N-1 regions out of N identical layers.

Returns the set of bridge root nodes — already-clustered nodes that are
reused as the root region of a bridge group and therefore intentionally
appear in two groups.
"""
claimed: set[Node] = set()
for region_group in region_groups:
for region in region_group:
claimed.update(region)

# Case 1: extend existing regions with unclaimed getitem siblings.
for region_group in region_groups:
root_region = region_group[0]
num_regions = len(region_group)

for pos in range(len(root_region)):
root_producer = root_region[pos]
getitems_by_idx: dict[int, list[Node]] = defaultdict(list)
for user in root_producer.users:
if (
user.target is operator.getitem
and user not in claimed
and user in strategies
):
getitems_by_idx[cast(int, user.args[1])].append(user)

for k, root_matches in getitems_by_idx.items():
if len(root_matches) != 1:
continue
root_getitem = root_matches[0]
if root_getitem not in node_to_duplicates:
continue
root_dups = node_to_duplicates[root_getitem]
root_phase = root_getitem.meta.get("partitioner_tag")

candidates = [root_getitem]
valid = True
for other_region in region_group[1:]:
other_producer = other_region[pos]
matches = [
user
for user in other_producer.users
if (
user.target is operator.getitem
and user.args[1] == k
and user not in claimed
and user in strategies
and user in node_to_duplicates
and node_to_duplicates[user] is root_dups
and user.meta.get("partitioner_tag") == root_phase
)
]
if len(matches) != 1:
valid = False
break
candidates.append(matches[0])

if valid and len(candidates) == num_regions:
for region, getitem_node in zip(region_group, candidates):
region.append(getitem_node)
claimed.add(getitem_node)

for region in region_group:
region.sort(key=lambda n: topological_ranking[n])

# Case 2: create bridge groups for orphaned getitems whose duplicates
# are already clustered. Each bridge group pairs one clustered sibling
# (as the root region) with the orphan, so create_cluster_links maps
# the orphan's decision variables to the root's.
bridge_roots: set[Node] = set()
seen_dup_groups: set[int] = set()
for node in strategies:
if node.target is not operator.getitem:
continue
if node in claimed:
continue
if node not in node_to_duplicates:
continue
dups = node_to_duplicates[node]
group_id = id(dups)
if group_id in seen_dup_groups:
continue
seen_dup_groups.add(group_id)

if len(dups) < 2:
continue
if not all(d in strategies for d in dups):
continue

# Find one claimed duplicate to serve as the root.
root = None
for d in dups:
if d in claimed:
root = d
break
if root is None:
continue

# Create a bridge group: [[root], [orphan1], [orphan2], ...]
bridge = [[root]]
for d in dups:
if d not in claimed:
bridge.append([d])
claimed.add(d)
if len(bridge) < 2:
continue
bridge.sort(key=lambda r: topological_ranking[r[0]])
# Ensure the root is first (create_cluster_links uses region[0]
# as the root).
root_idx = next(i for i, r in enumerate(bridge) if r[0] is root)
if root_idx != 0:
bridge[0], bridge[root_idx] = bridge[root_idx], bridge[0]
region_groups.append(bridge)
bridge_roots.add(root)

return bridge_roots


def get_identical_regions(
graph: torch.fx.Graph, strategies: dict[Node, OpStrategy]
) -> list[list[Region]]:
Expand Down Expand Up @@ -302,7 +165,6 @@ def _is_identical(n0: Node, n1: Node) -> bool:
# overlap.
t = time.time()
seen_nodes: set[Node] = set()
expanded_groups: list[list[Region]] = []
for region_group in region_groups:
# NOTE: this seems like it's missing in the original implementation
# from PyTorch. Given that fully_expand_region_group doesn't check
Expand All @@ -322,30 +184,23 @@ def _is_identical(n0: Node, n1: Node) -> bool:
# sort topologically
for region in region_group:
region.sort(key=lambda n: topological_ranking[n])
expanded_groups.append(region_group)

region_groups = [
region_group for region_group in expanded_groups if len(region_group[0]) > 1
region_group for region_group in region_groups if len(region_group[0]) > 1
]

bridge_roots = _extend_with_sibling_getitems(
region_groups, node_to_duplicates, strategies, topological_ranking
)

# sort everything so that we have nodes in topological ranking
for region_group in region_groups:
region_group.sort(key=lambda rg: topological_ranking[rg[0]])
region_groups.sort(key=lambda rg: topological_ranking[rg[0][0]])
logger.debug(f"Expanded regions in {time.time() - t} s")

# sanity check that we don't have duplicate nodes.
# Bridge roots are already-clustered nodes reused as root regions in
# bridge groups (case 2 above); they intentionally appear in two groups.
# sanity check that we don't have duplicate nodes
seen_nodes.clear()
for region_group in region_groups:
for region in region_group:
for node in region:
if node in seen_nodes and node not in bridge_roots:
if node in seen_nodes:
raise RuntimeError(f"Duplicate node {node} in region group")
seen_nodes.add(node)
return region_groups
13 changes: 7 additions & 6 deletions examples/example_ds3_local_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def _seed_dtensor_rng(rng_seed: Optional[int]) -> None:


def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
seq_len = 1024
# Match TorchTitan's DeepSeek V3 debug model shape. This example is a
# regression guard for placement/clustering issues that only appear at the
# larger debug shape used by TorchTitan GraphTrainer.
seq_len = 2048
if fake_evaluate:
world_size = 256

Expand Down Expand Up @@ -66,11 +69,9 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
mesh_dim_names=("dp", "ep"),
)

config = make_dsv3_config(
num_experts=4, top_k=2, n_layers=4, n_dense_layers=0, max_seq_len=seq_len
)
config = make_dsv3_config(max_seq_len=seq_len)

local_batch_size = 2
local_batch_size = 8
global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]

with torch.device("meta"):
Expand Down Expand Up @@ -140,7 +141,7 @@ def input_fn():
out.backward(torch.ones_like(out))
else:
for i, x in enumerate(microbatches):
assert x.shape[0] == 2
assert x.shape[0] == local_batch_size
out = parallel_mod(x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
out.backward(torch.ones_like(out))
Expand Down
Loading
Loading