diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 9ee4d5c9..d42d01f1 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -13,13 +13,13 @@ concurrency: jobs: test-cuda-single-gpu: - name: Test CUDA Single GPU (cuda12.6-py3.12) + name: Test CUDA Single GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -29,18 +29,18 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . pytest tests examples-cuda-single-gpu: - name: Examples CUDA Single GPU (cuda12.6-py3.12) + name: Examples CUDA Single GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -50,7 +50,7 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . run_timed() { local start=$SECONDS; "$@"; echo "$* : $((SECONDS - start))s" >> /tmp/timings.txt; } run_timed python examples/example_autoparallel.py @@ -60,13 +60,13 @@ jobs: cat /tmp/timings.txt test-cuda-multi-gpu: - name: Test CUDA Multi GPU (cuda12.6-py3.12) + name: Test CUDA Multi GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.12xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -76,7 +76,7 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . python examples/example_dcp.py torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py diff --git a/autoparallel/api.py b/autoparallel/api.py index 1670d509..54f0ebe3 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -5,9 +5,11 @@ import copy import logging +import operator import time from contextlib import ExitStack, contextmanager from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Optional, Union import torch @@ -57,7 +59,35 @@ logger = logging.getLogger(__name__) -def _boxed_nop_preserve_node_meta(fx_g, example_inputs): +def _boxed_nop_preserve_node_meta(fx_g, example_inputs, tag_forward=False): + if tag_forward: + # Tag the forward graph's OUTPUT values as "must save". These are + # the tensors the first min_cut decided to save for backward — + # only these should be saved in the second compilation. + # Uses the "custom" meta field (not "recompute") to avoid + # interfering with ac_joint_pass which uses "recompute" for + # activation checkpointing decisions. + output_node = next(n for n in fx_g.graph.nodes if n.op == "output") + for out in output_node.args[0]: + if not isinstance(out, torch.fx.Node) or out.op != "call_function": + continue + if out.target == operator.getitem: + # getitem metadata doesn't survive preserve_node_meta + # (Python builtin, not dispatched). Tag the parent + # multi-output op instead, keeping the getitem index so the + # second partitioner can replay the exact saved output. + parent = out.args[0] + if isinstance(parent, torch.fx.Node): + custom = parent.meta.setdefault("custom", {}) + custom["ap_must_save"] = True + indices = custom.setdefault("ap_must_save_getitem_indices", []) + idx = out.args[1] + if idx not in indices: + indices.append(idx) + else: + out.meta.setdefault("custom", {}) + out.meta["custom"]["ap_must_save"] = True + def run(args): with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(fx_g).boxed_run(args) @@ -482,9 +512,11 @@ def apply_placement(self, sharding_placement): self.parallel_gm.graph, self.reshard_after_forward ) + fw_compiler_fn = partial(self.compiler_fn, tag_forward=True) + self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( self.joint_with_descriptors, - fw_compiler=self.compiler_fn, + fw_compiler=fw_compiler_fn, bw_compiler=self.compiler_fn, ) diff --git a/autoparallel/compile.py b/autoparallel/compile.py index e9d8b5c1..7c882a30 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -3,13 +3,18 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Union +import operator +from contextlib import contextmanager +from typing import Any, Optional, Sequence, Union import torch import torch._functorch.config import torch._inductor.config from torch._inductor.compile_fx import compile_fx +from torch._inductor.custom_graph_pass import CustomPartitionerFn +from torch.utils.checkpoint import CheckpointPolicy +from .api import _suppress_wait_tensor_side_effect from .graph_passes.activation_checkpointing import ac_joint_pass _INDUCTOR_OVERLAP_PATCHES = { @@ -20,6 +25,255 @@ } +class _SaveAllPartitioner(CustomPartitionerFn): + """Reproduce the first partitioner's save/recompute decisions via tags. + + AutoParallel partitions the joint graph twice: once inside apply_placement + via aot_compile_joint_with_descriptors (the "first" partitioner), and + again when the user calls torch.compile(parallel_mod, backend=...). + + The first compilation tags each forward output that it decided to save + with custom.ap_must_save in _boxed_nop_preserve_node_meta(tag_forward=True). + Those tags propagate to the second compilation's joint graph through + preserve_node_meta. This partitioner reads the tags and saves exactly + those nodes — sidestepping min-cut, which would make independent + cost-based decisions on the second joint graph and may save FSDP + allgather outputs that should be recomputed via FSDP prefetch. + + Specifically, when ac_joint_pass runs in the second compilation it + adds PREFER_RECOMPUTE tags to compute ops, which causes min-cut to + recompute matmuls in backward. The backward then needs unsharded + weights to redo those computations, force_save_collectives marks the + allgather outputs as MUST_SAVE, and min-cut saves them. This + partitioner avoids the chain by ignoring those tags entirely and + saving only what the first partitioner already chose. + """ + + def __call__( + self, + gm: torch.fx.GraphModule, + joint_inputs: Sequence[object], + **kwargs: Any, + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + num_fwd_outputs: int = kwargs.pop("num_fwd_outputs") # type: ignore[assignment] + static_lifetime_input_indices: list[int] | None = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + from torch._functorch.partitioners import ( + _extract_fwd_bwd_modules, + _is_assert_only_symbool, + classify_nodes, + cleanup_recompute_tags, + default_partition, + force_save_bw_mutation_src, + force_save_collectives, + force_save_effectful_ops, + functionalize_rng_ops, + has_recomputable_ops, + has_recomputable_rng_ops, + is_opaque_node, + is_sym_node, + raise_getitems, + reordering_to_mimic_autograd_engine, + thread_graphsafe_rng_from_hops, + ) + + gm.graph.eliminate_dead_code() + gm.recompile() + + # CSE merges duplicate allgather chains: the forward's + # MUST_RECOMPUTE allgathers and the baked-in backward copies + # compute the same values from the same primals. Without CSE, + # both appear in the backward (duplication). + # + # Caveat: fx_graph_cse keeps the first occurrence and drops later + # duplicates without merging metadata. If a duplicate has + # ap_must_save and the kept node doesn't, the tag is lost. In + # practice the duplicates we care about are FSDP allgather chains + # whose first occurrence (forward) carries MUST_RECOMPUTE — kept + # correctly — so the safety contract holds. General replay + # semantics across CSE'd duplicates is a known limitation. + if torch._functorch.config.cse: + from torch._functorch.partitioners import fx_graph_cse + + cse_graph = fx_graph_cse(gm.graph) + gm.graph = cse_graph + + graph_has_recomputable_ops = has_recomputable_ops(gm) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(gm) + if graph_has_recomputable_ops: + gm = cleanup_recompute_tags(gm, is_default_partition=False) + + # Apply PyTorch's standard save-forcing passes, then honor their + # MUST_SAVE tags below. force_save_collectives skips nodes already + # tagged MUST_RECOMPUTE (e.g. FSDP allgathers), so the FSDP + # recomputation contract is preserved. + if not torch._functorch.config.unsafe_allow_optimization_of_collectives: + force_save_collectives(gm) + force_save_effectful_ops(gm) + force_save_bw_mutation_src(gm) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes(gm, static_lifetime_input_indices, num_fwd_outputs) + + if len(node_info.required_bw_nodes) == 0: + # Inference path (no backward nodes from autograd): fall back to + # the standard partitioner. Our ap_must_save tags are not used + # here, but inference doesn't have the fwd/bwd-divergence problem + # _SaveAllPartitioner exists to solve. + return default_partition( + gm, + joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, + ) + + saved_values = [] + saved_sym_nodes = [] + saved_opaque_nodes = [] + + def _has_tuple_val(node: torch.fx.Node) -> bool: + return isinstance(node.meta.get("val"), (list, tuple)) + + def _is_multi_output(node: torch.fx.Node) -> bool: + # Definitive test: the node returns a tuple/list. Checking users + # alone (only getitem) is fragile because a single non-getitem + # user (e.g. a debug op or a sym-shape extractor) flips the + # answer without changing the node's actual return type. + if _has_tuple_val(node): + return True + return len(node.users) > 0 and all( + user.target == operator.getitem for user in node.users + ) + + def _must_recompute(node: torch.fx.Node) -> bool: + return node.meta.get("recompute") is CheckpointPolicy.MUST_RECOMPUTE + + def _must_save(node: torch.fx.Node) -> bool: + return node.meta.get("recompute") is CheckpointPolicy.MUST_SAVE + + def _maybe_save(node: torch.fx.Node) -> None: + # MUST_RECOMPUTE wins over everything else. Check before any + # save branch (including opaque) so the first compilation's + # explicit recompute intent is always honored. + if _must_recompute(node): + return + # Sym nodes are saved unconditionally per the standard + # partitioner convention; ap_must_save/MUST_SAVE tags on them + # are ignored. Assert-only symbools are pure runtime checks + # and don't need to cross the fwd/bwd boundary. + if is_sym_node(node): + if not _is_assert_only_symbool(node): + saved_sym_nodes.append(node) + return + if _is_multi_output(node): + custom = node.meta.get("custom", {}) + if _must_save(node) or custom.get("ap_must_save"): + # getitem metadata does not survive preserve_node_meta, so + # the first partitioner tags the parent multi-output op. If + # only ap_must_save is set and it recorded specific getitem + # indices, replay only those. MUST_SAVE is a stronger + # directive ("save all tensor outputs needed") and overrides + # the index restriction — save every getitem child. + if _must_save(node): + indices = None + else: + indices = custom.get("ap_must_save_getitem_indices") + for user in node.users: + if user.target != operator.getitem: + continue + if indices is None or user.args[1] in indices: + saved_values.append(user) + return + if is_opaque_node(node): + # Opaque nodes (ProcessGroup, ScriptObject) can't be + # recomputed — they have no functional meaning. Saved + # unconditionally regardless of ap_must_save / MUST_SAVE + # tags; this matches the standard partitioner's behavior. + saved_opaque_nodes.append(node) + return + # Save nodes tagged ap_must_save by the first compilation. + # These are the forward graph's output tensors from the first + # partitioner — reproducing its save/recompute decisions. + if node.op == "placeholder": + saved_values.append(node) + elif node.meta.get("custom", {}).get("ap_must_save") or _must_save(node): + saved_values.append(node) + + for node in node_info.required_fw_nodes: + _maybe_save(node) + + # Unclaimed nodes (neither strictly forward nor backward) may be + # needed by backward outputs — e.g. mutable ops like index_put that + # are _must_be_in_forward. Save them so they're available as backward + # inputs in _extract_fwd_bwd_modules. + for node in node_info.unclaimed_nodes: + _maybe_save(node) + + # Deduplicate. Multi-output getitem handling, overlapping iteration + # over required_fw_nodes + unclaimed_nodes, or overlapping save tags + # can land the same node in a list twice. Matches upstream + # default_partition's dict.fromkeys deduping. + saved_values = list(dict.fromkeys(saved_values)) + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes)) + saved_opaque_nodes = list(dict.fromkeys(saved_opaque_nodes)) + + fw_module, bw_module = _extract_fwd_bwd_modules( + gm, + saved_values, + saved_sym_nodes=saved_sym_nodes, + saved_opaque_nodes=saved_opaque_nodes, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, + ) + + if graph_has_recomputable_ops and graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + gm, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) + bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) + + return fw_module, bw_module + + def uuid(self) -> Any: + return None + + +@contextmanager +def _patch_partitioner_dce(): + """Patch the partitioner's DCE to allow wait_tensor to be eliminated. + + The partitioner uses its own is_not_collective callback that treats all + _c10d_functional ops as impure, overriding _suppress_wait_tensor_side_effect. + We patch it to let wait_tensor through so unused collectives get DCE'd. + """ + import torch._functorch.partitioners as partitioners + + original = partitioners.is_not_collective + + def patched_is_not_collective(node): + if node.target in ( + torch.ops._c10d_functional.wait_tensor, + torch.ops._c10d_functional.wait_tensor.default, + ): + return False + return original(node) + + partitioners.is_not_collective = patched_is_not_collective + try: + yield + finally: + partitioners.is_not_collective = original + + def _make_ac_joint_pass( ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto", ): @@ -50,16 +304,36 @@ def autoparallel_backend( sqrt(total_recomputable_memory). overlap_scheduling: Enable comm/compute overlap scheduling. """ - functorch_patches = {} - inductor_patches = _INDUCTOR_OVERLAP_PATCHES if overlap_scheduling else None + functorch_patches: dict[str, Any] = {} if enable_ac: functorch_patches["joint_custom_pass"] = _make_ac_joint_pass( ac_stage_size_in_GiB ) + # Inductor configs split by lifetime: + # - overlap scheduling configs must persist to lazy backward compilation + # (which runs on the first .backward() call, after compile_fx returns). + # compile_fx's config_patches argument re-enters the patch when backward + # is later compiled out of scope. + # - custom_partitioner_fn only runs during the synchronous joint→fwd/bwd + # partitioning inside compile_fx, so a context manager suffices. + inductor_persistent_patches = ( + _INDUCTOR_OVERLAP_PATCHES if overlap_scheduling else None + ) + inductor_fwd_patches: dict[str, Any] = { + "custom_partitioner_fn": _SaveAllPartitioner(), + } + def backend(gm, example_inputs): - with torch._functorch.config.patch(functorch_patches): - return compile_fx(gm, example_inputs, config_patches=inductor_patches) + with ( + _suppress_wait_tensor_side_effect(), + _patch_partitioner_dce(), + torch._functorch.config.patch(functorch_patches), + torch._inductor.config.patch(inductor_fwd_patches), + ): + return compile_fx( + gm, example_inputs, config_patches=inductor_persistent_patches + ) return backend diff --git a/autoparallel/shardings/propagation_rules.py b/autoparallel/shardings/propagation_rules.py index 27a20049..d054f345 100644 --- a/autoparallel/shardings/propagation_rules.py +++ b/autoparallel/shardings/propagation_rules.py @@ -51,6 +51,49 @@ logger = logging.getLogger(__name__) + +def _deepcopy_preserving_mesh(obj): + """Like copy.deepcopy, but reuses DeviceMesh instances instead of + duplicating them. + + DeviceMesh carries process-group state (rank maps, _flatten_mapping + cache, backend overrides) that is logically shared across all callers. + Deep-copying it produces a fresh object with an empty _flatten_mapping + that misses the cache populated on the original, which later forces + DeviceMesh._flatten to dispatch as_strided on the rank_map inside + make_fx — failing FakeTensorMode's non-fake-input assertion. + + We pre-populate copy.deepcopy's memo dict with identity mappings for + every DeviceMesh reachable from obj. deepcopy returns existing entries + from memo without recursing, so the same DeviceMesh instances appear + in the copy. + """ + from torch.distributed.device_mesh import DeviceMesh + + memo: dict = {} + stack = [obj] + seen: set[int] = set() + while stack: + x = stack.pop() + if id(x) in seen: + continue + seen.add(id(x)) + if isinstance(x, DeviceMesh): + memo[id(x)] = x + continue + if isinstance(x, (list, tuple, set, frozenset)): + stack.extend(x) + elif isinstance(x, dict): + stack.extend(x.values()) + elif hasattr(x, "__dict__"): + stack.extend(x.__dict__.values()) + elif hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + stack.append(getattr(x, slot)) + return copy.deepcopy(obj, memo) + + # TODO: move this to PyTorch dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1) @@ -829,7 +872,7 @@ def expand_rule(mesh, op_schema_): from torch._subclasses.fake_tensor import unset_fake_temporarily with unset_fake_temporarily(): - op_schema = copy.deepcopy(op_schema_) + op_schema = _deepcopy_preserving_mesh(op_schema_) input_strat = op_schema.args_schema[0] orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape dest_shape = op_schema.args_schema[1] diff --git a/tests/test_mesh_identity.py b/tests/test_mesh_identity.py new file mode 100644 index 00000000..cbd7b705 --- /dev/null +++ b/tests/test_mesh_identity.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Regression test: DeviceMesh duplicates introduced by +``copy.deepcopy(op_schema_)`` in propagation rules used to trigger an +``as_strided``-inside-FakeTensorMode failure during ``apply_placement``. + +Background: ``copy.deepcopy(op_schema)`` inside ``expand_rule`` produces a +fresh DeviceMesh object with an empty ``_flatten_mapping``. When the solver +picks a redistribution that calls ``mesh._flatten()`` on the duplicate, +``_create_flatten_mesh`` runs uncached, dispatching ``as_strided`` on the +rank_map — and FakeTensorMode rejects the non-fake tensor input. + +Fix lives in ``autoparallel/shardings/propagation_rules.py`` as +``_deepcopy_preserving_mesh``: pre-seeds copy.deepcopy's memo with +DeviceMesh identity mappings so the deepcopy reuses the original meshes. + +This test asserts the property we actually care about: every DeviceMesh +referenced by the sharding solution has a populated ``_flatten_mapping`` +on its root, so a subsequent ``_flatten()`` call inside ``make_fx`` hits +the cache instead of dispatching. + +We use the Transformer model because it triggers ``expand_rule`` (a +simpler model wouldn't exercise that propagation rule). +""" + +import torch +from conftest import apply_cuda_patches +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel + + +@apply_cuda_patches +def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): + """After ``apply_placement``'s pre-warming, every spec mesh's root + must have the default flattened mesh cached. Otherwise a subsequent + ``_flatten()`` call inside ``make_fx`` triggers ``as_strided`` on the + rank_map and FakeTensorMode rejects it (the original CI failure). + """ + vocab_size = 1024 + seqlen = 128 + batch_size = 2 * device_mesh_2d.shape[0] + + with torch.device("meta"): + model = Transformer( + TransformerModelArgs( + dim=256, + n_layers=2, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=vocab_size, + max_seq_len=seqlen, + ) + ) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + device_mesh_2d, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + # apply_placement pre-warms the user mesh's _flatten cache so + # subsequent _flatten() calls inside make_fx hit the cache. + autop.apply_placement(sharding_placement) + + # Collect every distinct spec mesh from the solution + spec_meshes: dict[int, object] = {} + for strategy in sharding_placement.values(): + specs = [] + if hasattr(strategy, "output_specs"): + o = strategy.output_specs + specs.extend(o if isinstance(o, (list, tuple)) else [o]) + if hasattr(strategy, "input_specs"): + specs.extend(strategy.input_specs or []) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + spec_meshes[id(m)] = m + + cold = [] + for mid, m in spec_meshes.items(): + if m.ndim == 1: + # 1D meshes: _flatten() short-circuits to self without dispatch + continue + root = m._get_root_mesh() + default_name = "_".join(m._mesh_dim_names) + if default_name not in root._flatten_mapping: + cold.append((mid, m._mesh_dim_names, list(root._flatten_mapping))) + + assert not cold, ( + f"After apply_placement, {len(cold)} spec mesh(es) still have a " + f"cold _flatten_mapping for their default name. A subsequent " + f"_flatten() call inside make_fx will dispatch as_strided and " + f"fail FakeTensorMode's non-fake-input check. Details " + f"(id, dim_names, root cache keys): {cold}. See " + f"_deepcopy_preserving_mesh in propagation_rules.py." + ) diff --git a/tests/test_save_all_partitioner.py b/tests/test_save_all_partitioner.py new file mode 100644 index 00000000..7fb6f228 --- /dev/null +++ b/tests/test_save_all_partitioner.py @@ -0,0 +1,1031 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the _SaveAllPartitioner mechanism in autoparallel/compile.py. + +The mechanism reproduces the first partitioner's save/recompute decisions in +the second compilation (torch.compile with autoparallel_backend) by: + +1. Tagging forward outputs with `ap_must_save` in `_boxed_nop_preserve_node_meta` +2. `preserve_node_meta` propagates the tags to the second compilation's joint graph +3. `_SaveAllPartitioner` reads the tags and saves only those nodes + +Without this machinery, the default min-cut partitioner makes independent +decisions that diverge from the first partitioner (most importantly, it +saves FSDP allgather outputs that should be recomputed via prefetch). +""" + +import operator + +import pytest +import torch +from conftest import apply_cuda_patches +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.utils.checkpoint import CheckpointPolicy + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import _boxed_nop_preserve_node_meta +from autoparallel.compile import ( + _patch_partitioner_dce, + _SaveAllPartitioner, + autoparallel_backend, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_small_llama(n_layers=2): + """Tiny LLaMA-3 sized for fast tests.""" + return Transformer( + TransformerModelArgs( + dim=256, + n_layers=n_layers, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=1024, + max_seq_len=512, + ) + ) + + +def _run_autoparallel(mesh, n_layers=2, batch_size=None, seqlen=128): + """Run AutoParallel up to apply_placement and return the parallel module + plus AC pass for the second compilation.""" + from autoparallel.api import AutoParallel + from autoparallel.compile import _make_ac_joint_pass + + vocab_size = 1024 + if batch_size is None: + batch_size = 2 * mesh.shape[0] + + with torch.device("meta"): + model = _make_small_llama(n_layers=n_layers) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + if mesh.ndim == 2: + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + else: + autop.add_input_constraints([(Shard(0),)]) + autop.add_output_constraints([(Shard(0),)]) + sharding_placement = autop.optimize_placement(verbose=False) + + ac_pass = _make_ac_joint_pass() + with torch._functorch.config.patch({"joint_custom_pass": ac_pass}): + parallel_mod = autop.apply_placement(sharding_placement) + + parallel_mod.to_empty(device="cuda") + parallel_mod.init_weights() + return parallel_mod, batch_size, seqlen, vocab_size + + +# Module-scoped caches for parallel_mod. Each integration test below builds +# the same one (deterministic for a given mesh shape) — caching saves ~30s +# per test. The cache is keyed by mesh.ndim since the helper only varies +# along that dimension; both 1D and 2D meshes get their own entry. +_parallel_mod_cache: dict = {} + + +@pytest.fixture(scope="module") +@apply_cuda_patches +def parallel_mod_2d(device_mesh_2d): + if "2d" not in _parallel_mod_cache: + _parallel_mod_cache["2d"] = _run_autoparallel(device_mesh_2d, n_layers=2) + return _parallel_mod_cache["2d"] + + +@pytest.fixture(scope="module") +@apply_cuda_patches +def parallel_mod_1d(device_mesh_1d): + if "1d" not in _parallel_mod_cache: + _parallel_mod_cache["1d"] = _run_autoparallel(device_mesh_1d, n_layers=2) + return _parallel_mod_cache["1d"] + + +def _capture_partitioner_call( + parallel_mod, batch_size, seqlen, vocab_size, mesh, enable_ac=False +): + """Run the second compilation with _SaveAllPartitioner as the + partition_fn but use identity compilers (no Inductor codegen). This + keeps the partitioner under test in the loop while skipping the + expensive Triton kernel compilation. + + The dict contains: + - wait_tensor_recompute_tags: count of MUST_RECOMPUTE on wait_tensors + - allgather_recompute_tags: count of MUST_RECOMPUTE on all_gathers + - ap_must_save_count: count of nodes tagged ap_must_save + - fw_module, bw_module: the partitioned graph modules + - saved_activation_names: backward inputs that aren't primals or tangents + """ + from autoparallel.api import _suppress_wait_tensor_side_effect + from autoparallel.compile import _make_ac_joint_pass, _patch_partitioner_dce + + captured = {} + partitioner = _SaveAllPartitioner() + + def capturing_partition_fn( + joint_module, joint_inputs, *, num_fwd_outputs, **kwargs + ): + captured["wait_tensor_recompute_tags"] = sum( + 1 + for n in joint_module.graph.nodes + if "wait_tensor" in n.name + and n.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE + ) + captured["allgather_recompute_tags"] = sum( + 1 + for n in joint_module.graph.nodes + if "all_gather_into_tensor" in n.name + and n.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE + ) + captured["ap_must_save_count"] = sum( + 1 + for n in joint_module.graph.nodes + if n.meta.get("custom", {}).get("ap_must_save") + ) + fw, bw = partitioner( + joint_module, + joint_inputs, + num_fwd_outputs=num_fwd_outputs, + **kwargs, + ) + captured["fw_module"] = fw + captured["bw_module"] = bw + return fw, bw + + def capture_only_backend(gm, example_inputs): + from torch._functorch.aot_autograd import aot_module_simplified + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=lambda g, i: g, + bw_compiler=lambda g, i: g, + partition_fn=capturing_partition_fn, + ) + + functorch_patches = {} + if enable_ac: + functorch_patches["joint_custom_pass"] = _make_ac_joint_pass() + + with ( + _suppress_wait_tensor_side_effect(), + _patch_partitioner_dce(), + torch._functorch.config.patch(functorch_patches), + ): + compiled = torch.compile(parallel_mod, backend=capture_only_backend) + x = torch.randint( + 0, vocab_size, (batch_size // mesh.shape[0], seqlen), device="cuda" + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + + # The backward graph's placeholders (minus tangents and primals) are the + # saved activations. + bw = captured["bw_module"] + saved_names = [] + for node in bw.graph.nodes: + if node.op != "placeholder": + continue + if isinstance(node.target, str) and ( + "tangent" in node.target or "primals" in node.target + ): + continue + saved_names.append(node.name) + captured["saved_activation_names"] = saved_names + return captured + + +def _capture_first_partitioner_saves(mesh, n_layers=2, seqlen=128): + """Run AutoParallel and capture the first partitioner's saved values + (the forward outputs beyond num_fwd_outputs).""" + from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors + + from autoparallel.api import AutoParallel + from autoparallel.compile import _make_ac_joint_pass + + vocab_size = 1024 + batch_size = 2 * mesh.shape[0] + + with torch.device("meta"): + model = _make_small_llama(n_layers=n_layers) + + captured = {} + + def capturing_fw_compiler(fx_g, example_inputs, **kwargs): + # The compiled forward returns [*model_outputs, *saved_values] + output_node = next(n for n in fx_g.graph.nodes if n.op == "output") + captured["fw_outputs"] = list(output_node.args[0]) + from autoparallel.api import _boxed_nop_preserve_node_meta + + return _boxed_nop_preserve_node_meta(fx_g, example_inputs, **kwargs) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + if mesh.ndim == 2: + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + else: + autop.add_input_constraints([(Shard(0),)]) + autop.add_output_constraints([(Shard(0),)]) + sharding_placement = autop.optimize_placement(verbose=False) + + ac_pass = _make_ac_joint_pass() + from functools import partial + + with torch._functorch.config.patch({"joint_custom_pass": ac_pass}): + # Replicate apply_placement's compile call so we can intercept + # the fw_compiler. + autop._apply_placement_common(sharding_placement) + from autoparallel.graph_passes.activation_checkpointing import ( + mark_fsdp_all_gather_recomputation, + ) + + mark_fsdp_all_gather_recomputation( + autop.parallel_gm.graph, autop.reshard_after_forward + ) + aot_compile_joint_with_descriptors( + autop.joint_with_descriptors, + fw_compiler=partial(capturing_fw_compiler, tag_forward=True), + bw_compiler=autop.compiler_fn, + ) + + fw_metadata = autop.joint_with_descriptors._aot_state.fw_metadata + num_fwd_outputs = fw_metadata.num_forward_returns + # Saved values = forward outputs beyond num_fwd_outputs (the model outputs). + # Filter out primal/placeholder pass-throughs since the second + # partitioner's saved set doesn't include those either (they're inputs). + fw_outputs = captured["fw_outputs"] + saved = [ + n + for n in fw_outputs[num_fwd_outputs:] + if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + return saved + + +def _saved_names_from_default_compile( + parallel_mod, batch_size, seqlen, vocab_size, mesh, enable_ac=True +): + """Compile with the default min-cut partitioner (NOT _SaveAllPartitioner) + and capture which backward inputs are saved. This is the "without the fix" + baseline that motivates _SaveAllPartitioner. + + By default this enables AC (joint_custom_pass = ac_joint_pass) — the + motivating bad case is min-cut + AC tags driving force_save_collectives + to save FSDP allgather outputs. + """ + from autoparallel.compile import _make_ac_joint_pass + + captured = {} + + def simple_backend(gm, example_inputs): + from torch._functorch.aot_autograd import aot_module_simplified + from torch._functorch.partitioners import min_cut_rematerialization_partition + + def fw(g, i): + return g + + def bw(g, i): + captured["bw_module"] = g + return g + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=fw, + bw_compiler=bw, + partition_fn=min_cut_rematerialization_partition, + ) + + functorch_patches = {} + if enable_ac: + functorch_patches["joint_custom_pass"] = _make_ac_joint_pass() + + with torch._functorch.config.patch(functorch_patches): + compiled = torch.compile(parallel_mod, backend=simple_backend) + x = torch.randint( + 0, vocab_size, (batch_size // mesh.shape[0], seqlen), device="cuda" + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + + bw = captured["bw_module"] + saved_names = [] + for node in bw.graph.nodes: + if node.op != "placeholder": + continue + if isinstance(node.target, str) and ( + "tangent" in node.target or "primals" in node.target + ): + continue + saved_names.append(node.name) + return saved_names + + +def _simple_partition(saved_node_meta): + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + tangent = graph.placeholder("tangents_1") + tangent.meta["val"] = torch.randn(4, device="meta") + saved = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + saved.meta["val"] = torch.randn(4, device="meta") + saved.meta.update(saved_node_meta) + bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(saved, tangent)) + bwd.meta["val"] = torch.randn(4, device="meta") + output = graph.output((saved, bwd)) + output.meta["desc"] = [None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + return _SaveAllPartitioner()( + gm, [torch.randn(4), torch.randn(4)], num_fwd_outputs=1 + ) + + +def _multi_output_partition( + parent_meta, + saved_indices=None, + extra_parent_consumer=False, +): + """Build a tiny joint graph with a multi-output forward op (split) and + run _SaveAllPartitioner on it. + + parent_meta is merged into the multi-output op's meta. Optionally tag + saved_indices on the parent (mirroring what tag_forward does). If + extra_parent_consumer is True, the parent is also referenced directly + in the output (a non-getitem user that survives DCE), so the partitioner + observes split.users containing both getitems and a non-getitem node. + """ + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(6, device="meta") + tangent = graph.placeholder("tangents_1") + tangent.meta["val"] = torch.randn(6, device="meta") + + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 3 + custom = split.meta.setdefault("custom", {}) + custom.update(parent_meta.get("custom", {})) + if "recompute" in parent_meta: + split.meta["recompute"] = parent_meta["recompute"] + if saved_indices is not None: + custom["ap_must_save_getitem_indices"] = saved_indices + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g1 = graph.call_function(operator.getitem, args=(split, 1)) + g1.meta["val"] = torch.randn(2, device="meta") + g2 = graph.call_function(operator.getitem, args=(split, 2)) + g2.meta["val"] = torch.randn(2, device="meta") + + # Concatenate g0+g1+g2 + tangent so all getitems feed into backward + cat = graph.call_function(torch.ops.aten.cat.default, args=([g0, g1, g2],)) + cat.meta["val"] = torch.randn(6, device="meta") + bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(cat, tangent)) + bwd.meta["val"] = torch.randn(6, device="meta") + + # Optionally make split a direct forward output (non-getitem user). + if extra_parent_consumer: + output_args = (cat, split, bwd) + num_fwd_outputs = 2 + else: + output_args = (cat, bwd) + num_fwd_outputs = 1 + output = graph.output(output_args) + output.meta["desc"] = [None] * len(output_args) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + return _SaveAllPartitioner()( + gm, + [torch.randn(6), torch.randn(6)], + num_fwd_outputs=num_fwd_outputs, + ) + + +# --------------------------------------------------------------------------- +# Unit tests for the standalone mechanisms +# --------------------------------------------------------------------------- + + +def test_boxed_nop_tag_forward_marks_outputs(): + """_boxed_nop_preserve_node_meta(tag_forward=True) tags the forward + output node's tensor args with ap_must_save.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, device="meta") + mul = graph.call_function(torch.ops.aten.mul.Tensor, args=(add, add)) + mul.meta["val"] = torch.randn(4, device="meta") + graph.output((add, mul)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) + + assert add.meta.get("custom", {}).get("ap_must_save") is True + assert mul.meta.get("custom", {}).get("ap_must_save") is True + + +def test_boxed_nop_tag_forward_skips_getitem(): + """For getitem outputs (multi-output ops), the parent is tagged instead + since getitem metadata doesn't survive preserve_node_meta.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 2 + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + graph.output((g0,)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) + + # Parent (split) gets the tag, not the getitem + assert split.meta.get("custom", {}).get("ap_must_save") is True + assert split.meta.get("custom", {}).get("ap_must_save_getitem_indices") == [0] + assert g0.meta.get("custom", {}).get("ap_must_save") is None + + +def test_boxed_nop_tag_forward_records_getitem_indices(): + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(6, device="meta") + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 3 + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g2 = graph.call_function(operator.getitem, args=(split, 2)) + g2.meta["val"] = torch.randn(2, device="meta") + graph.output((g0, g2)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None, tag_forward=True) + + assert split.meta["custom"]["ap_must_save_getitem_indices"] == [0, 2] + + +def test_boxed_nop_no_tag_forward_default(): + """tag_forward defaults to False — no tagging happens.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, device="meta") + graph.output((add,)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + _boxed_nop_preserve_node_meta(gm, None) + + assert add.meta.get("custom", {}).get("ap_must_save") is None + + +def test_save_all_partitioner_saves_ap_must_save_despite_prefer_recompute(): + fw, bw = _simple_partition( + { + "custom": {"ap_must_save": True}, + "recompute": CheckpointPolicy.PREFER_RECOMPUTE, + } + ) + + fw_outputs = next(n for n in fw.graph.nodes if n.op == "output").args[0] + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" in [n.name for n in fw_outputs if isinstance(n, torch.fx.Node)] + assert "add_tensor" in bw_placeholders + + +def test_save_all_partitioner_honors_must_save(): + fw, bw = _simple_partition({"recompute": CheckpointPolicy.MUST_SAVE}) + + fw_outputs = next(n for n in fw.graph.nodes if n.op == "output").args[0] + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" in [n.name for n in fw_outputs if isinstance(n, torch.fx.Node)] + assert "add_tensor" in bw_placeholders + + +def test_save_all_partitioner_does_not_save_must_recompute(): + fw, bw = _simple_partition( + { + "custom": {"ap_must_save": True}, + "recompute": CheckpointPolicy.MUST_RECOMPUTE, + } + ) + + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" not in bw_placeholders + + +def test_save_all_partitioner_must_recompute_blocks_multi_output_save(): + """MUST_RECOMPUTE on a multi-output op blocks saving its getitem + children, even when ap_must_save is also set. Documents the invariant + that the first partitioner's recompute decision wins over its own + save tags (which can happen if both are set during graph passes).""" + fw, bw = _multi_output_partition( + parent_meta={ + "custom": {"ap_must_save": True}, + "recompute": CheckpointPolicy.MUST_RECOMPUTE, + }, + saved_indices=[0, 1, 2], + ) + + # No getitem children should appear in backward inputs + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert not any( + name.startswith("getitem") for name in bw_placeholders + ), f"MUST_RECOMPUTE should block multi-output save, got: {bw_placeholders}" + + +def test_save_all_partitioner_must_recompute_blocks_opaque_save(): + """MUST_RECOMPUTE blocks saving even on nodes that would otherwise be + saved as opaque. The _must_recompute check must run before is_opaque_node + so the first partitioner's recompute intent is honored uniformly.""" + # Build a fake graph with a node we'll force is_opaque_node to recognize. + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + tangent = graph.placeholder("tangents_1") + tangent.meta["val"] = torch.randn(4, device="meta") + opaque_like = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + opaque_like.meta["val"] = torch.randn(4, device="meta") + opaque_like.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(opaque_like, tangent)) + bwd.meta["val"] = torch.randn(4, device="meta") + output = graph.output((opaque_like, bwd)) + output.meta["desc"] = [None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + # Force is_opaque_node to return True for our op. _SaveAllPartitioner + # imports it from torch._functorch.partitioners, so patch there. + import torch._functorch.partitioners as partitioners + + original = partitioners.is_opaque_node + partitioners.is_opaque_node = lambda n: n.name == "add_tensor" + try: + fw, bw = _SaveAllPartitioner()( + gm, [torch.randn(4), torch.randn(4)], num_fwd_outputs=1 + ) + finally: + partitioners.is_opaque_node = original + + # add_tensor was tagged MUST_RECOMPUTE; even though is_opaque_node + # returned True, it should NOT be saved + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + assert "add_tensor" not in bw_placeholders, ( + f"MUST_RECOMPUTE on an opaque-classified node was ignored; " + f"backward placeholders: {bw_placeholders}" + ) + + +def test_save_all_partitioner_multi_output_with_non_getitem_user(): + """A multi-output op tagged ap_must_save must save its getitem children + (not the parent tuple) even when it has a non-getitem user. The save + logic must look at the node's value type, not just its users.""" + fw, bw = _multi_output_partition( + parent_meta={"custom": {"ap_must_save": True}}, + saved_indices=[0, 1, 2], + extra_parent_consumer=True, + ) + + # Backward should receive the getitem children, NOT the split tuple + bw_placeholders = [n for n in bw.graph.nodes if n.op == "placeholder"] + # No placeholder should have a tuple/list value (which would mean we + # saved the multi-output op directly) + for n in bw_placeholders: + val = n.meta.get("val") + assert not isinstance(val, (list, tuple)), ( + f"Backward got placeholder {n.name} with tuple/list val — multi-output op " + f"was saved directly instead of its getitem children" + ) + # And at least one getitem should be present (the save took effect) + getitem_names = [n.name for n in bw_placeholders if n.name.startswith("getitem")] + assert len(getitem_names) > 0, ( + f"Expected getitem children to be saved, got placeholders: " + f"{[n.name for n in bw_placeholders]}" + ) + + +def test_save_all_partitioner_replays_only_indexed_getitems(): + """When tag_forward records specific getitem indices on a multi-output + parent, only those indices should appear as saved backward inputs; + other indices should not be saved (and would be recomputed if needed). + + Build a graph where backward consumes BOTH getitem 0 and getitem 1, but + we only tag index 0 as ap_must_save. The partitioner should save + getitem 0 only. + """ + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + t0 = graph.placeholder("tangents_1") + t0.meta["val"] = torch.randn(2, device="meta") + t1 = graph.placeholder("tangents_2") + t1.meta["val"] = torch.randn(2, device="meta") + + # Multi-output op: split into 2 chunks. Mark only index 0 as ap_must_save. + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 2 + custom = split.meta.setdefault("custom", {}) + custom["ap_must_save"] = True + custom["ap_must_save_getitem_indices"] = [0] + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g1 = graph.call_function(operator.getitem, args=(split, 1)) + g1.meta["val"] = torch.randn(2, device="meta") + + # Forward outputs: a value that depends on both g0 and g1 (so both + # getitems are required in forward). + add_fw = graph.call_function(torch.ops.aten.add.Tensor, args=(g0, g1)) + add_fw.meta["val"] = torch.randn(2, device="meta") + + # Backward ops: independent muls using g0 and g1 respectively, so each + # getitem is a real backward dependency. + bwd0 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g0, t0)) + bwd0.meta["val"] = torch.randn(2, device="meta") + bwd1 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g1, t1)) + bwd1.meta["val"] = torch.randn(2, device="meta") + + output = graph.output((add_fw, bwd0, bwd1)) + output.meta["desc"] = [None, None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + fw, bw = _SaveAllPartitioner()( + gm, + [torch.randn(4), torch.randn(2), torch.randn(2)], + num_fwd_outputs=1, + ) + + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + # getitem 0 should be a saved backward input (we tagged index 0) + assert ( + "getitem" in bw_placeholders + ), f"Expected getitem (index 0) to be saved; got {bw_placeholders}" + # getitem 1 should NOT be a saved input — it's not in the indices list + assert "getitem_1" not in bw_placeholders, ( + f"Expected getitem_1 (index 1) NOT to be saved (index restriction); " + f"got {bw_placeholders}" + ) + + +def test_save_all_partitioner_must_save_overrides_getitem_indices(): + """MUST_SAVE on a multi-output parent should save ALL getitem children + even if ap_must_save_getitem_indices restricts to a subset. MUST_SAVE + is a stronger directive than ap_must_save's index-specific replay.""" + graph = torch.fx.Graph() + x = graph.placeholder("primals_1") + x.meta["val"] = torch.randn(4, device="meta") + t0 = graph.placeholder("tangents_1") + t0.meta["val"] = torch.randn(2, device="meta") + t1 = graph.placeholder("tangents_2") + t1.meta["val"] = torch.randn(2, device="meta") + + split = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2)) + split.meta["val"] = [torch.randn(2, device="meta")] * 2 + # Both MUST_SAVE AND restricted indices. MUST_SAVE should win. + split.meta["recompute"] = CheckpointPolicy.MUST_SAVE + custom = split.meta.setdefault("custom", {}) + custom["ap_must_save"] = True + custom["ap_must_save_getitem_indices"] = [0] + + g0 = graph.call_function(operator.getitem, args=(split, 0)) + g0.meta["val"] = torch.randn(2, device="meta") + g1 = graph.call_function(operator.getitem, args=(split, 1)) + g1.meta["val"] = torch.randn(2, device="meta") + + add_fw = graph.call_function(torch.ops.aten.add.Tensor, args=(g0, g1)) + add_fw.meta["val"] = torch.randn(2, device="meta") + bwd0 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g0, t0)) + bwd0.meta["val"] = torch.randn(2, device="meta") + bwd1 = graph.call_function(torch.ops.aten.mul.Tensor, args=(g1, t1)) + bwd1.meta["val"] = torch.randn(2, device="meta") + + output = graph.output((add_fw, bwd0, bwd1)) + output.meta["desc"] = [None, None, None] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + fw, bw = _SaveAllPartitioner()( + gm, + [torch.randn(4), torch.randn(2), torch.randn(2)], + num_fwd_outputs=1, + ) + + bw_placeholders = [n.name for n in bw.graph.nodes if n.op == "placeholder"] + # Both getitems should be saved (MUST_SAVE overrides the index restriction) + assert "getitem" in bw_placeholders, f"Expected getitem; got {bw_placeholders}" + assert "getitem_1" in bw_placeholders, ( + f"Expected getitem_1 even though indices=[0] (MUST_SAVE should override); " + f"got {bw_placeholders}" + ) + + +def test_preserve_node_meta_propagates_recompute_through_collectives(): + """preserve_node_meta correctly propagates MUST_RECOMPUTE through + collective ops (allgather, wait_tensor) — confirming the safety net + mechanism works. + """ + from torch.fx.experimental.proxy_tensor import make_fx + from torch.testing._internal.distributed.fake_pg import FakeStore + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "fake", store=FakeStore(), rank=0, world_size=8 + ) + + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, device="meta") + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, device="meta") + add.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + add.meta["ac_graph_id"] = 100000 + + ag = graph.call_function( + torch.ops._c10d_functional.all_gather_into_tensor.default, + args=(add, 2, "0"), + ) + ag.meta["val"] = torch.randn(8, device="meta") + ag.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + ag.meta["ac_graph_id"] = 100000 + + wt = graph.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(ag,)) + wt.meta["val"] = torch.randn(8, device="meta") + wt.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + wt.meta["ac_graph_id"] = 100000 + + graph.output((wt,)) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + compiled = _boxed_nop_preserve_node_meta(gm, None) + fake_input = torch.randn(4, device="cpu") + new_gm = make_fx(compiled, tracing_mode="fake")([fake_input]) + + for node in new_gm.graph.nodes: + if node.op != "call_function": + continue + name = ( + node.target.__name__ + if hasattr(node.target, "__name__") + else str(node.target) + ) + if name in ( + "add.Tensor", + "all_gather_into_tensor.default", + "wait_tensor.default", + ): + assert ( + node.meta.get("recompute") == CheckpointPolicy.MUST_RECOMPUTE + ), f"{name} lost MUST_RECOMPUTE tag through preserve_node_meta" + + +def test_patch_partitioner_dce_allows_wait_tensor_elimination(): + """_patch_partitioner_dce overrides is_not_collective for wait_tensor, + and the override is properly reverted on exit.""" + import torch._functorch.partitioners as partitioners + + graph = torch.fx.Graph() + x = graph.placeholder("x") + wt = graph.call_function(torch.ops._c10d_functional.wait_tensor.default, args=(x,)) + graph.output((wt,)) + + original_result = partitioners.is_not_collective(wt) + + with _patch_partitioner_dce(): + patched_result = partitioners.is_not_collective(wt) + + # After exit, the original function is restored + restored_result = partitioners.is_not_collective(wt) + + # Inside the patch, wait_tensor goes through our shortcut + assert patched_result is False + # And the patch is properly restored + assert restored_result == original_result + + +# --------------------------------------------------------------------------- +# Integration tests for the partitioner behavior on a real model +# --------------------------------------------------------------------------- + + +@apply_cuda_patches +def test_save_all_partitioner_does_not_save_fsdp_wait_tensors( + parallel_mod_2d, device_mesh_2d +): + """The whole point of the machinery: with FSDP allgathers tagged + MUST_RECOMPUTE, _SaveAllPartitioner should NOT save their wait_tensor + outputs in the backward. + """ + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d + + captured = _capture_partitioner_call( + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=False, + ) + + # FSDP wait_tensor outputs should NOT appear as saved activations + saved_wait_tensors = [ + n for n in captured["saved_activation_names"] if "wait_tensor" in n + ] + assert len(saved_wait_tensors) == 0, ( + f"FSDP wait_tensor outputs were saved (should be recomputed via " + f"FSDP prefetch): {saved_wait_tensors}" + ) + + # Sanity: MUST_RECOMPUTE tags survived into the joint graph + assert captured["allgather_recompute_tags"] > 0, ( + "Expected MUST_RECOMPUTE tags on all_gather nodes to survive " + "preserve_node_meta into the second compilation's joint graph" + ) + + # And ap_must_save tags should be present (sanity check on tag + # propagation — used to be a separate test). + assert captured["ap_must_save_count"] > 0, ( + "Expected ap_must_save tags from first compilation to survive into " + "second compilation's joint graph" + ) + activation_saves = [ + n for n in captured["saved_activation_names"] if not n.startswith("primals") + ] + assert len(activation_saves) > 0 + + +@apply_cuda_patches +def test_default_partitioner_diverges_from_save_all_partitioner( + parallel_mod_2d, device_mesh_2d +): + """The default min-cut partitioner produces a different save list than + _SaveAllPartitioner when AC is active. This is the motivating reason + _SaveAllPartitioner exists: min-cut + AC tags can save FSDP allgather + outputs (or other tensors that the first partitioner chose to recompute). + + We assert "differ" rather than checking for a specific wait_tensor name + because the exact divergence depends on model size, mesh shape, and the + AC stage budget. For the production LLaMA-3 8B config (32 layers, dim=4096, + 128 GPUs), the divergence manifests as ~1.2 GB of FSDP wait_tensor outputs + being saved per 4 layers — see [memory_gap_investigation memory note]. + """ + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d + + # Default partitioner with AC enabled in second compilation (bad case) + default_saves = _saved_names_from_default_compile( + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=True, + ) + + # _SaveAllPartitioner with AC enabled in second compilation (our fix) + save_all_captured = _capture_partitioner_call( + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=True, + ) + save_all_saves = save_all_captured["saved_activation_names"] + + # The two partitioners pick different things. Equal save count would + # only happen by coincidence — even when counts match the choice of + # tensors differs. + default_set = set(default_saves) + save_all_set = set(save_all_saves) + only_in_default = default_set - save_all_set + only_in_save_all = save_all_set - default_set + assert only_in_default or only_in_save_all, ( + f"Default partitioner and _SaveAllPartitioner produced identical " + f"saves ({sorted(default_set)}). If this is reproducible, the " + f"motivating divergence may have been fixed upstream and " + f"_SaveAllPartitioner could be reevaluated." + ) + + +@apply_cuda_patches +def test_save_all_partitioner_reproduces_first_partitioner_saves( + parallel_mod_2d, device_mesh_2d +): + """_SaveAllPartitioner's saved set should approximately match what the + first partitioner chose. The two operate on different graphs (the first + inside apply_placement, the second inside torch.compile), so we compare + by tensor shape/dtype histograms rather than node names — names will + differ across compilations but the underlying tensors should match. + """ + # First partitioner: capture the forward outputs beyond num_fwd_outputs. + first_saves = _capture_first_partitioner_saves(device_mesh_2d, n_layers=2) + + # Second partitioner: reuse the cached parallel_mod and capture the + # backward inputs (saved values). + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d + captured = _capture_partitioner_call( + parallel_mod, + batch_size, + seqlen, + vocab_size, + device_mesh_2d, + enable_ac=False, + ) + bw = captured["bw_module"] + second_saves = [] + for node in bw.graph.nodes: + if node.op != "placeholder": + continue + if isinstance(node.target, str) and ( + "tangent" in node.target or "primals" in node.target + ): + continue + second_saves.append(node) + + def _shape_sig(node): + val = node.meta.get("val") + if val is None or not hasattr(val, "shape"): + return None + return (tuple(val.shape), str(val.dtype)) + + # Counts should be close. Retracing through Dynamo can add or eliminate + # a small number of view/reshape nodes. + diff = abs(len(first_saves) - len(second_saves)) + assert diff <= 2, ( + f"First partitioner saved {len(first_saves)} values, but " + f"_SaveAllPartitioner saved {len(second_saves)} (diff {diff} > 2). " + f"They should match closely." + ) + + # Shape/dtype histograms should match closely too. + from collections import Counter + + first_shapes = Counter(_shape_sig(n) for n in first_saves) + second_shapes = Counter(_shape_sig(n) for n in second_saves) + # Drop entries with no shape (sym ints, opaque, etc.) + first_shapes.pop(None, None) + second_shapes.pop(None, None) + + # Symmetric difference should be small + diff_count = sum((first_shapes - second_shapes).values()) + sum( + (second_shapes - first_shapes).values() + ) + assert diff_count <= 4, ( + f"Shape histograms diverged too much (diff_count={diff_count}).\n" + f" Only in first: {first_shapes - second_shapes}\n" + f" Only in second: {second_shapes - first_shapes}" + ) + + +@apply_cuda_patches +def test_save_all_partitioner_compile_with_ac_enabled(parallel_mod_2d, device_mesh_2d): + """End-to-end smoke test: AutoParallel + torch.compile(autoparallel_backend) + with AC enabled. Exercises the full Inductor pipeline including AC joint + pass and codegen — kept around so a regression in the backend wiring + surfaces here even if the unit tests still pass.""" + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_2d + # AC enabled in the backend → joint_custom_pass adds PREFER_RECOMPUTE + backend = autoparallel_backend(enable_ac=True, overlap_scheduling=False) + + compiled = torch.compile(parallel_mod, backend=backend) + x = torch.randint( + 0, + vocab_size, + (batch_size // device_mesh_2d.shape[0], seqlen), + device="cuda", + ) + out = compiled(x) + out.backward(torch.randn_like(out)) + + +@apply_cuda_patches +def test_save_all_partitioner_compile_1d_mesh(parallel_mod_1d, device_mesh_1d): + """The partitioner works with a 1D (FSDP-only) mesh.""" + parallel_mod, batch_size, seqlen, vocab_size = parallel_mod_1d + backend = autoparallel_backend(enable_ac=False, overlap_scheduling=False) + + compiled = torch.compile(parallel_mod, backend=backend) + x = torch.randint( + 0, + vocab_size, + (batch_size // device_mesh_1d.shape[0], seqlen), + device="cuda", + ) + out = compiled(x) + out.backward(torch.randn_like(out))