Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 210 additions & 1 deletion autoparallel/graph_passes/auto_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,212 @@
# LICENSE file in the root directory of this source tree.

import logging
from collections import Counter, defaultdict
from functools import partial

import torch
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
from torch.utils._ordered_set import OrderedSet

from .autobucketing_inductor import bucket_func, bucket_plan, bucket_utils, reorder

logger = logging.getLogger(__name__)


def _patch_fsdp_bucketing():
"""Patch PyTorch's FSDP bucketing for better multi-group handling.

Two fixes:
1. Primary-group-only: only include the group with the most FSDP
all-gathers in fsdp_groups, preventing minority groups (tp, combined)
from limiting dp bucketing aggressiveness.
2. Topo-span-bounded bucketing: allow collectives to be bucketed even
when interleaved with non-FSDP collectives on other groups, but
close the bucket once its topo-span (rank of latest member minus
rank of first member) exceeds aten_autobucketing_config.max_topo_span.

Without a span bound, merging collectives that are far apart in the
graph rewires the dependency graph so that stable_topological_sort
can pull compute from late layers forward, batching many MMs before
any RS fires (the MM*N problem). The span bound caps how far compute
can be displaced by any one bucket merge.
"""
import torch._inductor.fx_passes.bucketing as bucketing_mod
import torch._inductor.fx_passes.fsdp as fsdp_mod
from torch._inductor.fx_passes.bucketing import (
collect_node_descendants,
is_wait_tensor,
)
from torch._inductor.fx_passes.fsdp import (
_find_all_gathers,
_get_group_name,
_get_group_size_from_node,
is_fsdp_all_gather,
)

def _patched_identify_fsdp_groups(gm):
fsdp_counts_by_group = Counter()
group_size = None
for n in _find_all_gathers(gm.graph):
if is_fsdp_all_gather(n):
gn = _get_group_name(n)
fsdp_counts_by_group[gn] += 1
if group_size is None:
group_size = _get_group_size_from_node(n)

if fsdp_counts_by_group:
primary_group = fsdp_counts_by_group.most_common(1)[0][0]
fsdp_groups = OrderedSet([primary_group])
else:
fsdp_groups = OrderedSet()

logger.debug(
"identify_fsdp_groups (patched): fsdp_groups=%s, all_counts=%s",
list(fsdp_groups),
dict(fsdp_counts_by_group),
)
return fsdp_groups, group_size

def _patched_greedy_bucket(
gm,
bucket_cap_mb_by_bucket_idx,
filter_node,
node_group_key,
filter_wait_node=None,
):
g = gm.graph
# Snapshot ranks before any bucketing mutates the graph. Used to
# bound each bucket's topo-span, which bounds how far compute can
# be displaced by stable_topological_sort after merging.
ranks = {n.name: i for i, n in enumerate(g.nodes)}

groups = defaultdict(list)
for node in g.nodes:
if is_wait_tensor(node) and filter_node(node.args[0]):
if (filter_wait_node is None) or filter_wait_node(node):
coll_node = node.args[0]
key = node_group_key(coll_node)
groups[key].append(coll_node)

if not groups:
return []

node_descendents = collect_node_descendants(g)
max_topo_span = aten_autobucketing_config.max_topo_span

buckets = []
# Metrics aggregated across all groups.
n_close_bytes = 0
n_close_span = 0
max_observed_span = 0

for key, nodes in groups.items():
cur_bucket = []
cur_bucket_descendents = OrderedSet()
cur_bucket_size_bytes = 0
cur_bucket_start_rank = None
cur_bucket_id = 0
bucket_size_bytes = int(
bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
)
for node in nodes:
if node in cur_bucket_descendents:
continue
n_val = node.meta["val"]
out_size_bytes = n_val.numel() * n_val.element_size()
n_input_val = node.all_input_nodes[0].meta["val"]
in_size_bytes = n_input_val.numel() * n_input_val.element_size()
size_bytes = max(out_size_bytes, in_size_bytes)

node_rank = ranks.get(node.name, 0)
would_span = (
node_rank - cur_bucket_start_rank
if cur_bucket_start_rank is not None
else 0
)

