Skip to content

Model communication-computation overlap in the sharding ILP#353

Draft
fmassa wants to merge 6 commits into
mainfrom
fmassa/comms_compute_overlap_model
Draft

Model communication-computation overlap in the sharding ILP#353
fmassa wants to merge 6 commits into
mainfrom
fmassa/comms_compute_overlap_model

Conversation

@fmassa
Copy link
Copy Markdown
Contributor

@fmassa fmassa commented Mar 7, 2026

The ILP objective currently treats communication and computation as fully sequential (Σ(comm + compute) * x), which is an upper bound on actual runtime. In practice, parameter redistributions (e.g., all-gathers) can overlap with preceding compute (prefetch), and gradient reduce-scatters can overlap with subsequent compute (post-compute overlap).

This PR models overlap within the ILP using continuous "savings" variables. For each overlappable edge, a savings variable is created with:

  • savings <= comm_cost(selected) — can't save more than the communication
  • Σ savings_using_A <= compute_cost(A, selected) — can't save more than the available compute budget

The solver maximizes savings (since they're subtracted from the objective), computing savings = min(comm, compute_budget).

Two scan passes identify overlappable edges:

  • Forward scan: edges from parameter-derived inputs (propagated transitively through dtype_cast, views, etc.) overlap with preceding compute
  • Reverse scan: edges into terminal-derived nodes (all paths lead to output) overlap with subsequent compute

A shared compute budget constraint across both scans prevents double-counting.

The feature is off by default (enable_prefetch_overlap=False) and can be enabled via AutoParallel or auto_parallel().

Authored with Claude.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 7, 2026
fmassa added 5 commits March 8, 2026 12:18
Previously each savings variable was added directly to every compute partner's budget constraint, which meant savings <= min(compute_cost(A), compute_cost(B)) — bounding it
by the smallest node in the group. This is too conservative: a 12-unit comm overlapping with two compute nodes (5 and 10) should allow up to 12 units of savings, not 5.

Fix by splitting each savings into per-node contribution variables (savings = contrib_A + contrib_B), where each contribution is non-negative and participates in its node's
budget constraint. The solver can now allocate e.g. 5 from A and 7 from B to fully hide 12 units of comm.

Authored with Claude.
… savings logging

Three improvements to the prefetch overlap model:

The forward scan now creates savings variables for all param-derived input edges (not just boundary edges), since the ILP may place the all-gather anywhere in the param
chain (e.g. param → dtype_cast). The compute group is only reset at boundary edges (non-param-derived consumer), so intermediate param-derived edges share the same compute
window and don't fragment the budget. The reverse scan applies symmetric boundary logic.

The violated-constraints logger now uses a 1e-6 tolerance, fixing false positives from floating-point residuals in the continuous savings/contribution equality constraints.

The cost summary now reports overlap_savings and effective_cost when prefetch overlap is enabled.

Authored with Claude.
The group-based overlap model splits each savings variable into per-node
contributions bounded by per-node compute budgets. When a savings variable
is created at a boundary edge, the compute group resets, so leftover
compute budget can't carry forward across windows.

This replaces it with a cumulative budget chain: a continuous LP variable
that grows with compute and shrinks with savings as we scan the graph.

```
B_0 = 0
B_i = B_{i-1} + compute(i) - savings(i)
B_i >= 0
```

The non-negativity of B_i implicitly prevents savings from exceeding
accumulated compute, and leftover budget naturally carries forward. This
is strictly more expressive — a comm preceded by compute windows of 3, 4,
5 can now draw from the full cumulative budget of 12 rather than only its
immediate group.

Also fixes a regression in the backward scan where a boundary_args
filter was too restrictive: in graphs where all nodes are terminal-derived,
it produced zero backward savings.

Authored with Claude.
The two cumulative budget chains (forward for all-gather prefetch,
backward for reduce-scatter) were both counting every node's compute
independently, allowing the same compute to be "spent" twice. This meant
overlap savings could exceed total compute cost.

Fix by introducing a per-node continuous LP variable that splits each
node's compute between the two chains. The forward chain gets
compute_share, the backward gets compute_expr - compute_share, and
the solver decides the optimal allocation. This guarantees total savings
never exceed total compute while still allowing full flexibility in how
compute is distributed between the two overlap directions.

Authored with Claude.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant