From 9871a486a2bf8dc6a9711d9e08cc3c690e2c6dd2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 2 Jun 2026 14:32:53 +0000 Subject: [PATCH 1/2] Improve FSDP bucketing and cap compute batches between ReduceScatters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ports three related changes to `auto_bucketing.py` from `fmassa/double_recomp` (cherry of 428e9d2). Together they address an interaction between FSDP all-gather bucketing and `stable_topological_sort` that batches recomputation MMs upfront and inflates backward peak memory. **1. `_patch_fsdp_bucketing()` — monkey-patches PyTorch's FSDP bucketing** - *Primary-group-only*: `identify_fsdp_groups` keeps only the group with the most FSDP all-gathers, so minority groups (1 tp AG, ~65 norm AGs on the combined group) no longer pollute the bucketing pool. - *Non-adjacent bucketing*: `greedy_bucket_collective_by_mb` collects all eligible collectives per group key instead of requiring graph adjacency, allowing dp AGs interleaved with tp activation collectives to bucket together. **2. `max_compute_pre_fetch`: 5 → 50** The previous value allowed only ~0.3 layers of prefetch (≈17 compute nodes/layer), insufficient to hide 5–7ms full-mesh AGs. 50 gives ≈3 layers of headroom. **3. `_cap_compute_batch_size(max_consecutive=8)`** After bucketing rewires dependencies, `stable_topological_sort` reorders 525/540 compute ops into an `MM*40` block before the first backward RS, blowing up peak memory. Snapshots the original compute/RS interleaving before scheduling, then for any post-schedule segment with >8 compute nodes between RS ops, chains chunks and pulls forward an RS that originally sat between them. Falls back gracefully if a cycle is detected. **Why this matters (from prior investigation, LLaMA-3 8B, 128 H100s, dp=16/tp=8):** - Bucketing patches + prefetch bump closed the unconstrained AP-to-reference gap from 12.1% → 4.6% (385ms vs 368ms reference). Prefetch alone: −20ms (−4.6%) unconstrained. - `_cap_compute_batch_size` brings constrained from a runaway `MM*40` recomp batch down to **358ms / 16.78 GB** (vs reference 339ms / 5.97 GB), trading +5.9ms latency for −4.3 GB peak memory vs the uncapped variant. A more aggressive `_restore_compute_order` was rejected — it killed 62.8ms of overlap for only 1.3 GB extra savings. Authored with Claude. --- autoparallel/graph_passes/auto_bucketing.py | 268 +++++++++++++++++++- 1 file changed, 267 insertions(+), 1 deletion(-) diff --git a/autoparallel/graph_passes/auto_bucketing.py b/autoparallel/graph_passes/auto_bucketing.py index fa01ff70..6aafcb70 100644 --- a/autoparallel/graph_passes/auto_bucketing.py +++ b/autoparallel/graph_passes/auto_bucketing.py @@ -4,16 +4,256 @@ # LICENSE file in the root directory of this source tree. import logging +from collections import Counter, defaultdict from functools import partial import torch from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing +from torch.utils._ordered_set import OrderedSet from .autobucketing_inductor import bucket_func, bucket_plan, bucket_utils, reorder logger = logging.getLogger(__name__) +def _patch_fsdp_bucketing(): + """Patch PyTorch's FSDP bucketing for better multi-group handling. + + Two fixes: + 1. Primary-group-only: only include the group with the most FSDP + all-gathers in fsdp_groups, preventing minority groups (tp, combined) + from limiting dp bucketing aggressiveness. + 2. Non-adjacent bucketing: allow collectives to be bucketed even when + interleaved with non-FSDP collectives on other groups. Only + descendant conflicts prevent bucketing, not graph position. + """ + import torch._inductor.fx_passes.bucketing as bucketing_mod + import torch._inductor.fx_passes.fsdp as fsdp_mod + from torch._inductor.fx_passes.bucketing import ( + collect_node_descendants, + is_wait_tensor, + ) + from torch._inductor.fx_passes.fsdp import ( + _find_all_gathers, + _get_group_name, + _get_group_size_from_node, + is_fsdp_all_gather, + ) + + def _patched_identify_fsdp_groups(gm): + fsdp_counts_by_group = Counter() + group_size = None + for n in _find_all_gathers(gm.graph): + if is_fsdp_all_gather(n): + gn = _get_group_name(n) + fsdp_counts_by_group[gn] += 1 + if group_size is None: + group_size = _get_group_size_from_node(n) + + if fsdp_counts_by_group: + primary_group = fsdp_counts_by_group.most_common(1)[0][0] + fsdp_groups = OrderedSet([primary_group]) + else: + fsdp_groups = OrderedSet() + + logger.debug( + "identify_fsdp_groups (patched): fsdp_groups=%s, all_counts=%s", + list(fsdp_groups), + dict(fsdp_counts_by_group), + ) + return fsdp_groups, group_size + + def _patched_greedy_bucket( + gm, + bucket_cap_mb_by_bucket_idx, + filter_node, + node_group_key, + filter_wait_node=None, + ): + g = gm.graph + groups = defaultdict(list) + for node in g.nodes: + if is_wait_tensor(node) and filter_node(node.args[0]): + if (filter_wait_node is None) or filter_wait_node(node): + coll_node = node.args[0] + key = node_group_key(coll_node) + groups[key].append(coll_node) + + if not groups: + return [] + + node_descendents = collect_node_descendants(g) + + buckets = [] + for key, nodes in groups.items(): + cur_bucket = [] + cur_bucket_descendents = OrderedSet() + cur_bucket_size_bytes = 0 + cur_bucket_id = 0 + bucket_size_bytes = int( + bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 + ) + for node in nodes: + if node in cur_bucket_descendents: + continue + n_val = node.meta["val"] + out_size_bytes = n_val.numel() * n_val.element_size() + n_input_val = node.all_input_nodes[0].meta["val"] + in_size_bytes = n_input_val.numel() * n_input_val.element_size() + size_bytes = max(out_size_bytes, in_size_bytes) + if ( + cur_bucket_size_bytes + size_bytes > bucket_size_bytes + and cur_bucket + ): + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_descendents = OrderedSet() + bucket_size_bytes = int( + bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 + ) + cur_bucket_size_bytes += size_bytes + cur_bucket.append(node) + cur_bucket_descendents |= node_descendents[node] + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + return buckets + + fsdp_mod.identify_fsdp_groups = _patched_identify_fsdp_groups + bucketing_mod.greedy_bucket_collective_by_mb = _patched_greedy_bucket + + +_patch_fsdp_bucketing() + + +def _cap_compute_batch_size( + graph, original_compute_names, original_rs_after_compute, max_consecutive=8 +): + """Break up long compute segments between ReduceScatter operations. + + After overlap scheduling + FSDP bucketing, compute nodes may be reordered + so that many layers' matmuls execute before any ReduceScatter fires. This + inflates peak memory because all layers' activations are alive + simultaneously. + + Instead of restoring the full original order (which kills comm/compute + overlap), this function only intervenes when the number of compute nodes + between consecutive ReduceScatter ops exceeds max_consecutive. For each + oversized segment, it sorts compute nodes by original order, splits into + chunks, then: + 1. Chains consecutive compute nodes within each chunk (so they move + together during topological sort). + 2. Adds a dep from the first compute node of each chunk to an RS node + that originally appeared between the previous chunk and this one. + + Args: + graph: The post-scheduled FX graph. + original_compute_names: List of compute node names in original order. + original_rs_after_compute: Dict mapping compute node name to the list + of RS node names that appeared between it and the next compute node + in the original (pre-scheduling) graph. + max_consecutive: Maximum compute nodes allowed between RS ops. + """ + from torch._dynamo.graph_deduplication import _stable_topological_sort + from torch._inductor.fx_passes.overlap_scheduling import is_compute_node + + def _is_rs(node): + if node.op != "call_function": + return False + name = str(node.target) + return "reduce_scatter" in name and "wait" not in name + + original_rank = {name: rank for rank, name in enumerate(original_compute_names)} + node_by_name = {n.name: n for n in graph.nodes} + + scheduled_rs_names = { + n.name for n in graph.nodes if n.op == "call_function" and _is_rs(n) + } + + # Collect compute nodes between RS ops in the post-scheduled graph. + segments: list[tuple[list[torch.fx.Node], torch.fx.Node | None]] = [] + current_compute: list[torch.fx.Node] = [] + for node in graph.nodes: + if node.op != "call_function": + continue + if is_compute_node(node): + current_compute.append(node) + elif _is_rs(node): + segments.append((current_compute, node)) + current_compute = [] + if current_compute: + segments.append((current_compute, None)) + + additional_deps: dict[torch.fx.Node, OrderedSet] = defaultdict(OrderedSet) + + for compute_nodes, _seg_rs_node in segments: + if len(compute_nodes) <= max_consecutive: + continue + + sorted_nodes = sorted( + compute_nodes, + key=lambda n: original_rank.get(n.name, float("inf")), + ) + + # Split into chunks of max_consecutive. + chunks = [] + for i in range(0, len(sorted_nodes), max_consecutive): + chunks.append(sorted_nodes[i : i + max_consecutive]) + + # Chain consecutive compute nodes within each chunk. + for chunk in chunks: + for j in range(1, len(chunk)): + additional_deps[chunk[j]].add(chunk[j - 1]) + + # At each chunk boundary, find an RS that originally appeared between + # the last compute of the previous chunk and the first compute of + # the current chunk, then add dep: first_of_chunk after RS. + for ci in range(1, len(chunks)): + prev_chunk = chunks[ci - 1] + curr_chunk = chunks[ci] + + last_in_prev = prev_chunk[-1] + first_in_curr = curr_chunk[0] + last_rank = original_rank.get(last_in_prev.name, -1) + first_rank = original_rank.get( + first_in_curr.name, len(original_compute_names) + ) + + found_rs_node = None + for r in range(last_rank, first_rank): + cname = ( + original_compute_names[r] + if r < len(original_compute_names) + else None + ) + if cname is None: + continue + for rs_name in original_rs_after_compute.get(cname, []): + if rs_name in scheduled_rs_names: + rs_obj = node_by_name.get(rs_name) + if rs_obj is not None: + found_rs_node = rs_obj + break + if found_rs_node is not None: + break + + if found_rs_node is not None: + additional_deps[first_in_curr].add(found_rs_node) + + if not additional_deps: + return + + try: + _stable_topological_sort(graph, additional_deps) + except AssertionError: + logger.warning( + "Failed to cap compute batch size (cycle detected), " + "falling back to uncapped ordering" + ) + + class simplefsdp_autobucketing_config: """ Config for simplefsdp's autobucketing pass, which by default would give good performance. @@ -107,7 +347,7 @@ class aten_autobucketing_config: compute_overlap_multipler = 1.0 max_coll_distance = 100 custom_runtime_estimation = None - max_compute_pre_fetch = 5 + max_compute_pre_fetch = 50 collective_bucketing = False save_trace = True _counter = 0 @@ -117,6 +357,28 @@ def aten_autobucketing_reordering_pass( gm: torch.fx.Graph, configs: "aten_autobucketing_config" ) -> torch.fx.GraphModule: assert gm.owning_module is not None + + # Record compute + RS interleaving before bucketing + overlap scheduling. + from torch._inductor.fx_passes.overlap_scheduling import is_compute_node + + def _is_rs_node(node): + if node.op != "call_function": + return False + name = str(node.target) + return "reduce_scatter" in name and "wait" not in name + + original_compute_names = [] + original_rs_after_compute: dict[str, list[str]] = {} + last_compute_name = None + for n in gm.owning_module.graph.nodes: + if n.op != "call_function": + continue + if is_compute_node(n): + original_compute_names.append(n.name) + last_compute_name = n.name + elif _is_rs_node(n) and last_compute_name is not None: + original_rs_after_compute.setdefault(last_compute_name, []).append(n.name) + new_gm = schedule_overlap_bucketing( gm.owning_module, collective_bucketing=configs.collective_bucketing, @@ -126,6 +388,10 @@ def aten_autobucketing_reordering_pass( max_in_flight_gb=configs.max_in_flight_gb, max_coll_distance=configs.max_coll_distance, ) + + _cap_compute_batch_size( + new_gm.graph, original_compute_names, original_rs_after_compute + ) new_gm.recompile() if configs.save_trace: From bebfbfb5af9c4b48f57e91bb64f20d34e8cbf9f3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 8 Jun 2026 15:53:49 +0000 Subject: [PATCH 2/2] Replace cap-pass remediation with topo-span-bounded bucketing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary Replaces the `_cap_compute_batch_size` post-hoc remediator with a topo-span close-bucket condition inside the patched FSDP bucketing. Same load-bearing knob (`bucket_cap_mb`) gains a companion (`max_topo_span`) that bounds the dependency-footprint problem the cap was trying to undo. The bucketing-time prevention has shipped at `max_topo_span=1500` on `fmassa/double_recomp` and matches the prior cap-based path on min latency while substantially improving variance. Motivation The prior cherry (`9871a48`) combined two opposing patches: non-adjacent bucketing (which widens bucket dependency footprints to improve NCCL throughput) and `_cap_compute_batch_size` (which reorders compute post-hoc to undo the resulting MM*N batching). The cap relied on pre-bucket reduce_scatter names surviving bucketing — a property that doesn't always hold when bucketing renames RS ops — so the cap silently degrades to chain-deps-only and doesn't bound `max_consec_compute_between_RS`. Bounding bucket topo-span at the source addresses the same MM*N harm without a downstream remediator, surfaces one tunable knob users can reason about, and degenerates to the prior cherry's behavior when `max_topo_span=None`. What changed `autoparallel/graph_passes/auto_bucketing.py`, −56 net LOC: - **Kept** `_patched_identify_fsdp_groups` (primary-group-only filter) and `max_compute_pre_fetch=50`. - **Replaced** the non-adjacent `_patched_greedy_bucket` with a version that adds `close_for_span` as a third close-bucket condition alongside `close_for_bytes`. Snapshots ranks at function entry, closes a bucket when `current_rank - bucket_start_rank > max_topo_span`. - **Added** `aten_autobucketing_config.max_topo_span: int | None = 1500`. Set to `None` to restore the prior bytes-only (cherry) behavior with no codepath difference. - **Added** INFO-level metrics inside `_patched_greedy_bucket`: per-invocation `(num_buckets, max_observed_span, n_close_bytes, n_close_span, max_topo_span)`. The cap's silent-failure mode is no longer possible. - **Added** `_max_consec_compute_between_rs(graph)` helper, logged at INFO after `aten_autobucketing_reordering_pass` as a regression metric matching what the old cap enforced. - **Removed** `_cap_compute_batch_size` (~120 LOC) and the pre/post-bucket name-snapshotting that fed it. Validation `tests/test_auto_bucketing_patches.py` (8 tests, ~3 s, no GPU): - `identify_fsdp_groups`: primary group wins on imbalanced counts; empty graph returns empty; ties pick exactly one. - `greedy_bucket`: merges within caps; splits on bytes; splits on span when bytes are under cap; `max_topo_span=None` disables span; descendant collectives never co-bucket. Mutation-verified — reverting either patch causes the corresponding test to fail with an informative message. Real benchmark results (LLaMA-3 8B, 32 layers, 128 H100s, seqlen=8192, global batch=32) `max_topo_span=1500`, job 6513650, vs the most-stable prior cap-based run (6477238): **Unconstrained** | Run | min ms | avg ms | max ms | alloc GiB | rsrvd GiB | MFU | | -------------------------- | ------- | --------- | --------- | --------- | --------- | --------- | | New (`max_topo_span=1500`) | 379.3 | **389.6** | **436.6** | 7.69 | **8.06** | 492.5 % | | Prior `_cap` (cleanest) | 378.3 | 384.2 | 412.0 | 7.55 | 8.20 | 499.4 % | | Prior `_cap` (typical) | 370–377 | 424–569 | 1.8–6.0 s | 7.55 | 8.20 | 337–452 % | | TorchTitan reference | 366.2 | 414.1 | 1742.8 | 8.62 | 9.43 | 463.4 % | - **Min latency**: 379.3 ms, ~+1–9 ms over best prior runs; within the prior-run min spread. - **Avg latency**: 389.6 ms — second-best across all runs (only the lucky 6477238 beats it at 384.2 ms); typical prior cap-based avg was 424–569 ms. - **Max latency**: 436.6 ms — **best of all benchmarked runs**; every other prior run except 6477238 had >1.8 s tails. - **Variance**: avg − min collapses from 5–200 ms range to 10 ms. The cap's silent-failure failure mode (which produced occasional multi-second tails) appears to be the source of prior variance; bucketing-time prevention is much more deterministic. - **Memory**: +0.14 GiB alloc vs cap-based prior; rsrvd memory is the best of all runs. **Constrained** | Run | min ms | avg ms | max ms | alloc GiB | rsrvd GiB | | -------------------------- | ----------- | ------- | ----------- | --------- | --------- | | New (`max_topo_span=1500`) | 443.6 | 484.3 | 662.9 | 7.11 | **7.51** | | Prior `_cap` (range) | 439.8–445.9 | 447–628 | 0.98–5.58 s | 7.12 | 7.49–7.59 | - Min within prior noise band; avg/max middle of prior variance; **rsrvd memory is the best of all runs**. Telemetry from the shipped run The `max_topo_span=1500` knob is **dormant in practice** at this configuration: - Fwd: 27 buckets, max span 619, closures `(bytes=26, span=0)`, `max_consec_compute_between_rs=8` - Bwd group A: 43 buckets, max span 1325, closures `(bytes=40, span=1)` - Bwd group B: 64 buckets, max span 561, closures `(bytes=62, span=0)` - Bwd `max_consec_compute_between_rs=12` (vs 8 under the old explicit cap — 50% looser) The span gate fires at most once across all bucketing invocations; actual behavior is byte-cap-only plus primary-group filter. The 4-extra MMs concentrated before each RS may explain the ~9 ms unconstrained min-latency gap vs the best prior run — testing `max_topo_span=1000` is in progress to validate whether tightening the gate closes that gap without reintroducing variance. What this gives up vs the cap-based approach - Min latency: within ±9 ms of best prior min, plausibly recoverable by tuning `max_topo_span` downward. - Memory: rsrvd best-of-all-runs in both modes; alloc within 2% of prior best. What this gains - Avg/max latency stability: variance collapses dramatically vs the prior cap-based path (typical prior tails of 1.8–6.0 s disappear in unconstrained; 0.98–5.58 s tails in constrained reduce to 0.66 s). - One declarative knob (`max_topo_span`) replaces two coupled patches plus a downstream remediator with a known name-matching brittleness. - Loud metrics: `n_close_bytes` / `n_close_span` / `max_consec_compute_between_rs` are logged at INFO; the prior cap's silent-failure mode is no longer possible. - ~56 net LOC removed. Authored with Claude. --- autoparallel/graph_passes/auto_bucketing.py | 239 +++++++---------- tests/test_auto_bucketing_patches.py | 280 ++++++++++++++++++++ 2 files changed, 371 insertions(+), 148 deletions(-) create mode 100644 tests/test_auto_bucketing_patches.py diff --git a/autoparallel/graph_passes/auto_bucketing.py b/autoparallel/graph_passes/auto_bucketing.py index 6aafcb70..42562a38 100644 --- a/autoparallel/graph_passes/auto_bucketing.py +++ b/autoparallel/graph_passes/auto_bucketing.py @@ -23,9 +23,16 @@ def _patch_fsdp_bucketing(): 1. Primary-group-only: only include the group with the most FSDP all-gathers in fsdp_groups, preventing minority groups (tp, combined) from limiting dp bucketing aggressiveness. - 2. Non-adjacent bucketing: allow collectives to be bucketed even when - interleaved with non-FSDP collectives on other groups. Only - descendant conflicts prevent bucketing, not graph position. + 2. Topo-span-bounded bucketing: allow collectives to be bucketed even + when interleaved with non-FSDP collectives on other groups, but + close the bucket once its topo-span (rank of latest member minus + rank of first member) exceeds aten_autobucketing_config.max_topo_span. + + Without a span bound, merging collectives that are far apart in the + graph rewires the dependency graph so that stable_topological_sort + can pull compute from late layers forward, batching many MMs before + any RS fires (the MM*N problem). The span bound caps how far compute + can be displaced by any one bucket merge. """ import torch._inductor.fx_passes.bucketing as bucketing_mod import torch._inductor.fx_passes.fsdp as fsdp_mod @@ -71,6 +78,11 @@ def _patched_greedy_bucket( filter_wait_node=None, ): g = gm.graph + # Snapshot ranks before any bucketing mutates the graph. Used to + # bound each bucket's topo-span, which bounds how far compute can + # be displaced by stable_topological_sort after merging. + ranks = {n.name: i for i, n in enumerate(g.nodes)} + groups = defaultdict(list) for node in g.nodes: if is_wait_tensor(node) and filter_node(node.args[0]): @@ -83,12 +95,19 @@ def _patched_greedy_bucket( return [] node_descendents = collect_node_descendants(g) + max_topo_span = aten_autobucketing_config.max_topo_span buckets = [] + # Metrics aggregated across all groups. + n_close_bytes = 0 + n_close_span = 0 + max_observed_span = 0 + for key, nodes in groups.items(): cur_bucket = [] cur_bucket_descendents = OrderedSet() cur_bucket_size_bytes = 0 + cur_bucket_start_rank = None cur_bucket_id = 0 bucket_size_bytes = int( bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 @@ -101,24 +120,65 @@ def _patched_greedy_bucket( n_input_val = node.all_input_nodes[0].meta["val"] in_size_bytes = n_input_val.numel() * n_input_val.element_size() size_bytes = max(out_size_bytes, in_size_bytes) - if ( + + node_rank = ranks.get(node.name, 0) + would_span = ( + node_rank - cur_bucket_start_rank + if cur_bucket_start_rank is not None + else 0 + ) + + close_for_bytes = ( cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket - ): + ) + close_for_span = ( + max_topo_span is not None + and would_span > max_topo_span + and cur_bucket + ) + + if close_for_bytes or close_for_span: if len(cur_bucket) > 1: buckets.append(cur_bucket) + if close_for_bytes: + n_close_bytes += 1 + if close_for_span: + n_close_span += 1 + observed_span = ( + ranks.get(cur_bucket[-1].name, 0) - cur_bucket_start_rank + ) + max_observed_span = max(max_observed_span, observed_span) cur_bucket = [] cur_bucket_size_bytes = 0 cur_bucket_id += 1 cur_bucket_descendents = OrderedSet() + cur_bucket_start_rank = None bucket_size_bytes = int( bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 ) cur_bucket_size_bytes += size_bytes cur_bucket.append(node) cur_bucket_descendents |= node_descendents[node] + if cur_bucket_start_rank is None: + cur_bucket_start_rank = node_rank if len(cur_bucket) > 1: buckets.append(cur_bucket) + observed_span = ( + ranks.get(cur_bucket[-1].name, 0) - cur_bucket_start_rank + ) + max_observed_span = max(max_observed_span, observed_span) + + if buckets: + logger.info( + "greedy_bucket: %d buckets, max_span=%d, " + "closed (bytes=%d, span=%d, max_topo_span=%s)", + len(buckets), + max_observed_span, + n_close_bytes, + n_close_span, + max_topo_span, + ) return buckets fsdp_mod.identify_fsdp_groups = _patched_identify_fsdp_groups @@ -128,130 +188,26 @@ def _patched_greedy_bucket( _patch_fsdp_bucketing() -def _cap_compute_batch_size( - graph, original_compute_names, original_rs_after_compute, max_consecutive=8 -): - """Break up long compute segments between ReduceScatter operations. - - After overlap scheduling + FSDP bucketing, compute nodes may be reordered - so that many layers' matmuls execute before any ReduceScatter fires. This - inflates peak memory because all layers' activations are alive - simultaneously. - - Instead of restoring the full original order (which kills comm/compute - overlap), this function only intervenes when the number of compute nodes - between consecutive ReduceScatter ops exceeds max_consecutive. For each - oversized segment, it sorts compute nodes by original order, splits into - chunks, then: - 1. Chains consecutive compute nodes within each chunk (so they move - together during topological sort). - 2. Adds a dep from the first compute node of each chunk to an RS node - that originally appeared between the previous chunk and this one. - - Args: - graph: The post-scheduled FX graph. - original_compute_names: List of compute node names in original order. - original_rs_after_compute: Dict mapping compute node name to the list - of RS node names that appeared between it and the next compute node - in the original (pre-scheduling) graph. - max_consecutive: Maximum compute nodes allowed between RS ops. +def _max_consec_compute_between_rs(graph): + """Return the maximum consecutive compute nodes between RS ops. + + Useful as a regression metric: bucketing followed by topo sort can + pull compute from late layers forward, batching many MMs before any + RS fires. This metric grows linearly with the size of such batches. """ - from torch._dynamo.graph_deduplication import _stable_topological_sort from torch._inductor.fx_passes.overlap_scheduling import is_compute_node - def _is_rs(node): - if node.op != "call_function": - return False - name = str(node.target) - return "reduce_scatter" in name and "wait" not in name - - original_rank = {name: rank for rank, name in enumerate(original_compute_names)} - node_by_name = {n.name: n for n in graph.nodes} - - scheduled_rs_names = { - n.name for n in graph.nodes if n.op == "call_function" and _is_rs(n) - } - - # Collect compute nodes between RS ops in the post-scheduled graph. - segments: list[tuple[list[torch.fx.Node], torch.fx.Node | None]] = [] - current_compute: list[torch.fx.Node] = [] - for node in graph.nodes: - if node.op != "call_function": - continue - if is_compute_node(node): - current_compute.append(node) - elif _is_rs(node): - segments.append((current_compute, node)) - current_compute = [] - if current_compute: - segments.append((current_compute, None)) - - additional_deps: dict[torch.fx.Node, OrderedSet] = defaultdict(OrderedSet) - - for compute_nodes, _seg_rs_node in segments: - if len(compute_nodes) <= max_consecutive: + max_run = cur = 0 + for n in graph.nodes: + if n.op != "call_function": continue - - sorted_nodes = sorted( - compute_nodes, - key=lambda n: original_rank.get(n.name, float("inf")), - ) - - # Split into chunks of max_consecutive. - chunks = [] - for i in range(0, len(sorted_nodes), max_consecutive): - chunks.append(sorted_nodes[i : i + max_consecutive]) - - # Chain consecutive compute nodes within each chunk. - for chunk in chunks: - for j in range(1, len(chunk)): - additional_deps[chunk[j]].add(chunk[j - 1]) - - # At each chunk boundary, find an RS that originally appeared between - # the last compute of the previous chunk and the first compute of - # the current chunk, then add dep: first_of_chunk after RS. - for ci in range(1, len(chunks)): - prev_chunk = chunks[ci - 1] - curr_chunk = chunks[ci] - - last_in_prev = prev_chunk[-1] - first_in_curr = curr_chunk[0] - last_rank = original_rank.get(last_in_prev.name, -1) - first_rank = original_rank.get( - first_in_curr.name, len(original_compute_names) - ) - - found_rs_node = None - for r in range(last_rank, first_rank): - cname = ( - original_compute_names[r] - if r < len(original_compute_names) - else None - ) - if cname is None: - continue - for rs_name in original_rs_after_compute.get(cname, []): - if rs_name in scheduled_rs_names: - rs_obj = node_by_name.get(rs_name) - if rs_obj is not None: - found_rs_node = rs_obj - break - if found_rs_node is not None: - break - - if found_rs_node is not None: - additional_deps[first_in_curr].add(found_rs_node) - - if not additional_deps: - return - - try: - _stable_topological_sort(graph, additional_deps) - except AssertionError: - logger.warning( - "Failed to cap compute batch size (cycle detected), " - "falling back to uncapped ordering" - ) + target = str(n.target) + if is_compute_node(n): + cur += 1 + max_run = max(max_run, cur) + elif "reduce_scatter" in target and "wait" not in target: + cur = 0 + return max_run class simplefsdp_autobucketing_config: @@ -341,6 +297,10 @@ class aten_autobucketing_config: - max_in_flight_gb: maximum GB of concurrent collective data - compute_overlap_multipler: scale factor for compute time used to hide collectives - max_coll_distance: maximum node distance for overlap consideration + - max_topo_span: maximum number of graph positions a single bucket may + span. Bounds how far compute can be displaced when bucketing rewires + the dep graph and stable_topological_sort runs afterwards. Set to + None to disable the span bound (only bytes cap applies). """ max_in_flight_gb = 2.0 @@ -348,6 +308,7 @@ class aten_autobucketing_config: max_coll_distance = 100 custom_runtime_estimation = None max_compute_pre_fetch = 50 + max_topo_span: int | None = 1500 collective_bucketing = False save_trace = True _counter = 0 @@ -358,27 +319,6 @@ def aten_autobucketing_reordering_pass( ) -> torch.fx.GraphModule: assert gm.owning_module is not None - # Record compute + RS interleaving before bucketing + overlap scheduling. - from torch._inductor.fx_passes.overlap_scheduling import is_compute_node - - def _is_rs_node(node): - if node.op != "call_function": - return False - name = str(node.target) - return "reduce_scatter" in name and "wait" not in name - - original_compute_names = [] - original_rs_after_compute: dict[str, list[str]] = {} - last_compute_name = None - for n in gm.owning_module.graph.nodes: - if n.op != "call_function": - continue - if is_compute_node(n): - original_compute_names.append(n.name) - last_compute_name = n.name - elif _is_rs_node(n) and last_compute_name is not None: - original_rs_after_compute.setdefault(last_compute_name, []).append(n.name) - new_gm = schedule_overlap_bucketing( gm.owning_module, collective_bucketing=configs.collective_bucketing, @@ -389,9 +329,12 @@ def _is_rs_node(node): max_coll_distance=configs.max_coll_distance, ) - _cap_compute_batch_size( - new_gm.graph, original_compute_names, original_rs_after_compute + logger.info( + "aten_autobucketing_reordering_pass: post-pass " + "max_consec_compute_between_rs=%d", + _max_consec_compute_between_rs(new_gm.graph), ) + new_gm.recompile() if configs.save_trace: diff --git a/tests/test_auto_bucketing_patches.py b/tests/test_auto_bucketing_patches.py new file mode 100644 index 00000000..f67c1ba6 --- /dev/null +++ b/tests/test_auto_bucketing_patches.py @@ -0,0 +1,280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the monkey-patches installed by +autoparallel.graph_passes.auto_bucketing._patch_fsdp_bucketing. + +The patches are installed at module import time, so importing +auto_bucketing replaces the originals in PyTorch's bucketing/fsdp modules. +""" + +import torch +import torch._inductor.fx_passes.bucketing as bucketing_mod +import torch._inductor.fx_passes.fsdp as fsdp_mod + +# Importing auto_bucketing triggers _patch_fsdp_bucketing() at import time. +import autoparallel.graph_passes.auto_bucketing as ab # noqa: F401 + + +def _make_fake_tensor_meta(shape, dtype=torch.bfloat16): + """Return a FakeTensor-like object usable as node.meta['val']. + + The bucketing helpers only access .numel(), .element_size(), and .dtype. + A real torch.empty meta tensor satisfies that contract. + """ + return torch.empty(shape, dtype=dtype, device="meta") + + +def _make_ag_node(g, x, group_name, shape, dtype=torch.bfloat16, group_size=8): + """Build an all_gather_into_tensor call_function node with FSDP-shaped + input chain: exactly one placeholder ancestor (which is_fsdp_all_gather + requires).""" + ag = g.call_function( + torch.ops._c10d_functional.all_gather_into_tensor.default, + (x, group_size, group_name), + ) + ag.meta["val"] = _make_fake_tensor_meta(shape, dtype) + return ag + + +def _make_wait_node(g, ag): + w = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (ag,)) + w.meta["val"] = ag.meta["val"] + return w + + +def _build_gm(graph): + return torch.fx.GraphModule(torch.nn.Module(), graph) + + +# ---------- identify_fsdp_groups tests ---------- + + +def test_identify_fsdp_groups_picks_primary_group(): + """With AGs split across two PGs, only the group with more AGs is returned.""" + g = torch.fx.Graph() + # Group "dp" gets 5 AGs, group "tp" gets 2. "dp" should win. + placeholders = [g.placeholder(f"p{i}") for i in range(7)] + for ph in placeholders: + ph.meta["val"] = _make_fake_tensor_meta((1024,), torch.bfloat16) + + for i in range(5): + _make_ag_node(g, placeholders[i], "dp", (8192,), group_size=8) + for i in range(5, 7): + _make_ag_node(g, placeholders[i], "tp", (8192,), group_size=8) + g.output(()) + gm = _build_gm(g) + + groups, group_size = fsdp_mod.identify_fsdp_groups(gm) + + assert list(groups) == ["dp"], f"expected only primary 'dp', got {list(groups)}" + assert group_size == 8 + + +def test_identify_fsdp_groups_handles_no_fsdp_ags(): + """Empty graph or graphs with no FSDP AGs return an empty OrderedSet.""" + g = torch.fx.Graph() + g.output(()) + gm = _build_gm(g) + + groups, group_size = fsdp_mod.identify_fsdp_groups(gm) + + assert len(groups) == 0 + assert group_size is None + + +def test_identify_fsdp_groups_ties_pick_one(): + """Equal AG counts: most_common is deterministic but the exact group + chosen depends on insertion order. We only require that exactly one is + returned (not both).""" + g = torch.fx.Graph() + placeholders = [g.placeholder(f"p{i}") for i in range(4)] + for ph in placeholders: + ph.meta["val"] = _make_fake_tensor_meta((1024,), torch.bfloat16) + for ph in placeholders[:2]: + _make_ag_node(g, ph, "dp", (8192,)) + for ph in placeholders[2:]: + _make_ag_node(g, ph, "tp", (8192,)) + g.output(()) + gm = _build_gm(g) + + groups, _ = fsdp_mod.identify_fsdp_groups(gm) + assert len(groups) == 1 + assert list(groups)[0] in ("dp", "tp") + + +# ---------- greedy_bucket tests ---------- + + +def _build_ag_chain_graph(n_ags, ag_shape_bytes_each=1024, group_name="dp"): + """Build a graph with n_ags AGs in the same group + their waits. + + Each AG has its own placeholder input (so descendant relations between + AGs are absent — every AG is independent). + + Returns (gm, ag_nodes, wait_nodes). + """ + g = torch.fx.Graph() + ag_nodes = [] + wait_nodes = [] + # Use shape such that numel * element_size == ag_shape_bytes_each. + # bfloat16 = 2 bytes, so numel = bytes / 2. + numel = max(1, ag_shape_bytes_each // 2) + for i in range(n_ags): + ph = g.placeholder(f"p{i}") + ph.meta["val"] = _make_fake_tensor_meta((numel,), torch.bfloat16) + ag = _make_ag_node(g, ph, group_name, (numel,), group_size=8) + w = _make_wait_node(g, ag) + ag_nodes.append(ag) + wait_nodes.append(w) + g.output(tuple(wait_nodes)) + return _build_gm(g), ag_nodes, wait_nodes + + +def _call_greedy_bucket(gm, bucket_cap_mb, *, filter_wait_node=None): + """Invoke the patched greedy_bucket via the public API surface.""" + return bucketing_mod.greedy_bucket_collective_by_mb( + gm, + lambda _idx: bucket_cap_mb, + # filter_node accepts the collective node (args[0] of the wait). + # We accept all AGs for these tests. + filter_node=bucketing_mod.is_all_gather_into_tensor, + node_group_key=bucketing_mod._ag_group_key, + filter_wait_node=filter_wait_node, + ) + + +def test_greedy_bucket_merges_within_caps(): + """4 small AGs, plenty of room: one bucket containing all of them.""" + gm, ag_nodes, _ = _build_ag_chain_graph(n_ags=4, ag_shape_bytes_each=1024) + # max_topo_span large enough to fit all (graph has ~12 nodes). + ab.aten_autobucketing_config.max_topo_span = 1000 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + assert len(buckets) == 1, f"expected 1 bucket, got {len(buckets)}" + assert len(buckets[0]) == 4, f"expected 4 AGs in bucket, got {len(buckets[0])}" + + +def test_greedy_bucket_splits_on_bytes_cap(): + """Each AG = 4 MB, cap = 10 MB → buckets of at most 2 AGs each.""" + # 5 MB per AG → second AG exceeds 10 MB cap. + bytes_per_ag = 5 * 1024 * 1024 + gm, ag_nodes, _ = _build_ag_chain_graph(n_ags=6, ag_shape_bytes_each=bytes_per_ag) + ab.aten_autobucketing_config.max_topo_span = 1000 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + assert len(buckets) >= 2, f"expected >=2 buckets, got {len(buckets)}: {buckets}" + # Each bucket must respect the 10 MB cap, so each bucket has ≤ 2 AGs. + for b in buckets: + assert len(b) <= 2, f"bucket has {len(b)} > 2 AGs (exceeds 10 MB cap)" + # Single-element buckets are dropped, so each surviving bucket has ≥ 2. + for b in buckets: + assert len(b) >= 2, f"surviving bucket has only {len(b)} member" + + +def test_greedy_bucket_splits_on_span_cap(): + """4 AGs spaced out in the graph, max_topo_span forces splits below + the byte cap.""" + # Each AG with its placeholder, wait, plus a string of no-op compute + # nodes between them so the topo distance between AGs is large. + g = torch.fx.Graph() + ag_nodes = [] + wait_nodes = [] + bytes_per_ag = 1024 # tiny — bytes cap won't fire + numel = bytes_per_ag // 2 + for i in range(4): + ph = g.placeholder(f"p{i}") + ph.meta["val"] = _make_fake_tensor_meta((numel,), torch.bfloat16) + ag = _make_ag_node(g, ph, "dp", (numel,), group_size=8) + w = _make_wait_node(g, ag) + ag_nodes.append(ag) + wait_nodes.append(w) + # Pad: add 50 unrelated ops between this wait and the next AG. + x = w + for j in range(50): + x = g.call_function(torch.ops.aten.relu.default, (x,)) + x.meta["val"] = w.meta["val"] + g.output(tuple(wait_nodes)) + gm = _build_gm(g) + + # With ~50+ nodes per AG, total span across 4 AGs ≈ 200. + # max_topo_span = 80 means at most 2 AGs per bucket. + ab.aten_autobucketing_config.max_topo_span = 80 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + assert ( + len(buckets) >= 2 + ), f"expected >=2 buckets due to span cap, got {len(buckets)}: {buckets}" + for b in buckets: + # Each bucket spans ~50 nodes per pair → 2 AGs at most fit in 80. + assert len(b) <= 2, f"bucket of {len(b)} AGs violates max_topo_span=80" + + +def test_greedy_bucket_span_cap_none_disables_span(): + """max_topo_span=None should restore byte-only behavior: + same 4 spaced-out AGs that split at span=80 above now go into one bucket.""" + g = torch.fx.Graph() + wait_nodes = [] + for i in range(4): + ph = g.placeholder(f"p{i}") + ph.meta["val"] = _make_fake_tensor_meta((512,), torch.bfloat16) + ag = _make_ag_node(g, ph, "dp", (512,), group_size=8) + w = _make_wait_node(g, ag) + wait_nodes.append(w) + x = w + for j in range(50): + x = g.call_function(torch.ops.aten.relu.default, (x,)) + x.meta["val"] = w.meta["val"] + g.output(tuple(wait_nodes)) + gm = _build_gm(g) + + saved = ab.aten_autobucketing_config.max_topo_span + ab.aten_autobucketing_config.max_topo_span = None + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = saved + + assert ( + len(buckets) == 1 + ), f"expected 1 bucket with span cap disabled, got {len(buckets)}" + assert len(buckets[0]) == 4 + + +def test_greedy_bucket_skips_descendant_collectives(): + """Two AGs where AG2 depends on AG1's wait must not bucket together + (would create a cycle on merge).""" + g = torch.fx.Graph() + p1 = g.placeholder("p1") + p1.meta["val"] = _make_fake_tensor_meta((512,), torch.bfloat16) + ag1 = _make_ag_node(g, p1, "dp", (512,), group_size=8) + w1 = _make_wait_node(g, ag1) + # AG2 takes w1 as input → AG2 is a descendant of AG1. + ag2 = _make_ag_node(g, w1, "dp", (512,), group_size=8) + w2 = _make_wait_node(g, ag2) + g.output((w2,)) + gm = _build_gm(g) + + ab.aten_autobucketing_config.max_topo_span = 1000 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + # Either no buckets or a single-element bucket (which the impl drops). + # The two AGs must never end up together. + for b in buckets: + assert not ( + ag1 in b and ag2 in b + ), "AG1 and AG2 ended up in same bucket despite descendant relation"