close_for_bytes = (
cur_bucket_size_bytes + size_bytes > bucket_size_bytes
and cur_bucket
)
close_for_span = (
max_topo_span is not None
and would_span > max_topo_span
and cur_bucket
)

if close_for_bytes or close_for_span:
if len(cur_bucket) > 1:
buckets.append(cur_bucket)
if close_for_bytes:
n_close_bytes += 1
if close_for_span:
n_close_span += 1
observed_span = (
ranks.get(cur_bucket[-1].name, 0) - cur_bucket_start_rank
)
max_observed_span = max(max_observed_span, observed_span)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
cur_bucket_descendents = OrderedSet()
cur_bucket_start_rank = None
bucket_size_bytes = int(
bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
)
cur_bucket_size_bytes += size_bytes
cur_bucket.append(node)
cur_bucket_descendents |= node_descendents[node]
if cur_bucket_start_rank is None:
cur_bucket_start_rank = node_rank
if len(cur_bucket) > 1:
buckets.append(cur_bucket)
observed_span = (
ranks.get(cur_bucket[-1].name, 0) - cur_bucket_start_rank
)
max_observed_span = max(max_observed_span, observed_span)

if buckets:
logger.info(
"greedy_bucket: %d buckets, max_span=%d, "
"closed (bytes=%d, span=%d, max_topo_span=%s)",
len(buckets),
max_observed_span,
n_close_bytes,
n_close_span,
max_topo_span,
)
return buckets

fsdp_mod.identify_fsdp_groups = _patched_identify_fsdp_groups
bucketing_mod.greedy_bucket_collective_by_mb = _patched_greedy_bucket


_patch_fsdp_bucketing()


def _max_consec_compute_between_rs(graph):
"""Return the maximum consecutive compute nodes between RS ops.

Useful as a regression metric: bucketing followed by topo sort can
pull compute from late layers forward, batching many MMs before any
RS fires. This metric grows linearly with the size of such batches.
"""
from torch._inductor.fx_passes.overlap_scheduling import is_compute_node

max_run = cur = 0
for n in graph.nodes:
if n.op != "call_function":
continue
target = str(n.target)
if is_compute_node(n):
cur += 1
max_run = max(max_run, cur)
elif "reduce_scatter" in target and "wait" not in target:
cur = 0
return max_run


class simplefsdp_autobucketing_config:
"""
Config for simplefsdp's autobucketing pass, which by default would give good performance.
Expand Down Expand Up @@ -101,13 +297,18 @@ class aten_autobucketing_config:
- max_in_flight_gb: maximum GB of concurrent collective data
- compute_overlap_multipler: scale factor for compute time used to hide collectives
- max_coll_distance: maximum node distance for overlap consideration
- max_topo_span: maximum number of graph positions a single bucket may
span. Bounds how far compute can be displaced when bucketing rewires
the dep graph and stable_topological_sort runs afterwards. Set to
None to disable the span bound (only bytes cap applies).
"""

max_in_flight_gb = 2.0
compute_overlap_multipler = 1.0
max_coll_distance = 100
custom_runtime_estimation = None
max_compute_pre_fetch = 5
max_compute_pre_fetch = 50
max_topo_span: int | None = 1500
collective_bucketing = False
save_trace = True
_counter = 0
Expand All @@ -117,6 +318,7 @@ def aten_autobucketing_reordering_pass(
gm: torch.fx.Graph, configs: "aten_autobucketing_config"
) -> torch.fx.GraphModule:
assert gm.owning_module is not None

new_gm = schedule_overlap_bucketing(
gm.owning_module,
collective_bucketing=configs.collective_bucketing,
Expand All @@ -126,6 +328,13 @@ def aten_autobucketing_reordering_pass(
max_in_flight_gb=configs.max_in_flight_gb,
max_coll_distance=configs.max_coll_distance,
)

logger.info(
"aten_autobucketing_reordering_pass: post-pass "
"max_consec_compute_between_rs=%d",
_max_consec_compute_between_rs(new_gm.graph),
)

new_gm.recompile()

if configs.save_trace:
Expand Down
Loading
Loading