Skip to content

Fix double recomputation in torch.compile round-trip#467

Draft
fmassa wants to merge 16 commits into
mainfrom
fmassa/double_recomp
Draft

Fix double recomputation in torch.compile round-trip#467
fmassa wants to merge 16 commits into
mainfrom
fmassa/double_recomp

Conversation

@fmassa
Copy link
Copy Markdown
Contributor

@fmassa fmassa commented May 20, 2026

When torch.compile(parallel_mod, backend=autoparallel_backend()) compiles an already-partitioned module, the second AOT autograd partitioner was making different recomputation decisions from the first, causing redundant computation and wasted memory.

The first commit fixes the issue and the second simplifies the implementation with no functional changes.

The first compilation's fw_compiler now tags forward output tensors with custom: {ap_must_save: True}, marking which tensors the first min_cut decided to save for backward. For multi-output ops (like SDPA), the parent op is tagged instead (getitem metadata doesn't survive preserve_node_meta), and DCE removes unused getitem children. Tags survive into the second compilation via _COPY_META_FIELDS.

The second compilation replaces min_cut_rematerialization_partition with
_SaveAllPartitioner — a tag-driven CustomPartitionerFn that:

  • Uses classify_nodes for topology-based fwd/bwd boundary detection (handles interleaved backward ops that default_partition can't)
  • Runs CSE to deduplicate allgather chains baked in from the first backward
  • Saves only ap_must_save nodes, reproducing the first min_cut's decisions
  • Respects ac_joint_pass's PREFER_RECOMPUTE overrides for memory optimization

Also fixes force_recompute_fsdp_all_gather to tag the full downstream single-input chain from wait_tensor (permute, view, etc.), preventing the first partitioner from saving 2.62 GiB of transposed allgathered weights.

Result on LLaMA-3 8B (32 layers, 128 GPUs): second compilation's forward allgather count exactly matches the first (357 = 357). Peak memory 6.40 GiB forward, 8.21 GiB backward.

Authored with Claude.

fmassa added 2 commits May 20, 2026 07:38
When `torch.compile(parallel_mod, backend=autoparallel_backend())` compiles an
already-partitioned module, the second AOT autograd partitioner was making
different recomputation decisions from the first, causing redundant computation
and wasted memory.

Three fixes, best reviewed in order:

**1. Tag the full downstream chain in `force_recompute_fsdp_all_gather`**
(`activation_checkpointing.py`). The previous code only tagged allgather,
wait_tensor, slice, and dtype_cast with MUST_RECOMPUTE, missing the
`permute([1,0])` that follows. The first partitioner saved the permute output
(the transposed allgathered weight, 2.62 GiB total), defeating MUST_RECOMPUTE.
The fix walks the single-input chain downstream from wait_tensor and tags all
layout ops.

**2. Replace the second partitioner with `_SaveAllPartitioner`**
(`compile.py`). Instead of running a second `min_cut_rematerialization_partition`
that diverges, use a custom tag-driven partitioner that reproduces the first
min_cut's save/recompute decisions. It uses `classify_nodes` for correct
fwd/bwd boundary detection (handles interleaved backward ops that
`default_partition` can't), CSE to deduplicate baked-in backward allgather
chains, and saves only nodes tagged `ap_must_save` by the first compilation.

**3. Tag forward/backward nodes and saved tensors in the first compilation**
(`api.py`). The first compilation's `fw_compiler` tags forward nodes with
`custom: {ap_graph_part: "forward"}` and marks min_cut's saved-for-backward
tensors with `ap_must_save: True` (or `ap_save_getitems` for multi-output
ops like SDPA). The `bw_compiler` tags backward nodes with
`custom: {ap_graph_part: "backward"}`. Tags survive into the second
compilation via `preserve_node_meta` and `_COPY_META_FIELDS`.

Result (LLaMA-3 8B, 32 layers, 128 GPUs): second compilation's forward
allgather count exactly matches the first (357 = 357). Forward peak memory
6.40 GiB, backward 8.21 GiB.

Authored with Claude.
…stead

Simplifies how the first compilation marks saved-for-backward tensors for
the second compilation to reproduce. No functional change — allgather counts
and peak memory are identical before and after.

For multi-output ops like SDPA, getitem metadata doesn't survive
`preserve_node_meta` (Python builtin, not dispatched). The previous approach
tracked specific getitem indices via `custom: {ap_save_getitems: [0, 1, 6, 7]}`
on the parent. This replaces it with `custom: {ap_must_save: True}` on the
parent — the second partitioner saves all getitem children, and
`_extract_fwd_bwd_modules`'s DCE removes the ones backward doesn't need.

Also removes dead code: `tag_backward` parameter and `ap_graph_part` tagging
were only used by the `must_be_in_backward` approach that was removed.

Authored with Claude.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 20, 2026
fmassa added 7 commits May 20, 2026 13:52
… single full-mesh allgather

When AutoParallel places weights as `S(0)S(0)` on a multi-dim mesh,
the `apply_sharding` pass decomposes the `S(0)S(0) → RR` redistribution
into per-dim allgathers: first a dp-dim allgather, then a tp-dim allgather.
In the backward graph, these appear as recomputed chains with cancelling
permute pairs between them:

```
dp_ag → wait → permute([1,0]) → permute([1,0]) → tp_ag → wait
```

Each pair produces two separate NCCL kernel launches when a single
full-mesh allgather would suffice.

This PR adds `fuse_chained_allgathers`, a graph pass that detects these
chains and replaces them with a single allgather on the flattened mesh
process group. The pass validates:
- Both allgathers are on known mesh subgroups (in descending dim order)
- Their group sizes multiply to the full mesh size
- The intermediate view ops compose to the identity (verified via
  FakeTensor shape/stride metadata, requiring at least one non-trivial
  dimension reorder like permute or transpose)
- No intermediate value has other consumers

The pass runs as a `pre_pass` on the partitioned forward and backward
graphs during both the first compilation (inside `AutoParallel`) and
the inference path. On LLaMA-3 8B (dp=16, tp=8, 32 layers), this fuses
49 allgather pairs in the backward, eliminating 49 standalone NCCL
kernel launches per iteration.

Authored with Claude.
When the ILP picks a placement for an `aten.alias.default` node that differs
from its producer's placement, any consumer whose `input_spec` matches the
producer's placement would redistribute the alias back to the producer's
placement at runtime. The original tensor must then stay alive longer than
necessary just to feed the redistribution.

The LLaMA-3 backward graph hits this on every transformer block: the
gradient at each residual add (`grad_h2`, S(0)S(1)) feeds an alias that the
optimizer assigns to S(0)R for two einsum consumers, but a third consumer
(the skip-add) wants S(0)S(1) — forcing a redistribution from R back to
S(1).

`eliminate_alias_round_trips` runs after `get_solution()` and rewires each
such consumer directly to the alias's producer. The alias keeps serving any
consumer that genuinely needs its placement; if no users remain, the alias
is erased from both the graph and the solution dict.

Unit tests in `tests/test_graph_utils.py` cover rewiring, alias erasure,
the no-op case (alias placement matches producer), intermediate
redistributions (consumer wants a third placement), and repeated inputs
(`x + x`-style consumers). On the LLaMA-3 8B example (32 layers, 128 GPUs),
the pass eliminates 32 round-trips. Authored with Claude.

Authored with Claude
…odel

Port TorchTitan's apply_ac activation checkpointing implementation into AutoParallel's LLaMA3 model definition, replacing the graph-level AC pass with module-level checkpoint wrapping applied before tracing.

apply_ac(model, mode, selective_ac_option) supports three strategies matching TorchTitan: full AC (every layer), layer-selective (every nth layer), and per-op selective (saves compute-heavy ops like matmuls and SDPA but recomputes every 2nd matmul). The example uses op-level selective AC and disables the autoparallel backend's built-in AC pass to avoid double-checkpointing.

Authored with Claude.
@fmassa fmassa marked this pull request as draft May 28, 2026 11:13
fmassa added 7 commits May 28, 2026 11:27
Three fixes for overlap scheduling and memory:

1. **Fix config.patch expiration for backward compilation** (`compile.py`): Overlap scheduling configs (`enable_overlap_scheduling`, `collective_bucketing`, etc.) were set via `config.patch` context manager inside `autoparallel_backend`. Since backward compilation is lazy (triggered on first `.backward()` after `compile_fx` returns), the context manager has already exited and the configs revert to defaults. Fix: set these configs globally instead. The `custom_partitioner_fn` stays in the context manager since it only affects forward partitioning.

2. **Add compute batch size capping** (`auto_bucketing.py`): After the overlap scheduler runs, AG bucketing can rewire dependencies such that `stable_topological_sort` batches hundreds of compute ops together (e.g., MM×40). This causes massive memory spikes. `_cap_compute_batch_size` adds ordering dependencies between compute nodes and their nearest reduce-scatter ops, limiting consecutive compute to `max_consecutive=8` ops. Also patches FSDP bucketing to use only the primary group and allow non-adjacent collectives, and increases `max_compute_pre_fetch` from 5 to 50.

3. **Save first residual add in SAC policy** (`llama3.py`): The selective AC policy now saves the first `aten.add.Tensor` output per transformer block (the attention residual `h = x + attn(...)`). Without this, the backward recomputes more ops between layers, giving the scheduler more latitude to prefetch and inflating peak memory. This reduced the post-scheduling memory gap vs reference from +3.30 GB to +0.39 GB.

The `example_llama3.py` changes are incidental (world_size, collective_bucketing flag, debug JSON dump).
…rder

The `fuse_chained_allgathers` pass never fired for the unconstrained LLaMA-3
8B model despite 64 fusible dp→tp allgather chains in the backward graph
(2 per layer × 32 layers, from `S(0)S(0) → RS(0) → RR` weight unsharding).

Three independent issues prevented fusion:

1. The row-major flat mesh (`mesh._flatten()`) has the wrong rank ordering for
   dp→tp chains produced by reverse shard order. The fix creates a column-major
   process group via `dist.new_group(col_major_ranks, sort_ranks=False)` and
   uses it for ascending (dp→tp) chains, while the row-major mesh continues to
   serve descending (tp→dp) chains.

2. AOT autograd eliminates the canceling permutes between the two allgathers,
   leaving a direct chain (`ag1 → wait → ag2`) that `_is_identity_view_chain`
   rejected. Direct chains are now accepted when `subgroup_order` validates the
   direction and the matching process group is available.

3. The subgroup_order direction check (commit bf4c912 changed `<=` to `>=`)
   accepted ascending chains on the row-major flat mesh — the wrong
   combination. The check is replaced by explicit direction-to-group routing.

Authored with Claude.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant