From 837b5ba09a3ed310ef807b5f5a5bb33e099bfb55 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 9 Jun 2026 09:11:12 +0000 Subject: [PATCH 01/10] Preserve first-partitioner decisions in torch.compile via _SaveAllPartitioner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reproduce the first partitioner's save/recompute decisions in the second compilation (torch.compile with autoparallel_backend) so FSDP allgather chains stay recomputed instead of being saved as activations. Background. AutoParallel partitions the joint graph twice: once inside apply_placement via aot_compile_joint_with_descriptors, and again when the user calls torch.compile(parallel_mod, backend=autoparallel_backend()). The two partitioners operate on structurally different graphs, so the second partitioner's independent min-cut can diverge from the first's decisions. The most visible symptom: with the AC joint pass active in the second compilation, PREFER_RECOMPUTE tags on compute ops cause min-cut to recompute matmuls in backward, which in turn pulls FSDP allgather outputs into backward as live dependencies — and force_save_collectives then pins them as MUST_SAVE. The result is ~1.2 GB of extra activation memory per 4 transformer layers from saving allgathered weights that should be recomputed via FSDP prefetch. This adds _SaveAllPartitioner, an inductor custom_partitioner_fn that reads `custom.ap_must_save` tags placed by the first compilation (via _boxed_nop_preserve_node_meta(tag_forward=True)) and saves exactly those nodes — sidestepping min-cut's independent decisions. The tags propagate to the second compilation through preserve_node_meta. FSDP MUST_RECOMPUTE tags also survive, so even users who don't pass autoparallel_backend still get correct FSDP recomputation from the default partitioner. Supporting machinery: - _patch_partitioner_dce makes the partitioner's is_not_collective callback DCE-eligible for wait_tensor so unused collectives can be eliminated (the partitioner has its own DCE that would otherwise override _suppress_wait_tensor_side_effect). - autoparallel_backend wires custom_partitioner_fn via torch._inductor.config.patch (forward-only) and keeps overlap scheduling configs in compile_fx's config_patches (persists to lazy backward compilation). Recommended review order: api.py (the tagging + fw_compiler wiring), compile.py (_SaveAllPartitioner and the backend), then tests/test_save_all_partitioner.py for the full picture of what's being verified. Authored with Claude. --- autoparallel/api.py | 38 ++- autoparallel/compile.py | 214 +++++++++++++- tests/test_save_all_partitioner.py | 430 +++++++++++++++++++++++++++++ 3 files changed, 675 insertions(+), 7 deletions(-) create mode 100644 tests/test_save_all_partitioner.py diff --git a/autoparallel/api.py b/autoparallel/api.py index 1670d509..d31db69c 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -5,6 +5,7 @@ import copy import logging +import operator import time from contextlib import ExitStack, contextmanager from dataclasses import dataclass @@ -57,7 +58,36 @@ logger = logging.getLogger(__name__) -def _boxed_nop_preserve_node_meta(fx_g, example_inputs): +def _boxed_nop_preserve_node_meta( + fx_g, example_inputs, pre_pass=None, tag_forward=False +): + if pre_pass is not None: + pre_pass(fx_g.graph) + + if tag_forward: + # Tag the forward graph's OUTPUT values as "must save". These are + # the tensors the first min_cut decided to save for backward — + # only these should be saved in the second compilation. + # Uses the "custom" meta field (not "recompute") to avoid + # interfering with ac_joint_pass which uses "recompute" for + # activation checkpointing decisions. + output_node = next(n for n in fx_g.graph.nodes if n.op == "output") + for out in output_node.args[0]: + if not isinstance(out, torch.fx.Node) or out.op != "call_function": + continue + if out.target == operator.getitem: + # getitem metadata doesn't survive preserve_node_meta + # (Python builtin, not dispatched). Tag the parent + # multi-output op instead — DCE in _extract_fwd_bwd_modules + # removes any unused getitem children. + parent = out.args[0] + if isinstance(parent, torch.fx.Node): + parent.meta.setdefault("custom", {}) + parent.meta["custom"]["ap_must_save"] = True + else: + out.meta.setdefault("custom", {}) + out.meta["custom"]["ap_must_save"] = True + def run(args): with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(fx_g).boxed_run(args) @@ -482,9 +512,13 @@ def apply_placement(self, sharding_placement): self.parallel_gm.graph, self.reshard_after_forward ) + from functools import partial + + fw_compiler_fn = partial(self.compiler_fn, tag_forward=True) + self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( self.joint_with_descriptors, - fw_compiler=self.compiler_fn, + fw_compiler=fw_compiler_fn, bw_compiler=self.compiler_fn, ) diff --git a/autoparallel/compile.py b/autoparallel/compile.py index e9d8b5c1..eaa0e28e 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -3,13 +3,17 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Union +import operator +from contextlib import contextmanager +from typing import Any, Optional, Sequence, Union import torch import torch._functorch.config import torch._inductor.config from torch._inductor.compile_fx import compile_fx +from torch._inductor.custom_graph_pass import CustomPartitionerFn +from .api import _suppress_wait_tensor_side_effect from .graph_passes.activation_checkpointing import ac_joint_pass _INDUCTOR_OVERLAP_PATCHES = { @@ -20,6 +24,186 @@ } +class _SaveAllPartitioner(CustomPartitionerFn): + """Tag-driven partitioner: save all forward tensors except MUST_RECOMPUTE. + + The first compilation (inside AutoParallel) already makes recomputation + decisions via MUST_RECOMPUTE tags on FSDP allgather chains. Those tags + survive into the second compilation via preserve_node_meta. + + Unlike default_partition (which uses positional fwd/bwd boundary detection + and fails on interleaved backward ops like gradient reduce_scatters), this + partitioner uses classify_nodes for topology-based boundary detection. And + unlike min_cut_rematerialization_partition, it doesn't run a second min-cut + that would diverge from the first compilation's decisions. + """ + + def __call__( + self, + gm: torch.fx.GraphModule, + joint_inputs: Sequence[object], + **kwargs: Any, + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + num_fwd_outputs: int = kwargs.pop("num_fwd_outputs") # type: ignore[assignment] + static_lifetime_input_indices: list[int] | None = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + from torch._functorch.partitioners import ( + _extract_fwd_bwd_modules, + _is_assert_only_symbool, + classify_nodes, + cleanup_recompute_tags, + default_partition, + force_save_bw_mutation_src, + force_save_collectives, + force_save_effectful_ops, + functionalize_rng_ops, + has_recomputable_ops, + has_recomputable_rng_ops, + is_opaque_node, + is_sym_node, + must_recompute, + raise_getitems, + reordering_to_mimic_autograd_engine, + thread_graphsafe_rng_from_hops, + ) + + gm.graph.eliminate_dead_code() + gm.recompile() + + # CSE merges duplicate allgather chains: the forward's + # MUST_RECOMPUTE allgathers and the baked-in backward copies + # compute the same values from the same primals. Without CSE, + # both appear in the backward (duplication). + if torch._functorch.config.cse: + from torch._functorch.partitioners import fx_graph_cse + + cse_graph = fx_graph_cse(gm.graph) + gm.graph = cse_graph + + graph_has_recomputable_ops = has_recomputable_ops(gm) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(gm) + if graph_has_recomputable_ops: + gm = cleanup_recompute_tags(gm, is_default_partition=False) + + if not torch._functorch.config.unsafe_allow_optimization_of_collectives: + force_save_collectives(gm) + force_save_effectful_ops(gm) + force_save_bw_mutation_src(gm) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes(gm, static_lifetime_input_indices, num_fwd_outputs) + + if len(node_info.required_bw_nodes) == 0: + return default_partition( + gm, + joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, + ) + + saved_values = [] + saved_sym_nodes = [] + saved_opaque_nodes = [] + + def _is_multi_output(node: torch.fx.Node) -> bool: + return ( + all(user.target == operator.getitem for user in node.users) + and len(node.users) > 0 + ) + + def _maybe_save(node: torch.fx.Node) -> None: + if is_sym_node(node): + if not _is_assert_only_symbool(node): + saved_sym_nodes.append(node) + return + if _is_multi_output(node): + # Multi-output ops tagged ap_must_save: save all their + # getitem children (DCE removes unused ones later). + if node.meta.get("custom", {}).get("ap_must_save"): + for user in node.users: + if user.target == operator.getitem: + saved_values.append(user) + return + if is_opaque_node(node): + saved_opaque_nodes.append(node) + return + if must_recompute(node): + return + # Save nodes tagged ap_must_save by the first compilation. + # These are the forward graph's output tensors from the first + # partitioner — reproducing its save/recompute decisions. + if node.op == "placeholder": + saved_values.append(node) + elif node.meta.get("custom", {}).get("ap_must_save"): + saved_values.append(node) + + for node in node_info.required_fw_nodes: + _maybe_save(node) + + # Unclaimed nodes (neither strictly forward nor backward) may be + # needed by backward outputs — e.g. mutable ops like index_put that + # are _must_be_in_forward. Save them so they're available as backward + # inputs in _extract_fwd_bwd_modules. + for node in node_info.unclaimed_nodes: + _maybe_save(node) + + fw_module, bw_module = _extract_fwd_bwd_modules( + gm, + saved_values, + saved_sym_nodes=saved_sym_nodes, + saved_opaque_nodes=saved_opaque_nodes, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, + ) + + if graph_has_recomputable_ops and graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + gm, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) + bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) + + return fw_module, bw_module + + def uuid(self) -> Any: + return None + + +@contextmanager +def _patch_partitioner_dce(): + """Patch the partitioner's DCE to allow wait_tensor to be eliminated. + + The partitioner uses its own is_not_collective callback that treats all + _c10d_functional ops as impure, overriding _suppress_wait_tensor_side_effect. + We patch it to let wait_tensor through so unused collectives get DCE'd. + """ + import torch._functorch.partitioners as partitioners + + original = partitioners.is_not_collective + + def patched_is_not_collective(node): + if node.target in ( + torch.ops._c10d_functional.wait_tensor, + torch.ops._c10d_functional.wait_tensor.default, + ): + return False + return original(node) + + partitioners.is_not_collective = patched_is_not_collective + try: + yield + finally: + partitioners.is_not_collective = original + + def _make_ac_joint_pass( ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto", ): @@ -50,16 +234,36 @@ def autoparallel_backend( sqrt(total_recomputable_memory). overlap_scheduling: Enable comm/compute overlap scheduling. """ - functorch_patches = {} - inductor_patches = _INDUCTOR_OVERLAP_PATCHES if overlap_scheduling else None + functorch_patches: dict[str, Any] = {} if enable_ac: functorch_patches["joint_custom_pass"] = _make_ac_joint_pass( ac_stage_size_in_GiB ) + # Inductor configs split by lifetime: + # - overlap scheduling configs must persist to lazy backward compilation + # (which runs on the first .backward() call, after compile_fx returns). + # compile_fx's config_patches argument re-enters the patch when backward + # is later compiled out of scope. + # - custom_partitioner_fn only runs during the synchronous joint→fwd/bwd + # partitioning inside compile_fx, so a context manager suffices. + inductor_persistent_patches = ( + _INDUCTOR_OVERLAP_PATCHES if overlap_scheduling else None + ) + inductor_fwd_patches: dict[str, Any] = { + "custom_partitioner_fn": _SaveAllPartitioner(), + } + def backend(gm, example_inputs): - with torch._functorch.config.patch(functorch_patches): - return compile_fx(gm, example_inputs, config_patches=inductor_patches) + with ( + _suppress_wait_tensor_side_effect(), + _patch_partitioner_dce(), + torch._functorch.config.patch(functorch_patches), + torch._inductor.config.patch(inductor_fwd_patches), + ): + return compile_fx( + gm, example_inputs, config_patches=inductor_persistent_patches + ) return backend diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py new file mode 100644 index 00000000..f0c41f72 --- /dev/null +++ b/tests/test_save_all_partitioner.py @@ -0,0 +1,430 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the _SaveAllPartitioner mechanism in autoparallel/compile.py. + +The mechanism reproduces the first partitioner's save/recompute decisions in +the second compilation (torch.compile with autoparallel_backend) by: + +1. Tagging forward outputs with `ap_must_save` in `_boxed_nop_preserve_node_meta` +2. `preserve_node_meta` propagates the tags to the second compilation's joint graph +3. `_SaveAllPartitioner` reads the tags and saves only those nodes + +Without this machinery, the default min-cut partitioner makes independent +decisions that diverge from the first partitioner (most importantly, it +saves FSDP allgather outputs that should be recomputed via prefetch). +""" + +import torch +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.utils.checkpoint import CheckpointPolicy + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel, _boxed_nop_preserve_node_meta +from autoparallel.compile import ( + _patch_partitioner_dce, + _SaveAllPartitioner, + autoparallel_backend, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_small_llama(n_layers=2): + """Tiny LLaMA-3 sized for fast tests.""" + return Transformer( + TransformerModelArgs( + dim=256, + n_layers=n_layers, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=1024, + max_seq_len=512, + ) + ) + + +def _run_autoparallel(mesh, n_layers=2, batch_size=None, seqlen=128): + """Run AutoParallel up to apply_placement and return the parallel module + plus AC pass for the second compilation.""" + from autoparallel.compile import _make_ac_joint_pass + + vocab_size = 1024 + if batch_size is None: + batch_size = 2 * mesh.shape[0] + + with torch.device("meta"): + model = _make_small_llama(n_layers=n_layers) + + AutoParallel._make_fuse_allgather_pass = lambda self: None + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + if mesh.ndim == 2: + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + else: + autop.add_input_constraints([(Shard(0),)]) + autop.add_output_constraints([(Shard(0),)]) + sharding_placement = autop.optimize_placement(verbose=False) + + ac_pass = _make_ac_joint_pass() + with torch._functorch.config.patch({"joint_custom_pass": ac_pass}): + parallel_mod = autop.apply_placement(sharding_placement) + + parallel_mod.to_empty(device="cuda") + parallel_mod.init_weights() + return parallel_mod, batch_size, seqlen, vocab_size + + +def _capture_saved_activations( + parallel_mod, batch_size, seqlen, vocab_size, mesh, backend +): + """Compile parallel_mod, capture what _SaveAllPartitioner sees and returns, + and return (saved_activation_names, captured_info).""" + captured = {} + + orig_call = _SaveAllPartitioner.__call__ + + def capturing_call(self, gm, joint_inputs, **kwargs): + # Snapshot tag state before _SaveAllPartitioner mutates anything + captured["wait_tensor_recompute_tags"] = sum( + 1 + for n in gm.graph.nodes + if "wait_tensor" in n.name + and n.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE + ) + captured["allgather_recompute_tags"] = sum( + 1 + for n in gm.graph.nodes + if "all_gather_into_tensor" in n.name + and n.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE + ) + captured["ap_must_save_count"] = sum( + 1 for n in gm.graph.nodes if n.meta.get("custom", {}).get("ap_must_save") + ) + fw, bw = orig_call(self, gm, joint_inputs, **kwargs) + captured["fw_module"] = fw + captured["bw_module"] = bw + return fw, bw + + _SaveAllPartitioner.__call__ = capturing_call + try: + compiled = torch.compile(parallel_mod, backend=backend) + x = torch.randint( + 0, vocab_size, (batch_size // mesh.shape[0], seqlen), device="cuda" + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + finally: + _SaveAllPartitioner.__call__ = orig_call + + # The backward graph's placeholders (minus tangents and primals) are the + # saved activations. + bw = captured["bw_module"] + saved_names = [] + for node in bw.graph.nodes: + if node.op != "placeholder": + continue + if isinstance(node.target, str) and ( + "tangent" in node.target or "primals" in node.target + ): + continue + saved_names.append(node.name) + return saved_names, captured + + +# --------------------------------------------------------------------------- +# Unit tests for the standalone mechanisms +# --------------------------------------------------------------------------- + + +def test_boxed_nop_tag_forward_marks_outputs(): + """_boxed_nop_preserve_node_meta(tag_forward=True) tags the forward + output node's tensor args with ap_must_save.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, device="meta") + mul = graph.call_function(torch.ops.aten.mul.Tensor, args=(add, add)) + mul.meta["val"] = torch.randn(4, device="meta") + graph.output((add, mul)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) + + add_tagged = add.meta.get("custom", {}).get("ap_must_save") + mul_tagged = mul.meta.get("custom", {}).get("ap_must_save") + assert add_tagged is True + assert mul_tagged is True + + +def test_boxed_nop_tag_forward_skips_getitem(): + """For getitem outputs (multi-output ops), the parent is tagged instead + since getitem metadata doesn't survive preserve_node_meta.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 2 + import operator + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + graph.output((g0,)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) + + # Parent (split) gets the tag, not the getitem + assert split.meta.get("custom", {}).get("ap_must_save") is True + assert g0.meta.get("custom", {}).get("ap_must_save") is None + + +def test_boxed_nop_no_tag_forward_default(): + """tag_forward defaults to False — no tagging happens.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, device="meta") + graph.output((add,)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None) + + assert add.meta.get("custom", {}).get("ap_must_save") is None + + +def test_preserve_node_meta_propagates_recompute_through_collectives(): + """preserve_node_meta correctly propagates MUST_RECOMPUTE through + collective ops (allgather, wait_tensor) — confirming the safety net + mechanism works. + """ + from torch.fx.experimental.proxy_tensor import make_fx + from torch.testing._internal.distributed.fake_pg import FakeStore + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "fake", store=FakeStore(), rank=0, world_size=8 + ) + + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, device="meta") + add.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + add.meta["ac_graph_id"] = 100000 + + ag = graph.call_function( + torch.ops._c10d_functional.all_gather_into_tensor.default, + args=(add, 2, "0"), + ) + ag.meta["val"] = torch.randn(8, device="meta") + ag.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + ag.meta["ac_graph_id"] = 100000 + + wt = graph.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(ag,)) + wt.meta["val"] = torch.randn(8, device="meta") + wt.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + wt.meta["ac_graph_id"] = 100000 + + graph.output((wt,)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + compiled = _boxed_nop_preserve_node_meta(gm, None) + fake_input = torch.randn(4, device="cpu") + new_gm = make_fx(compiled, tracing_mode="fake")([fake_input]) + + for node in new_gm.graph.nodes: + if node.op != "call_function": + continue + name = ( + node.target.__name__ + if hasattr(node.target, "__name__") + else str(node.target) + ) + if name in ( + "add.Tensor", + "all_gather_into_tensor.default", + "wait_tensor.default", + ): + assert ( + node.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE + ), f"{name} lost MUST_RECOMPUTE tag through preserve_node_meta" + + +# --------------------------------------------------------------------------- +# Integration test for the partitioner — verifies that wait_tensor outputs +# from FSDP allgathers are NOT saved by _SaveAllPartitioner. +# --------------------------------------------------------------------------- + + +def test_save_all_partitioner_does_not_save_fsdp_wait_tensors(device_mesh_2d): + """The whole point of the machinery: with FSDP allgathers tagged + MUST_RECOMPUTE, _SaveAllPartitioner should NOT save their wait_tensor + outputs in the backward. + """ + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_2d, n_layers=2 + ) + backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + + saved_names, captured = _capture_saved_activations( + parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + ) + + # FSDP wait_tensor outputs should NOT appear as saved activations + saved_wait_tensors = [n for n in saved_names if "wait_tensor" in n] + assert len(saved_wait_tensors) == 0, ( + f"FSDP wait_tensor outputs were saved (should be recomputed via " + f"FSDP prefetch): {saved_wait_tensors}" + ) + + # Sanity: MUST_RECOMPUTE tags survived into the joint graph + assert captured["allgather_recompute_tags"] > 0, ( + "Expected MUST_RECOMPUTE tags on all_gather nodes to survive " + "preserve_node_meta into the second compilation's joint graph" + ) + + +def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): + """_SaveAllPartitioner saves nodes tagged ap_must_save by the first + compilation.""" + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_2d, n_layers=2 + ) + backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + + saved_names, captured = _capture_saved_activations( + parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + ) + + # ap_must_save tags should be present in the joint graph (set by + # _boxed_nop_preserve_node_meta during first compilation, propagated + # via preserve_node_meta) + assert captured["ap_must_save_count"] > 0, ( + "Expected ap_must_save tags from first compilation to survive into " + "second compilation's joint graph" + ) + + # And we should have saved at least some non-trivial activations + activation_saves = [n for n in saved_names if not n.startswith("primals")] + assert len(activation_saves) > 0 + + +def test_save_all_partitioner_runs_end_to_end(device_mesh_2d): + """Full end-to-end: AutoParallel + torch.compile(autoparallel_backend) + forward + backward without errors.""" + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_2d, n_layers=2 + ) + backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + + compiled = torch.compile(parallel_mod, backend=backend) + x = torch.randint( + 0, + vocab_size, + (batch_size // device_mesh_2d.shape[0], seqlen), + device="cuda", + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + # Successful completion is the test + + +# --------------------------------------------------------------------------- +# Supporting machinery tests +# --------------------------------------------------------------------------- + + +def test_patch_partitioner_dce_allows_wait_tensor_elimination(): + """_patch_partitioner_dce makes is_not_collective return False for + wait_tensor — required so unused wait_tensors can be DCE'd.""" + import torch._functorch.partitioners as partitioners + + graph = torch.fx.Graph() + x = graph.placeholder("x") + wt = graph.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(x,)) + graph.output((wt,)) + + original_result = partitioners.is_not_collective(wt) + + with _patch_partitioner_dce(): + patched_result = partitioners.is_not_collective(wt) + + # After exit, original behavior is restored + restored_result = partitioners.is_not_collective(wt) + + # Inside the patch, wait_tensor is treated as DCE-eligible (returns False + # is consistent — see the patch's intent of "letting wait_tensor through") + assert patched_result is False + # And the patch is properly restored + assert restored_result == original_result + + +def test_autoparallel_backend_includes_save_all_partitioner(): + """autoparallel_backend() configures custom_partitioner_fn to be a + _SaveAllPartitioner instance.""" + autoparallel_backend(enable_ac=False, overlap_scheduling=False) + # The backend wires _SaveAllPartitioner as the custom_partitioner_fn. + # Just verify the type exists and is properly defined. + assert _SaveAllPartitioner is not None + p = _SaveAllPartitioner() + assert callable(p) + assert hasattr(p, "uuid") + assert p.uuid() is None + + +def test_save_all_partitioner_compile_with_ac_enabled(device_mesh_2d): + """autoparallel_backend(enable_ac=True) runs the AC joint pass before + the partitioner. The combination should compile end-to-end.""" + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_2d, n_layers=2 + ) + # AC enabled in the backend → joint_custom_pass adds PREFER_RECOMPUTE + backend = autoparallel_backend(enable_ac=True, overlap_scheduling=False) + + compiled = torch.compile(parallel_mod, backend=backend) + x = torch.randint( + 0, + vocab_size, + (batch_size // device_mesh_2d.shape[0], seqlen), + device="cuda", + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + + +def test_save_all_partitioner_compile_1d_mesh(device_mesh_1d): + """The partitioner works with a 1D (FSDP-only) mesh.""" + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_1d, n_layers=2 + ) + backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + + compiled = torch.compile(parallel_mod, backend=backend) + x = torch.randint( + 0, + vocab_size, + (batch_size // device_mesh_1d.shape[0], seqlen), + device="cuda", + ) + out = compiled(x) + out.backward(torch.randn_like(out)) From fb0de359a19853ef52ed51fe3ca9a6601beda051 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 9 Jun 2026 09:53:38 +0000 Subject: [PATCH 02/10] - **`autoparallel/api.py`**: - Removed unused `pre_pass` parameter from `_boxed_nop_preserve_node_meta` - Moved `from functools import partial` to top-of-file imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **`autoparallel/compile.py`**: - Rewrote `_SaveAllPartitioner` docstring to describe what it actually does (uses `ap_must_save` tags) and why (sidesteps the `force_save_collectives` + AC interaction) - Added comment explaining why we keep the `force_save_collectives`/`force_save_effectful_ops`/`force_save_bw_mutation_src` calls even though they don't affect our save decision - **Removed dead test pollution**: the `AutoParallel._make_fuse_allgather_pass = lambda self: None` line was leftover from a different branch - **Removed redundant test**: `test_autoparallel_backend_includes_save_all_partitioner` was a tautology - **Added bad-case test**: `test_default_partitioner_diverges_from_save_all_partitioner` — proves the default min-cut partitioner produces a different save list than `_SaveAllPartitioner` when AC is active (the motivating divergence) - **Added regression-guard test**: `test_save_all_partitioner_reproduces_first_partitioner_saves` — confirms `_SaveAllPartitioner`'s saved set approximately matches what the first partitioner saved (by shape histogram, tolerant to view/reshape differences from retracing) - All 12 tests in `test_save_all_partitioner.py` pass (~4 minutes runtime) - `test_api.py` and `test_activation_checkpointing.py` still pass (one pre-existing unrelated failure on main) - Lint clean (F401, F841 checks) --- autoparallel/api.py | 10 +- autoparallel/compile.py | 40 +++- tests/test_save_all_partitioner.py | 365 +++++++++++++++++++++++------ 3 files changed, 330 insertions(+), 85 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index d31db69c..7db6162f 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -9,6 +9,7 @@ import time from contextlib import ExitStack, contextmanager from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Optional, Union import torch @@ -58,12 +59,7 @@ logger = logging.getLogger(__name__) -def _boxed_nop_preserve_node_meta( - fx_g, example_inputs, pre_pass=None, tag_forward=False -): - if pre_pass is not None: - pre_pass(fx_g.graph) - +def _boxed_nop_preserve_node_meta(fx_g, example_inputs, tag_forward=False): if tag_forward: # Tag the forward graph's OUTPUT values as "must save". These are # the tensors the first min_cut decided to save for backward — @@ -512,8 +508,6 @@ def apply_placement(self, sharding_placement): self.parallel_gm.graph, self.reshard_after_forward ) - from functools import partial - fw_compiler_fn = partial(self.compiler_fn, tag_forward=True) self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( diff --git a/autoparallel/compile.py b/autoparallel/compile.py index eaa0e28e..5046233b 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -25,17 +25,27 @@ class _SaveAllPartitioner(CustomPartitionerFn): - """Tag-driven partitioner: save all forward tensors except MUST_RECOMPUTE. - - The first compilation (inside AutoParallel) already makes recomputation - decisions via MUST_RECOMPUTE tags on FSDP allgather chains. Those tags - survive into the second compilation via preserve_node_meta. - - Unlike default_partition (which uses positional fwd/bwd boundary detection - and fails on interleaved backward ops like gradient reduce_scatters), this - partitioner uses classify_nodes for topology-based boundary detection. And - unlike min_cut_rematerialization_partition, it doesn't run a second min-cut - that would diverge from the first compilation's decisions. + """Reproduce the first partitioner's save/recompute decisions via tags. + + AutoParallel partitions the joint graph twice: once inside apply_placement + via aot_compile_joint_with_descriptors (the "first" partitioner), and + again when the user calls torch.compile(parallel_mod, backend=...). + + The first compilation tags each forward output that it decided to save + with custom.ap_must_save in _boxed_nop_preserve_node_meta(tag_forward=True). + Those tags propagate to the second compilation's joint graph through + preserve_node_meta. This partitioner reads the tags and saves exactly + those nodes — sidestepping min-cut, which would make independent + cost-based decisions on the second joint graph and may save FSDP + allgather outputs that should be recomputed via FSDP prefetch. + + Specifically, when ac_joint_pass runs in the second compilation it + adds PREFER_RECOMPUTE tags to compute ops, which causes min-cut to + recompute matmuls in backward. The backward then needs unsharded + weights to redo those computations, force_save_collectives marks the + allgather outputs as MUST_SAVE, and min-cut saves them. This + partitioner avoids the chain by ignoring those tags entirely and + saving only what the first partitioner already chose. """ def __call__( @@ -86,6 +96,14 @@ def __call__( if graph_has_recomputable_ops: gm = cleanup_recompute_tags(gm, is_default_partition=False) + # Apply PyTorch's standard save-forcing passes. None of these affect + # our own save decision (which only consults `ap_must_save`), but + # they normalize the graph by setting MUST_SAVE on collectives, + # effectful ops, and backward-mutated values. We keep them as a + # defense against future PyTorch internals that may consult these + # tags during extraction. force_save_collectives correctly skips + # nodes already tagged MUST_RECOMPUTE (e.g. FSDP allgathers), so + # the FSDP recomputation contract is preserved. if not torch._functorch.config.unsafe_allow_optimization_of_collectives: force_save_collectives(gm) force_save_effectful_ops(gm) diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py index f0c41f72..1cc433cf 100644 --- a/tests/test_save_all_partitioner.py +++ b/tests/test_save_all_partitioner.py @@ -17,13 +17,15 @@ saves FSDP allgather outputs that should be recomputed via prefetch). """ +import operator + import torch from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard from torch.utils.checkpoint import CheckpointPolicy from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs -from autoparallel.api import AutoParallel, _boxed_nop_preserve_node_meta +from autoparallel.api import _boxed_nop_preserve_node_meta from autoparallel.compile import ( _patch_partitioner_dce, _SaveAllPartitioner, @@ -31,7 +33,7 @@ ) # --------------------------------------------------------------------------- -# Fixtures +# Helpers # --------------------------------------------------------------------------- @@ -55,6 +57,7 @@ def _make_small_llama(n_layers=2): def _run_autoparallel(mesh, n_layers=2, batch_size=None, seqlen=128): """Run AutoParallel up to apply_placement and return the parallel module plus AC pass for the second compilation.""" + from autoparallel.api import AutoParallel from autoparallel.compile import _make_ac_joint_pass vocab_size = 1024 @@ -64,8 +67,6 @@ def _run_autoparallel(mesh, n_layers=2, batch_size=None, seqlen=128): with torch.device("meta"): model = _make_small_llama(n_layers=n_layers) - AutoParallel._make_fuse_allgather_pass = lambda self: None - with AutoParallel( model, lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), @@ -91,17 +92,23 @@ def _run_autoparallel(mesh, n_layers=2, batch_size=None, seqlen=128): return parallel_mod, batch_size, seqlen, vocab_size -def _capture_saved_activations( +def _capture_partitioner_call( parallel_mod, batch_size, seqlen, vocab_size, mesh, backend ): - """Compile parallel_mod, capture what _SaveAllPartitioner sees and returns, - and return (saved_activation_names, captured_info).""" + """Compile parallel_mod with `backend`, capture what _SaveAllPartitioner + sees and returns, and return the captured info dict. + + The dict contains: + - wait_tensor_recompute_tags: count of MUST_RECOMPUTE on wait_tensors + - allgather_recompute_tags: count of MUST_RECOMPUTE on all_gathers + - ap_must_save_count: count of nodes tagged ap_must_save + - fw_module, bw_module: the partitioned graph modules + - saved_activation_names: backward inputs that aren't primals or tangents + """ captured = {} - orig_call = _SaveAllPartitioner.__call__ def capturing_call(self, gm, joint_inputs, **kwargs): - # Snapshot tag state before _SaveAllPartitioner mutates anything captured["wait_tensor_recompute_tags"] = sum( 1 for n in gm.graph.nodes @@ -145,7 +152,141 @@ def capturing_call(self, gm, joint_inputs, **kwargs): ): continue saved_names.append(node.name) - return saved_names, captured + captured["saved_activation_names"] = saved_names + return captured + + +def _capture_first_partitioner_saves(mesh, n_layers=2, seqlen=128): + """Run AutoParallel and capture the first partitioner's saved values + (the forward outputs beyond num_fwd_outputs).""" + from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors + + from autoparallel.api import AutoParallel + from autoparallel.compile import _make_ac_joint_pass + + vocab_size = 1024 + batch_size = 2 * mesh.shape[0] + + with torch.device("meta"): + model = _make_small_llama(n_layers=n_layers) + + captured = {} + + def capturing_fw_compiler(fx_g, example_inputs, **kwargs): + # The compiled forward returns [*model_outputs, *saved_values] + output_node = next(n for n in fx_g.graph.nodes if n.op == "output") + captured["fw_outputs"] = list(output_node.args[0]) + from autoparallel.api import _boxed_nop_preserve_node_meta + + return _boxed_nop_preserve_node_meta(fx_g, example_inputs, **kwargs) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + if mesh.ndim == 2: + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + else: + autop.add_input_constraints([(Shard(0),)]) + autop.add_output_constraints([(Shard(0),)]) + sharding_placement = autop.optimize_placement(verbose=False) + + ac_pass = _make_ac_joint_pass() + from functools import partial + + with torch._functorch.config.patch({"joint_custom_pass": ac_pass}): + # Replicate apply_placement's compile call so we can intercept + # the fw_compiler. + autop._apply_placement_common(sharding_placement) + from autoparallel.graph_passes.activation_checkpointing import ( + mark_fsdp_all_gather_recomputation, + ) + + mark_fsdp_all_gather_recomputation( + autop.parallel_gm.graph, autop.reshard_after_forward + ) + aot_compile_joint_with_descriptors( + autop.joint_with_descriptors, + fw_compiler=partial(capturing_fw_compiler, tag_forward=True), + bw_compiler=autop.compiler_fn, + ) + + fw_metadata = autop.joint_with_descriptors._aot_state.fw_metadata + num_fwd_outputs = fw_metadata.num_forward_returns + # Saved values = forward outputs beyond num_fwd_outputs (the model outputs). + # Filter out primal/placeholder pass-throughs since the second + # partitioner's saved set doesn't include those either (they're inputs). + fw_outputs = captured["fw_outputs"] + saved = [ + n + for n in fw_outputs[num_fwd_outputs:] + if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + return saved + + +def _saved_names_from_default_compile( + parallel_mod, batch_size, seqlen, vocab_size, mesh, enable_ac=True +): + """Compile with the default min-cut partitioner (NOT _SaveAllPartitioner) + and capture which backward inputs are saved. This is the "without the fix" + baseline that motivates _SaveAllPartitioner. + + By default this enables AC (joint_custom_pass = ac_joint_pass) — the + motivating bad case is min-cut + AC tags driving force_save_collectives + to save FSDP allgather outputs. + """ + from autoparallel.compile import _make_ac_joint_pass + + captured = {} + + def simple_backend(gm, example_inputs): + from torch._functorch.aot_autograd import aot_module_simplified + from torch._functorch.partitioners import min_cut_rematerialization_partition + + def fw(g, i): + return g + + def bw(g, i): + captured["bw_module"] = g + return g + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=fw, + bw_compiler=bw, + partition_fn=min_cut_rematerialization_partition, + ) + + functorch_patches = {} + if enable_ac: + functorch_patches["joint_custom_pass"] = _make_ac_joint_pass() + + with torch._functorch.config.patch(functorch_patches): + compiled = torch.compile(parallel_mod, backend=simple_backend) + x = torch.randint( + 0, vocab_size, (batch_size // mesh.shape[0], seqlen), device="cuda" + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + + bw = captured["bw_module"] + saved_names = [] + for node in bw.graph.nodes: + if node.op != "placeholder": + continue + if isinstance(node.target, str) and ( + "tangent" in node.target or "primals" in node.target + ): + continue + saved_names.append(node.name) + return saved_names # --------------------------------------------------------------------------- @@ -168,10 +309,8 @@ def test_boxed_nop_tag_forward_marks_outputs(): _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) - add_tagged = add.meta.get("custom", {}).get("ap_must_save") - mul_tagged = mul.meta.get("custom", {}).get("ap_must_save") - assert add_tagged is True - assert mul_tagged is True + assert add.meta.get("custom", {}).get("ap_must_save") is True + assert mul.meta.get("custom", {}).get("ap_must_save") is True def test_boxed_nop_tag_forward_skips_getitem(): @@ -182,8 +321,6 @@ def test_boxed_nop_tag_forward_skips_getitem(): x.meta["val"] = torch.randn(4, device="meta") split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) split.meta["val"] = [torch.randn(2, device="meta")] * 2 - import operator - g0 = graph.call_function(operator.getitem, args=(split, 0)) g0.meta["val"] = torch.randn(2, device="meta") graph.output((g0,)) @@ -270,9 +407,32 @@ def test_preserve_node_meta_propagates_recompute_through_collectives(): ), f"{name} lost MUST_RECOMPUTE tag through preserve_node_meta" +def test_patch_partitioner_dce_allows_wait_tensor_elimination(): + """_patch_partitioner_dce overrides is_not_collective for wait_tensor, + and the override is properly reverted on exit.""" + import torch._functorch.partitioners as partitioners + + graph = torch.fx.Graph() + x = graph.placeholder("x") + wt = graph.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(x,)) + graph.output((wt,)) + + original_result = partitioners.is_not_collective(wt) + + with _patch_partitioner_dce(): + patched_result = partitioners.is_not_collective(wt) + + # After exit, the original function is restored + restored_result = partitioners.is_not_collective(wt) + + # Inside the patch, wait_tensor goes through our shortcut + assert patched_result is False + # And the patch is properly restored + assert restored_result == original_result + + # --------------------------------------------------------------------------- -# Integration test for the partitioner — verifies that wait_tensor outputs -# from FSDP allgathers are NOT saved by _SaveAllPartitioner. +# Integration tests for the partitioner behavior on a real model # --------------------------------------------------------------------------- @@ -286,12 +446,14 @@ def test_save_all_partitioner_does_not_save_fsdp_wait_tensors(device_mesh_2d): ) backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) - saved_names, captured = _capture_saved_activations( + captured = _capture_partitioner_call( parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend ) # FSDP wait_tensor outputs should NOT appear as saved activations - saved_wait_tensors = [n for n in saved_names if "wait_tensor" in n] + saved_wait_tensors = [ + n for n in captured["saved_activation_names"] if "wait_tensor" in n + ] assert len(saved_wait_tensors) == 0, ( f"FSDP wait_tensor outputs were saved (should be recomputed via " f"FSDP prefetch): {saved_wait_tensors}" @@ -312,7 +474,7 @@ def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): ) backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) - saved_names, captured = _capture_saved_activations( + captured = _capture_partitioner_call( parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend ) @@ -325,10 +487,125 @@ def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): ) # And we should have saved at least some non-trivial activations - activation_saves = [n for n in saved_names if not n.startswith("primals")] + activation_saves = [ + n for n in captured["saved_activation_names"] if not n.startswith("primals") + ] assert len(activation_saves) > 0 +def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): + """The default min-cut partitioner produces a different save list than + _SaveAllPartitioner when AC is active. This is the motivating reason + _SaveAllPartitioner exists: min-cut + AC tags can save FSDP allgather + outputs (or other tensors that the first partitioner chose to recompute). + + We assert "differ" rather than checking for a specific wait_tensor name + because the exact divergence depends on model size, mesh shape, and the + AC stage budget. For the production LLaMA-3 8B config (32 layers, dim=4096, + 128 GPUs), the divergence manifests as ~1.2 GB of FSDP wait_tensor outputs + being saved per 4 layers — see [memory_gap_investigation memory note]. + """ + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_2d, n_layers=2 + ) + + # Default partitioner with AC enabled in second compilation (bad case) + default_saves = _saved_names_from_default_compile( + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=True, + ) + + # _SaveAllPartitioner with AC enabled in second compilation (our fix) + backend = autoparallel_backend(enable_ac=True, overlap_scheduling=False) + save_all_captured = _capture_partitioner_call( + parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + ) + save_all_saves = save_all_captured["saved_activation_names"] + + # The two partitioners pick different things. Equal save count would + # only happen by coincidence — even when counts match the choice of + # tensors differs. + default_set = set(default_saves) + save_all_set = set(save_all_saves) + only_in_default = default_set - save_all_set + only_in_save_all = save_all_set - default_set + assert only_in_default or only_in_save_all, ( + f"Default partitioner and _SaveAllPartitioner produced identical " + f"saves ({sorted(default_set)}). If this is reproducible, the " + f"motivating divergence may have been fixed upstream and " + f"_SaveAllPartitioner could be reevaluated." + ) + + +def test_save_all_partitioner_reproduces_first_partitioner_saves(device_mesh_2d): + """_SaveAllPartitioner's saved set should approximately match what the + first partitioner chose. The two operate on different graphs (the first + inside apply_placement, the second inside torch.compile), so we compare + by tensor shape/dtype histograms rather than node names — names will + differ across compilations but the underlying tensors should match. + """ + # First partitioner: capture the forward outputs beyond num_fwd_outputs. + first_saves = _capture_first_partitioner_saves(device_mesh_2d, n_layers=2) + + # Second partitioner: run torch.compile with our backend and capture + # the backward inputs (saved values). + parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( + device_mesh_2d, n_layers=2 + ) + backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + captured = _capture_partitioner_call( + parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + ) + bw = captured["bw_module"] + second_saves = [] + for node in bw.graph.nodes: + if node.op != "placeholder": + continue + if isinstance(node.target, str) and ( + "tangent" in node.target or "primals" in node.target + ): + continue + second_saves.append(node) + + def _shape_sig(node): + val = node.meta.get("val") + if val is None or not hasattr(val, "shape"): + return None + return (tuple(val.shape), str(val.dtype)) + + # Counts should be close. Retracing through Dynamo can add or eliminate + # a small number of view/reshape nodes. + diff = abs(len(first_saves) - len(second_saves)) + assert diff <= 2, ( + f"First partitioner saved {len(first_saves)} values, but " + f"_SaveAllPartitioner saved {len(second_saves)} (diff {diff} > 2). " + f"They should match closely." + ) + + # Shape/dtype histograms should match closely too. + from collections import Counter + + first_shapes = Counter(_shape_sig(n) for n in first_saves) + second_shapes = Counter(_shape_sig(n) for n in second_saves) + # Drop entries with no shape (sym ints, opaque, etc.) + first_shapes.pop(None, None) + second_shapes.pop(None, None) + + # Symmetric difference should be small + diff_count = sum((first_shapes - second_shapes).values()) + sum( + (second_shapes - first_shapes).values() + ) + assert diff_count <= 4, ( + f"Shape histograms diverged too much (diff_count={diff_count}).\n" + f" Only in first: {first_shapes - second_shapes}\n" + f" Only in second: {second_shapes - first_shapes}" + ) + + def test_save_all_partitioner_runs_end_to_end(device_mesh_2d): """Full end-to-end: AutoParallel + torch.compile(autoparallel_backend) forward + backward without errors.""" @@ -346,50 +623,6 @@ def test_save_all_partitioner_runs_end_to_end(device_mesh_2d): ) out = compiled(x) out.backward(torch.randn_like(out)) - # Successful completion is the test - - -# --------------------------------------------------------------------------- -# Supporting machinery tests -# --------------------------------------------------------------------------- - - -def test_patch_partitioner_dce_allows_wait_tensor_elimination(): - """_patch_partitioner_dce makes is_not_collective return False for - wait_tensor — required so unused wait_tensors can be DCE'd.""" - import torch._functorch.partitioners as partitioners - - graph = torch.fx.Graph() - x = graph.placeholder("x") - wt = graph.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(x,)) - graph.output((wt,)) - - original_result = partitioners.is_not_collective(wt) - - with _patch_partitioner_dce(): - patched_result = partitioners.is_not_collective(wt) - - # After exit, original behavior is restored - restored_result = partitioners.is_not_collective(wt) - - # Inside the patch, wait_tensor is treated as DCE-eligible (returns False - # is consistent — see the patch's intent of "letting wait_tensor through") - assert patched_result is False - # And the patch is properly restored - assert restored_result == original_result - - -def test_autoparallel_backend_includes_save_all_partitioner(): - """autoparallel_backend() configures custom_partitioner_fn to be a - _SaveAllPartitioner instance.""" - autoparallel_backend(enable_ac=False, overlap_scheduling=False) - # The backend wires _SaveAllPartitioner as the custom_partitioner_fn. - # Just verify the type exists and is properly defined. - assert _SaveAllPartitioner is not None - p = _SaveAllPartitioner() - assert callable(p) - assert hasattr(p, "uuid") - assert p.uuid() is None def test_save_all_partitioner_compile_with_ac_enabled(device_mesh_2d): From 696329baa68fedd38ef84a50446b2542e51ef6cd Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 9 Jun 2026 10:07:16 +0000 Subject: [PATCH 03/10] Improvements --- autoparallel/api.py | 12 +++-- autoparallel/compile.py | 39 ++++++++++------ tests/test_save_all_partitioner.py | 73 ++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 19 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 7db6162f..54f0ebe3 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -74,12 +74,16 @@ def _boxed_nop_preserve_node_meta(fx_g, example_inputs, tag_forward=False): if out.target == operator.getitem: # getitem metadata doesn't survive preserve_node_meta # (Python builtin, not dispatched). Tag the parent - # multi-output op instead — DCE in _extract_fwd_bwd_modules - # removes any unused getitem children. + # multi-output op instead, keeping the getitem index so the + # second partitioner can replay the exact saved output. parent = out.args[0] if isinstance(parent, torch.fx.Node): - parent.meta.setdefault("custom", {}) - parent.meta["custom"]["ap_must_save"] = True + custom = parent.meta.setdefault("custom", {}) + custom["ap_must_save"] = True + indices = custom.setdefault("ap_must_save_getitem_indices", []) + idx = out.args[1] + if idx not in indices: + indices.append(idx) else: out.meta.setdefault("custom", {}) out.meta["custom"]["ap_must_save"] = True diff --git a/autoparallel/compile.py b/autoparallel/compile.py index 5046233b..cbc119f5 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -12,6 +12,7 @@ import torch._inductor.config from torch._inductor.compile_fx import compile_fx from torch._inductor.custom_graph_pass import CustomPartitionerFn +from torch.utils.checkpoint import CheckpointPolicy from .api import _suppress_wait_tensor_side_effect from .graph_passes.activation_checkpointing import ac_joint_pass @@ -72,7 +73,6 @@ def __call__( has_recomputable_rng_ops, is_opaque_node, is_sym_node, - must_recompute, raise_getitems, reordering_to_mimic_autograd_engine, thread_graphsafe_rng_from_hops, @@ -96,14 +96,10 @@ def __call__( if graph_has_recomputable_ops: gm = cleanup_recompute_tags(gm, is_default_partition=False) - # Apply PyTorch's standard save-forcing passes. None of these affect - # our own save decision (which only consults `ap_must_save`), but - # they normalize the graph by setting MUST_SAVE on collectives, - # effectful ops, and backward-mutated values. We keep them as a - # defense against future PyTorch internals that may consult these - # tags during extraction. force_save_collectives correctly skips - # nodes already tagged MUST_RECOMPUTE (e.g. FSDP allgathers), so - # the FSDP recomputation contract is preserved. + # Apply PyTorch's standard save-forcing passes, then honor their + # MUST_SAVE tags below. force_save_collectives skips nodes already + # tagged MUST_RECOMPUTE (e.g. FSDP allgathers), so the FSDP + # recomputation contract is preserved. if not torch._functorch.config.unsafe_allow_optimization_of_collectives: force_save_collectives(gm) force_save_effectful_ops(gm) @@ -132,30 +128,43 @@ def _is_multi_output(node: torch.fx.Node) -> bool: and len(node.users) > 0 ) + def _must_recompute(node: torch.fx.Node) -> bool: + return node.meta.get("recompute") is CheckpointPolicy.MUST_RECOMPUTE + + def _must_save(node: torch.fx.Node) -> bool: + return node.meta.get("recompute") is CheckpointPolicy.MUST_SAVE + def _maybe_save(node: torch.fx.Node) -> None: if is_sym_node(node): if not _is_assert_only_symbool(node): saved_sym_nodes.append(node) return if _is_multi_output(node): - # Multi-output ops tagged ap_must_save: save all their - # getitem children (DCE removes unused ones later). - if node.meta.get("custom", {}).get("ap_must_save"): + if _must_recompute(node): + return + custom = node.meta.get("custom", {}) + if _must_save(node) or custom.get("ap_must_save"): + # getitem metadata does not survive preserve_node_meta, so + # the first partitioner tags the parent multi-output op. If + # it recorded specific getitem indices, replay only those. + indices = custom.get("ap_must_save_getitem_indices") for user in node.users: - if user.target == operator.getitem: + if user.target != operator.getitem: + continue + if indices is None or user.args[1] in indices: saved_values.append(user) return if is_opaque_node(node): saved_opaque_nodes.append(node) return - if must_recompute(node): + if _must_recompute(node): return # Save nodes tagged ap_must_save by the first compilation. # These are the forward graph's output tensors from the first # partitioner — reproducing its save/recompute decisions. if node.op == "placeholder": saved_values.append(node) - elif node.meta.get("custom", {}).get("ap_must_save"): + elif node.meta.get("custom", {}).get("ap_must_save") or _must_save(node): saved_values.append(node) for node in node_info.required_fw_nodes: diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py index 1cc433cf..5a7bc4a4 100644 --- a/tests/test_save_all_partitioner.py +++ b/tests/test_save_all_partitioner.py @@ -289,6 +289,25 @@ def bw(g, i): return saved_names +def _simple_partition(saved_node_meta): + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + tangent = graph.placeholder("tangents_1") + tangent.meta["val"] = torch.randn(4, device="meta") + saved = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + saved.meta["val"] = torch.randn(4, device="meta") + saved.meta.update(saved_node_meta) + bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(saved, tangent)) + bwd.meta["val"] = torch.randn(4, device="meta") + output = graph.output((saved, bwd)) + output.meta["desc"] = [None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + return _SaveAllPartitioner()( + gm, [torch.randn(4), torch.randn(4)], num_fwd_outputs=1 + ) + + # --------------------------------------------------------------------------- # Unit tests for the standalone mechanisms # --------------------------------------------------------------------------- @@ -330,9 +349,28 @@ def test_boxed_nop_tag_forward_skips_getitem(): # Parent (split) gets the tag, not the getitem assert split.meta.get("custom", {}).get("ap_must_save") is True + assert split.meta.get("custom", {}).get("ap_must_save_getitem_indices") == [0] assert g0.meta.get("custom", {}).get("ap_must_save") is None +def test_boxed_nop_tag_forward_records_getitem_indices(): + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(6, device="meta") + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 3 + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g2 = graph.call_function(operator.getitem, args=(split, 2)) + g2.meta["val"] = torch.randn(2, device="meta") + graph.output((g0, g2)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) + + assert split.meta["custom"]["ap_must_save_getitem_indices"] == [0, 2] + + def test_boxed_nop_no_tag_forward_default(): """tag_forward defaults to False — no tagging happens.""" graph = torch.fx.Graph() @@ -348,6 +386,41 @@ def test_boxed_nop_no_tag_forward_default(): assert add.meta.get("custom", {}).get("ap_must_save") is None +def test_save_all_partitioner_saves_ap_must_save_despite_prefer_recompute(): + fw, bw = _simple_partition( + { + "custom": {"ap_must_save": True}, + "recompute": CheckpointPolicy.PREFER_RECOMPUTE, + } + ) + + fw_outputs = next(n for n in fw.graph.nodes if n.op == "output").args[0] + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" in [n.name for n in fw_outputs if isinstance(n, torch.fx.Node)] + assert "add_tensor" in bw_placeholders + + +def test_save_all_partitioner_honors_must_save(): + fw, bw = _simple_partition({"recompute": CheckpointPolicy.MUST_SAVE}) + + fw_outputs = next(n for n in fw.graph.nodes if n.op == "output").args[0] + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" in [n.name for n in fw_outputs if isinstance(n, torch.fx.Node)] + assert "add_tensor" in bw_placeholders + + +def test_save_all_partitioner_does_not_save_must_recompute(): + fw, bw = _simple_partition( + { + "custom": {"ap_must_save": True}, + "recompute": CheckpointPolicy.MUST_RECOMPUTE, + } + ) + + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" not in bw_placeholders + + def test_preserve_node_meta_propagates_recompute_through_collectives(): """preserve_node_meta correctly propagates MUST_RECOMPUTE through collective ops (allgather, wait_tensor) — confirming the safety net From 080efd06805808508a0e0916414fb1d66ca902d7 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 9 Jun 2026 12:19:38 +0000 Subject: [PATCH 04/10] Tighten _SaveAllPartitioner policy and add corner-case tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three behavioral changes plus tests and clarifying comments. 1. MUST_SAVE on a multi-output parent now overrides ap_must_save_getitem_indices. Previously, if a multi-output op was both MUST_SAVE and ap_must_save with a specific index list, only the indexed children were saved — silently under-saving relative to what MUST_SAVE means ("save all tensor outputs needed from this op"). The fix makes MUST_SAVE clear the index restriction, keeping exact replay for ap_must_save while keeping PyTorch's MUST_SAVE tags conservative. 2. Deduplicate saved_values, saved_sym_nodes, and saved_opaque_nodes before handing them to _extract_fwd_bwd_modules. Matches upstream default_partition's dict.fromkeys pattern. Defensive — duplicates don't arise in the current control flow, but they could under future refactors of the iteration order or tag combinations, and the cost is one line per list. 3. Added clarifying comments on three corner cases: - Opaque nodes (ProcessGroup, ScriptObject) are saved unconditionally regardless of ap_must_save/MUST_SAVE; documents the intentional deviation from pure replay semantics, matching the standard partitioner. - The inference fallback to default_partition when there are no backward nodes; documents that ap_must_save tags are bypassed there because inference doesn't have the fwd/bwd-divergence problem the partitioner exists to solve. - CSE merges duplicate chains without combining metadata; documents that the safety contract holds for FSDP allgather chains (first occurrence keeps MUST_RECOMPUTE) and that general replay across CSE'd duplicates is a known limitation. Tests added (all verified to catch a real regression by reverting the corresponding fix and checking the assertion fires): - test_save_all_partitioner_must_save_overrides_getitem_indices - test_save_all_partitioner_replays_only_indexed_getitems (locks in the index-specific replay precision from the earlier round) - test_save_all_partitioner_must_recompute_blocks_opaque_save and test_save_all_partitioner_must_recompute_blocks_multi_output_save (regression guards for the _must_recompute ordering) - test_save_all_partitioner_multi_output_with_non_getitem_user (covers the tuple-aware _is_multi_output check) Authored with Claude. --- autoparallel/compile.py | 61 ++++++- tests/test_save_all_partitioner.py | 260 +++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+), 9 deletions(-) diff --git a/autoparallel/compile.py b/autoparallel/compile.py index cbc119f5..7c882a30 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -85,6 +85,14 @@ def __call__( # MUST_RECOMPUTE allgathers and the baked-in backward copies # compute the same values from the same primals. Without CSE, # both appear in the backward (duplication). + # + # Caveat: fx_graph_cse keeps the first occurrence and drops later + # duplicates without merging metadata. If a duplicate has + # ap_must_save and the kept node doesn't, the tag is lost. In + # practice the duplicates we care about are FSDP allgather chains + # whose first occurrence (forward) carries MUST_RECOMPUTE — kept + # correctly — so the safety contract holds. General replay + # semantics across CSE'd duplicates is a known limitation. if torch._functorch.config.cse: from torch._functorch.partitioners import fx_graph_cse @@ -110,6 +118,10 @@ def __call__( node_info = classify_nodes(gm, static_lifetime_input_indices, num_fwd_outputs) if len(node_info.required_bw_nodes) == 0: + # Inference path (no backward nodes from autograd): fall back to + # the standard partitioner. Our ap_must_save tags are not used + # here, but inference doesn't have the fwd/bwd-divergence problem + # _SaveAllPartitioner exists to solve. return default_partition( gm, joint_inputs, @@ -122,10 +134,18 @@ def __call__( saved_sym_nodes = [] saved_opaque_nodes = [] + def _has_tuple_val(node: torch.fx.Node) -> bool: + return isinstance(node.meta.get("val"), (list, tuple)) + def _is_multi_output(node: torch.fx.Node) -> bool: - return ( - all(user.target == operator.getitem for user in node.users) - and len(node.users) > 0 + # Definitive test: the node returns a tuple/list. Checking users + # alone (only getitem) is fragile because a single non-getitem + # user (e.g. a debug op or a sym-shape extractor) flips the + # answer without changing the node's actual return type. + if _has_tuple_val(node): + return True + return len(node.users) > 0 and all( + user.target == operator.getitem for user in node.users ) def _must_recompute(node: torch.fx.Node) -> bool: @@ -135,19 +155,32 @@ def _must_save(node: torch.fx.Node) -> bool: return node.meta.get("recompute") is CheckpointPolicy.MUST_SAVE def _maybe_save(node: torch.fx.Node) -> None: + # MUST_RECOMPUTE wins over everything else. Check before any + # save branch (including opaque) so the first compilation's + # explicit recompute intent is always honored. + if _must_recompute(node): + return + # Sym nodes are saved unconditionally per the standard + # partitioner convention; ap_must_save/MUST_SAVE tags on them + # are ignored. Assert-only symbools are pure runtime checks + # and don't need to cross the fwd/bwd boundary. if is_sym_node(node): if not _is_assert_only_symbool(node): saved_sym_nodes.append(node) return if _is_multi_output(node): - if _must_recompute(node): - return custom = node.meta.get("custom", {}) if _must_save(node) or custom.get("ap_must_save"): # getitem metadata does not survive preserve_node_meta, so # the first partitioner tags the parent multi-output op. If - # it recorded specific getitem indices, replay only those. - indices = custom.get("ap_must_save_getitem_indices") + # only ap_must_save is set and it recorded specific getitem + # indices, replay only those. MUST_SAVE is a stronger + # directive ("save all tensor outputs needed") and overrides + # the index restriction — save every getitem child. + if _must_save(node): + indices = None + else: + indices = custom.get("ap_must_save_getitem_indices") for user in node.users: if user.target != operator.getitem: continue @@ -155,10 +188,12 @@ def _maybe_save(node: torch.fx.Node) -> None: saved_values.append(user) return if is_opaque_node(node): + # Opaque nodes (ProcessGroup, ScriptObject) can't be + # recomputed — they have no functional meaning. Saved + # unconditionally regardless of ap_must_save / MUST_SAVE + # tags; this matches the standard partitioner's behavior. saved_opaque_nodes.append(node) return - if _must_recompute(node): - return # Save nodes tagged ap_must_save by the first compilation. # These are the forward graph's output tensors from the first # partitioner — reproducing its save/recompute decisions. @@ -177,6 +212,14 @@ def _maybe_save(node: torch.fx.Node) -> None: for node in node_info.unclaimed_nodes: _maybe_save(node) + # Deduplicate. Multi-output getitem handling, overlapping iteration + # over required_fw_nodes + unclaimed_nodes, or overlapping save tags + # can land the same node in a list twice. Matches upstream + # default_partition's dict.fromkeys deduping. + saved_values = list(dict.fromkeys(saved_values)) + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes)) + saved_opaque_nodes = list(dict.fromkeys(saved_opaque_nodes)) + fw_module, bw_module = _extract_fwd_bwd_modules( gm, saved_values, diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py index 5a7bc4a4..95fba5f4 100644 --- a/tests/test_save_all_partitioner.py +++ b/tests/test_save_all_partitioner.py @@ -308,6 +308,65 @@ def _simple_partition(saved_node_meta): ) +def _multi_output_partition( + parent_meta, + saved_indices=None, + extra_parent_consumer=False, +): + """Build a tiny joint graph with a multi-output forward op (split) and + run _SaveAllPartitioner on it. + + parent_meta is merged into the multi-output op's meta. Optionally tag + saved_indices on the parent (mirroring what tag_forward does). If + extra_parent_consumer is True, the parent is also referenced directly + in the output (a non-getitem user that survives DCE), so the partitioner + observes split.users containing both getitems and a non-getitem node. + """ + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(6, device="meta") + tangent = graph.placeholder("tangents_1") + tangent.meta["val"] = torch.randn(6, device="meta") + + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 3 + custom = split.meta.setdefault("custom", {}) + custom.update(parent_meta.get("custom", {})) + if "recompute" in parent_meta: + split.meta["recompute"] = parent_meta["recompute"] + if saved_indices is not None: + custom["ap_must_save_getitem_indices"] = saved_indices + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g1 = graph.call_function(operator.getitem, args=(split, 1)) + g1.meta["val"] = torch.randn(2, device="meta") + g2 = graph.call_function(operator.getitem, args=(split, 2)) + g2.meta["val"] = torch.randn(2, device="meta") + + # Concatenate g0+g1+g2 + tangent so all getitems feed into backward + cat = graph.call_function(torch.ops.aten.cat.default, args=([g0, g1, g2],)) + cat.meta["val"] = torch.randn(6, device="meta") + bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(cat, tangent)) + bwd.meta["val"] = torch.randn(6, device="meta") + + # Optionally make split a direct forward output (non-getitem user). + if extra_parent_consumer: + output_args = (cat, split, bwd) + num_fwd_outputs = 2 + else: + output_args = (cat, bwd) + num_fwd_outputs = 1 + output = graph.output(output_args) + output.meta["desc"] = [None] * len(output_args) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + return _SaveAllPartitioner()( + gm, + [torch.randn(6), torch.randn(6)], + num_fwd_outputs=num_fwd_outputs, + ) + + # --------------------------------------------------------------------------- # Unit tests for the standalone mechanisms # --------------------------------------------------------------------------- @@ -421,6 +480,207 @@ def test_save_all_partitioner_does_not_save_must_recompute(): assert "add_tensor" not in bw_placeholders +def test_save_all_partitioner_must_recompute_blocks_multi_output_save(): + """MUST_RECOMPUTE on a multi-output op blocks saving its getitem + children, even when ap_must_save is also set. Documents the invariant + that the first partitioner's recompute decision wins over its own + save tags (which can happen if both are set during graph passes).""" + fw, bw = _multi_output_partition( + parent_meta={ + "custom": {"ap_must_save": True}, + "recompute": CheckpointPolicy.MUST_RECOMPUTE, + }, + saved_indices=[0, 1, 2], + ) + + # No getitem children should appear in backward inputs + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert not any( + name.startswith("getitem") for name in bw_placeholders + ), f"MUST_RECOMPUTE should block multi-output save, got: {bw_placeholders}" + + +def test_save_all_partitioner_must_recompute_blocks_opaque_save(): + """MUST_RECOMPUTE blocks saving even on nodes that would otherwise be + saved as opaque. The _must_recompute check must run before is_opaque_node + so the first partitioner's recompute intent is honored uniformly.""" + # Build a fake graph with a node we'll force is_opaque_node to recognize. + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + tangent = graph.placeholder("tangents_1") + tangent.meta["val"] = torch.randn(4, device="meta") + opaque_like = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + opaque_like.meta["val"] = torch.randn(4, device="meta") + opaque_like.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(opaque_like, tangent)) + bwd.meta["val"] = torch.randn(4, device="meta") + output = graph.output((opaque_like, bwd)) + output.meta["desc"] = [None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + # Force is_opaque_node to return True for our op. _SaveAllPartitioner + # imports it from torch._functorch.partitioners, so patch there. + import torch._functorch.partitioners as partitioners + + original = partitioners.is_opaque_node + partitioners.is_opaque_node = lambda n: n.name == "add_tensor" + try: + fw, bw = _SaveAllPartitioner()( + gm, [torch.randn(4), torch.randn(4)], num_fwd_outputs=1 + ) + finally: + partitioners.is_opaque_node = original + + # add_tensor was tagged MUST_RECOMPUTE; even though is_opaque_node + # returned True, it should NOT be saved + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" not in bw_placeholders, ( + f"MUST_RECOMPUTE on an opaque-classified node was ignored; " + f"backward placeholders: {bw_placeholders}" + ) + + +def test_save_all_partitioner_multi_output_with_non_getitem_user(): + """A multi-output op tagged ap_must_save must save its getitem children + (not the parent tuple) even when it has a non-getitem user. The save + logic must look at the node's value type, not just its users.""" + fw, bw = _multi_output_partition( + parent_meta={"custom": {"ap_must_save": True}}, + saved_indices=[0, 1, 2], + extra_parent_consumer=True, + ) + + # Backward should receive the getitem children, NOT the split tuple + bw_placeholders = [n for n in bw.graph.nodes if n.op == "placeholder"] + # No placeholder should have a tuple/list value (which would mean we + # saved the multi-output op directly) + for n in bw_placeholders: + val = n.meta.get("val") + assert not isinstance(val, (list, tuple)), ( + f"Backward got placeholder {n.name} with tuple/list val — multi-output op " + f"was saved directly instead of its getitem children" + ) + # And at least one getitem should be present (the save took effect) + getitem_names = [n.name for n in bw_placeholders if n.name.startswith("getitem")] + assert len(getitem_names) > 0, ( + f"Expected getitem children to be saved, got placeholders: " + f"{[n.name for n in bw_placeholders]}" + ) + + +def test_save_all_partitioner_replays_only_indexed_getitems(): + """When tag_forward records specific getitem indices on a multi-output + parent, only those indices should appear as saved backward inputs; + other indices should not be saved (and would be recomputed if needed). + + Build a graph where backward consumes BOTH getitem 0 and getitem 1, but + we only tag index 0 as ap_must_save. The partitioner should save + getitem 0 only. + """ + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + t0 = graph.placeholder("tangents_1") + t0.meta["val"] = torch.randn(2, device="meta") + t1 = graph.placeholder("tangents_2") + t1.meta["val"] = torch.randn(2, device="meta") + + # Multi-output op: split into 2 chunks. Mark only index 0 as ap_must_save. + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 2 + custom = split.meta.setdefault("custom", {}) + custom["ap_must_save"] = True + custom["ap_must_save_getitem_indices"] = [0] + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g1 = graph.call_function(operator.getitem, args=(split, 1)) + g1.meta["val"] = torch.randn(2, device="meta") + + # Forward outputs: a value that depends on both g0 and g1 (so both + # getitems are required in forward). + add_fw = graph.call_function(torch.ops.aten.add.Tensor, args=(g0, g1)) + add_fw.meta["val"] = torch.randn(2, device="meta") + + # Backward ops: independent muls using g0 and g1 respectively, so each + # getitem is a real backward dependency. + bwd0 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g0, t0)) + bwd0.meta["val"] = torch.randn(2, device="meta") + bwd1 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g1, t1)) + bwd1.meta["val"] = torch.randn(2, device="meta") + + output = graph.output((add_fw, bwd0, bwd1)) + output.meta["desc"] = [None, None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + fw, bw = _SaveAllPartitioner()( + gm, + [torch.randn(4), torch.randn(2), torch.randn(2)], + num_fwd_outputs=1, + ) + + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + # getitem 0 should be a saved backward input (we tagged index 0) + assert ( + "getitem" in bw_placeholders + ), f"Expected getitem (index 0) to be saved; got {bw_placeholders}" + # getitem 1 should NOT be a saved input — it's not in the indices list + assert "getitem_1" not in bw_placeholders, ( + f"Expected getitem_1 (index 1) NOT to be saved (index restriction); " + f"got {bw_placeholders}" + ) + + +def test_save_all_partitioner_must_save_overrides_getitem_indices(): + """MUST_SAVE on a multi-output parent should save ALL getitem children + even if ap_must_save_getitem_indices restricts to a subset. MUST_SAVE + is a stronger directive than ap_must_save's index-specific replay.""" + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + t0 = graph.placeholder("tangents_1") + t0.meta["val"] = torch.randn(2, device="meta") + t1 = graph.placeholder("tangents_2") + t1.meta["val"] = torch.randn(2, device="meta") + + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 2 + # Both MUST_SAVE AND restricted indices. MUST_SAVE should win. + split.meta["recompute"] = CheckpointPolicy.MUST_SAVE + custom = split.meta.setdefault("custom", {}) + custom["ap_must_save"] = True + custom["ap_must_save_getitem_indices"] = [0] + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g1 = graph.call_function(operator.getitem, args=(split, 1)) + g1.meta["val"] = torch.randn(2, device="meta") + + add_fw = graph.call_function(torch.ops.aten.add.Tensor, args=(g0, g1)) + add_fw.meta["val"] = torch.randn(2, device="meta") + bwd0 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g0, t0)) + bwd0.meta["val"] = torch.randn(2, device="meta") + bwd1 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g1, t1)) + bwd1.meta["val"] = torch.randn(2, device="meta") + + output = graph.output((add_fw, bwd0, bwd1)) + output.meta["desc"] = [None, None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + fw, bw = _SaveAllPartitioner()( + gm, + [torch.randn(4), torch.randn(2), torch.randn(2)], + num_fwd_outputs=1, + ) + + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + # Both getitems should be saved (MUST_SAVE overrides the index restriction) + assert "getitem" in bw_placeholders, f"Expected getitem; got {bw_placeholders}" + assert "getitem_1" in bw_placeholders, ( + f"Expected getitem_1 even though indices=[0] (MUST_SAVE should override); " + f"got {bw_placeholders}" + ) + + def test_preserve_node_meta_propagates_recompute_through_collectives(): """preserve_node_meta correctly propagates MUST_RECOMPUTE through collective ops (allgather, wait_tensor) — confirming the safety net From ede4a91242ba56c2bcea88be9e35f56360218ebb Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 08:51:28 +0000 Subject: [PATCH 05/10] Fix tests --- tests/test_save_all_partitioner.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py index 95fba5f4..35b1aab7 100644 --- a/tests/test_save_all_partitioner.py +++ b/tests/test_save_all_partitioner.py @@ -20,6 +20,7 @@ import operator import torch +from conftest import apply_cuda_patches from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard from torch.utils.checkpoint import CheckpointPolicy @@ -769,6 +770,7 @@ def test_patch_partitioner_dce_allows_wait_tensor_elimination(): # --------------------------------------------------------------------------- +@apply_cuda_patches def test_save_all_partitioner_does_not_save_fsdp_wait_tensors(device_mesh_2d): """The whole point of the machinery: with FSDP allgathers tagged MUST_RECOMPUTE, _SaveAllPartitioner should NOT save their wait_tensor @@ -799,6 +801,7 @@ def test_save_all_partitioner_does_not_save_fsdp_wait_tensors(device_mesh_2d): ) +@apply_cuda_patches def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): """_SaveAllPartitioner saves nodes tagged ap_must_save by the first compilation.""" @@ -826,6 +829,7 @@ def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): assert len(activation_saves) > 0 +@apply_cuda_patches def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): """The default min-cut partitioner produces a different save list than _SaveAllPartitioner when AC is active. This is the motivating reason @@ -874,6 +878,7 @@ def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): ) +@apply_cuda_patches def test_save_all_partitioner_reproduces_first_partitioner_saves(device_mesh_2d): """_SaveAllPartitioner's saved set should approximately match what the first partitioner chose. The two operate on different graphs (the first @@ -939,6 +944,7 @@ def _shape_sig(node): ) +@apply_cuda_patches def test_save_all_partitioner_runs_end_to_end(device_mesh_2d): """Full end-to-end: AutoParallel + torch.compile(autoparallel_backend) forward + backward without errors.""" @@ -958,6 +964,7 @@ def test_save_all_partitioner_runs_end_to_end(device_mesh_2d): out.backward(torch.randn_like(out)) +@apply_cuda_patches def test_save_all_partitioner_compile_with_ac_enabled(device_mesh_2d): """autoparallel_backend(enable_ac=True) runs the AC joint pass before the partitioner. The combination should compile end-to-end.""" @@ -978,6 +985,7 @@ def test_save_all_partitioner_compile_with_ac_enabled(device_mesh_2d): out.backward(torch.randn_like(out)) +@apply_cuda_patches def test_save_all_partitioner_compile_1d_mesh(device_mesh_1d): """The partitioner works with a 1D (FSDP-only) mesh.""" parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( From f63eda0b58111afc4b0430af2011b86b1719eb4b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 09:58:31 +0000 Subject: [PATCH 06/10] Make tests faster MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary Two structural changes 1. **Capture-only backend for partitioner-checking tests.** Replaced `_capture_partitioner_call`'s implementation to wire `_SaveAllPartitioner` directly via `aot_module_simplified` with identity compilers. The partitioner runs and is captured, but no Triton kernel codegen happens. The function signature changed from accepting a pre-built `backend` to accepting `enable_ac=` directly. 2. **Module-scoped fixtures for `parallel_mod`.** Added `parallel_mod_2d` and `parallel_mod_1d` fixtures that cache the AutoParallel solve + apply_placement (~31s for 2D). Five integration tests now share the same parallel module instead of rebuilding it. Test cleanup - Removed `test_save_all_partitioner_runs_end_to_end` (redundant with `compile_with_ac_enabled`). - Removed `test_save_all_partitioner_uses_ap_must_save_tags` (its assertions folded into `does_not_save_fsdp_wait_tensors`). Result - **252s → 135s** (47% reduction), 21 → 19 tests - Capture-only tests now run in 3-6s each (was 25-50s with full Inductor) - The two real-Inductor smoke tests (`compile_with_ac_enabled`, `compile_1d_mesh`) remain to guard against full-pipeline regressions The breakdown: - 14 unit tests: ~1s total - 1 reproduces test (calls `_capture_first_partitioner_saves` which builds its own AutoParallel): ~24s - 3 capture-only 2D integration tests: ~9-15s combined - 1 Inductor 2D smoke test: ~14s - 1 Inductor 1D smoke test (separate mesh): ~45s + 8s setup - Module-scope 2D fixture setup: ~31s (shared across 5 tests) --- tests/test_save_all_partitioner.py | 193 ++++++++++++++++------------- 1 file changed, 110 insertions(+), 83 deletions(-) diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py index 35b1aab7..7fb6f228 100644 --- a/tests/test_save_all_partitioner.py +++ b/tests/test_save_all_partitioner.py @@ -19,6 +19,7 @@ import operator +import pytest import torch from conftest import apply_cuda_patches from torch.distributed.fsdp import MixedPrecisionPolicy @@ -93,11 +94,36 @@ def _run_autoparallel(mesh, n_layers=2, batch_size=None, seqlen=128): return parallel_mod, batch_size, seqlen, vocab_size +# Module-scoped caches for parallel_mod. Each integration test below builds +# the same one (deterministic for a given mesh shape) — caching saves ~30s +# per test. The cache is keyed by mesh.ndim since the helper only varies +# along that dimension; both 1D and 2D meshes get their own entry. +_parallel_mod_cache: dict = {} + + +@pytest.fixture(scope="module") +@apply_cuda_patches +def parallel_mod_2d(device_mesh_2d): + if "2d" not in _parallel_mod_cache: + _parallel_mod_cache["2d"] = _run_autoparallel(device_mesh_2d, n_layers=2) + return _parallel_mod_cache["2d"] + + +@pytest.fixture(scope="module") +@apply_cuda_patches +def parallel_mod_1d(device_mesh_1d): + if "1d" not in _parallel_mod_cache: + _parallel_mod_cache["1d"] = _run_autoparallel(device_mesh_1d, n_layers=2) + return _parallel_mod_cache["1d"] + + def _capture_partitioner_call( - parallel_mod, batch_size, seqlen, vocab_size, mesh, backend + parallel_mod, batch_size, seqlen, vocab_size, mesh, enable_ac=False ): - """Compile parallel_mod with `backend`, capture what _SaveAllPartitioner - sees and returns, and return the captured info dict. + """Run the second compilation with _SaveAllPartitioner as the + partition_fn but use identity compilers (no Inductor codegen). This + keeps the partitioner under test in the loop while skipping the + expensive Triton kernel compilation. The dict contains: - wait_tensor_recompute_tags: count of MUST_RECOMPUTE on wait_tensors @@ -106,40 +132,68 @@ def _capture_partitioner_call( - fw_module, bw_module: the partitioned graph modules - saved_activation_names: backward inputs that aren't primals or tangents """ + from autoparallel.api import _suppress_wait_tensor_side_effect + from autoparallel.compile import _make_ac_joint_pass, _patch_partitioner_dce + captured = {} - orig_call = _SaveAllPartitioner.__call__ + partitioner = _SaveAllPartitioner() - def capturing_call(self, gm, joint_inputs, **kwargs): + def capturing_partition_fn( + joint_module, joint_inputs, *, num_fwd_outputs, **kwargs + ): captured["wait_tensor_recompute_tags"] = sum( 1 - for n in gm.graph.nodes + for n in joint_module.graph.nodes if "wait_tensor" in n.name and n.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE ) captured["allgather_recompute_tags"] = sum( 1 - for n in gm.graph.nodes + for n in joint_module.graph.nodes if "all_gather_into_tensor" in n.name and n.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE ) captured["ap_must_save_count"] = sum( - 1 for n in gm.graph.nodes if n.meta.get("custom", {}).get("ap_must_save") + 1 + for n in joint_module.graph.nodes + if n.meta.get("custom", {}).get("ap_must_save") + ) + fw, bw = partitioner( + joint_module, + joint_inputs, + num_fwd_outputs=num_fwd_outputs, + **kwargs, ) - fw, bw = orig_call(self, gm, joint_inputs, **kwargs) captured["fw_module"] = fw captured["bw_module"] = bw return fw, bw - _SaveAllPartitioner.__call__ = capturing_call - try: - compiled = torch.compile(parallel_mod, backend=backend) + def capture_only_backend(gm, example_inputs): + from torch._functorch.aot_autograd import aot_module_simplified + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=lambda g, i: g, + bw_compiler=lambda g, i: g, + partition_fn=capturing_partition_fn, + ) + + functorch_patches = {} + if enable_ac: + functorch_patches["joint_custom_pass"] = _make_ac_joint_pass() + + with ( + _suppress_wait_tensor_side_effect(), + _patch_partitioner_dce(), + torch._functorch.config.patch(functorch_patches), + ): + compiled = torch.compile(parallel_mod, backend=capture_only_backend) x = torch.randint( 0, vocab_size, (batch_size // mesh.shape[0], seqlen), device="cuda" ) out = compiled(x) out.backward(torch.randn_like(out)) - finally: - _SaveAllPartitioner.__call__ = orig_call # The backward graph's placeholders (minus tangents and primals) are the # saved activations. @@ -771,18 +825,22 @@ def test_patch_partitioner_dce_allows_wait_tensor_elimination(): @apply_cuda_patches -def test_save_all_partitioner_does_not_save_fsdp_wait_tensors(device_mesh_2d): +def test_save_all_partitioner_does_not_save_fsdp_wait_tensors( + parallel_mod_2d, device_mesh_2d +): """The whole point of the machinery: with FSDP allgathers tagged MUST_RECOMPUTE, _SaveAllPartitioner should NOT save their wait_tensor outputs in the backward. """ - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_2d, n_layers=2 - ) - backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d captured = _capture_partitioner_call( - parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=False, ) # FSDP wait_tensor outputs should NOT appear as saved activations @@ -800,29 +858,12 @@ def test_save_all_partitioner_does_not_save_fsdp_wait_tensors(device_mesh_2d): "preserve_node_meta into the second compilation's joint graph" ) - -@apply_cuda_patches -def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): - """_SaveAllPartitioner saves nodes tagged ap_must_save by the first - compilation.""" - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_2d, n_layers=2 - ) - backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) - - captured = _capture_partitioner_call( - parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend - ) - - # ap_must_save tags should be present in the joint graph (set by - # _boxed_nop_preserve_node_meta during first compilation, propagated - # via preserve_node_meta) + # And ap_must_save tags should be present (sanity check on tag + # propagation — used to be a separate test). assert captured["ap_must_save_count"] > 0, ( "Expected ap_must_save tags from first compilation to survive into " "second compilation's joint graph" ) - - # And we should have saved at least some non-trivial activations activation_saves = [ n for n in captured["saved_activation_names"] if not n.startswith("primals") ] @@ -830,7 +871,9 @@ def test_save_all_partitioner_uses_ap_must_save_tags(device_mesh_2d): @apply_cuda_patches -def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): +def test_default_partitioner_diverges_from_save_all_partitioner( + parallel_mod_2d, device_mesh_2d +): """The default min-cut partitioner produces a different save list than _SaveAllPartitioner when AC is active. This is the motivating reason _SaveAllPartitioner exists: min-cut + AC tags can save FSDP allgather @@ -842,9 +885,7 @@ def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): 128 GPUs), the divergence manifests as ~1.2 GB of FSDP wait_tensor outputs being saved per 4 layers — see [memory_gap_investigation memory note]. """ - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_2d, n_layers=2 - ) + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d # Default partitioner with AC enabled in second compilation (bad case) default_saves = _saved_names_from_default_compile( @@ -857,9 +898,13 @@ def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): ) # _SaveAllPartitioner with AC enabled in second compilation (our fix) - backend = autoparallel_backend(enable_ac=True, overlap_scheduling=False) save_all_captured = _capture_partitioner_call( - parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=True, ) save_all_saves = save_all_captured["saved_activation_names"] @@ -879,7 +924,9 @@ def test_default_partitioner_diverges_from_save_all_partitioner(device_mesh_2d): @apply_cuda_patches -def test_save_all_partitioner_reproduces_first_partitioner_saves(device_mesh_2d): +def test_save_all_partitioner_reproduces_first_partitioner_saves( + parallel_mod_2d, device_mesh_2d +): """_SaveAllPartitioner's saved set should approximately match what the first partitioner chose. The two operate on different graphs (the first inside apply_placement, the second inside torch.compile), so we compare @@ -889,14 +936,16 @@ def test_save_all_partitioner_reproduces_first_partitioner_saves(device_mesh_2d) # First partitioner: capture the forward outputs beyond num_fwd_outputs. first_saves = _capture_first_partitioner_saves(device_mesh_2d, n_layers=2) - # Second partitioner: run torch.compile with our backend and capture - # the backward inputs (saved values). - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_2d, n_layers=2 - ) - backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + # Second partitioner: reuse the cached parallel_mod and capture the + # backward inputs (saved values). + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d captured = _capture_partitioner_call( - parallel_mod, batch_size, seqlen, vocab_size, device_mesh_2d, backend + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=False, ) bw = captured["bw_module"] second_saves = [] @@ -945,32 +994,12 @@ def _shape_sig(node): @apply_cuda_patches -def test_save_all_partitioner_runs_end_to_end(device_mesh_2d): - """Full end-to-end: AutoParallel + torch.compile(autoparallel_backend) - forward + backward without errors.""" - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_2d, n_layers=2 - ) - backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) - - compiled = torch.compile(parallel_mod, backend=backend) - x = torch.randint( - 0, - vocab_size, - (batch_size // device_mesh_2d.shape[0], seqlen), - device="cuda", - ) - out = compiled(x) - out.backward(torch.randn_like(out)) - - -@apply_cuda_patches -def test_save_all_partitioner_compile_with_ac_enabled(device_mesh_2d): - """autoparallel_backend(enable_ac=True) runs the AC joint pass before - the partitioner. The combination should compile end-to-end.""" - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_2d, n_layers=2 - ) +def test_save_all_partitioner_compile_with_ac_enabled(parallel_mod_2d, device_mesh_2d): + """End-to-end smoke test: AutoParallel + torch.compile(autoparallel_backend) + with AC enabled. Exercises the full Inductor pipeline including AC joint + pass and codegen — kept around so a regression in the backend wiring + surfaces here even if the unit tests still pass.""" + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d # AC enabled in the backend → joint_custom_pass adds PREFER_RECOMPUTE backend = autoparallel_backend(enable_ac=True, overlap_scheduling=False) @@ -986,11 +1015,9 @@ def test_save_all_partitioner_compile_with_ac_enabled(device_mesh_2d): @apply_cuda_patches -def test_save_all_partitioner_compile_1d_mesh(device_mesh_1d): +def test_save_all_partitioner_compile_1d_mesh(parallel_mod_1d, device_mesh_1d): """The partitioner works with a 1D (FSDP-only) mesh.""" - parallel_mod, batch_size, seqlen, vocab_size = _run_autoparallel( - device_mesh_1d, n_layers=2 - ) + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_1d backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) compiled = torch.compile(parallel_mod, backend=backend) From ca5de35eaa6163d9736e2c2e4b20ed5bba8e550d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 12:10:39 +0000 Subject: [PATCH 07/10] Diagnose failures --- .github/workflows/test_cuda.yml | 11 ++ tests/diagnose_flatten_ci.py | 258 ++++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 tests/diagnose_flatten_ci.py diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 9ee4d5c9..e34ed9c4 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -31,7 +31,18 @@ jobs: pip uninstall -y torch pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 pip install --quiet . + # Run tests, but always emit the _flatten diagnostic afterward so + # CI job logs capture which mesh triggers the as_strided dispatch + # failure on A10G (see tests/diagnose_flatten_ci.py). + set +e pytest tests + pytest_status=$? + set -e + echo "============================================================" + echo " _flatten diagnostic (always runs)" + echo "============================================================" + python tests/diagnose_flatten_ci.py || true + exit $pytest_status examples-cuda-single-gpu: name: Examples CUDA Single GPU (cuda12.6-py3.12) diff --git a/tests/diagnose_flatten_ci.py b/tests/diagnose_flatten_ci.py new file mode 100644 index 00000000..63c8a045 --- /dev/null +++ b/tests/diagnose_flatten_ci.py @@ -0,0 +1,258 @@ +""" +Diagnose the CI _flatten cache miss. + +Wraps DeviceMesh._flatten and DeviceMesh._create_flatten_mesh to log: + - At every _flatten call: mesh identity, dim_names, root_mesh id, + cache state, call site + - At every _create_flatten_mesh entry: same info, plus whether the + cache check at line 904 hit or missed + +The goal is to figure out, on CI, which mesh is triggering the +`as_strided` dispatch failure: is it the user's top-level mesh (cache +should hit), or a sub-mesh / independently-constructed mesh? + +Run with: + PYTHONPATH=. python tests/diagnose_flatten_ci.py + +This script reproduces the same path as the failing CI test +(test_save_all_partitioner_compile_with_ac_enabled) but with logging. +""" + +import os +import traceback +from typing import Any + +# Force device emulation matching CI before importing torch +from unittest.mock import patch + +_PATCHES: list[Any] = [ + patch("torch.cuda.device_count", lambda: 8), + patch("torch.cuda.get_device_name", lambda *a, **k: "NVIDIA A10G"), + patch("torch.cuda.get_device_capability", lambda *a, **k: (8, 6)), + patch( + "torch.cuda.get_device_properties", + lambda *a, **k: type( + "Props", + (), + { + "major": 8, + "minor": 6, + "name": "NVIDIA A10G", + "total_memory": 24 * 1024**3, + "multi_processor_count": 80, + }, + )(), + ), +] +for p in _PATCHES: + p.start() + +import torch # noqa: E402 +from torch.distributed.device_mesh import DeviceMesh # noqa: E402 +from torch.testing._internal.distributed.fake_pg import FakeStore # noqa: E402 + +# --- Instrumentation --- + +_log_lines: list[str] = [] + + +def _log(msg: str) -> None: + _log_lines.append(msg) + print(msg, flush=True) + + +def _mesh_id(m) -> str: + return f"id={id(m):#x}" + + +def _summarize(m) -> str: + if m is None: + return "None" + dim_names = getattr(m, "_mesh_dim_names", None) + ndim = getattr(m, "ndim", "?") + root = m._get_root_mesh() if hasattr(m, "_get_root_mesh") else None + is_root = root is m + cache_keys = ( + list(root._flatten_mapping.keys()) + if root is not None and hasattr(root, "_flatten_mapping") + else "" + ) + return ( + f"{_mesh_id(m)} ndim={ndim} dim_names={dim_names} " + f"is_root={is_root} root_id={_mesh_id(root)} " + f"root_cache_keys={cache_keys}" + ) + + +def _short_traceback() -> str: + # Get the call site that invoked _flatten — skip the wrapper frames. + stack = traceback.extract_stack() + interesting = [ + f" {f.filename}:{f.lineno} in {f.name}" + for f in stack + if "device_mesh" not in f.filename and "diagnose_flatten_ci" not in f.filename + ] + return "\n".join(interesting[-6:]) + + +_orig_flatten = DeviceMesh._flatten +_orig_create = DeviceMesh._create_flatten_mesh + +_call_count = {"flatten": 0, "create": 0, "create_miss": 0} + + +def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): + _call_count["flatten"] += 1 + n = _call_count["flatten"] + requested_name = mesh_dim_name or ( + "_".join(self._mesh_dim_names) if self._mesh_dim_names else "" + ) + _log( + f"\n[_flatten #{n}] CALL on mesh {_summarize(self)} " + f"requested_name={requested_name!r}" + ) + _log(f"[_flatten #{n}] call site:\n{_short_traceback()}") + try: + result = _orig_flatten(self, mesh_dim_name, backend_override) + _log(f"[_flatten #{n}] OK → result {_summarize(result)}") + return result + except Exception as e: + _log(f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}") + raise + + +def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): + _call_count["create"] += 1 + n = _call_count["create"] + root = self._get_root_mesh() + cache_hit = mesh_dim_name in root._flatten_mapping + if not cache_hit: + _call_count["create_miss"] += 1 + _log( + f" [_create_flatten_mesh #{n}] *** CACHE MISS *** " + f"requested {mesh_dim_name!r} on root {_summarize(root)}" + ) + else: + _log( + f" [_create_flatten_mesh #{n}] cache hit " + f"{mesh_dim_name!r} on root id={id(root):#x}" + ) + return _orig_create(self, mesh_dim_name, backend_override) + + +DeviceMesh._flatten = _wrapped_flatten # type: ignore[method-assign] +DeviceMesh._create_flatten_mesh = _wrapped_create # type: ignore[method-assign] + + +# --- Reproduction of the failing test path --- + + +def main() -> None: + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "fake", store=FakeStore(), rank=0, world_size=256 + ) + + mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", (32, 8), mesh_dim_names=("dp", "tp") + ) + _log(f"\n=== USER MESH: {_summarize(mesh)} ===\n") + + from torch.distributed.fsdp import MixedPrecisionPolicy + from torch.distributed.tensor.placement_types import Replicate, Shard + + from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs + from autoparallel.api import AutoParallel + from autoparallel.compile import _make_ac_joint_pass + + with torch.device("meta"): + model = Transformer( + TransformerModelArgs( + dim=256, + n_layers=2, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=1024, + max_seq_len=512, + ) + ) + + vocab_size = 1024 + seqlen = 128 + batch_size = 2 * mesh.shape[0] + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + + # Log identity of every mesh referenced by the sharding solution + seen: dict[int, Any] = {} + _log("\n=== MESHES REFERENCED BY SHARDING SOLUTION ===") + for node, strategy in sharding_placement.items(): + specs: list[Any] = [] + if hasattr(strategy, "output_specs"): + output_specs = strategy.output_specs + if isinstance(output_specs, (list, tuple)): + specs.extend(output_specs) + else: + specs.append(output_specs) + if hasattr(strategy, "input_specs"): + input_specs = strategy.input_specs or [] + specs.extend(input_specs) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + key = id(m) + if key not in seen: + seen[key] = m + _log(f" spec mesh: {_summarize(m)}") + _log(f"=== TOTAL UNIQUE SPEC MESHES: {len(seen)} ===\n") + + ac_pass = _make_ac_joint_pass() + with torch._functorch.config.patch({"joint_custom_pass": ac_pass}): + _log("\n=== ENTERING apply_placement ===\n") + try: + autop.apply_placement(sharding_placement) + _log("\n=== apply_placement SUCCEEDED ===\n") + except Exception as e: + _log(f"\n=== apply_placement FAILED: {type(e).__name__}: {e} ===\n") + raise + + +if __name__ == "__main__": + try: + main() + finally: + _log(f"\n=== TOTAL _flatten calls: {_call_count['flatten']} ===") + _log(f"=== TOTAL _create_flatten_mesh calls: {_call_count['create']} ===") + _log( + f"=== TOTAL _create_flatten_mesh CACHE MISSES: " + f"{_call_count['create_miss']} ===" + ) + if _call_count["create_miss"] > 1: + _log( + "\n*** SMOKING GUN: cache missed more than once. The first miss " + "is the pre-warming. Any subsequent miss is a real mesh that " + "wasn't covered by the pre-warming and will trigger as_strided " + "dispatch inside make_fx. Search the log above for " + "'*** CACHE MISS ***' to see which mesh." + ) + # Also write to a file for CI artifact upload + log_path = os.environ.get("FLATTEN_DIAG_LOG", "flatten_diagnosis.log") + with open(log_path, "w") as f: + f.write("\n".join(_log_lines)) + print(f"\nFull diagnostic log written to {log_path}") From d33130ad6f5053e4a714d19a79711d14aae849cb Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 14:22:35 +0000 Subject: [PATCH 08/10] Preserve DeviceMesh identity in expand_rule's op_schema deepcopy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit expand_rule used copy.deepcopy(op_schema_) to snapshot the schema before mutating it. DeviceMesh has no __deepcopy__, so deepcopy went through __getstate__/__setstate__ and produced a fresh DeviceMesh with the same logical content but an empty _flatten_mapping cache. The DTensorSpecs returned from expand_rule carried these duplicates, which propagated into the sharding solution. apply_placement's pre-warming in _apply_placement_common only populates the user mesh's cache. When _optimize_same_nd_sharding_as_1d inside make_fx hit a duplicate mesh, _flatten() cache-missed and dispatched as_strided on the real rank_map — failing FakeTensorMode's non-fake-input check. Which solution the solver picked depended on the cost model, so the failure surfaced on g5/A10G CI but not on H100. Fix: _deepcopy_preserving_mesh pre-seeds copy.deepcopy's memo with DeviceMesh identity mappings so duplicates aren't produced. Adds a regression test that asserts every spec mesh's root has a warm _flatten cache after apply_placement. Authored with Claude. --- autoparallel/shardings/propagation_rules.py | 45 +++++++- tests/test_mesh_identity.py | 113 ++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 tests/test_mesh_identity.py diff --git a/autoparallel/shardings/propagation_rules.py b/autoparallel/shardings/propagation_rules.py index 27a20049..d054f345 100644 --- a/autoparallel/shardings/propagation_rules.py +++ b/autoparallel/shardings/propagation_rules.py @@ -51,6 +51,49 @@ logger = logging.getLogger(__name__) + +def _deepcopy_preserving_mesh(obj): + """Like copy.deepcopy, but reuses DeviceMesh instances instead of + duplicating them. + + DeviceMesh carries process-group state (rank maps, _flatten_mapping + cache, backend overrides) that is logically shared across all callers. + Deep-copying it produces a fresh object with an empty _flatten_mapping + that misses the cache populated on the original, which later forces + DeviceMesh._flatten to dispatch as_strided on the rank_map inside + make_fx — failing FakeTensorMode's non-fake-input assertion. + + We pre-populate copy.deepcopy's memo dict with identity mappings for + every DeviceMesh reachable from obj. deepcopy returns existing entries + from memo without recursing, so the same DeviceMesh instances appear + in the copy. + """ + from torch.distributed.device_mesh import DeviceMesh + + memo: dict = {} + stack = [obj] + seen: set[int] = set() + while stack: + x = stack.pop() + if id(x) in seen: + continue + seen.add(id(x)) + if isinstance(x, DeviceMesh): + memo[id(x)] = x + continue + if isinstance(x, (list, tuple, set, frozenset)): + stack.extend(x) + elif isinstance(x, dict): + stack.extend(x.values()) + elif hasattr(x, "__dict__"): + stack.extend(x.__dict__.values()) + elif hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + stack.append(getattr(x, slot)) + return copy.deepcopy(obj, memo) + + # TODO: move this to PyTorch dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1) @@ -829,7 +872,7 @@ def expand_rule(mesh, op_schema_): from torch._subclasses.fake_tensor import unset_fake_temporarily with unset_fake_temporarily(): - op_schema = copy.deepcopy(op_schema_) + op_schema = _deepcopy_preserving_mesh(op_schema_) input_strat = op_schema.args_schema[0] orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape dest_shape = op_schema.args_schema[1] diff --git a/tests/test_mesh_identity.py b/tests/test_mesh_identity.py new file mode 100644 index 00000000..cbd7b705 --- /dev/null +++ b/tests/test_mesh_identity.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Regression test: DeviceMesh duplicates introduced by +``copy.deepcopy(op_schema_)`` in propagation rules used to trigger an +``as_strided``-inside-FakeTensorMode failure during ``apply_placement``. + +Background: ``copy.deepcopy(op_schema)`` inside ``expand_rule`` produces a +fresh DeviceMesh object with an empty ``_flatten_mapping``. When the solver +picks a redistribution that calls ``mesh._flatten()`` on the duplicate, +``_create_flatten_mesh`` runs uncached, dispatching ``as_strided`` on the +rank_map — and FakeTensorMode rejects the non-fake tensor input. + +Fix lives in ``autoparallel/shardings/propagation_rules.py`` as +``_deepcopy_preserving_mesh``: pre-seeds copy.deepcopy's memo with +DeviceMesh identity mappings so the deepcopy reuses the original meshes. + +This test asserts the property we actually care about: every DeviceMesh +referenced by the sharding solution has a populated ``_flatten_mapping`` +on its root, so a subsequent ``_flatten()`` call inside ``make_fx`` hits +the cache instead of dispatching. + +We use the Transformer model because it triggers ``expand_rule`` (a +simpler model wouldn't exercise that propagation rule). +""" + +import torch +from conftest import apply_cuda_patches +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel + + +@apply_cuda_patches +def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): + """After ``apply_placement``'s pre-warming, every spec mesh's root + must have the default flattened mesh cached. Otherwise a subsequent + ``_flatten()`` call inside ``make_fx`` triggers ``as_strided`` on the + rank_map and FakeTensorMode rejects it (the original CI failure). + """ + vocab_size = 1024 + seqlen = 128 + batch_size = 2 * device_mesh_2d.shape[0] + + with torch.device("meta"): + model = Transformer( + TransformerModelArgs( + dim=256, + n_layers=2, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=vocab_size, + max_seq_len=seqlen, + ) + ) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + device_mesh_2d, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + # apply_placement pre-warms the user mesh's _flatten cache so + # subsequent _flatten() calls inside make_fx hit the cache. + autop.apply_placement(sharding_placement) + + # Collect every distinct spec mesh from the solution + spec_meshes: dict[int, object] = {} + for strategy in sharding_placement.values(): + specs = [] + if hasattr(strategy, "output_specs"): + o = strategy.output_specs + specs.extend(o if isinstance(o, (list, tuple)) else [o]) + if hasattr(strategy, "input_specs"): + specs.extend(strategy.input_specs or []) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + spec_meshes[id(m)] = m + + cold = [] + for mid, m in spec_meshes.items(): + if m.ndim == 1: + # 1D meshes: _flatten() short-circuits to self without dispatch + continue + root = m._get_root_mesh() + default_name = "_".join(m._mesh_dim_names) + if default_name not in root._flatten_mapping: + cold.append((mid, m._mesh_dim_names, list(root._flatten_mapping))) + + assert not cold, ( + f"After apply_placement, {len(cold)} spec mesh(es) still have a " + f"cold _flatten_mapping for their default name. A subsequent " + f"_flatten() call inside make_fx will dispatch as_strided and " + f"fail FakeTensorMode's non-fake-input check. Details " + f"(id, dim_names, root cache keys): {cold}. See " + f"_deepcopy_preserving_mesh in propagation_rules.py." + ) From 872f3fb39f7df2321fecbb1338c61f68cebdf374 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 14:27:37 +0000 Subject: [PATCH 09/10] Revert "Diagnose failures" This reverts commit ca5de35eaa6163d9736e2c2e4b20ed5bba8e550d. --- .github/workflows/test_cuda.yml | 11 -- tests/diagnose_flatten_ci.py | 258 -------------------------------- 2 files changed, 269 deletions(-) delete mode 100644 tests/diagnose_flatten_ci.py diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index e34ed9c4..9ee4d5c9 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -31,18 +31,7 @@ jobs: pip uninstall -y torch pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 pip install --quiet . - # Run tests, but always emit the _flatten diagnostic afterward so - # CI job logs capture which mesh triggers the as_strided dispatch - # failure on A10G (see tests/diagnose_flatten_ci.py). - set +e pytest tests - pytest_status=$? - set -e - echo "============================================================" - echo " _flatten diagnostic (always runs)" - echo "============================================================" - python tests/diagnose_flatten_ci.py || true - exit $pytest_status examples-cuda-single-gpu: name: Examples CUDA Single GPU (cuda12.6-py3.12) diff --git a/tests/diagnose_flatten_ci.py b/tests/diagnose_flatten_ci.py deleted file mode 100644 index 63c8a045..00000000 --- a/tests/diagnose_flatten_ci.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Diagnose the CI _flatten cache miss. - -Wraps DeviceMesh._flatten and DeviceMesh._create_flatten_mesh to log: - - At every _flatten call: mesh identity, dim_names, root_mesh id, - cache state, call site - - At every _create_flatten_mesh entry: same info, plus whether the - cache check at line 904 hit or missed - -The goal is to figure out, on CI, which mesh is triggering the -`as_strided` dispatch failure: is it the user's top-level mesh (cache -should hit), or a sub-mesh / independently-constructed mesh? - -Run with: - PYTHONPATH=. python tests/diagnose_flatten_ci.py - -This script reproduces the same path as the failing CI test -(test_save_all_partitioner_compile_with_ac_enabled) but with logging. -""" - -import os -import traceback -from typing import Any - -# Force device emulation matching CI before importing torch -from unittest.mock import patch - -_PATCHES: list[Any] = [ - patch("torch.cuda.device_count", lambda: 8), - patch("torch.cuda.get_device_name", lambda *a, **k: "NVIDIA A10G"), - patch("torch.cuda.get_device_capability", lambda *a, **k: (8, 6)), - patch( - "torch.cuda.get_device_properties", - lambda *a, **k: type( - "Props", - (), - { - "major": 8, - "minor": 6, - "name": "NVIDIA A10G", - "total_memory": 24 * 1024**3, - "multi_processor_count": 80, - }, - )(), - ), -] -for p in _PATCHES: - p.start() - -import torch # noqa: E402 -from torch.distributed.device_mesh import DeviceMesh # noqa: E402 -from torch.testing._internal.distributed.fake_pg import FakeStore # noqa: E402 - -# --- Instrumentation --- - -_log_lines: list[str] = [] - - -def _log(msg: str) -> None: - _log_lines.append(msg) - print(msg, flush=True) - - -def _mesh_id(m) -> str: - return f"id={id(m):#x}" - - -def _summarize(m) -> str: - if m is None: - return "None" - dim_names = getattr(m, "_mesh_dim_names", None) - ndim = getattr(m, "ndim", "?") - root = m._get_root_mesh() if hasattr(m, "_get_root_mesh") else None - is_root = root is m - cache_keys = ( - list(root._flatten_mapping.keys()) - if root is not None and hasattr(root, "_flatten_mapping") - else "" - ) - return ( - f"{_mesh_id(m)} ndim={ndim} dim_names={dim_names} " - f"is_root={is_root} root_id={_mesh_id(root)} " - f"root_cache_keys={cache_keys}" - ) - - -def _short_traceback() -> str: - # Get the call site that invoked _flatten — skip the wrapper frames. - stack = traceback.extract_stack() - interesting = [ - f" {f.filename}:{f.lineno} in {f.name}" - for f in stack - if "device_mesh" not in f.filename and "diagnose_flatten_ci" not in f.filename - ] - return "\n".join(interesting[-6:]) - - -_orig_flatten = DeviceMesh._flatten -_orig_create = DeviceMesh._create_flatten_mesh - -_call_count = {"flatten": 0, "create": 0, "create_miss": 0} - - -def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): - _call_count["flatten"] += 1 - n = _call_count["flatten"] - requested_name = mesh_dim_name or ( - "_".join(self._mesh_dim_names) if self._mesh_dim_names else "" - ) - _log( - f"\n[_flatten #{n}] CALL on mesh {_summarize(self)} " - f"requested_name={requested_name!r}" - ) - _log(f"[_flatten #{n}] call site:\n{_short_traceback()}") - try: - result = _orig_flatten(self, mesh_dim_name, backend_override) - _log(f"[_flatten #{n}] OK → result {_summarize(result)}") - return result - except Exception as e: - _log(f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}") - raise - - -def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): - _call_count["create"] += 1 - n = _call_count["create"] - root = self._get_root_mesh() - cache_hit = mesh_dim_name in root._flatten_mapping - if not cache_hit: - _call_count["create_miss"] += 1 - _log( - f" [_create_flatten_mesh #{n}] *** CACHE MISS *** " - f"requested {mesh_dim_name!r} on root {_summarize(root)}" - ) - else: - _log( - f" [_create_flatten_mesh #{n}] cache hit " - f"{mesh_dim_name!r} on root id={id(root):#x}" - ) - return _orig_create(self, mesh_dim_name, backend_override) - - -DeviceMesh._flatten = _wrapped_flatten # type: ignore[method-assign] -DeviceMesh._create_flatten_mesh = _wrapped_create # type: ignore[method-assign] - - -# --- Reproduction of the failing test path --- - - -def main() -> None: - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - "fake", store=FakeStore(), rank=0, world_size=256 - ) - - mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", (32, 8), mesh_dim_names=("dp", "tp") - ) - _log(f"\n=== USER MESH: {_summarize(mesh)} ===\n") - - from torch.distributed.fsdp import MixedPrecisionPolicy - from torch.distributed.tensor.placement_types import Replicate, Shard - - from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs - from autoparallel.api import AutoParallel - from autoparallel.compile import _make_ac_joint_pass - - with torch.device("meta"): - model = Transformer( - TransformerModelArgs( - dim=256, - n_layers=2, - n_heads=8, - n_kv_heads=2, - ffn_dim_multiplier=1.3, - multiple_of=64, - rope_theta=500000, - vocab_size=1024, - max_seq_len=512, - ) - ) - - vocab_size = 1024 - seqlen = 128 - batch_size = 2 * mesh.shape[0] - - with AutoParallel( - model, - lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), - mesh, - MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), - repeated_subgraphs=True, - ) as autop: - autop.add_parameter_memory_constraint(low=None, high=None) - autop.add_input_constraints([(Shard(0), Replicate())]) - autop.add_output_constraints([(Shard(0), Shard(2))]) - sharding_placement = autop.optimize_placement(verbose=False) - - # Log identity of every mesh referenced by the sharding solution - seen: dict[int, Any] = {} - _log("\n=== MESHES REFERENCED BY SHARDING SOLUTION ===") - for node, strategy in sharding_placement.items(): - specs: list[Any] = [] - if hasattr(strategy, "output_specs"): - output_specs = strategy.output_specs - if isinstance(output_specs, (list, tuple)): - specs.extend(output_specs) - else: - specs.append(output_specs) - if hasattr(strategy, "input_specs"): - input_specs = strategy.input_specs or [] - specs.extend(input_specs) - for s in specs: - if s is None: - continue - m = getattr(s, "mesh", None) - if m is None: - continue - key = id(m) - if key not in seen: - seen[key] = m - _log(f" spec mesh: {_summarize(m)}") - _log(f"=== TOTAL UNIQUE SPEC MESHES: {len(seen)} ===\n") - - ac_pass = _make_ac_joint_pass() - with torch._functorch.config.patch({"joint_custom_pass": ac_pass}): - _log("\n=== ENTERING apply_placement ===\n") - try: - autop.apply_placement(sharding_placement) - _log("\n=== apply_placement SUCCEEDED ===\n") - except Exception as e: - _log(f"\n=== apply_placement FAILED: {type(e).__name__}: {e} ===\n") - raise - - -if __name__ == "__main__": - try: - main() - finally: - _log(f"\n=== TOTAL _flatten calls: {_call_count['flatten']} ===") - _log(f"=== TOTAL _create_flatten_mesh calls: {_call_count['create']} ===") - _log( - f"=== TOTAL _create_flatten_mesh CACHE MISSES: " - f"{_call_count['create_miss']} ===" - ) - if _call_count["create_miss"] > 1: - _log( - "\n*** SMOKING GUN: cache missed more than once. The first miss " - "is the pre-warming. Any subsequent miss is a real mesh that " - "wasn't covered by the pre-warming and will trigger as_strided " - "dispatch inside make_fx. Search the log above for " - "'*** CACHE MISS ***' to see which mesh." - ) - # Also write to a file for CI artifact upload - log_path = os.environ.get("FLATTEN_DIAG_LOG", "flatten_diagnosis.log") - with open(log_path, "w") as f: - f.write("\n".join(_log_lines)) - print(f"\nFull diagnostic log written to {log_path}") From 7a5830068d9f4ab6c224898c33a0373311e477a3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 15:10:23 +0000 Subject: [PATCH 10/10] Update to use CUDA 13.0 --- .github/workflows/test_cuda.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 9ee4d5c9..d42d01f1 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -13,13 +13,13 @@ concurrency: jobs: test-cuda-single-gpu: - name: Test CUDA Single GPU (cuda12.6-py3.12) + name: Test CUDA Single GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -29,18 +29,18 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . pytest tests examples-cuda-single-gpu: - name: Examples CUDA Single GPU (cuda12.6-py3.12) + name: Examples CUDA Single GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -50,7 +50,7 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . run_timed() { local start=$SECONDS; "$@"; echo "$* : $((SECONDS - start))s" >> /tmp/timings.txt; } run_timed python examples/example_autoparallel.py @@ -60,13 +60,13 @@ jobs: cat /tmp/timings.txt test-cuda-multi-gpu: - name: Test CUDA Multi GPU (cuda12.6-py3.12) + name: Test CUDA Multi GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.12xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -76,7 +76,7 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . python examples/example_dcp.py torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py