Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,70 +13,85 @@

jobs:
test-cuda-single-gpu:
name: Test CUDA Single GPU (cuda12.6-py3.12)
name: Test CUDA Single GPU (cuda13.0-py3.12)
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.6"
gpu-arch-version: "13.0"
submodules: recursive
script: |
conda create --yes --quiet --name py312 python=3.12
source $(conda info --base)/etc/profile.d/conda.sh
conda activate py312

pip install --quiet -r requirements-test.txt
# For some reason the spec above isn't working
pip uninstall -y torch
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130
pip install --quiet .
pytest tests
# Run tests but always emit the mesh-identity diagnostic afterward
# so CI captures which mesh triggers the as_strided dispatch on the
# real A10G hardware (see tests/diagnose_mesh_identity_ci.py).
set +e
pytest tests --ignore=tests/test_diagnose_mesh_identity_pytest.py
pytest_status=$?
set -e
echo "============================================================"
echo " mesh-identity diagnostic (standalone script)"
echo "============================================================"
python tests/diagnose_mesh_identity_ci.py || true
echo "============================================================"
echo " mesh-identity diagnostic (pytest mode, after other tests)"
echo "============================================================"
pytest -s tests/test_diagnose_mesh_identity_pytest.py -v || true
exit $pytest_status

examples-cuda-single-gpu:

Check warning

Code scanning / CodeQL

Workflow does not contain permissions Medium test

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {}
name: Examples CUDA Single GPU (cuda12.6-py3.12)
name: Examples CUDA Single GPU (cuda13.0-py3.12)
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.6"
gpu-arch-version: "13.0"
submodules: recursive
script: |
conda create --yes --quiet --name py312 python=3.12
source $(conda info --base)/etc/profile.d/conda.sh
conda activate py312

pip install --quiet -r requirements-test.txt
# For some reason the spec above isn't working
pip uninstall -y torch
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130
pip install --quiet .
run_timed() { local start=$SECONDS; "$@"; echo "$* : $((SECONDS - start))s" >> /tmp/timings.txt; }
run_timed python examples/example_autoparallel.py
run_timed python examples/example_llama3.py
run_timed python examples/example_local_map.py
echo "========== Timings =========="
cat /tmp/timings.txt

test-cuda-multi-gpu:

Check warning

Code scanning / CodeQL

Workflow does not contain permissions Medium test

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {}
name: Test CUDA Multi GPU (cuda12.6-py3.12)
name: Test CUDA Multi GPU (cuda13.0-py3.12)
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.g5.12xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.6"
gpu-arch-version: "13.0"
submodules: recursive
script: |
conda create --yes --quiet --name py312 python=3.12
source $(conda info --base)/etc/profile.d/conda.sh
conda activate py312

pip install --quiet -r requirements-test.txt
# For some reason the spec above isn't working
pip uninstall -y torch
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130
pip install --quiet .
python -m pytest tests/test_dcp_roundtrip.py -v
torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py

Check warning

Code scanning / CodeQL

Workflow does not contain permissions Medium test

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {}
45 changes: 44 additions & 1 deletion autoparallel/shardings/propagation_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,49 @@

logger = logging.getLogger(__name__)


def _deepcopy_preserving_mesh(obj):
"""Like copy.deepcopy, but reuses DeviceMesh instances instead of
duplicating them.

DeviceMesh carries process-group state (rank maps, _flatten_mapping
cache, backend overrides) that is logically shared across all callers.
Deep-copying it produces a fresh object with an empty _flatten_mapping
that misses the cache populated on the original, which later forces
DeviceMesh._flatten to dispatch as_strided on the rank_map inside
make_fx — failing FakeTensorMode's non-fake-input assertion.

We pre-populate copy.deepcopy's memo dict with identity mappings for
every DeviceMesh reachable from obj. deepcopy returns existing entries
from memo without recursing, so the same DeviceMesh instances appear
in the copy.
"""
from torch.distributed.device_mesh import DeviceMesh

memo: dict = {}
stack = [obj]
seen: set[int] = set()
while stack:
x = stack.pop()
if id(x) in seen:
continue
seen.add(id(x))
if isinstance(x, DeviceMesh):
memo[id(x)] = x
continue
if isinstance(x, (list, tuple, set, frozenset)):
stack.extend(x)
elif isinstance(x, dict):
stack.extend(x.values())
elif hasattr(x, "__dict__"):
stack.extend(x.__dict__.values())
elif hasattr(x, "__slots__"):
for slot in x.__slots__:
if hasattr(x, slot):
stack.append(getattr(x, slot))
return copy.deepcopy(obj, memo)


# TODO: move this to PyTorch
dim_maps[torch.t] = lambda input: dim_transpose(input.ndim, -2, -1)

Expand Down Expand Up @@ -829,7 +872,7 @@ def expand_rule(mesh, op_schema_):
from torch._subclasses.fake_tensor import unset_fake_temporarily

with unset_fake_temporarily():
op_schema = copy.deepcopy(op_schema_)
op_schema = _deepcopy_preserving_mesh(op_schema_)
input_strat = op_schema.args_schema[0]
orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape
dest_shape = op_schema.args_schema[1]
Expand Down
227 changes: 227 additions & 0 deletions tests/diagnose_mesh_identity_ci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""Diagnose CI failure for test_mesh_identity.py.

Wraps DeviceMesh._flatten and DeviceMesh._create_flatten_mesh to log
every call. Mirrors the exact path the failing test takes — no extra
config patches, no joint_custom_pass.

