Skip to content

Preserve first-partitioner decisions in torch.compile via _SaveAllPartitioner#485

Open
fmassa wants to merge 10 commits into
mainfrom
fmassa/new_partitioner
Open

Preserve first-partitioner decisions in torch.compile via _SaveAllPartitioner#485
fmassa wants to merge 10 commits into
mainfrom
fmassa/new_partitioner

Conversation

@fmassa

@fmassa fmassa commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Reproduce the first partitioner's save/recompute decisions in the second compilation (torch.compile with autoparallel_backend) so FSDP allgather chains stay recomputed instead of being saved as activations.

Background. AutoParallel partitions the joint graph twice: once inside apply_placement via aot_compile_joint_with_descriptors, and again when the user calls torch.compile(parallel_mod, backend=autoparallel_backend()). The two partitioners operate on structurally different graphs, so the second partitioner's independent min-cut can diverge from the first's decisions. The most visible symptom: with the AC joint pass active in the second compilation, PREFER_RECOMPUTE tags on compute ops cause min-cut to recompute matmuls in backward, which in turn pulls FSDP allgather outputs into backward as live dependencies — and force_save_collectives then pins them as MUST_SAVE. The result is ~1.2 GB of extra activation memory per 4 transformer layers from saving allgathered weights that should be recomputed via FSDP prefetch.

This adds _SaveAllPartitioner, an inductor custom_partitioner_fn that reads custom.ap_must_save tags placed by the first compilation (via _boxed_nop_preserve_node_meta(tag_forward=True)) and saves exactly those nodes — sidestepping min-cut's independent decisions. The tags propagate to the second compilation through preserve_node_meta. FSDP MUST_RECOMPUTE tags also survive, so even users who don't pass autoparallel_backend still get correct FSDP recomputation from the default partitioner.

Supporting machinery:

  • _patch_partitioner_dce makes the partitioner's is_not_collective callback DCE-eligible for wait_tensor so unused collectives can be eliminated (the partitioner has its own DCE that would otherwise override _suppress_wait_tensor_side_effect).
  • autoparallel_backend wires custom_partitioner_fn via torch._inductor.config.patch (forward-only) and keeps overlap scheduling configs in compile_fx's config_patches (persists to lazy backward compilation).

Recommended review order: api.py (the tagging + fw_compiler wiring), compile.py (_SaveAllPartitioner and the backend), then tests/test_save_all_partitioner.py for the full picture of what's being verified.

Authored with Claude.

…titioner

Reproduce the first partitioner's save/recompute decisions in the second
compilation (torch.compile with autoparallel_backend) so FSDP allgather
chains stay recomputed instead of being saved as activations.

Background. AutoParallel partitions the joint graph twice: once inside
apply_placement via aot_compile_joint_with_descriptors, and again when
the user calls torch.compile(parallel_mod, backend=autoparallel_backend()).
The two partitioners operate on structurally different graphs, so the
second partitioner's independent min-cut can diverge from the first's
decisions. The most visible symptom: with the AC joint pass active in
the second compilation, PREFER_RECOMPUTE tags on compute ops cause min-cut
to recompute matmuls in backward, which in turn pulls FSDP allgather
outputs into backward as live dependencies — and force_save_collectives
then pins them as MUST_SAVE. The result is ~1.2 GB of extra activation
memory per 4 transformer layers from saving allgathered weights that
should be recomputed via FSDP prefetch.

This adds _SaveAllPartitioner, an inductor custom_partitioner_fn that
reads `custom.ap_must_save` tags placed by the first compilation
(via _boxed_nop_preserve_node_meta(tag_forward=True)) and saves exactly
those nodes — sidestepping min-cut's independent decisions. The tags
propagate to the second compilation through preserve_node_meta. FSDP
MUST_RECOMPUTE tags also survive, so even users who don't pass
autoparallel_backend still get correct FSDP recomputation from the
default partitioner.

Supporting machinery:
- _patch_partitioner_dce makes the partitioner's is_not_collective
  callback DCE-eligible for wait_tensor so unused collectives can be
  eliminated (the partitioner has its own DCE that would otherwise
  override _suppress_wait_tensor_side_effect).
- autoparallel_backend wires custom_partitioner_fn via
  torch._inductor.config.patch (forward-only) and keeps overlap
  scheduling configs in compile_fx's config_patches (persists to
  lazy backward compilation).

Recommended review order: api.py (the tagging + fw_compiler wiring),
compile.py (_SaveAllPartitioner and the backend), then
tests/test_save_all_partitioner.py for the full picture of what's
being verified.

Authored with Claude.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 9, 2026
fmassa added 9 commits June 9, 2026 09:53
  - Removed unused `pre_pass` parameter from `_boxed_nop_preserve_node_meta`
  - Moved `from functools import partial` to top-of-file imports

- **`autoparallel/compile.py`**:
  - Rewrote `_SaveAllPartitioner` docstring to describe what it actually does (uses `ap_must_save` tags) and why (sidesteps the `force_save_collectives` + AC interaction)
  - Added comment explaining why we keep the `force_save_collectives`/`force_save_effectful_ops`/`force_save_bw_mutation_src` calls even though they don't affect our save decision

