Fuse chained allgathers on different subgroups into single full-mesh allgather#472
Draft
fmassa wants to merge 3 commits into
Draft
Fuse chained allgathers on different subgroups into single full-mesh allgather#472fmassa wants to merge 3 commits into
fmassa wants to merge 3 commits into
Conversation
…allgather When weights are placed as `S(0)S(0)` on a multi-dim mesh, `apply_sharding` decomposes the `S(0)S(0) → RR` redistribution into per-dim allgathers: a dp-dim allgather followed by a tp-dim allgather, with cancelling permute pairs between them. Each pair produces two separate NCCL kernel launches when a single full-mesh allgather would suffice. This adds `fuse_chained_allgathers`, a graph pass that detects these chains and replaces them with a single allgather on the flattened mesh process group. The pass validates that both allgathers are on known mesh subgroups in descending dim order, their group sizes multiply to the full mesh size, and the intermediate view ops compose to the identity (verified via FakeTensor shape/stride metadata). The pass runs on the partitioned forward and backward graphs during the first compilation and on the inference path, gated on `mesh.ndim > 1`. Authored with Claude.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
When the forward and backward use different shardings for a weight, the backward's recomputed allgather chain can decompose into two sequential allgathers on different mesh dimensions — e.g.
S(0)S(0) → RS(0)(dp allgather) thenRS(0) → RR(tp allgather) — with cancelling permute pairs between them. While the forward path already fusesS(0)S(0) → RRinto a single collective via_optimize_same_nd_sharding_as_1d, these recomputed backward chains bypass that optimization and produce two separate NCCL kernel launches.This adds
fuse_chained_allgathers, a graph pass that detects these chains and replaces them with a single allgather on the flattened mesh process group. The pass validates that both allgathers are on known mesh subgroups in descending dim order, their group sizes multiply to the full mesh size, and the intermediate view ops compose to the identity (verified via FakeTensor shape/stride metadata). The pass runs on the partitioned forward and backward graphs during the first compilation and on the inference path, gated onmesh.ndim > 1.Authored with Claude.