diff --git a/autoparallel/graph_passes/auto_bucketing.py b/autoparallel/graph_passes/auto_bucketing.py index fa01ff70..42562a38 100644 --- a/autoparallel/graph_passes/auto_bucketing.py +++ b/autoparallel/graph_passes/auto_bucketing.py @@ -4,16 +4,212 @@ # LICENSE file in the root directory of this source tree. import logging +from collections import Counter, defaultdict from functools import partial import torch from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing +from torch.utils._ordered_set import OrderedSet from .autobucketing_inductor import bucket_func, bucket_plan, bucket_utils, reorder logger = logging.getLogger(__name__) +def _patch_fsdp_bucketing(): + """Patch PyTorch's FSDP bucketing for better multi-group handling. + + Two fixes: + 1. Primary-group-only: only include the group with the most FSDP + all-gathers in fsdp_groups, preventing minority groups (tp, combined) + from limiting dp bucketing aggressiveness. + 2. Topo-span-bounded bucketing: allow collectives to be bucketed even + when interleaved with non-FSDP collectives on other groups, but + close the bucket once its topo-span (rank of latest member minus + rank of first member) exceeds aten_autobucketing_config.max_topo_span. + + Without a span bound, merging collectives that are far apart in the + graph rewires the dependency graph so that stable_topological_sort + can pull compute from late layers forward, batching many MMs before + any RS fires (the MM*N problem). The span bound caps how far compute + can be displaced by any one bucket merge. 
+ """ + import torch._inductor.fx_passes.bucketing as bucketing_mod + import torch._inductor.fx_passes.fsdp as fsdp_mod + from torch._inductor.fx_passes.bucketing import ( + collect_node_descendants, + is_wait_tensor, + ) + from torch._inductor.fx_passes.fsdp import ( + _find_all_gathers, + _get_group_name, + _get_group_size_from_node, + is_fsdp_all_gather, + ) + + def _patched_identify_fsdp_groups(gm): + fsdp_counts_by_group = Counter() + group_size = None + for n in _find_all_gathers(gm.graph): + if is_fsdp_all_gather(n): + gn = _get_group_name(n) + fsdp_counts_by_group[gn] += 1 + if group_size is None: + group_size = _get_group_size_from_node(n) + + if fsdp_counts_by_group: + primary_group = fsdp_counts_by_group.most_common(1)[0][0] + fsdp_groups = OrderedSet([primary_group]) + else: + fsdp_groups = OrderedSet() + + logger.debug( + "identify_fsdp_groups (patched): fsdp_groups=%s, all_counts=%s", + list(fsdp_groups), + dict(fsdp_counts_by_group), + ) + return fsdp_groups, group_size + + def _patched_greedy_bucket( + gm, + bucket_cap_mb_by_bucket_idx, + filter_node, + node_group_key, + filter_wait_node=None, + ): + g = gm.graph + # Snapshot ranks before any bucketing mutates the graph. Used to + # bound each bucket's topo-span, which bounds how far compute can + # be displaced by stable_topological_sort after merging. + ranks = {n.name: i for i, n in enumerate(g.nodes)} + + groups = defaultdict(list) + for node in g.nodes: + if is_wait_tensor(node) and filter_node(node.args[0]): + if (filter_wait_node is None) or filter_wait_node(node): + coll_node = node.args[0] + key = node_group_key(coll_node) + groups[key].append(coll_node) + + if not groups: + return [] + + node_descendents = collect_node_descendants(g) + max_topo_span = aten_autobucketing_config.max_topo_span + + buckets = [] + # Metrics aggregated across all groups. 
+ n_close_bytes = 0 + n_close_span = 0 + max_observed_span = 0 + + for key, nodes in groups.items(): + cur_bucket = [] + cur_bucket_descendents = OrderedSet() + cur_bucket_size_bytes = 0 + cur_bucket_start_rank = None + cur_bucket_id = 0 + bucket_size_bytes = int( + bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 + ) + for node in nodes: + if node in cur_bucket_descendents: + continue + n_val = node.meta["val"] + out_size_bytes = n_val.numel() * n_val.element_size() + n_input_val = node.all_input_nodes[0].meta["val"] + in_size_bytes = n_input_val.numel() * n_input_val.element_size() + size_bytes = max(out_size_bytes, in_size_bytes) + + node_rank = ranks.get(node.name, 0) + would_span = ( + node_rank - cur_bucket_start_rank + if cur_bucket_start_rank is not None + else 0 + ) + + close_for_bytes = ( + cur_bucket_size_bytes + size_bytes > bucket_size_bytes + and cur_bucket + ) + close_for_span = ( + max_topo_span is not None + and would_span > max_topo_span + and cur_bucket + ) + + if close_for_bytes or close_for_span: + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + if close_for_bytes: + n_close_bytes += 1 + if close_for_span: + n_close_span += 1 + observed_span = ( + ranks.get(cur_bucket[-1].name, 0) - cur_bucket_start_rank + ) + max_observed_span = max(max_observed_span, observed_span) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_descendents = OrderedSet() + cur_bucket_start_rank = None + bucket_size_bytes = int( + bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 + ) + cur_bucket_size_bytes += size_bytes + cur_bucket.append(node) + cur_bucket_descendents |= node_descendents[node] + if cur_bucket_start_rank is None: + cur_bucket_start_rank = node_rank + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + observed_span = ( + ranks.get(cur_bucket[-1].name, 0) - cur_bucket_start_rank + ) + max_observed_span = max(max_observed_span, observed_span) + + if buckets: + logger.info( + "greedy_bucket: %d buckets, 
max_span=%d, " + "closed (bytes=%d, span=%d, max_topo_span=%s)", + len(buckets), + max_observed_span, + n_close_bytes, + n_close_span, + max_topo_span, + ) + return buckets + + fsdp_mod.identify_fsdp_groups = _patched_identify_fsdp_groups + bucketing_mod.greedy_bucket_collective_by_mb = _patched_greedy_bucket + + +_patch_fsdp_bucketing() + + +def _max_consec_compute_between_rs(graph): + """Return the maximum consecutive compute nodes between RS ops. + + Useful as a regression metric: bucketing followed by topo sort can + pull compute from late layers forward, batching many MMs before any + RS fires. This metric grows linearly with the size of such batches. + """ + from torch._inductor.fx_passes.overlap_scheduling import is_compute_node + + max_run = cur = 0 + for n in graph.nodes: + if n.op != "call_function": + continue + target = str(n.target) + if is_compute_node(n): + cur += 1 + max_run = max(max_run, cur) + elif "reduce_scatter" in target and "wait" not in target: + cur = 0 + return max_run + + class simplefsdp_autobucketing_config: """ Config for simplefsdp's autobucketing pass, which by default would give good performance. @@ -101,13 +297,18 @@ class aten_autobucketing_config: - max_in_flight_gb: maximum GB of concurrent collective data - compute_overlap_multipler: scale factor for compute time used to hide collectives - max_coll_distance: maximum node distance for overlap consideration + - max_topo_span: maximum number of graph positions a single bucket may + span. Bounds how far compute can be displaced when bucketing rewires + the dep graph and stable_topological_sort runs afterwards. Set to + None to disable the span bound (only bytes cap applies). 
""" max_in_flight_gb = 2.0 compute_overlap_multipler = 1.0 max_coll_distance = 100 custom_runtime_estimation = None - max_compute_pre_fetch = 5 + max_compute_pre_fetch = 50 + max_topo_span: int | None = 1500 collective_bucketing = False save_trace = True _counter = 0 @@ -117,6 +318,7 @@ def aten_autobucketing_reordering_pass( gm: torch.fx.Graph, configs: "aten_autobucketing_config" ) -> torch.fx.GraphModule: assert gm.owning_module is not None + new_gm = schedule_overlap_bucketing( gm.owning_module, collective_bucketing=configs.collective_bucketing, @@ -126,6 +328,13 @@ def aten_autobucketing_reordering_pass( max_in_flight_gb=configs.max_in_flight_gb, max_coll_distance=configs.max_coll_distance, ) + + logger.info( + "aten_autobucketing_reordering_pass: post-pass " + "max_consec_compute_between_rs=%d", + _max_consec_compute_between_rs(new_gm.graph), + ) + new_gm.recompile() if configs.save_trace: diff --git a/tests/test_auto_bucketing_patches.py b/tests/test_auto_bucketing_patches.py new file mode 100644 index 00000000..f67c1ba6 --- /dev/null +++ b/tests/test_auto_bucketing_patches.py @@ -0,0 +1,280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the monkey-patches installed by +autoparallel.graph_passes.auto_bucketing._patch_fsdp_bucketing. + +The patches are installed at module import time, so importing +auto_bucketing replaces the originals in PyTorch's bucketing/fsdp modules. +""" + +import torch +import torch._inductor.fx_passes.bucketing as bucketing_mod +import torch._inductor.fx_passes.fsdp as fsdp_mod + +# Importing auto_bucketing triggers _patch_fsdp_bucketing() at import time. +import autoparallel.graph_passes.auto_bucketing as ab # noqa: F401 + + +def _make_fake_tensor_meta(shape, dtype=torch.bfloat16): + """Return a FakeTensor-like object usable as node.meta['val']. 
+ + The bucketing helpers only access .numel(), .element_size(), and .dtype. + A real torch.empty meta tensor satisfies that contract. + """ + return torch.empty(shape, dtype=dtype, device="meta") + + +def _make_ag_node(g, x, group_name, shape, dtype=torch.bfloat16, group_size=8): + """Build an all_gather_into_tensor call_function node with FSDP-shaped + input chain: exactly one placeholder ancestor (which is_fsdp_all_gather + requires).""" + ag = g.call_function( + torch.ops._c10d_functional.all_gather_into_tensor.default, + (x, group_size, group_name), + ) + ag.meta["val"] = _make_fake_tensor_meta(shape, dtype) + return ag + + +def _make_wait_node(g, ag): + w = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (ag,)) + w.meta["val"] = ag.meta["val"] + return w + + +def _build_gm(graph): + return torch.fx.GraphModule(torch.nn.Module(), graph) + + +# ---------- identify_fsdp_groups tests ---------- + + +def test_identify_fsdp_groups_picks_primary_group(): + """With AGs split across two PGs, only the group with more AGs is returned.""" + g = torch.fx.Graph() + # Group "dp" gets 5 AGs, group "tp" gets 2. "dp" should win. 
+ placeholders = [g.placeholder(f"p{i}") for i in range(7)] + for ph in placeholders: + ph.meta["val"] = _make_fake_tensor_meta((1024,), torch.bfloat16) + + for i in range(5): + _make_ag_node(g, placeholders[i], "dp", (8192,), group_size=8) + for i in range(5, 7): + _make_ag_node(g, placeholders[i], "tp", (8192,), group_size=8) + g.output(()) + gm = _build_gm(g) + + groups, group_size = fsdp_mod.identify_fsdp_groups(gm) + + assert list(groups) == ["dp"], f"expected only primary 'dp', got {list(groups)}" + assert group_size == 8 + + +def test_identify_fsdp_groups_handles_no_fsdp_ags(): + """Empty graph or graphs with no FSDP AGs return an empty OrderedSet.""" + g = torch.fx.Graph() + g.output(()) + gm = _build_gm(g) + + groups, group_size = fsdp_mod.identify_fsdp_groups(gm) + + assert len(groups) == 0 + assert group_size is None + + +def test_identify_fsdp_groups_ties_pick_one(): + """Equal AG counts: most_common is deterministic but the exact group + chosen depends on insertion order. We only require that exactly one is + returned (not both).""" + g = torch.fx.Graph() + placeholders = [g.placeholder(f"p{i}") for i in range(4)] + for ph in placeholders: + ph.meta["val"] = _make_fake_tensor_meta((1024,), torch.bfloat16) + for ph in placeholders[:2]: + _make_ag_node(g, ph, "dp", (8192,)) + for ph in placeholders[2:]: + _make_ag_node(g, ph, "tp", (8192,)) + g.output(()) + gm = _build_gm(g) + + groups, _ = fsdp_mod.identify_fsdp_groups(gm) + assert len(groups) == 1 + assert list(groups)[0] in ("dp", "tp") + + +# ---------- greedy_bucket tests ---------- + + +def _build_ag_chain_graph(n_ags, ag_shape_bytes_each=1024, group_name="dp"): + """Build a graph with n_ags AGs in the same group + their waits. + + Each AG has its own placeholder input (so descendant relations between + AGs are absent — every AG is independent). + + Returns (gm, ag_nodes, wait_nodes). 
+ """ + g = torch.fx.Graph() + ag_nodes = [] + wait_nodes = [] + # Use shape such that numel * element_size == ag_shape_bytes_each. + # bfloat16 = 2 bytes, so numel = bytes / 2. + numel = max(1, ag_shape_bytes_each // 2) + for i in range(n_ags): + ph = g.placeholder(f"p{i}") + ph.meta["val"] = _make_fake_tensor_meta((numel,), torch.bfloat16) + ag = _make_ag_node(g, ph, group_name, (numel,), group_size=8) + w = _make_wait_node(g, ag) + ag_nodes.append(ag) + wait_nodes.append(w) + g.output(tuple(wait_nodes)) + return _build_gm(g), ag_nodes, wait_nodes + + +def _call_greedy_bucket(gm, bucket_cap_mb, *, filter_wait_node=None): + """Invoke the patched greedy_bucket via the public API surface.""" + return bucketing_mod.greedy_bucket_collective_by_mb( + gm, + lambda _idx: bucket_cap_mb, + # filter_node accepts the collective node (args[0] of the wait). + # We accept all AGs for these tests. + filter_node=bucketing_mod.is_all_gather_into_tensor, + node_group_key=bucketing_mod._ag_group_key, + filter_wait_node=filter_wait_node, + ) + + +def test_greedy_bucket_merges_within_caps(): + """4 small AGs, plenty of room: one bucket containing all of them.""" + gm, ag_nodes, _ = _build_ag_chain_graph(n_ags=4, ag_shape_bytes_each=1024) + # max_topo_span large enough to fit all (graph has ~12 nodes). + ab.aten_autobucketing_config.max_topo_span = 1000 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + assert len(buckets) == 1, f"expected 1 bucket, got {len(buckets)}" + assert len(buckets[0]) == 4, f"expected 4 AGs in bucket, got {len(buckets[0])}" + + +def test_greedy_bucket_splits_on_bytes_cap(): + """Each AG = 5 MB, cap = 10 MB → buckets of at most 2 AGs each.""" + # 5 MB per AG → two AGs exactly fill the 10 MB cap; the third exceeds it.
+ bytes_per_ag = 5 * 1024 * 1024 + gm, ag_nodes, _ = _build_ag_chain_graph(n_ags=6, ag_shape_bytes_each=bytes_per_ag) + ab.aten_autobucketing_config.max_topo_span = 1000 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + assert len(buckets) >= 2, f"expected >=2 buckets, got {len(buckets)}: {buckets}" + # Each bucket must respect the 10 MB cap, so each bucket has ≤ 2 AGs. + for b in buckets: + assert len(b) <= 2, f"bucket has {len(b)} > 2 AGs (exceeds 10 MB cap)" + # Single-element buckets are dropped, so each surviving bucket has ≥ 2. + for b in buckets: + assert len(b) >= 2, f"surviving bucket has only {len(b)} member" + + +def test_greedy_bucket_splits_on_span_cap(): + """4 AGs spaced out in the graph, max_topo_span forces splits below + the byte cap.""" + # Each AG with its placeholder, wait, plus a string of no-op compute + # nodes between them so the topo distance between AGs is large. + g = torch.fx.Graph() + ag_nodes = [] + wait_nodes = [] + bytes_per_ag = 1024 # tiny — bytes cap won't fire + numel = bytes_per_ag // 2 + for i in range(4): + ph = g.placeholder(f"p{i}") + ph.meta["val"] = _make_fake_tensor_meta((numel,), torch.bfloat16) + ag = _make_ag_node(g, ph, "dp", (numel,), group_size=8) + w = _make_wait_node(g, ag) + ag_nodes.append(ag) + wait_nodes.append(w) + # Pad: add 50 unrelated ops between this wait and the next AG. + x = w + for j in range(50): + x = g.call_function(torch.ops.aten.relu.default, (x,)) + x.meta["val"] = w.meta["val"] + g.output(tuple(wait_nodes)) + gm = _build_gm(g) + + # With ~50+ nodes per AG, total span across 4 AGs ≈ 200. + # max_topo_span = 80 means at most 2 AGs per bucket. 
+ ab.aten_autobucketing_config.max_topo_span = 80 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + assert ( + len(buckets) >= 2 + ), f"expected >=2 buckets due to span cap, got {len(buckets)}: {buckets}" + for b in buckets: + # Each bucket spans ~50 nodes per pair → 2 AGs at most fit in 80. + assert len(b) <= 2, f"bucket of {len(b)} AGs violates max_topo_span=80" + + +def test_greedy_bucket_span_cap_none_disables_span(): + """max_topo_span=None should restore byte-only behavior: + same 4 spaced-out AGs that split at span=80 above now go into one bucket.""" + g = torch.fx.Graph() + wait_nodes = [] + for i in range(4): + ph = g.placeholder(f"p{i}") + ph.meta["val"] = _make_fake_tensor_meta((512,), torch.bfloat16) + ag = _make_ag_node(g, ph, "dp", (512,), group_size=8) + w = _make_wait_node(g, ag) + wait_nodes.append(w) + x = w + for j in range(50): + x = g.call_function(torch.ops.aten.relu.default, (x,)) + x.meta["val"] = w.meta["val"] + g.output(tuple(wait_nodes)) + gm = _build_gm(g) + + saved = ab.aten_autobucketing_config.max_topo_span + ab.aten_autobucketing_config.max_topo_span = None + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = saved + + assert ( + len(buckets) == 1 + ), f"expected 1 bucket with span cap disabled, got {len(buckets)}" + assert len(buckets[0]) == 4 + + +def test_greedy_bucket_skips_descendant_collectives(): + """Two AGs where AG2 depends on AG1's wait must not bucket together + (would create a cycle on merge).""" + g = torch.fx.Graph() + p1 = g.placeholder("p1") + p1.meta["val"] = _make_fake_tensor_meta((512,), torch.bfloat16) + ag1 = _make_ag_node(g, p1, "dp", (512,), group_size=8) + w1 = _make_wait_node(g, ag1) + # AG2 takes w1 as input → AG2 is a descendant of AG1. 
+ ag2 = _make_ag_node(g, w1, "dp", (512,), group_size=8) + w2 = _make_wait_node(g, ag2) + g.output((w2,)) + gm = _build_gm(g) + + ab.aten_autobucketing_config.max_topo_span = 1000 + try: + buckets = _call_greedy_bucket(gm, bucket_cap_mb=10.0) + finally: + ab.aten_autobucketing_config.max_topo_span = 1500 + + # Either no buckets or a single-element bucket (which the impl drops). + # The two AGs must never end up together. + for b in buckets: + assert not ( + ag1 in b and ag2 in b + ), "AG1 and AG2 ended up in same bucket despite descendant relation"