Pre-overlap collective bucketing pass for FSDP/DDP #412

Open: fmassa wants to merge 1 commit into main from fmassa/bucketing_pass

@fmassa fmassa commented Apr 3, 2026

PyTorch's overlap scheduler adds sequential timeline dependencies between all consecutive events on each process group, which prevents its downstream bucketer from merging collectives that were originally independent. This PR adds a pre-pass that merges per-parameter FSDP/DDP collectives before the overlap scheduler runs, so it sees fewer, larger collectives.
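
To make the motivation concrete, here is a toy model (not the scheduler's actual code) of why chaining consecutive events blocks merging. It assumes a bucketer may only merge two collectives when neither depends on the other; the node names and the `mergeable_pairs` helper are hypothetical.

```python
from itertools import combinations


def reachable(deps, src, dst):
    """Return True if dst depends, directly or transitively, on src."""
    stack, seen = [dst], set()
    while stack:
        node = stack.pop()
        if node == src:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(deps.get(node, ()))
    return False


def mergeable_pairs(nodes, deps):
    """Pairs with no dependency path in either direction (toy merge rule)."""
    return [
        (a, b)
        for a, b in combinations(nodes, 2)
        if not reachable(deps, a, b) and not reachable(deps, b, a)
    ]


collectives = ["ag0", "ag1", "ag2"]

# Originally independent per-parameter all-gathers: no edges between them.
independent = {}

# After sequential timeline dependencies are added on the process group,
# each collective depends on the previous one.
chained = {"ag1": ["ag0"], "ag2": ["ag1"]}

pairs_before = mergeable_pairs(collectives, independent)
pairs_after = mergeable_pairs(collectives, chained)
```

In this model all three pairs are mergeable before the chain is added and none are afterwards, which is why merging has to happen before the overlap scheduler runs.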

The pass targets three collective types: forward all-gathers (param-derived), backward reduce-scatters (terminal-derived), and backward all-reduces (terminal-derived, for DDP).

The implementation is split into two phases:

  • Tagging runs on the joint graph (where placeholder metadata is available) and marks eligible collectives via node.meta. Tags survive the fw/bw partition via node_copy's shallow copy.
  • Bucketing runs on the split fw/bw graphs inside the compiler, reads the tags, and merges collectives using PyTorch's existing merge functions.
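
The two phases can be sketched in plain Python. This is an illustration only: `Node`, `BUCKET_TAG`, `tag_collectives`, `partition`, and `bucket` are hypothetical stand-ins, eligibility is reduced to the op type (the real pass checks whether a collective is param-derived or terminal-derived), and `partition` mimics the shallow copy of `meta` that lets tags survive the fw/bw split.

```python
import copy
from dataclasses import dataclass, field

BUCKET_TAG = "bucket_group"  # hypothetical meta key, not the PR's real one


@dataclass
class Node:
    name: str
    op: str
    meta: dict = field(default_factory=dict)


def tag_collectives(joint_nodes):
    """Phase 1: mark eligible collectives on the joint graph via node.meta."""
    for node in joint_nodes:
        if node.op == "all_gather":
            node.meta[BUCKET_TAG] = "fwd_ag"
        elif node.op == "reduce_scatter":
            node.meta[BUCKET_TAG] = "bwd_rs"


def partition(joint_nodes, keep):
    """Stand-in for the fw/bw split: copy nodes, shallow-copying their meta."""
    return [Node(n.name, n.op, copy.copy(n.meta)) for n in joint_nodes if keep(n)]


def bucket(nodes):
    """Phase 2: read the tags on a split graph and group collectives by tag."""
    buckets = {}
    for node in nodes:
        tag = node.meta.get(BUCKET_TAG)
        if tag is not None:
            buckets.setdefault(tag, []).append(node.name)
    return buckets


joint = [
    Node("ag_w1", "all_gather"),
    Node("mm", "matmul"),
    Node("ag_w2", "all_gather"),
    Node("rs_g1", "reduce_scatter"),
    Node("rs_g2", "reduce_scatter"),
]
tag_collectives(joint)

fw = partition(joint, lambda n: n.op != "reduce_scatter")
bw = partition(joint, lambda n: n.op == "reduce_scatter")

fw_buckets = bucket(fw)
bw_buckets = bucket(bw)
```

Because the tags travel with the copied `meta`, the bucketing phase does not need the placeholder metadata that is only available on the joint graph.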

This PR also fixes a pre-existing bug in _copy_descriptors_and_rename_placeholders: make_fx could nest the output tuple while desc stayed flat, so the zip in get_all_input_and_grad_nodes silently mismatched output nodes with their descriptors.

Collective counts on LLaMA: all-gathers (AG) 290→98 in forward and 225→97 recomputed in backward; reduce-scatters (RS) 290→194 in backward.
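
For a sense of scale, the relative reductions can be computed from the figures above, assuming they are before/after collective counts:

```python
# (before, after) pairs taken from the results line above.
counts = {
    "AG fwd": (290, 98),
    "AG bwd recomputed": (225, 97),
    "RS bwd": (290, 194),
}

# Percentage of collectives removed, rounded to one decimal place.
reductions = {
    name: round(100 * (before - after) / before, 1)
    for name, (before, after) in counts.items()
}

for name, pct in reductions.items():
    print(f"{name}: {pct}% fewer collectives")
```

That works out to roughly 66.2% fewer forward all-gathers, 56.9% fewer recomputed all-gathers, and 33.1% fewer reduce-scatters.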

Test plan

  • example_autoparallel.py passes (compile=False)
  • example_llama3.py passes
  • pytest tests/ passes

Authored with Claude.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 3, 2026