Fix double recomputation in torch.compile round-trip #467
Draft
fmassa wants to merge 16 commits into
When `torch.compile(parallel_mod, backend=autoparallel_backend())` compiles an
already-partitioned module, the second AOT autograd partitioner was making
different recomputation decisions from the first, causing redundant computation
and wasted memory.
Three fixes, best reviewed in order:
**1. Tag the full downstream chain in `force_recompute_fsdp_all_gather`**
(`activation_checkpointing.py`). The previous code only tagged allgather,
wait_tensor, slice, and dtype_cast with MUST_RECOMPUTE, missing the
`permute([1,0])` that follows. The first partitioner saved the permute output
(the transposed allgathered weight, 2.62 GiB total), defeating MUST_RECOMPUTE.
The fix walks the single-input chain downstream from wait_tensor and tags all
layout ops.
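The chain walk described in fix 1 can be sketched on a toy graph. This is a minimal illustration, not the code in `activation_checkpointing.py`: the `Node` class, `connect`, `tag_chain_from_wait`, and the contents of `LAYOUT_OPS` are hypothetical stand-ins for the real FX nodes and op list, and the stop condition (single user, single input, layout op) is an assumption.

```python
from dataclasses import dataclass, field

# Toy stand-in for a graph node; the real pass operates on torch.fx nodes.
@dataclass(eq=False)
class Node:
    name: str
    op: str
    inputs: list = field(default_factory=list)
    users: list = field(default_factory=list)
    meta: dict = field(default_factory=dict)

def connect(producer, consumer):
    producer.users.append(consumer)
    consumer.inputs.append(producer)

# Assumed set of layout ops; the real list may differ.
LAYOUT_OPS = {"slice", "dtype_cast", "permute", "view"}

def tag_chain_from_wait(wait_node):
    """Tag wait_node and every single-input layout op downstream of it."""
    wait_node.meta["recompute"] = "MUST_RECOMPUTE"
    tagged = [wait_node.name]
    node = wait_node
    while len(node.users) == 1:
        nxt = node.users[0]
        # Stop at the first op that is not a single-input layout op.
        if nxt.op not in LAYOUT_OPS or len(nxt.inputs) != 1:
            break
        nxt.meta["recompute"] = "MUST_RECOMPUTE"
        tagged.append(nxt.name)
        node = nxt
    return tagged

# Example chain: ag -> wait -> slice -> cast -> permute -> mm (mm also reads x)
ag, wait = Node("ag", "all_gather"), Node("wait", "wait_tensor")
sl, cast = Node("slice", "slice"), Node("cast", "dtype_cast")
perm, x, mm = Node("permute", "permute"), Node("x", "placeholder"), Node("mm", "mm")
for a, b in [(ag, wait), (wait, sl), (sl, cast), (cast, perm), (perm, mm), (x, mm)]:
    connect(a, b)

tagged_names = tag_chain_from_wait(wait)
```

In this sketch the permute is tagged along with the slice and cast, while the matmul that consumes it is left untagged.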
**2. Replace the second partitioner with `_SaveAllPartitioner`**
(`compile.py`). Instead of running a second `min_cut_rematerialization_partition`
that diverges, use a custom tag-driven partitioner that reproduces the first
min_cut's save/recompute decisions. It uses `classify_nodes` for correct
fwd/bwd boundary detection (handles interleaved backward ops that
`default_partition` can't), CSE to deduplicate baked-in backward allgather
chains, and saves only nodes tagged `ap_must_save` by the first compilation.
**3. Tag forward/backward nodes and saved tensors in the first compilation**
(`api.py`). The first compilation's `fw_compiler` tags forward nodes with
`custom: {ap_graph_part: "forward"}` and marks min_cut's saved-for-backward
tensors with `ap_must_save: True` (or `ap_save_getitems` for multi-output
ops like SDPA). The `bw_compiler` tags backward nodes with
`custom: {ap_graph_part: "backward"}`. Tags survive into the second
compilation via `preserve_node_meta` and `_COPY_META_FIELDS`.
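To make the tag flow between the two compilations concrete, here is a rough sketch using plain dictionaries in place of FX nodes. The helper names (`tag_forward`, `retrace`, `nodes_to_save`) and the `_COPY` tuple are hypothetical; `_COPY` only mimics the idea of `_COPY_META_FIELDS` and is not its real contents.

```python
# Hypothetical stand-in for the set of meta fields copied across retracing.
_COPY = ("custom",)

def tag_forward(nodes, saved_names):
    """First compilation: mark forward nodes and the tensors min_cut saved."""
    for node in nodes:
        custom = node["meta"].setdefault("custom", {})
        custom["ap_graph_part"] = "forward"
        if node["name"] in saved_names:
            custom["ap_must_save"] = True

def retrace(nodes):
    """Simulate a re-trace that keeps only the whitelisted meta fields."""
    return [
        {
            "name": n["name"],
            "meta": {k: dict(v) for k, v in n["meta"].items() if k in _COPY},
        }
        for n in nodes
    ]

def nodes_to_save(nodes):
    """Second compilation: save exactly the nodes tagged by the first one."""
    return [
        n["name"]
        for n in nodes
        if n["meta"].get("custom", {}).get("ap_must_save")
    ]

first = [
    {"name": "mm_1", "meta": {"stack_trace": "..."}},
    {"name": "relu", "meta": {}},
    {"name": "mm_2", "meta": {}},
]
tag_forward(first, saved_names={"mm_1", "mm_2"})
second = retrace(first)
saved = nodes_to_save(second)
```

The point of the sketch is that the second partitioner makes no decisions of its own: it only reads tags that the first compilation wrote.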
Result (LLaMA-3 8B, 32 layers, 128 GPUs): second compilation's forward
allgather count exactly matches the first (357 = 357). Forward peak memory
6.40 GiB, backward 8.21 GiB.
Authored with Claude.
…stead
Simplifies how the first compilation marks saved-for-backward tensors for
the second compilation to reproduce. No functional change — allgather counts
and peak memory are identical before and after.
For multi-output ops like SDPA, getitem metadata doesn't survive
`preserve_node_meta` (Python builtin, not dispatched). The previous approach
tracked specific getitem indices via `custom: {ap_save_getitems: [0, 1, 6, 7]}`
on the parent. This replaces it with `custom: {ap_must_save: True}` on the
parent — the second partitioner saves all getitem children, and
`_extract_fwd_bwd_modules`'s DCE removes the ones backward doesn't need.
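The simplification can be illustrated with a small sketch: tag the parent, save every getitem child, and let dead-code elimination drop the ones backward never reads. The function names and the count of nine outputs are assumptions for illustration only; the index list matches the one quoted above.

```python
def save_all_getitems(parent_meta, getitem_indices):
    """Save every getitem child when the parent carries ap_must_save."""
    if parent_meta.get("custom", {}).get("ap_must_save"):
        return list(getitem_indices)
    return []

def dce(saved_indices, used_by_backward):
    """Toy dead-code elimination: keep only outputs that backward reads."""
    return [i for i in saved_indices if i in used_by_backward]

# Assumed: a multi-output op with nine outputs, of which backward reads four.
parent_meta = {"custom": {"ap_must_save": True}}
candidates = save_all_getitems(parent_meta, range(9))
kept = dce(candidates, used_by_backward={0, 1, 6, 7})
```

Under these assumptions the surviving set equals the explicit index list from the previous approach, which is why the change is described as having no functional effect.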
Also removes dead code: `tag_backward` parameter and `ap_graph_part` tagging
were only used by the `must_be_in_backward` approach that was removed.
Authored with Claude.
… single full-mesh allgather

When AutoParallel places weights as `S(0)S(0)` on a multi-dim mesh, the `apply_sharding` pass decomposes the `S(0)S(0) → RR` redistribution into per-dim allgathers: first a dp-dim allgather, then a tp-dim allgather. In the backward graph, these appear as recomputed chains with cancelling permute pairs between them:

```
dp_ag → wait → permute([1,0]) → permute([1,0]) → tp_ag → wait
```

Each pair produces two separate NCCL kernel launches when a single full-mesh allgather would suffice.

This PR adds `fuse_chained_allgathers`, a graph pass that detects these chains and replaces them with a single allgather on the flattened mesh process group. The pass validates:

- Both allgathers are on known mesh subgroups (in descending dim order)
- Their group sizes multiply to the full mesh size
- The intermediate view ops compose to the identity (verified via FakeTensor shape/stride metadata, requiring at least one non-trivial dimension reorder like permute or transpose)
- No intermediate value has other consumers

The pass runs as a `pre_pass` on the partitioned forward and backward graphs during both the first compilation (inside `AutoParallel`) and the inference path.

On LLaMA-3 8B (dp=16, tp=8, 32 layers), this fuses 49 allgather pairs in the backward, eliminating 49 standalone NCCL kernel launches per iteration.

Authored with Claude.
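Two of the validation checks lend themselves to a short sketch. The real pass verifies the identity property through FakeTensor shape/stride metadata; the version below substitutes plain permutation composition as a simplified stand-in, and all function names are hypothetical.

```python
from math import prod

def compose(perms, ndim):
    """Compose permute ops in order; result[i] is the source dim of output dim i."""
    order = list(range(ndim))
    for p in perms:
        order = [order[i] for i in p]
    return order

def is_identity_with_reorder(perms, ndim):
    """True if the ops cancel out and at least one of them really reorders dims."""
    identity = list(range(ndim))
    has_reorder = any(list(p) != identity for p in perms)
    return has_reorder and compose(perms, ndim) == identity

def sizes_cover_mesh(group_sizes, mesh_shape):
    """True if the subgroup sizes multiply to the full mesh size."""
    return prod(group_sizes) == prod(mesh_shape)

cancelling = is_identity_with_reorder([[1, 0], [1, 0]], ndim=2)
single = is_identity_with_reorder([[1, 0]], ndim=2)
covers = sizes_cover_mesh([16, 8], mesh_shape=(16, 8))
```

With dp=16 and tp=8 the two group sizes multiply to 128, the full mesh size, and the two `permute([1,0])` ops compose to the identity, so a chain of this shape would pass both checks.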
When the ILP picks a placement for an `aten.alias.default` node that differs from its producer's placement, any consumer whose `input_spec` matches the producer's placement would redistribute the alias back to the producer's placement at runtime. The original tensor must then stay alive longer than necessary just to feed the redistribution.

The LLaMA-3 backward graph hits this on every transformer block: the gradient at each residual add (`grad_h2`, S(0)S(1)) feeds an alias that the optimizer assigns to S(0)R for two einsum consumers, but a third consumer (the skip-add) wants S(0)S(1), forcing a redistribution from R back to S(1).

`eliminate_alias_round_trips` runs after `get_solution()` and rewires each such consumer directly to the alias's producer. The alias keeps serving any consumer that genuinely needs its placement; if no users remain, the alias is erased from both the graph and the solution dict.

Unit tests in `tests/test_graph_utils.py` cover rewiring, alias erasure, the no-op case (alias placement matches producer), intermediate redistributions (consumer wants a third placement), and repeated inputs (`x + x`-style consumers).

On the LLaMA-3 8B example (32 layers, 128 GPUs), the pass eliminates 32 round-trips.

Authored with Claude.
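The rewiring logic can be approximated on a dictionary-based toy graph. This is only a sketch under stated assumptions: the data layout (`nodes`, `placement`, `wanted`) and the function body are invented for illustration and do not reflect the real graph or solution structures.

```python
def eliminate_alias_round_trips(nodes, placement, wanted):
    """Rewire consumers that want the producer's placement past the alias.

    nodes:     name -> {"op": str, "inputs": [names]}
    placement: name -> placement chosen for that node's output
    wanted:    consumer name -> placement wanted for each input position
    Returns the number of rewired inputs.
    """
    rewired = 0
    for name, node in list(nodes.items()):
        if node["op"] != "alias":
            continue
        producer = node["inputs"][0]
        if placement[name] == placement[producer]:
            continue  # no-op case: the alias does not change the placement
        for cname, consumer in nodes.items():
            for i, inp in enumerate(consumer["inputs"]):
                if inp == name and wanted[cname][i] == placement[producer]:
                    consumer["inputs"][i] = producer
                    rewired += 1
        # Erase the alias if nothing reads it any more.
        if not any(name in c["inputs"] for c in nodes.values()):
            del nodes[name]
            del placement[name]
    return rewired

nodes = {
    "grad_h2": {"op": "add", "inputs": []},
    "alias": {"op": "alias", "inputs": ["grad_h2"]},
    "einsum_1": {"op": "einsum", "inputs": ["alias"]},
    "einsum_2": {"op": "einsum", "inputs": ["alias"]},
    "skip_add": {"op": "add", "inputs": ["alias"]},
}
placement = {"grad_h2": "S(0)S(1)", "alias": "S(0)R"}
wanted = {
    "einsum_1": ["S(0)R"],
    "einsum_2": ["S(0)R"],
    "skip_add": ["S(0)S(1)"],
}
count = eliminate_alias_round_trips(nodes, placement, wanted)
```

In this toy version of the residual-add case, only the skip-add is rewired to `grad_h2`; the alias stays because the two einsum consumers still need its placement.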
…odel

Port TorchTitan's `apply_ac` activation checkpointing implementation into AutoParallel's LLaMA3 model definition, replacing the graph-level AC pass with module-level checkpoint wrapping applied before tracing.

`apply_ac(model, mode, selective_ac_option)` supports three strategies matching TorchTitan: full AC (every layer), layer-selective (every nth layer), and per-op selective (saves compute-heavy ops like matmuls and SDPA but recomputes every 2nd matmul). The example uses op-level selective AC and disables the autoparallel backend's built-in AC pass to avoid double-checkpointing.

Authored with Claude.
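The per-op selective strategy can be pictured as a small stateful policy. The sketch below is a simplification with plain strings for ops: `make_policy` is a hypothetical helper, and the choice to recompute the even-numbered matmuls (rather than the odd-numbered ones) is an assumption, as is the omission of any forward/recompute bookkeeping the real policy may need.

```python
def make_policy(save_ops, mm_op="mm"):
    """Return a policy that saves ops in save_ops but recomputes every 2nd matmul."""
    mm_count = 0

    def policy(op):
        nonlocal mm_count
        if op == mm_op:
            mm_count += 1
            # Assumed convention: the 2nd, 4th, ... matmul is recomputed.
            return "recompute" if mm_count % 2 == 0 else "save"
        return "save" if op in save_ops else "recompute"

    return policy

policy = make_policy(save_ops={"mm", "sdpa"})
decisions = [policy(op) for op in ["mm", "sdpa", "mm", "relu", "mm"]]
```

Alternating matmuls trades a bounded amount of recomputation for roughly half of the matmul activations, while cheap ops such as the activation function are always recomputed.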
Three fixes for overlap scheduling and memory:

1. **Fix config.patch expiration for backward compilation** (`compile.py`): Overlap scheduling configs (`enable_overlap_scheduling`, `collective_bucketing`, etc.) were set via `config.patch` context manager inside `autoparallel_backend`. Since backward compilation is lazy (triggered on first `.backward()` after `compile_fx` returns), the context manager has already exited and the configs revert to defaults. Fix: set these configs globally instead. The `custom_partitioner_fn` stays in the context manager since it only affects forward partitioning.

2. **Add compute batch size capping** (`auto_bucketing.py`): After the overlap scheduler runs, AG bucketing can rewire dependencies such that `stable_topological_sort` batches hundreds of compute ops together (e.g., MM×40). This causes massive memory spikes. `_cap_compute_batch_size` adds ordering dependencies between compute nodes and their nearest reduce-scatter ops, limiting consecutive compute to `max_consecutive=8` ops. Also patches FSDP bucketing to use only the primary group and allow non-adjacent collectives, and increases `max_compute_pre_fetch` from 5 to 50.

3. **Save first residual add in SAC policy** (`llama3.py`): The selective AC policy now saves the first `aten.add.Tensor` output per transformer block (the attention residual `h = x + attn(...)`). Without this, the backward recomputes more ops between layers, giving the scheduler more latitude to prefetch and inflating peak memory. This reduced the post-scheduling memory gap vs reference from +3.30 GB to +0.39 GB.

The `example_llama3.py` changes are incidental (world_size, collective_bucketing flag, debug JSON dump).
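The first fix hinges on a timing issue that is easy to reproduce without any compiler. The sketch below uses a toy `CONFIG` dict and a hand-written `patch` context manager as stand-ins for the real config module; `compile_module` and `lazy_backward` are hypothetical names that only model "forward compiles now, backward compiles later".

```python
from contextlib import contextmanager

# Toy global config standing in for the real compiler config.
CONFIG = {"enable_overlap_scheduling": False}

@contextmanager
def patch(**overrides):
    """Temporarily override config values, restoring them on exit."""
    old = {k: CONFIG[k] for k in overrides}
    CONFIG.update(overrides)
    try:
        yield
    finally:
        CONFIG.update(old)

def compile_module():
    """Forward is compiled eagerly; backward reads the config only when called."""
    forward_seen = CONFIG["enable_overlap_scheduling"]

    def lazy_backward():
        return CONFIG["enable_overlap_scheduling"]

    return forward_seen, lazy_backward

with patch(enable_overlap_scheduling=True):
    forward_seen, lazy_backward = compile_module()

# The context manager has exited, so the lazy backward sees the default.
backward_seen_patched = lazy_backward()

# Setting the value globally makes it visible to the lazy backward as well.
CONFIG["enable_overlap_scheduling"] = True
backward_seen_global = lazy_backward()
```

The forward path observes the patched value while the deferred backward path does not, which mirrors why the overlap scheduling configs had to move out of the context manager.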
…rder

The `fuse_chained_allgathers` pass never fired for the unconstrained LLaMA-3 8B model despite 64 fusible dp→tp allgather chains in the backward graph (2 per layer × 32 layers, from `S(0)S(0) → RS(0) → RR` weight unsharding). Three independent issues prevented fusion:

1. The row-major flat mesh (`mesh._flatten()`) has the wrong rank ordering for dp→tp chains produced by reverse shard order. The fix creates a column-major process group via `dist.new_group(col_major_ranks, sort_ranks=False)` and uses it for ascending (dp→tp) chains, while the row-major mesh continues to serve descending (tp→dp) chains.

2. AOT autograd eliminates the canceling permutes between the two allgathers, leaving a direct chain (`ag1 → wait → ag2`) that `_is_identity_view_chain` rejected. Direct chains are now accepted when `subgroup_order` validates the direction and the matching process group is available.

3. The subgroup_order direction check (commit bf4c912 changed `<=` to `>=`) accepted ascending chains on the row-major flat mesh, which is the wrong combination. The check is replaced by explicit direction-to-group routing.

Authored with Claude.
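The rank-ordering issue in item 1 can be checked with a few lines of arithmetic. The sketch assumes a 2-D `(dp, tp)` mesh where the rank at coordinate `(d, t)` is `d * tp + t`; the function names are hypothetical, and each rank's shard is represented simply by its rank number.

```python
def rank(d, t, tp):
    """Assumed layout: rank of mesh coordinate (d, t) on a (dp, tp) mesh."""
    return d * tp + t

def row_major_ranks(dp, tp):
    return [rank(d, t, tp) for d in range(dp) for t in range(tp)]

def col_major_ranks(dp, tp):
    return [rank(d, t, tp) for t in range(tp) for d in range(dp)]

def chained_dp_then_tp(dp, tp):
    """Shard order produced by a dp allgather followed by a tp allgather."""
    # dp allgather: each tp column concatenates its shards over d.
    after_dp = {t: [rank(d, t, tp) for d in range(dp)] for t in range(tp)}
    # tp allgather: concatenate the per-column results over t.
    return [r for t in range(tp) for r in after_dp[t]]

def chained_tp_then_dp(dp, tp):
    """Shard order produced by a tp allgather followed by a dp allgather."""
    after_tp = {d: [rank(d, t, tp) for t in range(tp)] for d in range(dp)}
    return [r for d in range(dp) for r in after_tp[d]]

dp_tp_order = chained_dp_then_tp(2, 3)
tp_dp_order = chained_tp_then_dp(2, 3)
```

Under this layout the dp→tp chain concatenates shards in column-major rank order, so a single fused allgather only reproduces it on a group whose ranks are listed in that order, while the tp→dp chain matches the row-major flattened mesh.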
When `torch.compile(parallel_mod, backend=autoparallel_backend())` compiles an already-partitioned module, the second AOT autograd partitioner was making different recomputation decisions from the first, causing redundant computation and wasted memory.

The first commit fixes the issue and the second simplifies the implementation with no functional changes.

The first compilation's `fw_compiler` now tags forward output tensors with `custom: {ap_must_save: True}`, marking which tensors the first min_cut decided to save for backward. For multi-output ops (like SDPA), the parent op is tagged instead (getitem metadata doesn't survive `preserve_node_meta`), and DCE removes unused getitem children. Tags survive into the second compilation via `_COPY_META_FIELDS`.

The second compilation replaces `min_cut_rematerialization_partition` with `_SaveAllPartitioner`, a tag-driven `CustomPartitionerFn` that:

- Uses `classify_nodes` for topology-based fwd/bwd boundary detection (handles interleaved backward ops that `default_partition` can't)
- Saves only `ap_must_save` nodes, reproducing the first min_cut's decisions
- `ac_joint_pass`'s `PREFER_RECOMPUTE` overrides for memory optimization

Also fixes `force_recompute_fsdp_all_gather` to tag the full downstream single-input chain from `wait_tensor` (permute, view, etc.), preventing the first partitioner from saving 2.62 GiB of transposed allgathered weights.

Result on LLaMA-3 8B (32 layers, 128 GPUs): second compilation's forward allgather count exactly matches the first (357 = 357). Peak memory 6.40 GiB forward, 8.21 GiB backward.
Authored with Claude.