Improve FSDP bucketing and cap compute batches between ReduceScatters #481

Open: fmassa wants to merge 1 commit into main from fmassa/better_bucketing

fmassa (Contributor) commented Jun 2, 2026

Ports three related changes to `auto_bucketing.py` from `fmassa/double_recomp` (cherry-pick of 428e9d2). Together they address an interaction between FSDP all-gather bucketing and `stable_topological_sort` that batches recomputation MMs upfront and inflates backward peak memory.

**1. `_patch_fsdp_bucketing()` — monkey-patches PyTorch's FSDP bucketing**
- *Primary-group-only*: `identify_fsdp_groups` keeps only the group with the most FSDP all-gathers, so minority groups (1 tp AG, ~65 norm AGs on the combined group) no longer pollute the bucketing pool.
- *Non-adjacent bucketing*: `greedy_bucket_collective_by_mb` collects all eligible collectives per group key instead of requiring graph adjacency, allowing dp AGs interleaved with tp activation collectives to bucket together.
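
The two behaviours above can be illustrated with a small standalone sketch. It does not use or reproduce PyTorch's `identify_fsdp_groups` or `greedy_bucket_collective_by_mb`; the `Collective` record, the helper names, the sizes, and the 25 MB cap are all hypothetical and exist only to show the idea of "keep the largest group, then bucket by group key rather than by adjacency".

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class Collective:
    # Hypothetical stand-in for a graph node; not a PyTorch type.
    name: str
    group: str      # process-group key, e.g. "dp" or "tp"
    size_mb: float


def primary_group(all_gathers):
    """Return the group key that owns the most all-gathers."""
    counts = Counter(c.group for c in all_gathers)
    return counts.most_common(1)[0][0]


def bucket_non_adjacent(collectives, group, cap_mb):
    """Greedily pack every collective of `group` into buckets of at most
    `cap_mb`, walking in graph order. Collectives from other groups are
    skipped instead of closing the current bucket."""
    buckets, current, current_mb = [], [], 0.0
    for c in collectives:
        if c.group != group:
            continue  # an interleaved tp collective does not break the bucket
        if current and current_mb + c.size_mb > cap_mb:
            buckets.append(current)
            current, current_mb = [], 0.0
        current.append(c.name)
        current_mb += c.size_mb
    if current:
        buckets.append(current)
    return buckets


# dp all-gathers interleaved with tp activation collectives (made-up sizes).
graph_order = [
    Collective("ag0", "dp", 10), Collective("tp_act0", "tp", 4),
    Collective("ag1", "dp", 10), Collective("tp_act1", "tp", 4),
    Collective("ag2", "dp", 10),
]
group = primary_group(graph_order)
buckets = bucket_non_adjacent(graph_order, group, cap_mb=25)
print(group, buckets)
```

With an adjacency requirement, `ag0` and `ag1` could not share a bucket here because `tp_act0` sits between them; grouping by key lets them merge.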

**2. `max_compute_pre_fetch`: 5 → 50**
The previous value allowed only ~0.3 layers of prefetch (≈17 compute nodes/layer), insufficient to hide 5–7ms full-mesh AGs. 50 gives ≈3 layers of headroom.
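
The layer estimates follow from dividing the prefetch budget by the approximate number of compute nodes per layer. A quick check, assuming the ≈17 nodes/layer figure quoted above (the function name is illustrative only):

```python
NODES_PER_LAYER = 17  # approximate figure from the description above


def layers_of_prefetch(max_compute_pre_fetch, nodes_per_layer=NODES_PER_LAYER):
    """How many layers of compute the prefetch window spans."""
    return max_compute_pre_fetch / nodes_per_layer


for value in (5, 50):
    print(value, round(layers_of_prefetch(value), 2))
```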

**3. `_cap_compute_batch_size(max_consecutive=8)`**
After bucketing rewires dependencies, `stable_topological_sort` reorders 525/540 compute ops into an `MM*40` block before the first backward RS, blowing up peak memory. This pass snapshots the original compute/RS interleaving before scheduling. Then, for any post-schedule segment with more than 8 compute nodes between RS ops, it chains chunks and pulls forward an RS that originally sat between them. It falls back gracefully if a cycle is detected.
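
As a rough illustration of the capping idea, here is a toy sketch over flat lists of node names. It is not the code in this PR: the real pass works on graph dependencies and re-sorts, whereas this version just walks a schedule. The function names, the `"rs"` naming convention, and the `rs_inputs` map are assumptions made for the example. An RS is only pulled forward when its inputs are already scheduled, which stands in for the cycle check.

```python
def is_rs(node):
    # Toy convention: reduce-scatter nodes are named "rs<N>".
    return node.startswith("rs")


def cap_compute_batch(scheduled, original, rs_inputs, max_consecutive=8):
    """Limit runs of compute nodes in `scheduled` to `max_consecutive` by
    pulling forward an RS that preceded the next compute node in `original`.
    If no RS is eligible, the run is left as scheduled."""
    orig_pos = {n: i for i, n in enumerate(original)}
    all_rs = [n for n in scheduled if is_rs(n)]
    out, done, run = [], set(), 0
    for node in scheduled:
        if is_rs(node):
            if node not in done:  # skip RS ops already pulled forward
                out.append(node)
                done.add(node)
                run = 0
            continue
        if run >= max_consecutive:
            for rs in all_rs:
                ready = set(rs_inputs[rs]) <= done
                if rs not in done and orig_pos[rs] < orig_pos[node] and ready:
                    out.append(rs)
                    done.add(rs)
                    run = 0
                    break
        out.append(node)
        done.add(node)
        run += 1
    return out


def longest_compute_run(order):
    best = run = 0
    for node in order:
        run = 0 if is_rs(node) else run + 1
        best = max(best, run)
    return best


original = ["c0", "c1", "rs0", "c2", "c3", "rs1", "c4", "c5", "rs2"]
scheduled = ["c0", "c1", "c2", "c3", "c4", "c5", "rs0", "rs1", "rs2"]
rs_inputs = {"rs0": ["c0", "c1"], "rs1": ["c2", "c3"], "rs2": ["c4", "c5"]}

capped = cap_compute_batch(scheduled, original, rs_inputs, max_consecutive=2)
print(capped, longest_compute_run(scheduled), longest_compute_run(capped))
```

In this small example the capped order happens to match the original interleaving. In general the cap only bounds the run length, which leaves the scheduler more freedom than restoring the original order outright.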

**Why this matters (from prior investigation, LLaMA-3 8B, 128 H100s, dp=16/tp=8):**
- Bucketing patches + prefetch bump closed the unconstrained AP-to-reference gap from 12.1% → 4.6% (385ms vs 368ms reference). Prefetch alone: −20ms (−4.6%) unconstrained.
- `_cap_compute_batch_size` brings the constrained configuration from a runaway `MM*40` recomp batch down to **358ms / 16.78 GB** (vs reference 339ms / 5.97 GB), trading +5.9ms latency for −4.3 GB peak memory vs the uncapped variant. A more aggressive `_restore_compute_order` was rejected because it killed 62.8ms of overlap for only 1.3 GB extra savings.
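
For readers who want to relate the raw timings to the percentage gaps, a minimal check using only the numbers quoted above (the helper name is illustrative):

```python
def gap_pct(measured_ms, reference_ms):
    """Relative slowdown of `measured_ms` over `reference_ms`, in percent."""
    return 100.0 * (measured_ms - reference_ms) / reference_ms


unconstrained_gap = round(gap_pct(385, 368), 1)   # after bucketing + prefetch
constrained_gap = round(gap_pct(358, 339), 1)     # with _cap_compute_batch_size
memory_ratio = round(16.78 / 5.97, 2)             # capped peak vs reference peak

print(unconstrained_gap, constrained_gap, memory_ratio)
```

The first value reproduces the 4.6% unconstrained gap. The other two are derived here from the quoted figures, not stated in the original measurements.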

Authored with Claude.

meta-cla Bot added the CLA Signed label Jun 2, 2026
fmassa requested a review from IvanKobzarev June 2, 2026 19:57