Bypass aliases for consumers that would redistribute to the producer #479

Open
fmassa wants to merge 2 commits into main from fmassa/bypass_alias

@fmassa fmassa commented May 27, 2026

When the ILP picks a placement for an `aten.alias.default` node that differs from its producer's placement, any consumer whose `input_spec` matches the producer's placement would redistribute the alias back to the producer's placement at runtime. The original tensor must then stay alive longer than necessary just to feed the redistribution.

The LLaMA-3 backward graph hits this on every transformer block: the gradient at each residual add (`grad_h2`, placed as `S(0)S(1)`) feeds an alias that the optimizer assigns to `S(0)R` for two einsum consumers, but a third consumer (the skip-add) wants `S(0)S(1)`, which forces a redistribution from `R` back to `S(1)`.
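
The condition above can be illustrated with a minimal, library-free sketch. It is not the code in this PR: the `Consumer` class, the `round_trip_consumers` helper, and the consumer names `einsum_1`, `einsum_2`, and `skip_add` are hypothetical, and placements are modeled as plain tuples of strings rather than real placement objects.

```python
from dataclasses import dataclass

# One entry per mesh dimension, e.g. ("S(0)", "S(1)"). Hypothetical encoding.
Placement = tuple[str, ...]


@dataclass
class Consumer:
    name: str
    input_spec: Placement  # placement the consumer expects for the alias input


def round_trip_consumers(producer_out, alias_out, consumers):
    """Return names of consumers that would redistribute back to the producer."""
    if alias_out == producer_out:
        return []  # alias keeps the producer's placement: nothing to bypass
    return [c.name for c in consumers if c.input_spec == producer_out]


grad_h2_out = ("S(0)", "S(1)")  # producer placement from the example above
alias_out = ("S(0)", "R")       # placement assigned to the alias
consumers = [
    Consumer("einsum_1", ("S(0)", "R")),
    Consumer("einsum_2", ("S(0)", "R")),
    Consumer("skip_add", ("S(0)", "S(1)")),
]
affected = round_trip_consumers(grad_h2_out, alias_out, consumers)
print(affected)  # ['skip_add']
```

Under these assumptions only the skip-add is flagged; the two einsum consumers genuinely need the alias's `S(0)R` placement.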

`eliminate_alias_round_trips` runs after `get_solution()` and rewires each such consumer directly to the alias's producer. The alias keeps serving any consumer that genuinely needs its placement; if no users remain, the alias is erased from both the graph and the solution dict.
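
As a rough illustration of the rewiring described above, here is a toy version of the pass on a hand-rolled graph. Everything in it is an assumption made for the sketch: the `Node` and `Strategy` classes, the function name `eliminate_alias_round_trips_sketch`, and the shape of the solution dict are hypothetical stand-ins, not the data structures used in this PR.

```python
from dataclasses import dataclass, field

Placement = tuple[str, ...]  # hypothetical encoding, e.g. ("S(0)", "R")


@dataclass(eq=False)  # compare nodes by identity, as graph nodes usually are
class Node:
    name: str
    op: str  # "placeholder", "alias", or "call" in this toy model
    args: list["Node"] = field(default_factory=list)


@dataclass
class Strategy:
    input_specs: list[Placement]  # one expected placement per argument
    output_spec: Placement


def eliminate_alias_round_trips_sketch(nodes, solution):
    """Rewire round-trip consumers to the alias's producer; return slots rewired."""
    rewired = 0
    for alias in [n for n in nodes if n.op == "alias"]:
        producer = alias.args[0]
        producer_out = solution[producer.name].output_spec
        if solution[alias.name].output_spec == producer_out:
            continue  # no-op case: alias placement matches the producer
        for user in nodes:
            specs = solution[user.name].input_specs
            for i, arg in enumerate(user.args):
                # Only bypass slots that want exactly the producer's placement.
                if arg is alias and specs[i] == producer_out:
                    user.args[i] = producer
                    rewired += 1
        # Erase the alias once nothing reads it any more.
        if not any(arg is alias for n in nodes for arg in n.args):
            nodes.remove(alias)
            del solution[alias.name]
    return rewired


S0S1, S0R = ("S(0)", "S(1)"), ("S(0)", "R")
grad_h2 = Node("grad_h2", "placeholder")
h = Node("h", "placeholder")
alias = Node("alias", "alias", [grad_h2])
einsum = Node("einsum", "call", [alias])
skip_add = Node("skip_add", "call", [alias, h])
nodes = [grad_h2, h, alias, einsum, skip_add]
solution = {
    "grad_h2": Strategy([], S0S1),
    "h": Strategy([], S0S1),
    "alias": Strategy([S0S1], S0R),
    "einsum": Strategy([S0R], S0R),
    "skip_add": Strategy([S0S1, S0S1], S0S1),
}
num_rewired = eliminate_alias_round_trips_sketch(nodes, solution)
```

In this toy graph the skip-add ends up reading `grad_h2` directly, while the alias survives because the einsum still needs `S(0)R`. A consumer that wants a third placement is left on the alias in this sketch, since its argument slot does not match the producer's placement.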

Unit tests in `tests/test_graph_utils.py` cover rewiring, alias erasure, the no-op case (alias placement matches producer), intermediate redistributions (consumer wants a third placement), and repeated inputs (`x + x`-style consumers). On the LLaMA-3 8B example (32 layers, 128 GPUs), the pass eliminates 32 round-trips.

Authored with Claude

@meta-cla meta-cla Bot added the CLA Signed label May 27, 2026