From 66416015226b81e26dac876ce467917e72955344 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 14:22:35 +0000 Subject: [PATCH 1/6] Preserve DeviceMesh identity in expand_rule's op_schema deepcopy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit expand_rule used copy.deepcopy(op_schema_) to snapshot the schema before mutating it. DeviceMesh has no __deepcopy__, so deepcopy went through __getstate__/__setstate__ and produced a fresh DeviceMesh with the same logical content but an empty _flatten_mapping cache. The DTensorSpecs returned from expand_rule carried these duplicates, which propagated into the sharding solution. apply_placement's pre-warming in _apply_placement_common only populates the user mesh's cache. When _optimize_same_nd_sharding_as_1d inside make_fx hit a duplicate mesh, _flatten() cache-missed and dispatched as_strided on the real rank_map — failing FakeTensorMode's non-fake-input check. Which solution the solver picked depended on the cost model, so the failure surfaced on g5/A10G CI but not on H100. Fix: _deepcopy_preserving_mesh pre-seeds copy.deepcopy's memo with DeviceMesh identity mappings so duplicates aren't produced. Adds a regression test that asserts every spec mesh's root has a warm _flatten cache after apply_placement. Authored with Claude. --- autoparallel/shardings/propagation_rules.py | 45 +++++++- tests/test_mesh_identity.py | 113 ++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 tests/test_mesh_identity.py diff --git a/autoparallel/shardings/propagation_rules.py b/autoparallel/shardings/propagation_rules.py index 27a20049..d054f345 100644 --- a/autoparallel/shardings/propagation_rules.py +++ b/autoparallel/shardings/propagation_rules.py @@ -51,6 +51,49 @@ logger = logging.getLogger(__name__) + +def _deepcopy_preserving_mesh(obj): + """Like copy.deepcopy, but reuses DeviceMesh instances instead of + duplicating them. + + DeviceMesh carries process-group state (rank maps, _flatten_mapping + cache, backend overrides) that is logically shared across all callers. + Deep-copying it produces a fresh object with an empty _flatten_mapping + that misses the cache populated on the original, which later forces + DeviceMesh._flatten to dispatch as_strided on the rank_map inside + make_fx — failing FakeTensorMode's non-fake-input assertion. + + We pre-populate copy.deepcopy's memo dict with identity mappings for + every DeviceMesh reachable from obj. deepcopy returns existing entries + from memo without recursing, so the same DeviceMesh instances appear + in the copy. + """ + from torch.distributed.device_mesh import DeviceMesh + + memo: dict = {} + stack = [obj] + seen: set[int] = set() + while stack: + x = stack.pop() + if id(x) in seen: + continue + seen.add(id(x)) + if isinstance(x, DeviceMesh): + memo[id(x)] = x + continue + if isinstance(x, (list, tuple, set, frozenset)): + stack.extend(x) + elif isinstance(x, dict): + stack.extend(x.values()) + elif hasattr(x, "__dict__"): + stack.extend(x.__dict__.values()) + elif hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + stack.append(getattr(x, slot)) + return copy.deepcopy(obj, memo) + + # TODO: move this to PyTorch dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1) @@ -829,7 +872,7 @@ def expand_rule(mesh, op_schema_): from torch._subclasses.fake_tensor import unset_fake_temporarily with unset_fake_temporarily(): - op_schema = copy.deepcopy(op_schema_) + op_schema = _deepcopy_preserving_mesh(op_schema_) input_strat = op_schema.args_schema[0] orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape dest_shape = op_schema.args_schema[1] diff --git a/tests/test_mesh_identity.py b/tests/test_mesh_identity.py new file mode 100644 index 00000000..cbd7b705 --- /dev/null +++ b/tests/test_mesh_identity.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Regression test: DeviceMesh duplicates introduced by +``copy.deepcopy(op_schema_)`` in propagation rules used to trigger an +``as_strided``-inside-FakeTensorMode failure during ``apply_placement``. + +Background: ``copy.deepcopy(op_schema)`` inside ``expand_rule`` produces a +fresh DeviceMesh object with an empty ``_flatten_mapping``. When the solver +picks a redistribution that calls ``mesh._flatten()`` on the duplicate, +``_create_flatten_mesh`` runs uncached, dispatching ``as_strided`` on the +rank_map — and FakeTensorMode rejects the non-fake tensor input. + +Fix lives in ``autoparallel/shardings/propagation_rules.py`` as +``_deepcopy_preserving_mesh``: pre-seeds copy.deepcopy's memo with +DeviceMesh identity mappings so the deepcopy reuses the original meshes. + +This test asserts the property we actually care about: every DeviceMesh +referenced by the sharding solution has a populated ``_flatten_mapping`` +on its root, so a subsequent ``_flatten()`` call inside ``make_fx`` hits +the cache instead of dispatching. + +We use the Transformer model because it triggers ``expand_rule`` (a +simpler model wouldn't exercise that propagation rule). +""" + +import torch +from conftest import apply_cuda_patches +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel + + +@apply_cuda_patches +def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): + """After ``apply_placement``'s pre-warming, every spec mesh's root + must have the default flattened mesh cached. Otherwise a subsequent + ``_flatten()`` call inside ``make_fx`` triggers ``as_strided`` on the + rank_map and FakeTensorMode rejects it (the original CI failure). + """ + vocab_size = 1024 + seqlen = 128 + batch_size = 2 * device_mesh_2d.shape[0] + + with torch.device("meta"): + model = Transformer( + TransformerModelArgs( + dim=256, + n_layers=2, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=vocab_size, + max_seq_len=seqlen, + ) + ) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + device_mesh_2d, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + # apply_placement pre-warms the user mesh's _flatten cache so + # subsequent _flatten() calls inside make_fx hit the cache. + autop.apply_placement(sharding_placement) + + # Collect every distinct spec mesh from the solution + spec_meshes: dict[int, object] = {} + for strategy in sharding_placement.values(): + specs = [] + if hasattr(strategy, "output_specs"): + o = strategy.output_specs + specs.extend(o if isinstance(o, (list, tuple)) else [o]) + if hasattr(strategy, "input_specs"): + specs.extend(strategy.input_specs or []) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + spec_meshes[id(m)] = m + + cold = [] + for mid, m in spec_meshes.items(): + if m.ndim == 1: + # 1D meshes: _flatten() short-circuits to self without dispatch + continue + root = m._get_root_mesh() + default_name = "_".join(m._mesh_dim_names) + if default_name not in root._flatten_mapping: + cold.append((mid, m._mesh_dim_names, list(root._flatten_mapping))) + + assert not cold, ( + f"After apply_placement, {len(cold)} spec mesh(es) still have a " + f"cold _flatten_mapping for their default name. A subsequent " + f"_flatten() call inside make_fx will dispatch as_strided and " + f"fail FakeTensorMode's non-fake-input check. Details " + f"(id, dim_names, root cache keys): {cold}. See " + f"_deepcopy_preserving_mesh in propagation_rules.py." + ) From 6785c98b9de16df3599c7df6cf94cd233b463839 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 15:10:23 +0000 Subject: [PATCH 2/6] Update to use CUDA 13.0 --- .github/workflows/test_cuda.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 6cd42626..31857c35 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -13,13 +13,13 @@ concurrency: jobs: test-cuda-single-gpu: - name: Test CUDA Single GPU (cuda12.6-py3.12) + name: Test CUDA Single GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -29,18 +29,18 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . pytest tests examples-cuda-single-gpu: - name: Examples CUDA Single GPU (cuda12.6-py3.12) + name: Examples CUDA Single GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -50,7 +50,7 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . run_timed() { local start=$SECONDS; "$@"; echo "$* : $((SECONDS - start))s" >> /tmp/timings.txt; } run_timed python examples/example_autoparallel.py @@ -60,13 +60,13 @@ jobs: cat /tmp/timings.txt test-cuda-multi-gpu: - name: Test CUDA Multi GPU (cuda12.6-py3.12) + name: Test CUDA Multi GPU (cuda13.0-py3.12) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: linux.g5.12xlarge.nvidia.gpu gpu-arch-type: cuda - gpu-arch-version: "12.6" + gpu-arch-version: "13.0" submodules: recursive script: | conda create --yes --quiet --name py312 python=3.12 @@ -76,7 +76,7 @@ jobs: pip install --quiet -r requirements-test.txt # For some reason the spec above isnt working pip uninstall -y torch - pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . python -m pytest tests/test_dcp_roundtrip.py -v torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py From ac4d2b30c39d8b543bd05905e99c190b2f40a667 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 17:36:13 +0000 Subject: [PATCH 3/6] Temp test --- .github/workflows/test_cuda.yml | 11 ++ tests/diagnose_mesh_identity_ci.py | 227 +++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+) create mode 100644 tests/diagnose_mesh_identity_ci.py diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 31857c35..3834e437 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -31,7 +31,18 @@ jobs: pip uninstall -y torch pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . + # Run tests but always emit the mesh-identity diagnostic afterward + # so CI captures which mesh triggers the as_strided dispatch on the + # real A10G hardware (see tests/diagnose_mesh_identity_ci.py). + set +e pytest tests + pytest_status=$? + set -e + echo "============================================================" + echo " mesh-identity diagnostic (always runs)" + echo "============================================================" + python tests/diagnose_mesh_identity_ci.py || true + exit $pytest_status examples-cuda-single-gpu: name: Examples CUDA Single GPU (cuda13.0-py3.12) diff --git a/tests/diagnose_mesh_identity_ci.py b/tests/diagnose_mesh_identity_ci.py new file mode 100644 index 00000000..1839f31b --- /dev/null +++ b/tests/diagnose_mesh_identity_ci.py @@ -0,0 +1,227 @@ +"""Diagnose CI failure for test_mesh_identity.py. + +Wraps DeviceMesh._flatten and DeviceMesh._create_flatten_mesh to log +every call. Mirrors the exact path the failing test takes — no extra +config patches, no joint_custom_pass. + +Run on CI immediately after the failure to see which mesh triggered +the as_strided dispatch and why the cache lookup missed. +""" + +import os +import sys +import traceback +from typing import Any + +sys.path.insert(0, os.path.join(os.path.dirname(__file__))) + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.testing._internal.distributed.fake_pg import FakeStore + +# Mirror the apply_cuda_patches in conftest (H100, capability 9.0) +from unittest.mock import patch + +_PATCHES: list[Any] = [ + patch("torch.cuda.device_count", lambda: 8), + patch("torch.cuda.get_device_name", lambda *a, **k: "H100"), + patch("torch.cuda.get_device_capability", lambda *a, **k: (9, 0)), + patch( + "torch.cuda.get_device_properties", + lambda *a, **k: type( + "Props", + (), + { + "major": 9, + "minor": 0, + "name": "H100", + "total_memory": 80 * 1024**3, + "multi_processor_count": 132, + }, + )(), + ), +] +for p in _PATCHES: + p.start() + + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel + +_log_lines: list[str] = [] + + +def _log(msg: str) -> None: + _log_lines.append(msg) + print(msg, flush=True) + + +def _summarize(m) -> str: + if m is None: + return "None" + dim_names = getattr(m, "_mesh_dim_names", None) + ndim = getattr(m, "ndim", "?") + root = m._get_root_mesh() if hasattr(m, "_get_root_mesh") else None + is_root = root is m + cache_keys = ( + list(root._flatten_mapping.keys()) + if root is not None and hasattr(root, "_flatten_mapping") + else "" + ) + return ( + f"mesh_id={id(m):#x} ndim={ndim} dim_names={dim_names} " + f"is_root={is_root} root_id={id(root):#x} root_cache={cache_keys}" + ) + + +def _short_traceback() -> str: + stack = traceback.extract_stack() + interesting = [ + f" {f.filename}:{f.lineno} in {f.name}" + for f in stack + if "device_mesh" not in f.filename + and os.path.basename(__file__) not in f.filename + ] + return "\n".join(interesting[-8:]) + + +_orig_flatten = DeviceMesh._flatten +_orig_create = DeviceMesh._create_flatten_mesh + +_count = {"flatten": 0, "create": 0, "miss": 0} + + +def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): + _count["flatten"] += 1 + n = _count["flatten"] + requested = mesh_dim_name or ( + "_".join(self._mesh_dim_names) if self._mesh_dim_names else "" + ) + _log(f"\n[_flatten #{n}] CALL on {_summarize(self)} requested={requested!r}") + _log(f"[_flatten #{n}] call site:\n{_short_traceback()}") + try: + result = _orig_flatten(self, mesh_dim_name, backend_override) + _log(f"[_flatten #{n}] OK → {_summarize(result)}") + return result + except Exception as e: + _log(f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}") + raise + + +def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): + _count["create"] += 1 + n = _count["create"] + root = self._get_root_mesh() + cache_hit = mesh_dim_name in root._flatten_mapping + if not cache_hit: + _count["miss"] += 1 + _log( + f" [_create_flatten_mesh #{n}] *** CACHE MISS *** " + f"name={mesh_dim_name!r} root_id={id(root):#x} " + f"root_cache={list(root._flatten_mapping)}" + ) + else: + _log( + f" [_create_flatten_mesh #{n}] cache hit name={mesh_dim_name!r} " + f"root_id={id(root):#x}" + ) + return _orig_create(self, mesh_dim_name, backend_override) + + +DeviceMesh._flatten = _wrapped_flatten # type: ignore[method-assign] +DeviceMesh._create_flatten_mesh = _wrapped_create # type: ignore[method-assign] + + +def main() -> None: + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "fake", store=FakeStore(), rank=0, world_size=256 + ) + + mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", (32, 8), mesh_dim_names=("dp", "tp") + ) + _log(f"\n=== USER MESH: {_summarize(mesh)} ===\n") + + vocab_size = 1024 + seqlen = 128 + batch_size = 2 * mesh.shape[0] + + with torch.device("meta"): + model = Transformer( + TransformerModelArgs( + dim=256, + n_layers=2, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=vocab_size, + max_seq_len=seqlen, + ) + ) + + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + + seen: dict[int, Any] = {} + _log("\n=== MESHES IN SHARDING SOLUTION ===") + for strategy in sharding_placement.values(): + specs: list[Any] = [] + if hasattr(strategy, "output_specs"): + o = strategy.output_specs + specs.extend(o if isinstance(o, (list, tuple)) else [o]) + if hasattr(strategy, "input_specs"): + specs.extend(strategy.input_specs or []) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + if id(m) not in seen: + seen[id(m)] = m + _log(f" spec mesh: {_summarize(m)}") + _log(f"=== TOTAL UNIQUE SPEC MESHES: {len(seen)} ===\n") + + _log("\n=== ENTERING apply_placement ===\n") + try: + autop.apply_placement(sharding_placement) + _log("\n=== apply_placement SUCCEEDED ===\n") + except Exception as e: + _log(f"\n=== apply_placement FAILED: {type(e).__name__}: {e} ===\n") + raise + + +if __name__ == "__main__": + try: + main() + finally: + _log(f"\n=== TOTAL _flatten calls: {_count['flatten']} ===") + _log(f"=== TOTAL _create_flatten_mesh calls: {_count['create']} ===") + _log(f"=== TOTAL _create_flatten_mesh CACHE MISSES: {_count['miss']} ===") + if _count["miss"] > 1: + _log( + "\n*** SMOKING GUN: cache missed more than once. " + "First miss is the pre-warming. Any further miss is the " + "duplicate mesh triggering the as_strided dispatch. " + "Search above for '*** CACHE MISS ***'." + ) + log_path = os.environ.get( + "FLATTEN_DIAG_LOG", "mesh_identity_diagnosis.log" + ) + with open(log_path, "w") as f: + f.write("\n".join(_log_lines)) + print(f"\nFull diagnostic log written to {log_path}") From 5734c507192e7307bc62fb70d198a0f1588009b9 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 18:07:05 +0000 Subject: [PATCH 4/6] Debug --- .github/workflows/test_cuda.yml | 8 +- tests/test_diagnose_mesh_identity_pytest.py | 186 ++++++++++++++++++++ 2 files changed, 192 insertions(+), 2 deletions(-) create mode 100644 tests/test_diagnose_mesh_identity_pytest.py diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 3834e437..133b4957 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -35,13 +35,17 @@ jobs: # so CI captures which mesh triggers the as_strided dispatch on the # real A10G hardware (see tests/diagnose_mesh_identity_ci.py). set +e - pytest tests + pytest tests --ignore=tests/test_diagnose_mesh_identity_pytest.py pytest_status=$? set -e echo "============================================================" - echo " mesh-identity diagnostic (always runs)" + echo " mesh-identity diagnostic (standalone script)" echo "============================================================" python tests/diagnose_mesh_identity_ci.py || true + echo "============================================================" + echo " mesh-identity diagnostic (pytest mode, after other tests)" + echo "============================================================" + pytest -s tests/test_diagnose_mesh_identity_pytest.py -v || true exit $pytest_status examples-cuda-single-gpu: diff --git a/tests/test_diagnose_mesh_identity_pytest.py b/tests/test_diagnose_mesh_identity_pytest.py new file mode 100644 index 00000000..04a5212f --- /dev/null +++ b/tests/test_diagnose_mesh_identity_pytest.py @@ -0,0 +1,186 @@ +"""Diagnose CI mesh-identity failure FROM WITHIN PYTEST. + +Adds extensive instrumentation around the actual failing test to expose +state at every flatten call. Crucial difference from +tests/diagnose_mesh_identity_ci.py: this runs as a pytest test, so +fixtures and prior-test pollution are in effect — matching the actual +CI failure conditions. +""" + +import os +import traceback +from typing import Any + +import torch +from conftest import apply_cuda_patches +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel + + +def _summarize(m) -> str: + if m is None: + return "None" + dim_names = getattr(m, "_mesh_dim_names", None) + ndim = getattr(m, "ndim", "?") + root = m._get_root_mesh() if hasattr(m, "_get_root_mesh") else None + cache_keys = ( + list(root._flatten_mapping.keys()) + if root is not None and hasattr(root, "_flatten_mapping") + else "" + ) + return ( + f"mesh_id={id(m):#x} ndim={ndim} dim_names={dim_names} " + f"is_root={root is m} root_id={id(root):#x} root_cache={cache_keys}" + ) + + +def _short_traceback() -> str: + stack = traceback.extract_stack() + interesting = [ + f" {f.filename}:{f.lineno} in {f.name}" + for f in stack + if "device_mesh" not in f.filename + and "diagnose_mesh_identity_pytest" not in f.filename + ] + return "\n".join(interesting[-10:]) + + +_state: dict[str, Any] = { + "log": [], + "count_flatten": 0, + "count_create": 0, + "count_miss": 0, + "installed": False, +} + + +def _log(msg: str) -> None: + _state["log"].append(msg) + print(msg, flush=True) + + +def _install_hooks(): + if _state["installed"]: + return + _state["installed"] = True + + _orig_flatten = DeviceMesh._flatten + _orig_create = DeviceMesh._create_flatten_mesh + + def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): + _state["count_flatten"] += 1 + n = _state["count_flatten"] + requested = mesh_dim_name or ( + "_".join(self._mesh_dim_names) if self._mesh_dim_names else "" + ) + _log(f"\n[_flatten #{n}] CALL on {_summarize(self)} requested={requested!r}") + _log(f"[_flatten #{n}] call site:\n{_short_traceback()}") + try: + result = _orig_flatten(self, mesh_dim_name, backend_override) + _log(f"[_flatten #{n}] OK → {_summarize(result)}") + return result + except Exception as e: + _log(f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}") + raise + + def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): + _state["count_create"] += 1 + n = _state["count_create"] + root = self._get_root_mesh() + cache_hit = mesh_dim_name in root._flatten_mapping + if not cache_hit: + _state["count_miss"] += 1 + _log( + f" [_create_flatten_mesh #{n}] *** CACHE MISS *** " + f"name={mesh_dim_name!r} root_id={id(root):#x} " + f"root_cache={list(root._flatten_mapping)}" + ) + else: + _log( + f" [_create_flatten_mesh #{n}] cache hit name={mesh_dim_name!r} " + f"root_id={id(root):#x}" + ) + return _orig_create(self, mesh_dim_name, backend_override) + + DeviceMesh._flatten = _wrapped_flatten # type: ignore[method-assign] + DeviceMesh._create_flatten_mesh = _wrapped_create # type: ignore[method-assign] + + +@apply_cuda_patches +def test_diagnose_mesh_identity_pytest(device_mesh_2d): + """Run the same scenario as test_mesh_identity.py but with instrumentation.""" + _install_hooks() + mesh = device_mesh_2d + _log(f"\n=== USER MESH (from fixture): {_summarize(mesh)} ===\n") + + vocab_size = 1024 + seqlen = 128 + batch_size = 2 * mesh.shape[0] + + with torch.device("meta"): + model = Transformer( + TransformerModelArgs( + dim=256, + n_layers=2, + n_heads=8, + n_kv_heads=2, + ffn_dim_multiplier=1.3, + multiple_of=64, + rope_theta=500000, + vocab_size=vocab_size, + max_seq_len=seqlen, + ) + ) + + try: + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + mesh, + MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + ), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + + seen: dict[int, Any] = {} + _log("\n=== MESHES IN SHARDING SOLUTION ===") + for strategy in sharding_placement.values(): + specs: list[Any] = [] + if hasattr(strategy, "output_specs"): + o = strategy.output_specs + specs.extend(o if isinstance(o, (list, tuple)) else [o]) + if hasattr(strategy, "input_specs"): + specs.extend(strategy.input_specs or []) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + if id(m) not in seen: + seen[id(m)] = m + _log(f" spec mesh: {_summarize(m)}") + _log(f"=== TOTAL UNIQUE SPEC MESHES: {len(seen)} ===\n") + + _log("\n=== ENTERING apply_placement ===\n") + autop.apply_placement(sharding_placement) + _log("\n=== apply_placement SUCCEEDED ===\n") + finally: + _log(f"\n=== TOTAL _flatten calls: {_state['count_flatten']} ===") + _log(f"=== TOTAL _create_flatten_mesh calls: {_state['count_create']} ===") + _log(f"=== TOTAL _create_flatten_mesh CACHE MISSES: {_state['count_miss']} ===") + log_path = os.environ.get( + "MESH_IDENTITY_PYTEST_LOG", "mesh_identity_pytest.log" + ) + with open(log_path, "w") as f: + f.write("\n".join(_state["log"])) + print(f"\nFull diagnostic log written to {log_path}", flush=True) From 33caa9671c921257ec3ef7dfe4d3cdccbb3e8fb6 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 10 Jun 2026 18:59:42 +0000 Subject: [PATCH 5/6] Debug --- tests/test_mesh_identity.py | 164 +++++++++++++++++++++++++++++++++--- 1 file changed, 150 insertions(+), 14 deletions(-) diff --git a/tests/test_mesh_identity.py b/tests/test_mesh_identity.py index cbd7b705..7b5e9952 100644 --- a/tests/test_mesh_identity.py +++ b/tests/test_mesh_identity.py @@ -26,8 +26,11 @@ simpler model wouldn't exercise that propagation rule). """ +import traceback + import torch from conftest import apply_cuda_patches +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard @@ -35,6 +38,88 @@ from autoparallel.api import AutoParallel +def _summarize(m) -> str: + if m is None: + return "None" + dim_names = getattr(m, "_mesh_dim_names", None) + ndim = getattr(m, "ndim", "?") + root = m._get_root_mesh() if hasattr(m, "_get_root_mesh") else None + cache_keys = ( + list(root._flatten_mapping.keys()) + if root is not None and hasattr(root, "_flatten_mapping") + else "" + ) + return ( + f"mesh_id={id(m):#x} ndim={ndim} dim_names={dim_names} " + f"is_root={root is m} root_id={id(root):#x} root_cache={cache_keys}" + ) + + +def _short_traceback() -> str: + stack = traceback.extract_stack() + interesting = [ + f" {f.filename}:{f.lineno} in {f.name}" + for f in stack + if "device_mesh" not in f.filename and "test_mesh_identity" not in f.filename + ] + return "\n".join(interesting[-10:]) + + +_diag = {"flatten": 0, "create": 0, "miss": 0, "installed": False} + + +def _install_diag_hooks(): + if _diag["installed"]: + return + _diag["installed"] = True + + _orig_flatten = DeviceMesh._flatten + _orig_create = DeviceMesh._create_flatten_mesh + + def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): + _diag["flatten"] += 1 + n = _diag["flatten"] + requested = mesh_dim_name or ( + "_".join(self._mesh_dim_names) if self._mesh_dim_names else "" + ) + print( + f"\n[_flatten #{n}] CALL on {_summarize(self)} requested={requested!r}", + flush=True, + ) + print(f"[_flatten #{n}] call site:\n{_short_traceback()}", flush=True) + try: + result = _orig_flatten(self, mesh_dim_name, backend_override) + print(f"[_flatten #{n}] OK → {_summarize(result)}", flush=True) + return result + except Exception as e: + print(f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}", flush=True) + raise + + def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): + _diag["create"] += 1 + n = _diag["create"] + root = self._get_root_mesh() + cache_hit = mesh_dim_name in root._flatten_mapping + if not cache_hit: + _diag["miss"] += 1 + print( + f" [_create_flatten_mesh #{n}] *** CACHE MISS *** " + f"name={mesh_dim_name!r} root_id={id(root):#x} " + f"root_cache={list(root._flatten_mapping)}", + flush=True, + ) + else: + print( + f" [_create_flatten_mesh #{n}] cache hit name={mesh_dim_name!r} " + f"root_id={id(root):#x}", + flush=True, + ) + return _orig_create(self, mesh_dim_name, backend_override) + + DeviceMesh._flatten = _wrapped_flatten # type: ignore[method-assign] + DeviceMesh._create_flatten_mesh = _wrapped_create # type: ignore[method-assign] + + @apply_cuda_patches def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): """After ``apply_placement``'s pre-warming, every spec mesh's root @@ -42,6 +127,27 @@ def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): ``_flatten()`` call inside ``make_fx`` triggers ``as_strided`` on the rank_map and FakeTensorMode rejects it (the original CI failure). """ + _install_diag_hooks() + print( + f"\n=== USER MESH (from fixture): {_summarize(device_mesh_2d)} ===\n", + flush=True, + ) + + # Clear caches that might carry duplicated meshes from prior tests + # (Dynamo guard cache, DTensor propagation lru_cache, etc.). + torch._dynamo.reset() + try: + from torch.distributed.tensor._api import DTensor + + if hasattr(DTensor, "_op_dispatcher"): + sp = DTensor._op_dispatcher.sharding_propagator + sp._propagate_tensor_meta_cached.cache_clear() + sp.op_strategy_funcs # ensure attr exists + if hasattr(sp, "op_to_rules_lru"): + sp.op_to_rules_lru.cache_clear() + except Exception as e: + print(f"cache clear note: {e}", flush=True) + vocab_size = 1024 seqlen = 128 batch_size = 2 * device_mesh_2d.shape[0] @@ -61,20 +167,50 @@ def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): ) ) - with AutoParallel( - model, - lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), - device_mesh_2d, - MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), - repeated_subgraphs=True, - ) as autop: - autop.add_parameter_memory_constraint(low=None, high=None) - autop.add_input_constraints([(Shard(0), Replicate())]) - autop.add_output_constraints([(Shard(0), Shard(2))]) - sharding_placement = autop.optimize_placement(verbose=False) - # apply_placement pre-warms the user mesh's _flatten cache so - # subsequent _flatten() calls inside make_fx hit the cache. - autop.apply_placement(sharding_placement) + try: + with AutoParallel( + model, + lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), + device_mesh_2d, + MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + ), + repeated_subgraphs=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + sharding_placement = autop.optimize_placement(verbose=False) + + seen: dict[int, object] = {} + print("\n=== MESHES IN SHARDING SOLUTION ===", flush=True) + for strategy in sharding_placement.values(): + specs = [] + if hasattr(strategy, "output_specs"): + o = strategy.output_specs + specs.extend(o if isinstance(o, (list, tuple)) else [o]) + if hasattr(strategy, "input_specs"): + specs.extend(strategy.input_specs or []) + for s in specs: + if s is None: + continue + m = getattr(s, "mesh", None) + if m is None: + continue + if id(m) not in seen: + seen[id(m)] = m + print(f" spec mesh: {_summarize(m)}", flush=True) + print(f"=== TOTAL UNIQUE SPEC MESHES: {len(seen)} ===\n", flush=True) + + # apply_placement pre-warms the user mesh's _flatten cache so + # subsequent _flatten() calls inside make_fx hit the cache. + autop.apply_placement(sharding_placement) + finally: + print( + f"\n=== TOTAL _flatten={_diag['flatten']} " + f"create={_diag['create']} MISSES={_diag['miss']} ===", + flush=True, + ) # Collect every distinct spec mesh from the solution spec_meshes: dict[int, object] = {} From 88deb04bfad14881bc0843c5ca86e943a0206e0d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 11 Jun 2026 07:32:24 +0000 Subject: [PATCH 6/6] Debug --- .github/workflows/test_cuda.yml | 24 +++++------------ tests/test_mesh_identity.py | 48 +++++++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 133b4957..a6b2e6f9 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -31,22 +31,12 @@ jobs: pip uninstall -y torch pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . - # Run tests but always emit the mesh-identity diagnostic afterward - # so CI captures which mesh triggers the as_strided dispatch on the - # real A10G hardware (see tests/diagnose_mesh_identity_ci.py). - set +e - pytest tests --ignore=tests/test_diagnose_mesh_identity_pytest.py - pytest_status=$? - set -e - echo "============================================================" - echo " mesh-identity diagnostic (standalone script)" - echo "============================================================" - python tests/diagnose_mesh_identity_ci.py || true - echo "============================================================" - echo " mesh-identity diagnostic (pytest mode, after other tests)" - echo "============================================================" - pytest -s tests/test_diagnose_mesh_identity_pytest.py -v || true - exit $pytest_status + # Run test_mesh_identity in a separate invocation FIRST to test + # if prior-test pollution is causing the as_strided failure. + echo "=== Running test_mesh_identity.py in isolation ===" + pytest -s tests/test_mesh_identity.py -v + echo "=== Running full test suite ===" + pytest -s tests --ignore=tests/test_diagnose_mesh_identity_pytest.py --ignore=tests/test_mesh_identity.py examples-cuda-single-gpu: name: Examples CUDA Single GPU (cuda13.0-py3.12) @@ -93,5 +83,5 @@ jobs: pip uninstall -y torch pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130 pip install --quiet . - python -m pytest tests/test_dcp_roundtrip.py -v + python examples/example_dcp.py torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py diff --git a/tests/test_mesh_identity.py b/tests/test_mesh_identity.py index 7b5e9952..1135279c 100644 --- a/tests/test_mesh_identity.py +++ b/tests/test_mesh_identity.py @@ -36,8 +36,6 @@ from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel - - def _summarize(m) -> str: if m is None: return "None" @@ -60,12 +58,20 @@ def _short_traceback() -> str: interesting = [ f" {f.filename}:{f.lineno} in {f.name}" for f in stack - if "device_mesh" not in f.filename and "test_mesh_identity" not in f.filename + if "device_mesh" not in f.filename + and "test_mesh_identity" not in f.filename ] return "\n".join(interesting[-10:]) -_diag = {"flatten": 0, "create": 0, "miss": 0, "installed": False} +_diag = { + "flatten": 0, + "create": 0, + "miss": 0, + "installed": False, + # id(mesh) → traceback string for every DeviceMesh constructed + "init_sites": {}, +} def _install_diag_hooks(): @@ -73,9 +79,23 @@ def _install_diag_hooks(): return _diag["installed"] = True + _orig_init = DeviceMesh.__init__ _orig_flatten = DeviceMesh._flatten _orig_create = DeviceMesh._create_flatten_mesh + def _wrapped_init(self, *args, **kwargs): + _orig_init(self, *args, **kwargs) + stack = traceback.extract_stack() + sites = [ + f" {f.filename}:{f.lineno} in {f.name}" + for f in stack + if "device_mesh" not in f.filename + and "test_mesh_identity" not in f.filename + and "/pluggy/" not in f.filename + and "/_pytest/" not in f.filename + ] + _diag["init_sites"][id(self)] = "\n".join(sites[-10:]) + def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): _diag["flatten"] += 1 n = _diag["flatten"] @@ -92,7 +112,9 @@ def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None): print(f"[_flatten #{n}] OK → {_summarize(result)}", flush=True) return result except Exception as e: - print(f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}", flush=True) + print( + f"[_flatten #{n}] RAISED: {type(e).__name__}: {e}", flush=True + ) raise def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): @@ -116,10 +138,16 @@ def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)): ) return _orig_create(self, mesh_dim_name, backend_override) + DeviceMesh.__init__ = _wrapped_init # type: ignore[method-assign] DeviceMesh._flatten = _wrapped_flatten # type: ignore[method-assign] DeviceMesh._create_flatten_mesh = _wrapped_create # type: ignore[method-assign] +# Install hooks at import time so we capture meshes constructed by +# fixtures and prior tests (which can leak duplicates into our solution). +_install_diag_hooks() + + @apply_cuda_patches def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): """After ``apply_placement``'s pre-warming, every spec mesh's root @@ -172,9 +200,7 @@ def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): model, lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"), device_mesh_2d, - MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32 - ), + MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), repeated_subgraphs=True, ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) @@ -200,6 +226,12 @@ def test_sharding_solution_meshes_have_warm_flatten_cache(device_mesh_2d): if id(m) not in seen: seen[id(m)] = m print(f" spec mesh: {_summarize(m)}", flush=True) + if id(m) in _diag["init_sites"]: + print( + f" spec mesh CONSTRUCTION SITE:\n" + f"{_diag['init_sites'][id(m)]}", + flush=True, + ) print(f"=== TOTAL UNIQUE SPEC MESHES: {len(seen)} ===\n", flush=True) # apply_placement pre-warms the user mesh's _flatten cache so