- **Removed dead test pollution**: the `AutoParallel._make_fuse_allgather_pass = lambda self: None` line was leftover from a different branch
- **Removed redundant test**: `test_autoparallel_backend_includes_save_all_partitioner` was a tautology
- **Added bad-case test**: `test_default_partitioner_diverges_from_save_all_partitioner` — proves the default min-cut partitioner produces a different save list than `_SaveAllPartitioner` when AC is active (the motivating divergence)
- **Added regression-guard test**: `test_save_all_partitioner_reproduces_first_partitioner_saves` — confirms `_SaveAllPartitioner`'s saved set approximately matches what the first partitioner saved (by shape histogram, tolerant to view/reshape differences from retracing)

- All 12 tests in `test_save_all_partitioner.py` pass (~4 minutes runtime)
- `test_api.py` and `test_activation_checkpointing.py` still pass (one pre-existing unrelated failure on main)
- Lint clean (F401, F841 checks)
Three behavioral changes plus tests and clarifying comments.

1. MUST_SAVE on a multi-output parent now overrides ap_must_save_getitem_indices.
   Previously, if a multi-output op was both MUST_SAVE and ap_must_save with
   a specific index list, only the indexed children were saved — silently
   under-saving relative to what MUST_SAVE means ("save all tensor outputs
   needed from this op"). The fix makes MUST_SAVE clear the index restriction,
   keeping exact replay for ap_must_save while keeping PyTorch's MUST_SAVE
   tags conservative.

2. Deduplicate saved_values, saved_sym_nodes, and saved_opaque_nodes before
   handing them to _extract_fwd_bwd_modules. Matches upstream
   default_partition's dict.fromkeys pattern. Defensive — duplicates don't
   arise in the current control flow, but they could under future
   refactors of the iteration order or tag combinations, and the cost is
   one line per list.

3. Added clarifying comments on three corner cases:
   - Opaque nodes (ProcessGroup, ScriptObject) are saved unconditionally
     regardless of ap_must_save/MUST_SAVE; documents the intentional
     deviation from pure replay semantics, matching the standard partitioner.
   - The inference fallback to default_partition when there are no backward
     nodes; documents that ap_must_save tags are bypassed there because
     inference doesn't have the fwd/bwd-divergence problem the partitioner
     exists to solve.
   - CSE merges duplicate chains without combining metadata; documents that
     the safety contract holds for FSDP allgather chains (first occurrence
     keeps MUST_RECOMPUTE) and that general replay across CSE'd duplicates
     is a known limitation.

Tests added (all verified to catch a real regression by reverting the
corresponding fix and checking the assertion fires):
   - test_save_all_partitioner_must_save_overrides_getitem_indices
   - test_save_all_partitioner_replays_only_indexed_getitems (locks in the
     index-specific replay precision from the earlier round)
   - test_save_all_partitioner_must_recompute_blocks_opaque_save and
     test_save_all_partitioner_must_recompute_blocks_multi_output_save
     (regression guards for the _must_recompute ordering)
   - test_save_all_partitioner_multi_output_with_non_getitem_user (covers
     the tuple-aware _is_multi_output check)

Authored with Claude.
Summary

Two structural changes

1. **Capture-only backend for partitioner-checking tests.** Replaced `_capture_partitioner_call`'s implementation to wire `_SaveAllPartitioner` directly via `aot_module_simplified` with identity compilers. The partitioner runs and is captured, but no Triton kernel codegen happens. The function signature changed from accepting a pre-built `backend` to accepting `enable_ac=` directly.

2. **Module-scoped fixtures for `parallel_mod`.** Added `parallel_mod_2d` and `parallel_mod_1d` fixtures that cache the AutoParallel solve + apply_placement (~31s for 2D). Five integration tests now share the same parallel module instead of rebuilding it.

Test cleanup

- Removed `test_save_all_partitioner_runs_end_to_end` (redundant with `compile_with_ac_enabled`).
- Removed `test_save_all_partitioner_uses_ap_must_save_tags` (its assertions folded into `does_not_save_fsdp_wait_tensors`).

Result

- **252s → 135s** (47% reduction), 21 → 19 tests
- Capture-only tests now run in 3-6s each (was 25-50s with full Inductor)
- The two real-Inductor smoke tests (`compile_with_ac_enabled`, `compile_1d_mesh`) remain to guard against full-pipeline regressions

The breakdown:
- 14 unit tests: ~1s total
- 1 reproduces test (calls `_capture_first_partitioner_saves` which builds its own AutoParallel): ~24s
- 3 capture-only 2D integration tests: ~9-15s combined
- 1 Inductor 2D smoke test: ~14s
- 1 Inductor 1D smoke test (separate mesh): ~45s + 8s setup
- Module-scope 2D fixture setup: ~31s (shared across 5 tests)
expand_rule used copy.deepcopy(op_schema_) to snapshot the schema before
mutating it. DeviceMesh has no __deepcopy__, so deepcopy went through
__getstate__/__setstate__ and produced a fresh DeviceMesh with the same
logical content but an empty _flatten_mapping cache. The DTensorSpecs
returned from expand_rule carried these duplicates, which propagated
into the sharding solution.

apply_placement's pre-warming in _apply_placement_common only populates
the user mesh's cache. When _optimize_same_nd_sharding_as_1d inside
make_fx hit a duplicate mesh, _flatten() cache-missed and dispatched
as_strided on the real rank_map — failing FakeTensorMode's non-fake-input
check. Which solution the solver picked depended on the cost model, so
the failure surfaced on g5/A10G CI but not on H100.

Fix: _deepcopy_preserving_mesh pre-seeds copy.deepcopy's memo with
DeviceMesh identity mappings so duplicates aren't produced. Adds a
regression test that asserts every spec mesh's root has a warm
_flatten cache after apply_placement.

Authored with Claude.
This reverts commit ca5de35.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant