Improve FSDP bucketing and cap compute batches between ReduceScatters#481
Open
fmassa wants to merge 1 commit into
Open
Improve FSDP bucketing and cap compute batches between ReduceScatters#481fmassa wants to merge 1 commit into
fmassa wants to merge 1 commit into
Conversation
Ports three related changes to `auto_bucketing.py` from `fmassa/double_recomp` (cherry of 428e9d2). Together they address an interaction between FSDP all-gather bucketing and `stable_topological_sort` that batches recomputation MMs upfront and inflates backward peak memory. **1. `_patch_fsdp_bucketing()` — monkey-patches PyTorch's FSDP bucketing** - *Primary-group-only*: `identify_fsdp_groups` keeps only the group with the most FSDP all-gathers, so minority groups (1 tp AG, ~65 norm AGs on the combined group) no longer pollute the bucketing pool. - *Non-adjacent bucketing*: `greedy_bucket_collective_by_mb` collects all eligible collectives per group key instead of requiring graph adjacency, allowing dp AGs interleaved with tp activation collectives to bucket together. **2. `max_compute_pre_fetch`: 5 → 50** The previous value allowed only ~0.3 layers of prefetch (≈17 compute nodes/layer), insufficient to hide 5–7ms full-mesh AGs. 50 gives ≈3 layers of headroom. **3. `_cap_compute_batch_size(max_consecutive=8)`** After bucketing rewires dependencies, `stable_topological_sort` reorders 525/540 compute ops into an `MM*40` block before the first backward RS, blowing up peak memory. Snapshots the original compute/RS interleaving before scheduling, then for any post-schedule segment with >8 compute nodes between RS ops, chains chunks and pulls forward an RS that originally sat between them. Falls back gracefully if a cycle is detected. **Why this matters (from prior investigation, LLaMA-3 8B, 128 H100s, dp=16/tp=8):** - Bucketing patches + prefetch bump closed the unconstrained AP-to-reference gap from 12.1% → 4.6% (385ms vs 368ms reference). Prefetch alone: −20ms (−4.6%) unconstrained. - `_cap_compute_batch_size` brings constrained from a runaway `MM*40` recomp batch down to **358ms / 16.78 GB** (vs reference 339ms / 5.97 GB), trading +5.9ms latency for −4.3 GB peak memory vs the uncapped variant. A more aggressive `_restore_compute_order` was rejected — it killed 62.8ms of overlap for only 1.3 GB extra savings. Authored with Claude.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Ports three related changes to
auto_bucketing.pyfromfmassa/double_recomp(cherry of 428e9d2). Together they address an interaction between FSDP all-gather bucketing andstable_topological_sortthat batches recomputation MMs upfront and inflates backward peak memory.1.
_patch_fsdp_bucketing()— monkey-patches PyTorch's FSDP bucketingidentify_fsdp_groupskeeps only the group with the most FSDP all-gathers, so minority groups (1 tp AG, ~65 norm AGs on the combined group) no longer pollute the bucketing pool.greedy_bucket_collective_by_mbcollects all eligible collectives per group key instead of requiring graph adjacency, allowing dp AGs interleaved with tp activation collectives to bucket together.2.
max_compute_pre_fetch: 5 → 50The previous value allowed only ~0.3 layers of prefetch (≈17 compute nodes/layer), insufficient to hide 5–7ms full-mesh AGs. 50 gives ≈3 layers of headroom.
3.
_cap_compute_batch_size(max_consecutive=8)After bucketing rewires dependencies,stable_topological_sortreorders 525/540 compute ops into anMM*40block before the first backward RS, blowing up peak memory. Snapshots the original compute/RS interleaving before scheduling, then for any post-schedule segment with >8 compute nodes between RS ops, chains chunks and pulls forward an RS that originally sat between them. Falls back gracefully if a cycle is detected.Why this matters (from prior investigation, LLaMA-3 8B, 128 H100s, dp=16/tp=8):
_cap_compute_batch_sizebrings constrained from a runawayMM*40recomp batch down to 358ms / 16.78 GB (vs reference 339ms / 5.97 GB), trading +5.9ms latency for −4.3 GB peak memory vs the uncapped variant. A more aggressive_restore_compute_orderwas rejected — it killed 62.8ms of overlap for only 1.3 GB extra savings.Authored with Claude.