Fuse chained allgathers on different subgroups into single full-mesh allgather #472

Draft: fmassa wants to merge 3 commits into main from fmassa/fuse-allgather

fmassa (Contributor) commented May 21, 2026

When the forward and backward use different shardings for a weight, the backward's recomputed allgather chain can decompose into two sequential allgathers on different mesh dimensions, with cancelling permute pairs between them: for example, `S(0)S(0) → RS(0)` (dp allgather) followed by `RS(0) → RR` (tp allgather). The forward path already fuses `S(0)S(0) → RR` into a single collective via `_optimize_same_nd_sharding_as_1d`, but these recomputed backward chains bypass that optimization and produce two separate NCCL kernel launches.
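As a rough illustration of why the two forms are interchangeable, here is a toy pure-Python model; it is not code from this PR. It assumes a 2 by 3 `(dp, tp)` mesh on which rank `(i, j)` holds row block `i * TP + j` of the weight. That layout and every name below are assumptions made for the example, and the real graph pass works on FX nodes and process groups rather than lists.

```python
# Toy model of a (dp, tp) mesh; all names and the layout are hypothetical.
DP, TP = 2, 3
full = list(range(12))            # the unsharded weight, 12 "rows"
rows = len(full) // (DP * TP)     # rows held by each rank

def shard(i, j):
    """Rows held by rank (i, j), assuming row-major flattening of the mesh."""
    r = i * TP + j
    return full[r * rows:(r + 1) * rows]

# One allgather over the flattened mesh: concatenate shards in flat rank order.
flat = [x for r in range(DP * TP) for x in shard(r // TP, r % TP)]

# Chained form: a dp allgather per tp index, then a tp allgather of the results.
# chained[j][i] is the shard that originated on rank (i, j).
chained = [[shard(i, j) for i in range(DP)] for j in range(TP)]

# Concatenating the chained result as-is leaves the blocks in (tp, dp) order.
naive = [x for j in range(TP) for i in range(DP) for x in chained[j][i]]

# A single reorder back to (dp, tp) order recovers the flat result.
reordered = [x for i in range(DP) for j in range(TP) for x in chained[j][i]]
```

In this model `flat` and `reordered` both equal `full`, while `naive` does not, which is one way to picture why data movement between the two collectives has to net out before they can be replaced by one.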

This adds `fuse_chained_allgathers`, a graph pass that detects these chains and replaces them with a single allgather on the flattened mesh process group. The pass validates that both allgathers are on known mesh subgroups in descending dim order, their group sizes multiply to the full mesh size, and the intermediate view ops compose to the identity (verified via FakeTensor shape/stride metadata). The pass runs on the partitioned forward and backward graphs during the first compilation and on the inference path, gated on `mesh.ndim > 1`.
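Two of the checks described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the pass itself: the function names are hypothetical, tensor metadata is modelled as a `(shape, stride)` tuple instead of a FakeTensor, only permutes are modelled as view ops, and the subgroup-ordering check is left out because the description does not spell out its convention.

```python
from math import prod

def group_sizes_cover_mesh(mesh_shape, first_group_size, second_group_size):
    """Hypothetical check: the two group sizes multiply to the full mesh size."""
    if len(mesh_shape) <= 1:      # mirrors the `mesh.ndim > 1` gate
        return False
    return first_group_size * second_group_size == prod(mesh_shape)

def permute_meta(meta, dims):
    """Apply a permute to (shape, stride) metadata only; no data is touched."""
    shape, stride = meta
    return (tuple(shape[d] for d in dims), tuple(stride[d] for d in dims))

def views_compose_to_identity(meta, permutes):
    """Hypothetical check: the permute chain leaves shape and stride unchanged."""
    out = meta
    for dims in permutes:
        out = permute_meta(out, dims)
    return out == meta
```

For example, `views_compose_to_identity(((4, 6), (6, 1)), [(1, 0), (1, 0)])` holds because the two transposes cancel, whereas a single `(1, 0)` permute changes both shape and stride and would block the fusion in this sketch.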

Authored with Claude.

…allgather

When weights are placed as `S(0)S(0)` on a multi-dim mesh, `apply_sharding` decomposes the `S(0)S(0) → RR` redistribution into per-dim allgathers: a dp-dim allgather followed by a tp-dim allgather, with cancelling permute pairs between them. Each pair produces two separate NCCL kernel launches when a single full-mesh allgather would suffice.

This adds `fuse_chained_allgathers`, a graph pass that detects these chains and replaces them with a single allgather on the flattened mesh process group. The pass validates that both allgathers are on known mesh subgroups in descending dim order, their group sizes multiply to the full mesh size, and the intermediate view ops compose to the identity (verified via FakeTensor shape/stride metadata). The pass runs on the partitioned forward and backward graphs during the first compilation and on the inference path, gated on `mesh.ndim > 1`.

Authored with Claude.
meta-cla (bot) added the CLA Signed label May 21, 2026
fmassa marked this pull request as draft May 29, 2026 06:44