Preserve first-partitioner decisions in torch.compile via _SaveAllPartitioner#485
Open
fmassa wants to merge 10 commits into
Open
Preserve first-partitioner decisions in torch.compile via _SaveAllPartitioner#485fmassa wants to merge 10 commits into
fmassa wants to merge 10 commits into
Conversation
…titioner Reproduce the first partitioner's save/recompute decisions in the second compilation (torch.compile with autoparallel_backend) so FSDP allgather chains stay recomputed instead of being saved as activations. Background. AutoParallel partitions the joint graph twice: once inside apply_placement via aot_compile_joint_with_descriptors, and again when the user calls torch.compile(parallel_mod, backend=autoparallel_backend()). The two partitioners operate on structurally different graphs, so the second partitioner's independent min-cut can diverge from the first's decisions. The most visible symptom: with the AC joint pass active in the second compilation, PREFER_RECOMPUTE tags on compute ops cause min-cut to recompute matmuls in backward, which in turn pulls FSDP allgather outputs into backward as live dependencies — and force_save_collectives then pins them as MUST_SAVE. The result is ~1.2 GB of extra activation memory per 4 transformer layers from saving allgathered weights that should be recomputed via FSDP prefetch. This adds _SaveAllPartitioner, an inductor custom_partitioner_fn that reads `custom.ap_must_save` tags placed by the first compilation (via _boxed_nop_preserve_node_meta(tag_forward=True)) and saves exactly those nodes — sidestepping min-cut's independent decisions. The tags propagate to the second compilation through preserve_node_meta. FSDP MUST_RECOMPUTE tags also survive, so even users who don't pass autoparallel_backend still get correct FSDP recomputation from the default partitioner. Supporting machinery: - _patch_partitioner_dce makes the partitioner's is_not_collective callback DCE-eligible for wait_tensor so unused collectives can be eliminated (the partitioner has its own DCE that would otherwise override _suppress_wait_tensor_side_effect). - autoparallel_backend wires custom_partitioner_fn via torch._inductor.config.patch (forward-only) and keeps overlap scheduling configs in compile_fx's config_patches (persists to lazy backward compilation). Recommended review order: api.py (the tagging + fw_compiler wiring), compile.py (_SaveAllPartitioner and the backend), then tests/test_save_all_partitioner.py for the full picture of what's being verified. Authored with Claude.
- Removed unused `pre_pass` parameter from `_boxed_nop_preserve_node_meta` - Moved `from functools import partial` to top-of-file imports - **`autoparallel/compile.py`**: - Rewrote `_SaveAllPartitioner` docstring to describe what it actually does (uses `ap_must_save` tags) and why (sidesteps the `force_save_collectives` + AC interaction) - Added comment explaining why we keep the `force_save_collectives`/`force_save_effectful_ops`/`force_save_bw_mutation_src` calls even though they don't affect our save decision - **Removed dead test pollution**: the `AutoParallel._make_fuse_allgather_pass = lambda self: None` line was leftover from a different branch - **Removed redundant test**: `test_autoparallel_backend_includes_save_all_partitioner` was a tautology - **Added bad-case test**: `test_default_partitioner_diverges_from_save_all_partitioner` — proves the default min-cut partitioner produces a different save list than `_SaveAllPartitioner` when AC is active (the motivating divergence) - **Added regression-guard test**: `test_save_all_partitioner_reproduces_first_partitioner_saves` — confirms `_SaveAllPartitioner`'s saved set approximately matches what the first partitioner saved (by shape histogram, tolerant to view/reshape differences from retracing) - All 12 tests in `test_save_all_partitioner.py` pass (~4 minutes runtime) - `test_api.py` and `test_activation_checkpointing.py` still pass (one pre-existing unrelated failure on main) - Lint clean (F401, F841 checks)
Three behavioral changes plus tests and clarifying comments.
1. MUST_SAVE on a multi-output parent now overrides ap_must_save_getitem_indices.
Previously, if a multi-output op was both MUST_SAVE and ap_must_save with
a specific index list, only the indexed children were saved — silently
under-saving relative to what MUST_SAVE means ("save all tensor outputs
needed from this op"). The fix makes MUST_SAVE clear the index restriction,
keeping exact replay for ap_must_save while keeping PyTorch's MUST_SAVE
tags conservative.
2. Deduplicate saved_values, saved_sym_nodes, and saved_opaque_nodes before
handing them to _extract_fwd_bwd_modules. Matches upstream
default_partition's dict.fromkeys pattern. Defensive — duplicates don't
arise in the current control flow, but they could under future
refactors of the iteration order or tag combinations, and the cost is
one line per list.
3. Added clarifying comments on three corner cases:
- Opaque nodes (ProcessGroup, ScriptObject) are saved unconditionally
regardless of ap_must_save/MUST_SAVE; documents the intentional
deviation from pure replay semantics, matching the standard partitioner.
- The inference fallback to default_partition when there are no backward
nodes; documents that ap_must_save tags are bypassed there because
inference doesn't have the fwd/bwd-divergence problem the partitioner
exists to solve.
- CSE merges duplicate chains without combining metadata; documents that
the safety contract holds for FSDP allgather chains (first occurrence
keeps MUST_RECOMPUTE) and that general replay across CSE'd duplicates
is a known limitation.
Tests added (all verified to catch a real regression by reverting the
corresponding fix and checking the assertion fires):
- test_save_all_partitioner_must_save_overrides_getitem_indices
- test_save_all_partitioner_replays_only_indexed_getitems (locks in the
index-specific replay precision from the earlier round)
- test_save_all_partitioner_must_recompute_blocks_opaque_save and
test_save_all_partitioner_must_recompute_blocks_multi_output_save
(regression guards for the _must_recompute ordering)
- test_save_all_partitioner_multi_output_with_non_getitem_user (covers
the tuple-aware _is_multi_output check)
Authored with Claude.
Summary Two structural changes 1. **Capture-only backend for partitioner-checking tests.** Replaced `_capture_partitioner_call`'s implementation to wire `_SaveAllPartitioner` directly via `aot_module_simplified` with identity compilers. The partitioner runs and is captured, but no Triton kernel codegen happens. The function signature changed from accepting a pre-built `backend` to accepting `enable_ac=` directly. 2. **Module-scoped fixtures for `parallel_mod`.** Added `parallel_mod_2d` and `parallel_mod_1d` fixtures that cache the AutoParallel solve + apply_placement (~31s for 2D). Five integration tests now share the same parallel module instead of rebuilding it. Test cleanup - Removed `test_save_all_partitioner_runs_end_to_end` (redundant with `compile_with_ac_enabled`). - Removed `test_save_all_partitioner_uses_ap_must_save_tags` (its assertions folded into `does_not_save_fsdp_wait_tensors`). Result - **252s → 135s** (47% reduction), 21 → 19 tests - Capture-only tests now run in 3-6s each (was 25-50s with full Inductor) - The two real-Inductor smoke tests (`compile_with_ac_enabled`, `compile_1d_mesh`) remain to guard against full-pipeline regressions The breakdown: - 14 unit tests: ~1s total - 1 reproduces test (calls `_capture_first_partitioner_saves` which builds its own AutoParallel): ~24s - 3 capture-only 2D integration tests: ~9-15s combined - 1 Inductor 2D smoke test: ~14s - 1 Inductor 1D smoke test (separate mesh): ~45s + 8s setup - Module-scope 2D fixture setup: ~31s (shared across 5 tests)
expand_rule used copy.deepcopy(op_schema_) to snapshot the schema before mutating it. DeviceMesh has no __deepcopy__, so deepcopy went through __getstate__/__setstate__ and produced a fresh DeviceMesh with the same logical content but an empty _flatten_mapping cache. The DTensorSpecs returned from expand_rule carried these duplicates, which propagated into the sharding solution. apply_placement's pre-warming in _apply_placement_common only populates the user mesh's cache. When _optimize_same_nd_sharding_as_1d inside make_fx hit a duplicate mesh, _flatten() cache-missed and dispatched as_strided on the real rank_map — failing FakeTensorMode's non-fake-input check. Which solution the solver picked depended on the cost model, so the failure surfaced on g5/A10G CI but not on H100. Fix: _deepcopy_preserving_mesh pre-seeds copy.deepcopy's memo with DeviceMesh identity mappings so duplicates aren't produced. Adds a regression test that asserts every spec mesh's root has a warm _flatten cache after apply_placement. Authored with Claude.
This reverts commit ca5de35.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Reproduce the first partitioner's save/recompute decisions in the second compilation (torch.compile with autoparallel_backend) so FSDP allgather chains stay recomputed instead of being saved as activations.
Background. AutoParallel partitions the joint graph twice: once inside apply_placement via aot_compile_joint_with_descriptors, and again when the user calls torch.compile(parallel_mod, backend=autoparallel_backend()). The two partitioners operate on structurally different graphs, so the second partitioner's independent min-cut can diverge from the first's decisions. The most visible symptom: with the AC joint pass active in the second compilation, PREFER_RECOMPUTE tags on compute ops cause min-cut to recompute matmuls in backward, which in turn pulls FSDP allgather outputs into backward as live dependencies — and force_save_collectives then pins them as MUST_SAVE. The result is ~1.2 GB of extra activation memory per 4 transformer layers from saving allgathered weights that should be recomputed via FSDP prefetch.
This adds _SaveAllPartitioner, an inductor custom_partitioner_fn that reads
custom.ap_must_savetags placed by the first compilation (via _boxed_nop_preserve_node_meta(tag_forward=True)) and saves exactly those nodes — sidestepping min-cut's independent decisions. The tags propagate to the second compilation through preserve_node_meta. FSDP MUST_RECOMPUTE tags also survive, so even users who don't pass autoparallel_backend still get correct FSDP recomputation from the default partitioner.Supporting machinery:
Recommended review order: api.py (the tagging + fw_compiler wiring), compile.py (_SaveAllPartitioner and the backend), then tests/test_save_all_partitioner.py for the full picture of what's being verified.
Authored with Claude.