Run on CI immediately after the failure to see which mesh triggered
the as_strided dispatch and why the cache lookup missed.
"""

import os
import sys
import traceback
from typing import Any

sys.path.insert(0, os.path.join(os.path.dirname(__file__)))

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

# Mirror the apply_cuda_patches in conftest (H100, capability 9.0)
from unittest.mock import patch

# Patchers that make torch.cuda report a fixed fake configuration:
# 8 devices, each named "H100" with compute capability (9, 0).
_PATCHES: list[Any] = [
    patch("torch.cuda.device_count", lambda: 8),
    patch("torch.cuda.get_device_name", lambda *a, **k: "H100"),
    patch("torch.cuda.get_device_capability", lambda *a, **k: (9, 0)),
    patch(
        "torch.cuda.get_device_properties",
        # Each call builds a throwaway "Props" class and returns an instance
        # whose class attributes carry the faked values.
        lambda *a, **k: type(
            "Props",
            (),
            {
                "major": 9,
                "minor": 0,
                "name": "H100",
                "total_memory": 80 * 1024**3,
                "multi_processor_count": 132,
            },
        )(),
    ),
]
# Activate every patch at import time; nothing in this script stops them.
for p in _PATCHES:
    p.start()


from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
from autoparallel.api import AutoParallel

_log_lines: list[str] = []


def _log(msg: str) -> None:
    """Record ``msg`` in the in-memory log and echo it to stdout.

    The line is appended to ``_log_lines`` (written to a file when the script
    finishes) and printed with ``flush=True`` so it is not held back by
    output buffering.
    """
    _log_lines.append(msg)
    print(msg, flush=True)


def _summarize(m) -> str:
if m is None:
return "None"
dim_names = getattr(m, "_mesh_dim_names", None)
ndim = getattr(m, "ndim", "?")
root = m._get_root_mesh() if hasattr(m, "_get_root_mesh") else None
is_root = root is m
cache_keys = (
list(root._flatten_mapping.keys())
if root is not None and hasattr(root, "_flatten_mapping")
else "<none>"
)
return (
f"mesh_id={id(m):#x} ndim={ndim} dim_names={dim_names} "
f"is_root={is_root} root_id={id(root):#x} root_cache={cache_keys}"
)


def _short_traceback() -> str:
stack = traceback.extract_stack()
interesting = [
f" {f.filename}:{f.lineno} in {f.name}"
for f in stack
if "device_mesh" not in f.filename
and os.path.basename(__file__) not in f.filename
]
return "\n".join(interesting[-8:])


# Keep references to the unpatched methods so the logging wrappers can
# delegate to them.
_orig_flatten = DeviceMesh._flatten
_orig_create = DeviceMesh._create_flatten_mesh

# Running tallies: calls to _flatten, calls to _create_flatten_mesh, and how
# many of the latter found the requested name missing from the root's
# _flatten_mapping.
_count = {"flatten": 0, "create": 0, "miss": 0}


def _wrapped_flatten(self, mesh_dim_name=None, backend_override=None):
    """Logging stand-in for ``DeviceMesh._flatten``.

    Bumps the ``flatten`` counter, logs the receiving mesh, the requested
    name and the call site, then delegates to the saved original method.
    The result is logged on success; on failure the exception type and
    message are logged and the exception is re-raised unchanged.
    """
    _count["flatten"] += 1
    call_no = _count["flatten"]

    # Name used only for the log line: the explicit argument if truthy,
    # otherwise the mesh's dim names joined with "_", otherwise a placeholder.
    if mesh_dim_name:
        requested = mesh_dim_name
    elif self._mesh_dim_names:
        requested = "_".join(self._mesh_dim_names)
    else:
        requested = "<unnamed>"

    _log(f"\n[_flatten #{call_no}] CALL on {_summarize(self)} requested={requested!r}")
    _log(f"[_flatten #{call_no}] call site:\n{_short_traceback()}")
    try:
        flattened = _orig_flatten(self, mesh_dim_name, backend_override)
        _log(f"[_flatten #{call_no}] OK → {_summarize(flattened)}")
    except Exception as exc:
        _log(f"[_flatten #{call_no}] RAISED: {type(exc).__name__}: {exc}")
        raise
    return flattened


def _wrapped_create(self, mesh_dim_name, backend_override=(None, None)):
    """Logging stand-in for ``DeviceMesh._create_flatten_mesh``.

    Before delegating to the saved original, checks whether
    ``mesh_dim_name`` is already a key of the root mesh's
    ``_flatten_mapping`` and logs a hit or a miss. Misses also bump the
    ``miss`` counter.
    """
    _count["create"] += 1
    call_no = _count["create"]
    root_mesh = self._get_root_mesh()

    if mesh_dim_name in root_mesh._flatten_mapping:
        _log(
            f" [_create_flatten_mesh #{call_no}] cache hit name={mesh_dim_name!r} "
            f"root_id={id(root_mesh):#x}"
        )
    else:
        _count["miss"] += 1
        _log(
            f" [_create_flatten_mesh #{call_no}] *** CACHE MISS *** "
            f"name={mesh_dim_name!r} root_id={id(root_mesh):#x} "
            f"root_cache={list(root_mesh._flatten_mapping)}"
        )

    return _orig_create(self, mesh_dim_name, backend_override)


# Install the logging wrappers on the class itself, so every DeviceMesh
# instance in this process goes through them from here on.
DeviceMesh._flatten = _wrapped_flatten  # type: ignore[method-assign]
DeviceMesh._create_flatten_mesh = _wrapped_create  # type: ignore[method-assign]


def main() -> None:
    """Run the AutoParallel flow on a small Llama3 model and log every mesh.

    Steps:
      1. Initialise a "fake" process group (world size 256) if none exists.
      2. Build a (32, 8) device mesh with dims ("dp", "tp").
      3. Construct a 2-layer Transformer on the meta device.
      4. Inside an AutoParallel context, add constraints, solve for a
         sharding placement, log each distinct mesh object referenced by the
         solution's specs, then call ``apply_placement``.

    Any exception from ``apply_placement`` is logged and re-raised.
    """
    dist = torch.distributed
    if not dist.is_initialized():
        dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=256)

    mesh = dist.device_mesh.init_device_mesh(
        "cuda", (32, 8), mesh_dim_names=("dp", "tp")
    )
    _log(f"\n=== USER MESH: {_summarize(mesh)} ===\n")

    vocab_size = 1024
    seqlen = 128
    batch_size = 2 * mesh.shape[0]

    with torch.device("meta"):
        model_args = TransformerModelArgs(
            dim=256,
            n_layers=2,
            n_heads=8,
            n_kv_heads=2,
            ffn_dim_multiplier=1.3,
            multiple_of=64,
            rope_theta=500000,
            vocab_size=vocab_size,
            max_seq_len=seqlen,
        )
        model = Transformer(model_args)

    with AutoParallel(
        model,
        lambda: torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda"),
        mesh,
        MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        repeated_subgraphs=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        autop.add_input_constraints([(Shard(0), Replicate())])
        autop.add_output_constraints([(Shard(0), Shard(2))])
        sharding_placement = autop.optimize_placement(verbose=False)

        # Distinct mesh objects (keyed by id) referenced by the solution.
        unique_meshes: dict[int, Any] = {}
        _log("\n=== MESHES IN SHARDING SOLUTION ===")
        for strategy in sharding_placement.values():
            candidates: list[Any] = []
            if hasattr(strategy, "output_specs"):
                out_specs = strategy.output_specs
                if isinstance(out_specs, (list, tuple)):
                    candidates.extend(out_specs)
                else:
                    candidates.append(out_specs)
            if hasattr(strategy, "input_specs"):
                candidates.extend(strategy.input_specs or [])

            for spec in candidates:
                spec_mesh = None if spec is None else getattr(spec, "mesh", None)
                if spec_mesh is None or id(spec_mesh) in unique_meshes:
                    continue
                unique_meshes[id(spec_mesh)] = spec_mesh
                _log(f" spec mesh: {_summarize(spec_mesh)}")
        _log(f"=== TOTAL UNIQUE SPEC MESHES: {len(unique_meshes)} ===\n")

        _log("\n=== ENTERING apply_placement ===\n")
        try:
            autop.apply_placement(sharding_placement)
            _log("\n=== apply_placement SUCCEEDED ===\n")
        except Exception as exc:
            _log(f"\n=== apply_placement FAILED: {type(exc).__name__}: {exc} ===\n")
            raise


if __name__ == "__main__":
    try:
        main()
    finally:
        # Always emit the call totals and persist the log, even when main()
        # raised, so the diagnostic output survives a failing run.
        _log(f"\n=== TOTAL _flatten calls: {_count['flatten']} ===")
        _log(f"=== TOTAL _create_flatten_mesh calls: {_count['create']} ===")
        _log(f"=== TOTAL _create_flatten_mesh CACHE MISSES: {_count['miss']} ===")
        if _count["miss"] > 1:
            _log(
                "\n*** SMOKING GUN: cache missed more than once. "
                "First miss is the pre-warming. Any further miss is the "
                "duplicate mesh triggering the as_strided dispatch. "
                "Search above for '*** CACHE MISS ***'."
            )
        # Destination can be overridden through the FLATTEN_DIAG_LOG env var.
        log_path = os.environ.get(
            "FLATTEN_DIAG_LOG", "mesh_identity_diagnosis.log"
        )
        try:
            # The log contains non-ASCII text ("→"), so pin the encoding
            # instead of relying on the locale default.
            with open(log_path, "w", encoding="utf-8") as f:
                f.write("\n".join(_log_lines))
        except OSError as e:
            # Writing the log is best-effort: a failure here must not replace
            # an exception that is propagating out of main().
            print(
                f"\nCould not write diagnostic log to {log_path}: {e}",
                file=sys.stderr,
            )
        else:
            print(f"\nFull diagnostic log written to {log_path}")
Loading
Loading