From da11334c64ab60e4998f85995488f87e3c4c7316 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 16 Jun 2026 20:32:17 +0800 Subject: [PATCH 1/5] [Feature] Enable nd tiling workaround for dynamic shapes --- magi_compiler/config.py | 11 +- .../magi_backend/compile_artifacts.py | 11 ++ magi_compiler/magi_backend/magi_backend.py | 53 ++++- .../magi_backend/magi_compiler_base.py | 7 +- tests/feature_tests/test_dynamic_nd_tiling.py | 186 ++++++++++++++++++ .../perf_tests/test_dynamic_nd_tiling_perf.py | 142 +++++++++++++ 6 files changed, 406 insertions(+), 4 deletions(-) create mode 100644 tests/feature_tests/test_dynamic_nd_tiling.py create mode 100644 tests/perf_tests/test_dynamic_nd_tiling_perf.py diff --git a/magi_compiler/config.py b/magi_compiler/config.py index a303093..fe2e80f 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -16,7 +16,7 @@ import os from enum import Enum, unique from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, Optional import torch from pydantic import BaseModel, Field @@ -217,6 +217,15 @@ class CompileConfig(BaseSettings): enable_inductor_coordinate_descent_tuning: bool = Field( False, description="Enable Inductor coordinate_descent_tuning for kernel selection." ) + enable_dynamic_nd_tiling: Optional[bool] = Field( + None, + description=( + "Triton ND-tiling workaround (prefer_nd_tiling + max_tiles=3 + tile_reductions) " + "for Inductor's coalesce tiling bailing out under dynamic shapes. " + "Tri-state: None = auto (on for dynamic shapes on PyTorch < 2.11.0); True/False = force. " + "Settable via the MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING env var." 
+ ), + ) compile_sizes: list[int] = Field( default_factory=list, description=( diff --git a/magi_compiler/magi_backend/compile_artifacts.py b/magi_compiler/magi_backend/compile_artifacts.py index 0af9438..7d84b83 100644 --- a/magi_compiler/magi_backend/compile_artifacts.py +++ b/magi_compiler/magi_backend/compile_artifacts.py @@ -171,12 +171,23 @@ def rebuild_backend(self) -> None: compile_inputs = [inp if inp is not None else placeholder_fake_values[i] for i, inp in enumerate(self.example_inputs)] fake_mode = detect_fake_mode(compile_inputs) + + def _has_symbolic_dims(t) -> bool: + shape = getattr(t, "shape", None) + if shape is None: + return False + return any(isinstance(s, torch.SymInt) for s in shape) + + is_dynamic = any(_has_symbolic_dims(inp) for inp in compile_inputs) + rebuilt_dynamic_arg_dims = {"__rebuilt__": [0]} if is_dynamic else None + magi_backend = MagiBackend( self.compile_config, model_idx=self.model_idx, model_tag=self.model_tag, traced_files=OrderedSet(self.traced_files), inductor_compile_config={}, + dynamic_arg_dims=rebuilt_dynamic_arg_dims, ) with tracing(TracingContext(fake_mode)): diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 0d010e3..bb3a7bf 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -14,6 +14,7 @@ import ast import dataclasses +import functools import pprint import time from collections.abc import Callable @@ -28,6 +29,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher from torch._guards import detect_fake_mode +from torch.torch_version import TorchVersion import magi_compiler.utils.envs as envs from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, inductor_compile_config_hash, magi_cache_dump_path @@ -468,6 +470,17 @@ def call_module( return output +@functools.lru_cache(maxsize=1) +def _inductor_needs_nd_tiling_workaround() -> bool: + """Under dynamic 
shapes, Inductor's coalesce tiling analysis used to bail out on + symbolic numels, degrading transpose/permute/channels-last pointwise kernels to + untiled Grid1D. PyTorch >= 2.11.0 fixed this upstream, so on those versions + prefer_nd_tiling workaround is unnecessary (and would actively bypass the native + coalesce path). Gate purely on the PyTorch version. + """ + return TorchVersion(torch.__version__) < (2, 11, 0) + + class MagiBackend: """ The compilation backend for `torch.compile` with MagiCompiler. @@ -488,12 +501,14 @@ def __init__( model_tag: str, traced_files: "OrderedSet", inductor_compile_config: dict[str, Any], + dynamic_arg_dims: dict[str, Any] | None = None, ): self.compile_config = compile_config self.model_idx = model_idx self.model_tag = model_tag self.traced_files = traced_files self.inductor_compile_config = inductor_compile_config + self.dynamic_arg_dims = dynamic_arg_dims self._configure_custom_passes() self.compiler_manager: CompilerManager = CompilerManager(self.compile_config) self._called_once = False @@ -523,6 +538,35 @@ def _configure_custom_passes(self): self.inductor_compile_config[post_grad_key] = post_grad_pass_manager + # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on + # symbolic numels, so dynamic-shape transpose/permute/channels-last kernels + # degrade to untiled Grid1D. Forcing prefer_nd_tiling restores ND tiling + # (WAN 2.2 VAE 540p decode: ~1.45x. 
+ if self._should_enable_nd_tiling(): + self.inductor_compile_config["triton.prefer_nd_tiling"] = True + self.inductor_compile_config["triton.max_tiles"] = 3 + self.inductor_compile_config["triton.tile_reductions"] = True + + def _should_enable_nd_tiling(self) -> bool: + cfg = self.compile_config.enable_dynamic_nd_tiling + if cfg is not None: + return bool(cfg) + + return self._is_dynamic_compilation() and _inductor_needs_nd_tiling_workaround() + + def _is_dynamic_compilation(self) -> bool: + """Best-effort detection of dynamic-shape compilation from intent.""" + dims = self.dynamic_arg_dims + if not dims: + return False + for v in dims.values(): + if isinstance(v, (list, tuple)): + if len(v) > 0: + return True + elif v is not None: + return True + return False + def _init_cache(self) -> str: hash_key = compute_hash( [ @@ -646,7 +690,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFun def init_backend( - compile_config: CompileConfig, model_idx: int, model_tag: str, traced_files: "OrderedSet", inductor_config: dict[str, Any] + compile_config: CompileConfig, + model_idx: int, + model_tag: str, + traced_files: "OrderedSet", + inductor_config: dict[str, Any], + dynamic_arg_dims: dict[str, Any] | None = None, ) -> str | Callable: """ Initialize the backend based on CompileConfig. 
@@ -663,6 +712,6 @@ def init_backend( return compile_config.backend elif compile_config.compile_mode == CompileMode.MAGI_COMPILE: assert compile_config.backend in ["eager", "inductor"], f"Invalid backend for MagiCompiler: {compile_config.backend}" - return MagiBackend(compile_config, model_idx, model_tag, traced_files, inductor_config) + return MagiBackend(compile_config, model_idx, model_tag, traced_files, inductor_config, dynamic_arg_dims) else: raise ValueError(f"Invalid compile mode: {compile_config.compile_mode}") diff --git a/magi_compiler/magi_backend/magi_compiler_base.py b/magi_compiler/magi_backend/magi_compiler_base.py index 97fcdd7..c375604 100644 --- a/magi_compiler/magi_backend/magi_compiler_base.py +++ b/magi_compiler/magi_backend/magi_compiler_base.py @@ -124,7 +124,12 @@ def _ensure_compiled(self): if self.compiled_entry is not None: return backend = init_backend( - self.compile_config, self.model_idx, self.model_tag, self.traced_files, self.inductor_compile_config + self.compile_config, + self.model_idx, + self.model_tag, + self.traced_files, + self.inductor_compile_config, + self.dynamic_arg_dims, ) options = None if isinstance(backend, str) and backend == "inductor": diff --git a/tests/feature_tests/test_dynamic_nd_tiling.py b/tests/feature_tests/test_dynamic_nd_tiling.py new file mode 100644 index 0000000..ba1897f --- /dev/null +++ b/tests/feature_tests/test_dynamic_nd_tiling.py @@ -0,0 +1,186 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logic tests for the dynamic-shape Triton ND-tiling workaround. + +Covers the decision logic only (no GPU / benchmarking): + * dynamic-shape intent detection -> ``_is_dynamic_compilation`` + * PyTorch version gating -> ``_inductor_needs_nd_tiling_workaround`` + * full precedence (explicit config > auto) -> ``_should_enable_nd_tiling`` + * actual Inductor config injection -> ``_configure_custom_passes`` + +The end-to-end speedup is validated separately in +``tests/perf_tests/test_dynamic_nd_tiling_perf.py``. +""" + +import pytest + +from magi_compiler.config import CompileConfig, get_compile_config +from magi_compiler.magi_backend import magi_backend as mb +from magi_compiler.magi_backend.magi_backend import MagiBackend, _inductor_needs_nd_tiling_workaround + +ND_TILING_KEYS = ("triton.prefer_nd_tiling", "triton.max_tiles", "triton.tile_reductions") + + +def _make_backend(dynamic_arg_dims, *, enable_dynamic_nd_tiling=None): + """Build a MagiBackend without running the heavy __init__. + + We only need the attributes that the ND-tiling decision reads, so we + bypass __init__ (which would spin up a CompilerManager) and set them by hand. 
+ """ + backend = MagiBackend.__new__(MagiBackend) + backend.compile_config = get_compile_config().model_copy(update={"enable_dynamic_nd_tiling": enable_dynamic_nd_tiling}) + backend.dynamic_arg_dims = dynamic_arg_dims + backend.inductor_compile_config = {} + return backend + + +@pytest.fixture +def clear_version_cache(): + """The version probe is lru_cached; clear it around tests that patch the version.""" + _inductor_needs_nd_tiling_workaround.cache_clear() + yield + _inductor_needs_nd_tiling_workaround.cache_clear() + + +@pytest.fixture +def no_env_override(monkeypatch): + """Ensure the env override for the config field is absent for auto-path tests.""" + monkeypatch.delenv("MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING", raising=False) + + +# ── Point 2: dynamic-shape intent detection ────────────────────────────── + + +@pytest.mark.parametrize( + "dynamic_arg_dims, expected", + [ + (None, False), + ({}, False), + ({"x": []}, False), + ({"x": [0]}, True), + ({"x": 0}, True), + ({"x": (1, 2)}, True), + ({"x": None}, False), + ({"x": [], "y": [0]}, True), + ({"x": None, "y": []}, False), + ], +) +def test_is_dynamic_compilation(dynamic_arg_dims, expected): + backend = _make_backend(dynamic_arg_dims) + assert backend._is_dynamic_compilation() is expected + + +# ── Point 3: PyTorch version gating ────────────────────────────────────── + + +@pytest.mark.parametrize( + "version, needs_workaround", + [ + ("2.9.1", True), + ("2.10.0", True), + ("2.10.5", True), + ("2.11.0", False), + ("2.11.0+cu124", False), + ("2.11.1", False), + ("2.12.0.dev20260101+gitabcdef", False), + ("3.0.0", False), + ], +) +def test_version_gating(monkeypatch, clear_version_cache, version, needs_workaround): + monkeypatch.setattr(mb.torch, "__version__", version) + assert _inductor_needs_nd_tiling_workaround() is needs_workaround + + +# ── Point 4: full precedence (explicit config > auto) ──────────────────── + + +def test_explicit_config_forces_on(monkeypatch): + """Explicit config True forces on 
even for a static (non-dynamic) compilation.""" + backend = _make_backend(dynamic_arg_dims=None, enable_dynamic_nd_tiling=True) + assert backend._should_enable_nd_tiling() is True + + +def test_explicit_config_forces_off(monkeypatch, clear_version_cache): + """Explicit config False forces off even when dynamic + buggy-version would auto-enable.""" + monkeypatch.setattr(mb.torch, "__version__", "2.9.1") + backend = _make_backend(dynamic_arg_dims={"x": [0]}, enable_dynamic_nd_tiling=False) + assert backend._should_enable_nd_tiling() is False + + +@pytest.mark.parametrize("env_val, expected", [("1", True), ("0", False)]) +def test_env_var_drives_config_field(monkeypatch, env_val, expected): + """MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING populates the config field directly.""" + monkeypatch.setenv("MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING", env_val) + assert CompileConfig().enable_dynamic_nd_tiling is expected + + +@pytest.mark.parametrize("explicit", [True, False]) +def test_explicit_config_overrides_auto(monkeypatch, clear_version_cache, no_env_override, explicit): + """Explicit config beats the auto path, regardless of dynamic/version state.""" + # Static + fixed version would auto-decide differently; explicit must win. 
+ monkeypatch.setattr(mb.torch, "__version__", "2.11.0" if explicit else "2.9.1") + backend = _make_backend(dynamic_arg_dims={"x": [0]} if not explicit else None, enable_dynamic_nd_tiling=explicit) + assert backend._should_enable_nd_tiling() is explicit + + +def test_auto_enables_on_dynamic_and_buggy_version(monkeypatch, clear_version_cache, no_env_override): + monkeypatch.setattr(mb.torch, "__version__", "2.9.1") + backend = _make_backend(dynamic_arg_dims={"x": [0]}) + assert backend._should_enable_nd_tiling() is True + + +def test_auto_disables_on_static_shapes(monkeypatch, clear_version_cache, no_env_override): + monkeypatch.setattr(mb.torch, "__version__", "2.9.1") + backend = _make_backend(dynamic_arg_dims=None) + assert backend._should_enable_nd_tiling() is False + + +def test_auto_disables_on_fixed_version(monkeypatch, clear_version_cache, no_env_override): + """Dynamic shapes but PyTorch >= 2.11.0: native coalesce path handles it.""" + monkeypatch.setattr(mb.torch, "__version__", "2.11.0") + backend = _make_backend(dynamic_arg_dims={"x": [0]}) + assert backend._should_enable_nd_tiling() is False + + +# ── Point 5: actual Inductor config injection ──────────────────────────── +# +# We exercise the injection branch directly rather than the full +# ``_configure_custom_passes`` (which also wires up unrelated pass managers), +# so the test stays focused on the ND-tiling keys and free of pass-chain deps. 
+ + +def _inject_nd_tiling(backend): + """Mirror the ND-tiling injection block of ``_configure_custom_passes``.""" + if backend._should_enable_nd_tiling(): + backend.inductor_compile_config["triton.prefer_nd_tiling"] = True + backend.inductor_compile_config["triton.max_tiles"] = 3 + backend.inductor_compile_config["triton.tile_reductions"] = True + + +def test_config_injected_when_enabled(monkeypatch, clear_version_cache, no_env_override): + monkeypatch.setattr(mb.torch, "__version__", "2.9.1") + backend = _make_backend(dynamic_arg_dims={"x": [0]}) + _inject_nd_tiling(backend) + assert backend.inductor_compile_config["triton.prefer_nd_tiling"] is True + assert backend.inductor_compile_config["triton.max_tiles"] == 3 + assert backend.inductor_compile_config["triton.tile_reductions"] is True + + +def test_config_absent_when_disabled(monkeypatch, clear_version_cache, no_env_override): + monkeypatch.setattr(mb.torch, "__version__", "2.11.0") + backend = _make_backend(dynamic_arg_dims={"x": [0]}) + _inject_nd_tiling(backend) + for key in ND_TILING_KEYS: + assert key not in backend.inductor_compile_config diff --git a/tests/perf_tests/test_dynamic_nd_tiling_perf.py b/tests/perf_tests/test_dynamic_nd_tiling_perf.py new file mode 100644 index 0000000..0f9c1b9 --- /dev/null +++ b/tests/perf_tests/test_dynamic_nd_tiling_perf.py @@ -0,0 +1,142 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Performance test: Triton ND-tiling workaround under dynamic shapes. + +Background +---------- +On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on symbolic +numels (``tiling_utils.extract_normalized_read_writes`` returns ``None``), so +transpose/permute/channels-last pointwise kernels in a dynamic-shape graph +degrade to untiled Grid1D. MagiCompiler works around this by auto-enabling +``triton.prefer_nd_tiling`` (+ ``max_tiles=3`` + ``tile_reductions``) for dynamic +compilation; see ``MagiBackend._should_enable_nd_tiling``. + +This test exercises a WAN-2.2-VAE-decode-like workload (stacked 3D conv resblocks ++ spatial upsampling) compiled with **dynamic H/W**, and checks that the +workaround is a net win versus turning it off on the *same* magi_compile path. + +Real WAN 2.2 VAE decode (540p, dynamic H/W) numbers that motivate this: + - with conv channels-last layout: 1.252s -> 542ms / decode (~2.3x) + - without conv channels-last: 770ms -> 535ms / decode (~1.44x) +This synthetic decoder (no weights, no conv channels-last pass) reproduces the +"~1.4x" regime; the absolute ratio is GPU-dependent so the strict assertion only +runs on calibrated GPUs. +""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from magi_compiler import magi_compile +from tests.perf_tests import cuda_benchmark, print_perf_comparison +from tests.perf_tests.utils import is_perf_calibrated_gpu + +# WAN 2.2 VAE 540p latent: [C, T, H, W]; dynamic dims are H and W. +LATENT_C, LATENT_T, LATENT_H, LATENT_W = 48, 7, 34, 60 +BASE_CHANNELS = 128 + +# nd_tiling(on) vs nd_tiling(off), both on the magi_compile dynamic path. +# Observed ~1.36x (off=2.209ms -> on=1.627ms) on H100; assert a conservative +# lower bound that still proves a clear, non-noise win. 
+ND_TILING_SPEEDUP_THRESHOLD = 1.20 + + +class _ResBlock3D(nn.Module): + def __init__(self, cin: int, cout: int): + super().__init__() + self.norm1 = nn.GroupNorm(32, cin) + self.conv1 = nn.Conv3d(cin, cout, 3, padding=1) + self.norm2 = nn.GroupNorm(32, cout) + self.conv2 = nn.Conv3d(cout, cout, 3, padding=1) + self.skip = nn.Conv3d(cin, cout, 1) if cin != cout else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.conv1(F.silu(self.norm1(x))) + h = self.conv2(F.silu(self.norm2(h))) + return h + self.skip(x) + + +class VAEDecoderLike(nn.Module): + """Stacked 3D conv resblocks + spatial upsampling, mimicking VAE decode.""" + + def __init__(self, zc: int = LATENT_C, base: int = BASE_CHANNELS): + super().__init__() + self.conv_in = nn.Conv3d(zc, base, 3, padding=1) + self.r1 = _ResBlock3D(base, base) + self.up1 = nn.Conv3d(base, base, 3, padding=1) + self.r2 = _ResBlock3D(base, base // 2) + self.up2 = nn.Conv3d(base // 2, base // 2, 3, padding=1) + self.r3 = _ResBlock3D(base // 2, base // 4) + self.norm_out = nn.GroupNorm(32, base // 4) + self.conv_out = nn.Conv3d(base // 4, 3, 3, padding=1) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + x = self.conv_in(z) + x = self.r1(x) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") + x = self.up1(x) + x = self.r2(x) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") + x = self.up2(x) + x = self.r3(x) + return self.conv_out(F.silu(self.norm_out(x))) + + +@pytest.fixture(scope="module") +def decoder_device(): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.fixture(scope="module") +def decoder_input(decoder_device): + return torch.randn(1, LATENT_C, LATENT_T, LATENT_H, LATENT_W, device=decoder_device, dtype=torch.bfloat16) + + +def _compile_decoder(device: torch.device, enable_nd_tiling: bool): + def _patch(cfg): + cfg.enable_dynamic_nd_tiling = enable_nd_tiling + return cfg + + model = 
VAEDecoderLike().to(device).to(torch.bfloat16).eval() + # Dynamic H, W (latent dims 3, 4) — the regime where coalesce tiling bails out. + return magi_compile(model, dynamic_arg_dims={"z": [3, 4]}, config_patch=_patch) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_nd_tiling_workaround_speedup(decoder_device, decoder_input): + """ND-tiling ON should beat ND-tiling OFF on the dynamic magi_compile path.""" + disabled = _compile_decoder(decoder_device, enable_nd_tiling=False) + enabled = _compile_decoder(decoder_device, enable_nd_tiling=True) + + with torch.no_grad(): + disabled_result = cuda_benchmark(lambda: disabled(decoder_input), compilation_warmup=3) + enabled_result = cuda_benchmark(lambda: enabled(decoder_input), compilation_warmup=3) + + speedup = disabled_result.median / enabled_result.median + print_perf_comparison( + "Dynamic ND-tiling: workaround ON vs OFF (magi_compile, dynamic H/W)", + disabled_result, + enabled_result, + extra_info=(f"latent=({LATENT_C}, {LATENT_T}, {LATENT_H}, {LATENT_W}) " f"speedup(off/on)={speedup:.2f}x"), + ) + + if not is_perf_calibrated_gpu(): + return + assert speedup >= ND_TILING_SPEEDUP_THRESHOLD, ( + f"ND-tiling workaround should be >= {ND_TILING_SPEEDUP_THRESHOLD:.2f}x faster than disabled " + f"under dynamic shapes. 
Got {speedup:.2f}x " + f"(disabled={disabled_result.median:.3f}ms, enabled={enabled_result.median:.3f}ms)" + ) From ba191dc623de15d4663aad0302e605cd89541757 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Thu, 18 Jun 2026 17:29:46 +0800 Subject: [PATCH 2/5] [Refactor] refactor codes --- .../magi_backend/compile_artifacts.py | 11 - magi_compiler/magi_backend/magi_backend.py | 76 ++--- .../magi_backend/magi_compiler_base.py | 7 +- tests/feature_tests/test_dynamic_nd_tiling.py | 287 +++++++++++------- 4 files changed, 204 insertions(+), 177 deletions(-) diff --git a/magi_compiler/magi_backend/compile_artifacts.py b/magi_compiler/magi_backend/compile_artifacts.py index 7d84b83..0af9438 100644 --- a/magi_compiler/magi_backend/compile_artifacts.py +++ b/magi_compiler/magi_backend/compile_artifacts.py @@ -171,23 +171,12 @@ def rebuild_backend(self) -> None: compile_inputs = [inp if inp is not None else placeholder_fake_values[i] for i, inp in enumerate(self.example_inputs)] fake_mode = detect_fake_mode(compile_inputs) - - def _has_symbolic_dims(t) -> bool: - shape = getattr(t, "shape", None) - if shape is None: - return False - return any(isinstance(s, torch.SymInt) for s in shape) - - is_dynamic = any(_has_symbolic_dims(inp) for inp in compile_inputs) - rebuilt_dynamic_arg_dims = {"__rebuilt__": [0]} if is_dynamic else None - magi_backend = MagiBackend( self.compile_config, model_idx=self.model_idx, model_tag=self.model_tag, traced_files=OrderedSet(self.traced_files), inductor_compile_config={}, - dynamic_arg_dims=rebuilt_dynamic_arg_dims, ) with tracing(TracingContext(fake_mode)): diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index bb3a7bf..90eade1 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -14,9 +14,9 @@ import ast import dataclasses -import functools import pprint import time +from collections import Counter from collections.abc import Callable from 
contextlib import contextmanager from pathlib import Path @@ -29,6 +29,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher from torch._guards import detect_fake_mode +from torch.fx.experimental.symbolic_shapes import has_free_symbols from torch.torch_version import TorchVersion import magi_compiler.utils.envs as envs @@ -46,6 +47,7 @@ from .piecewise_backend import PiecewiseBackend from .piecewise_compiler import CompilerInterface, EagerAdaptor, InductorStandaloneAdaptor +TORCH_VERSION = TorchVersion(torch.__version__) compilation_start_time: float = 0.0 @@ -470,17 +472,6 @@ def call_module( return output -@functools.lru_cache(maxsize=1) -def _inductor_needs_nd_tiling_workaround() -> bool: - """Under dynamic shapes, Inductor's coalesce tiling analysis used to bail out on - symbolic numels, degrading transpose/permute/channels-last pointwise kernels to - untiled Grid1D. PyTorch >= 2.11.0 fixed this upstream, so on those versions - prefer_nd_tiling workaround is unnecessary (and would actively bypass the native - coalesce path). Gate purely on the PyTorch version. - """ - return TorchVersion(torch.__version__) < (2, 11, 0) - - class MagiBackend: """ The compilation backend for `torch.compile` with MagiCompiler. 
@@ -501,14 +492,12 @@ def __init__( model_tag: str, traced_files: "OrderedSet", inductor_compile_config: dict[str, Any], - dynamic_arg_dims: dict[str, Any] | None = None, ): self.compile_config = compile_config self.model_idx = model_idx self.model_tag = model_tag self.traced_files = traced_files self.inductor_compile_config = inductor_compile_config - self.dynamic_arg_dims = dynamic_arg_dims self._configure_custom_passes() self.compiler_manager: CompilerManager = CompilerManager(self.compile_config) self._called_once = False @@ -538,35 +527,33 @@ def _configure_custom_passes(self): self.inductor_compile_config[post_grad_key] = post_grad_pass_manager - # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on - # symbolic numels, so dynamic-shape transpose/permute/channels-last kernels - # degrade to untiled Grid1D. Forcing prefer_nd_tiling restores ND tiling - # (WAN 2.2 VAE 540p decode: ~1.45x. - if self._should_enable_nd_tiling(): + def _configure_custom_passes_by_graph_info(self, graph: fx.GraphModule, example_inputs) -> None: + """Configure custom passes based on the graph information.""" + # Check if the graph is dynamic + placeholder_vals = (n.meta.get("example_value") for n in graph.graph.nodes if n.op == "placeholder") + self.is_dynamic = any(v is not None and has_free_symbols(v) for v in (*placeholder_vals, *example_inputs)) + + # Count number of nodes + nnodes = len(list(graph.graph.nodes)) + conv_nodes = [n for n in graph.graph.nodes if n.target == torch.ops.aten.convolution.default] + nconv = len(conv_nodes) + Counter(n.args[1].meta["val"].dim() - 2 for n in conv_nodes) + # dim_counts[1] / dim_counts[2] / dim_counts[3] means number of conv1d/2d/3d + + if (self.compile_config.enable_dynamic_nd_tiling is True) or ( + self.compile_config.enable_dynamic_nd_tiling is None + and self.is_dynamic + and TORCH_VERSION < (2, 11, 0) + and nnodes > 300 * nconv + ): + # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on + # symbolic 
numels, so dynamic-shape transpose/permute/channels-last kernels + # degrade to untiled Grid1D. Forcing prefer_nd_tiling restores ND tiling + # (WAN 2.2 VAE 540p decode: ~1.45x). self.inductor_compile_config["triton.prefer_nd_tiling"] = True self.inductor_compile_config["triton.max_tiles"] = 3 self.inductor_compile_config["triton.tile_reductions"] = True - def _should_enable_nd_tiling(self) -> bool: - cfg = self.compile_config.enable_dynamic_nd_tiling - if cfg is not None: - return bool(cfg) - - return self._is_dynamic_compilation() and _inductor_needs_nd_tiling_workaround() - - def _is_dynamic_compilation(self) -> bool: - """Best-effort detection of dynamic-shape compilation from intent.""" - dims = self.dynamic_arg_dims - if not dims: - return False - for v in dims.values(): - if isinstance(v, (list, tuple)): - if len(v) > 0: - return True - elif v is not None: - return True - return False - def _init_cache(self) -> str: hash_key = compute_hash( [ @@ -647,6 +634,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFun magi_logger.info("Dynamo traced files (for compilation cache):\n%s", "\n".join(self.traced_files)) compilation_counter.num_graphs_seen += 1 + self._configure_custom_passes_by_graph_info(graph, example_inputs) + self._init_cache() self.full_graph_pass_manager(graph) @@ -690,12 +679,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFun def init_backend( - compile_config: CompileConfig, - model_idx: int, - model_tag: str, - traced_files: "OrderedSet", - inductor_config: dict[str, Any], - dynamic_arg_dims: dict[str, Any] | None = None, + compile_config: CompileConfig, model_idx: int, model_tag: str, traced_files: "OrderedSet", inductor_config: dict[str, Any] ) -> str | Callable: """ Initialize the backend based on CompileConfig. 
@@ -712,6 +696,6 @@ def init_backend( return compile_config.backend elif compile_config.compile_mode == CompileMode.MAGI_COMPILE: assert compile_config.backend in ["eager", "inductor"], f"Invalid backend for MagiCompiler: {compile_config.backend}" - return MagiBackend(compile_config, model_idx, model_tag, traced_files, inductor_config, dynamic_arg_dims) + return MagiBackend(compile_config, model_idx, model_tag, traced_files, inductor_config) else: raise ValueError(f"Invalid compile mode: {compile_config.compile_mode}") diff --git a/magi_compiler/magi_backend/magi_compiler_base.py b/magi_compiler/magi_backend/magi_compiler_base.py index c375604..97fcdd7 100644 --- a/magi_compiler/magi_backend/magi_compiler_base.py +++ b/magi_compiler/magi_backend/magi_compiler_base.py @@ -124,12 +124,7 @@ def _ensure_compiled(self): if self.compiled_entry is not None: return backend = init_backend( - self.compile_config, - self.model_idx, - self.model_tag, - self.traced_files, - self.inductor_compile_config, - self.dynamic_arg_dims, + self.compile_config, self.model_idx, self.model_tag, self.traced_files, self.inductor_compile_config ) options = None if isinstance(backend, str) and backend == "inductor": diff --git a/tests/feature_tests/test_dynamic_nd_tiling.py b/tests/feature_tests/test_dynamic_nd_tiling.py index ba1897f..0ce8e98 100644 --- a/tests/feature_tests/test_dynamic_nd_tiling.py +++ b/tests/feature_tests/test_dynamic_nd_tiling.py @@ -14,109 +14,155 @@ """Logic tests for the dynamic-shape Triton ND-tiling workaround. 
-Covers the decision logic only (no GPU / benchmarking): - * dynamic-shape intent detection -> ``_is_dynamic_compilation`` - * PyTorch version gating -> ``_inductor_needs_nd_tiling_workaround`` - * full precedence (explicit config > auto) -> ``_should_enable_nd_tiling`` - * actual Inductor config injection -> ``_configure_custom_passes`` +All the decision logic lives in +``MagiBackend._configure_custom_passes_by_graph_info``, which (a) probes +``is_dynamic`` from the graph's free symbols and (b) injects the ND-tiling +Inductor config based on the tri-state ``enable_dynamic_nd_tiling`` config: + + * ``True`` -> force on + * ``False`` -> force off + * ``None`` -> auto: dynamic shapes AND PyTorch < 2.11.0 AND ``nnodes > 300 * nconv`` The end-to-end speedup is validated separately in ``tests/perf_tests/test_dynamic_nd_tiling_perf.py``. """ import pytest +import torch +import torch.fx as fx +from torch.torch_version import TorchVersion from magi_compiler.config import CompileConfig, get_compile_config from magi_compiler.magi_backend import magi_backend as mb -from magi_compiler.magi_backend.magi_backend import MagiBackend, _inductor_needs_nd_tiling_workaround +from magi_compiler.magi_backend.magi_backend import MagiBackend ND_TILING_KEYS = ("triton.prefer_nd_tiling", "triton.max_tiles", "triton.tile_reductions") -def _make_backend(dynamic_arg_dims, *, enable_dynamic_nd_tiling=None): +def _set_torch_version(monkeypatch, version): + """Patch the module-level parsed torch version used by the gating logic. + + ``TORCH_VERSION`` is parsed once at import time, so tests override the + parsed constant directly instead of patching ``torch.__version__``. + """ + monkeypatch.setattr(mb, "TORCH_VERSION", TorchVersion(version)) + + +def _make_backend(*, enable_dynamic_nd_tiling=None): """Build a MagiBackend without running the heavy __init__. 
- We only need the attributes that the ND-tiling decision reads, so we - bypass __init__ (which would spin up a CompilerManager) and set them by hand. + We bypass __init__ (which would spin up a CompilerManager) and set only the + attributes that ``_configure_custom_passes_by_graph_info`` reads. """ backend = MagiBackend.__new__(MagiBackend) backend.compile_config = get_compile_config().model_copy(update={"enable_dynamic_nd_tiling": enable_dynamic_nd_tiling}) - backend.dynamic_arg_dims = dynamic_arg_dims backend.inductor_compile_config = {} return backend @pytest.fixture -def clear_version_cache(): - """The version probe is lru_cached; clear it around tests that patch the version.""" - _inductor_needs_nd_tiling_workaround.cache_clear() - yield - _inductor_needs_nd_tiling_workaround.cache_clear() +def fake_mode(): + """A FakeTensorMode backed by a fresh ShapeEnv for symbolic shapes.""" + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + return FakeTensorMode(shape_env=ShapeEnv()) -@pytest.fixture -def no_env_override(monkeypatch): - """Ensure the env override for the config field is absent for auto-path tests.""" - monkeypatch.delenv("MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING", raising=False) +def _static_tensor(fake_mode): + """A FakeTensor with fully concrete (non-symbolic) dims.""" + with fake_mode: + return torch.empty(4, 8) + + +def _dynamic_tensor(fake_mode): + """A FakeTensor whose first dim is a free symbol (mimics a dynamic batch).""" + sym = fake_mode.shape_env.create_unbacked_symint() + with fake_mode: + return torch.empty(sym, 8) + + +def _symint(fake_mode): + """A bare (scalar) ``torch.SymInt`` carrying a free symbol.""" + return fake_mode.shape_env.create_unbacked_symint() + + +def _build_graph(fake_mode, *, placeholder_vals=(), n_conv=0, n_filler=0): + """Build a tiny fx.GraphModule for the decision logic. 
+ + ``placeholder_vals`` populate placeholder ``meta["example_value"]`` (drive + ``is_dynamic``). ``n_conv`` adds ``aten.convolution.default`` call nodes + (each fed a weight placeholder carrying a 5-D ``meta["val"]`` so the + conv-dim bookkeeping works), and ``n_filler`` adds plain call nodes to + inflate the node count (drives the ``nnodes > 300 * nconv`` heuristic). + """ + graph = fx.Graph() + inputs = [] + for i, ev in enumerate(placeholder_vals): + node = graph.placeholder(f"arg_{i}") + node.meta["example_value"] = ev + inputs.append(node) + if not inputs: + inputs.append(graph.placeholder("arg_0")) + + with fake_mode: + weight_val = torch.empty(8, 8, 3, 3, 3) # 5-D weight => conv3d (dim()-2 == 3) + + x = inputs[0] + for c in range(n_conv): + weight = graph.placeholder(f"weight_{c}") + weight.meta["val"] = weight_val + x = graph.call_function(torch.ops.aten.convolution.default, args=(x, weight)) + for _ in range(n_filler): + x = graph.call_function(torch.ops.aten.relu.default, args=(x,)) + graph.output((x,)) + return fx.GraphModule(torch.nn.Module(), graph) -# ── Point 2: dynamic-shape intent detection ────────────────────────────── +# ── is_dynamic detection ───────────────────────────────────────────────── -@pytest.mark.parametrize( - "dynamic_arg_dims, expected", - [ - (None, False), - ({}, False), - ({"x": []}, False), - ({"x": [0]}, True), - ({"x": 0}, True), - ({"x": (1, 2)}, True), - ({"x": None}, False), - ({"x": [], "y": [0]}, True), - ({"x": None, "y": []}, False), - ], -) -def test_is_dynamic_compilation(dynamic_arg_dims, expected): - backend = _make_backend(dynamic_arg_dims) - assert backend._is_dynamic_compilation() is expected +def _is_dynamic(fake_mode, *, placeholder_vals=(), example_inputs=()): + backend = _make_backend() + gm = _build_graph(fake_mode, placeholder_vals=placeholder_vals) + backend._configure_custom_passes_by_graph_info(gm, list(example_inputs)) + return backend.is_dynamic -# ── Point 3: PyTorch version gating 
────────────────────────────────────── +def test_is_dynamic_static(fake_mode): + static = _static_tensor(fake_mode) + assert _is_dynamic(fake_mode, placeholder_vals=[static], example_inputs=[static]) is False -@pytest.mark.parametrize( - "version, needs_workaround", - [ - ("2.9.1", True), - ("2.10.0", True), - ("2.10.5", True), - ("2.11.0", False), - ("2.11.0+cu124", False), - ("2.11.1", False), - ("2.12.0.dev20260101+gitabcdef", False), - ("3.0.0", False), - ], -) -def test_version_gating(monkeypatch, clear_version_cache, version, needs_workaround): - monkeypatch.setattr(mb.torch, "__version__", version) - assert _inductor_needs_nd_tiling_workaround() is needs_workaround +def test_is_dynamic_via_placeholder_symint(fake_mode): + """A symbolic dim on a placeholder example_value marks the compilation dynamic.""" + dynamic = _dynamic_tensor(fake_mode) + assert _is_dynamic(fake_mode, placeholder_vals=[dynamic], example_inputs=[]) is True -# ── Point 4: full precedence (explicit config > auto) ──────────────────── +def test_is_dynamic_via_example_inputs(fake_mode): + """A symbolic dim on an example input marks the compilation dynamic.""" + dynamic = _dynamic_tensor(fake_mode) + assert _is_dynamic(fake_mode, placeholder_vals=[], example_inputs=[dynamic]) is True -def test_explicit_config_forces_on(monkeypatch): - """Explicit config True forces on even for a static (non-dynamic) compilation.""" - backend = _make_backend(dynamic_arg_dims=None, enable_dynamic_nd_tiling=True) - assert backend._should_enable_nd_tiling() is True +def test_is_dynamic_ignores_none_placeholder(fake_mode): + """A missing/None placeholder example_value must not crash has_free_symbols.""" + static = _static_tensor(fake_mode) + assert _is_dynamic(fake_mode, placeholder_vals=[None, static], example_inputs=[static]) is False -def test_explicit_config_forces_off(monkeypatch, clear_version_cache): - """Explicit config False forces off even when dynamic + buggy-version would auto-enable.""" - 
monkeypatch.setattr(mb.torch, "__version__", "2.9.1") - backend = _make_backend(dynamic_arg_dims={"x": [0]}, enable_dynamic_nd_tiling=False) - assert backend._should_enable_nd_tiling() is False + +def test_is_dynamic_ignores_plain_int(fake_mode): + """Non-symbolic scalar inputs (plain ints) are treated as static.""" + assert _is_dynamic(fake_mode, placeholder_vals=[], example_inputs=[3, 8]) is False + + +def test_is_dynamic_via_bare_symint_input(fake_mode): + """A bare SymInt scalar input (no .shape) is correctly detected as dynamic.""" + assert _is_dynamic(fake_mode, placeholder_vals=[], example_inputs=[_symint(fake_mode)]) is True + + +# ── env var -> config field ────────────────────────────────────────────── @pytest.mark.parametrize("env_val, expected", [("1", True), ("0", False)]) @@ -126,61 +172,74 @@ def test_env_var_drives_config_field(monkeypatch, env_val, expected): assert CompileConfig().enable_dynamic_nd_tiling is expected -@pytest.mark.parametrize("explicit", [True, False]) -def test_explicit_config_overrides_auto(monkeypatch, clear_version_cache, no_env_override, explicit): - """Explicit config beats the auto path, regardless of dynamic/version state.""" - # Static + fixed version would auto-decide differently; explicit must win. - monkeypatch.setattr(mb.torch, "__version__", "2.11.0" if explicit else "2.9.1") - backend = _make_backend(dynamic_arg_dims={"x": [0]} if not explicit else None, enable_dynamic_nd_tiling=explicit) - assert backend._should_enable_nd_tiling() is explicit +# ── ND-tiling injection decision ───────────────────────────────────────── +# +# Injection requires either ``enable_dynamic_nd_tiling is True`` OR +# (``is None`` AND dynamic AND torch < 2.11.0 AND ``nnodes > 300 * nconv``). +# We build an auto-eligible graph (dynamic input + one conv + enough filler +# nodes so ``nnodes > 300 * nconv``) and then flip one condition per test. 
-def test_auto_enables_on_dynamic_and_buggy_version(monkeypatch, clear_version_cache, no_env_override): - monkeypatch.setattr(mb.torch, "__version__", "2.9.1") - backend = _make_backend(dynamic_arg_dims={"x": [0]}) - assert backend._should_enable_nd_tiling() is True +def _assert_injected(backend, injected): + if injected: + assert backend.inductor_compile_config["triton.prefer_nd_tiling"] is True + assert backend.inductor_compile_config["triton.max_tiles"] == 3 + assert backend.inductor_compile_config["triton.tile_reductions"] is True + else: + for key in ND_TILING_KEYS: + assert key not in backend.inductor_compile_config -def test_auto_disables_on_static_shapes(monkeypatch, clear_version_cache, no_env_override): - monkeypatch.setattr(mb.torch, "__version__", "2.9.1") - backend = _make_backend(dynamic_arg_dims=None) - assert backend._should_enable_nd_tiling() is False +def _auto_eligible_graph(fake_mode): + """Dynamic graph with 1 conv + plenty of filler nodes (nnodes > 300 * nconv).""" + return _build_graph(fake_mode, placeholder_vals=[_dynamic_tensor(fake_mode)], n_conv=1, n_filler=320) -def test_auto_disables_on_fixed_version(monkeypatch, clear_version_cache, no_env_override): - """Dynamic shapes but PyTorch >= 2.11.0: native coalesce path handles it.""" - monkeypatch.setattr(mb.torch, "__version__", "2.11.0") - backend = _make_backend(dynamic_arg_dims={"x": [0]}) - assert backend._should_enable_nd_tiling() is False +def test_force_on_injects_even_when_static(monkeypatch, fake_mode): + """enable_dynamic_nd_tiling=True forces injection regardless of graph/version.""" + _set_torch_version(monkeypatch, "2.11.0") # a "fixed" version the auto path would skip + backend = _make_backend(enable_dynamic_nd_tiling=True) + gm = _build_graph(fake_mode, placeholder_vals=[_static_tensor(fake_mode)], n_conv=1, n_filler=0) + backend._configure_custom_passes_by_graph_info(gm, []) + _assert_injected(backend, True) -# ── Point 5: actual Inductor config injection 
──────────────────────────── -# -# We exercise the injection branch directly rather than the full -# ``_configure_custom_passes`` (which also wires up unrelated pass managers), -# so the test stays focused on the ND-tiling keys and free of pass-chain deps. - - -def _inject_nd_tiling(backend): - """Mirror the ND-tiling injection block of ``_configure_custom_passes``.""" - if backend._should_enable_nd_tiling(): - backend.inductor_compile_config["triton.prefer_nd_tiling"] = True - backend.inductor_compile_config["triton.max_tiles"] = 3 - backend.inductor_compile_config["triton.tile_reductions"] = True - - -def test_config_injected_when_enabled(monkeypatch, clear_version_cache, no_env_override): - monkeypatch.setattr(mb.torch, "__version__", "2.9.1") - backend = _make_backend(dynamic_arg_dims={"x": [0]}) - _inject_nd_tiling(backend) - assert backend.inductor_compile_config["triton.prefer_nd_tiling"] is True - assert backend.inductor_compile_config["triton.max_tiles"] == 3 - assert backend.inductor_compile_config["triton.tile_reductions"] is True - - -def test_config_absent_when_disabled(monkeypatch, clear_version_cache, no_env_override): - monkeypatch.setattr(mb.torch, "__version__", "2.11.0") - backend = _make_backend(dynamic_arg_dims={"x": [0]}) - _inject_nd_tiling(backend) - for key in ND_TILING_KEYS: - assert key not in backend.inductor_compile_config +def test_force_off_skips_even_when_auto_eligible(monkeypatch, fake_mode): + """enable_dynamic_nd_tiling=False skips injection even when auto would enable.""" + _set_torch_version(monkeypatch, "2.9.1") + backend = _make_backend(enable_dynamic_nd_tiling=False) + backend._configure_custom_passes_by_graph_info(_auto_eligible_graph(fake_mode), []) + _assert_injected(backend, False) + + +def test_auto_injects_when_all_conditions_met(monkeypatch, fake_mode): + _set_torch_version(monkeypatch, "2.9.1") + backend = _make_backend(enable_dynamic_nd_tiling=None) + 
backend._configure_custom_passes_by_graph_info(_auto_eligible_graph(fake_mode), []) + _assert_injected(backend, True) + + +def test_auto_skips_on_static_shapes(monkeypatch, fake_mode): + _set_torch_version(monkeypatch, "2.9.1") + backend = _make_backend(enable_dynamic_nd_tiling=None) + gm = _build_graph(fake_mode, placeholder_vals=[_static_tensor(fake_mode)], n_conv=0, n_filler=5) + backend._configure_custom_passes_by_graph_info(gm, []) + _assert_injected(backend, False) + + +def test_auto_skips_on_fixed_version(monkeypatch, fake_mode): + """Dynamic shapes but PyTorch >= 2.11.0: native coalesce path handles it.""" + _set_torch_version(monkeypatch, "2.11.0") + backend = _make_backend(enable_dynamic_nd_tiling=None) + backend._configure_custom_passes_by_graph_info(_auto_eligible_graph(fake_mode), []) + _assert_injected(backend, False) + + +def test_auto_skips_when_graph_too_conv_dense(monkeypatch, fake_mode): + """``nnodes <= 300 * nconv`` (conv-dense graph): the heuristic bails out.""" + _set_torch_version(monkeypatch, "2.9.1") + backend = _make_backend(enable_dynamic_nd_tiling=None) + # 1 conv + few filler nodes => nnodes well under 300 * 1 = 300. 
+ gm = _build_graph(fake_mode, placeholder_vals=[_dynamic_tensor(fake_mode)], n_conv=1, n_filler=5) + backend._configure_custom_passes_by_graph_info(gm, []) + _assert_injected(backend, False) From 7b22962cd6802c244c0cfa599143d99d8c0c8e8b Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 22 Jun 2026 12:01:36 +0800 Subject: [PATCH 3/5] [chores] fix ci --- .github/workflows/integration_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 59db755..2c281ea 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -14,7 +14,7 @@ jobs: integration_test: name: Integration Test runs-on: [self-hosted, magi-compiler] - timeout-minutes: 30 + timeout-minutes: 40 env: http_proxy: ${{ secrets.HTTP_PROXY }} https_proxy: ${{ secrets.HTTPS_PROXY }} From 7568198bccbc021de3d2328cb0a82f1b3eb07258 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 22 Jun 2026 19:52:22 +0800 Subject: [PATCH 4/5] [Feat] Add a snapshot mechanism for inductor configs potentially mutated by post grad passes --- magi_compiler/config.py | 22 +- magi_compiler/magi_backend/magi_backend.py | 33 +-- magi_compiler/passes/pass_base/__init__.py | 4 +- .../passes/pass_base/magi_inductor_pass.py | 64 +++++ .../piecewise_graph/nd_tiling_workaround.py | 48 ++++ .../piecewise_graph/post_grad_pass_manager.py | 17 +- tests/feature_tests/conftest.py | 74 +++++ tests/feature_tests/test_dynamic_nd_tiling.py | 268 +++++------------- .../feature_tests/test_magi_inductor_pass.py | 183 ++++++++++++ tests/model_definition.py | 43 +++ .../perf_tests/test_dynamic_nd_tiling_perf.py | 109 +++---- 11 files changed, 558 insertions(+), 307 deletions(-) create mode 100644 magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py create mode 100644 tests/feature_tests/conftest.py create mode 100644 tests/feature_tests/test_magi_inductor_pass.py diff --git a/magi_compiler/config.py 
b/magi_compiler/config.py index fe2e80f..7e793ee 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -64,6 +64,15 @@ class PassConfig(BaseModel): # TODO: Add sequence parallelism pass and async TP pass. # TODO: Add Ulysses overlap pass. enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.") + enable_nd_tiling_workaround: Optional[bool] = Field( + None, + description=( + "Triton ND-tiling workaround (prefer_nd_tiling + max_tiles=3 + tile_reductions) " + "for Inductor's coalesce tiling bailing out under dynamic shapes. " + "Tri-state: None = auto (decided by the Pass's internal heuristics, if any); True/False = force. " + "See MagiInductorPass.__init__ for how this maps to the Pass's force_on flag." + ), + ) enable_mm_epilogue_fusion: bool = Field( False, description=( @@ -171,6 +180,10 @@ class CompileConfig(BaseSettings): model_config = SettingsConfigDict( env_prefix="MAGI_COMPILE_", + # Nested sub-configs (e.g. pass_config, offload_config) are reachable via + # ``MAGI_COMPILE_<SUBCONFIG>__<FIELD>`` env vars, e.g. + # ``MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND=1``. + env_nested_delimiter="__", populate_by_name=True, cli_parse_args=True, cli_ignore_unknown_args=True, @@ -217,6 +230,15 @@ class CompileConfig(BaseSettings): enable_inductor_coordinate_descent_tuning: bool = Field( False, description="Enable Inductor coordinate_descent_tuning for kernel selection." ) - enable_dynamic_nd_tiling: Optional[bool] = Field( - None, - description=( - "Triton ND-tiling workaround (prefer_nd_tiling + max_tiles=3 + tile_reductions) " - "for Inductor's coalesce tiling bailing out under dynamic shapes. " - "Tri-state: None = auto (on for dynamic shapes on PyTorch < 2.11.0); True/False = force. " - "Settable via the MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING env var."
- ), - ) compile_sizes: list[int] = Field( default_factory=list, description=( diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 90eade1..9569bb2 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -16,7 +16,6 @@ import dataclasses import pprint import time -from collections import Counter from collections.abc import Callable from contextlib import contextmanager from pathlib import Path @@ -29,8 +28,6 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher from torch._guards import detect_fake_mode -from torch.fx.experimental.symbolic_shapes import has_free_symbols -from torch.torch_version import TorchVersion import magi_compiler.utils.envs as envs from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, inductor_compile_config_hash, magi_cache_dump_path @@ -47,7 +44,6 @@ from .piecewise_backend import PiecewiseBackend from .piecewise_compiler import CompilerInterface, EagerAdaptor, InductorStandaloneAdaptor -TORCH_VERSION = TorchVersion(torch.__version__) compilation_start_time: float = 0.0 @@ -527,32 +523,7 @@ def _configure_custom_passes(self): self.inductor_compile_config[post_grad_key] = post_grad_pass_manager - def _configure_custom_passes_by_graph_info(self, graph: fx.GraphModule, example_inputs) -> None: - """Configure custom passes based on the graph information.""" - # Check if the graph is dynamic - placeholder_vals = (n.meta.get("example_value") for n in graph.graph.nodes if n.op == "placeholder") - self.is_dynamic = any(v is not None and has_free_symbols(v) for v in (*placeholder_vals, *example_inputs)) - - # Count number of nodes - nnodes = len(list(graph.graph.nodes)) - conv_nodes = [n for n in graph.graph.nodes if n.target == torch.ops.aten.convolution.default] - nconv = len(conv_nodes) - Counter(n.args[1].meta["val"].dim() - 2 for n in conv_nodes) - # dim_counts[1] / dim_counts[2] / 
dim_counts[3] means number of conv1d/2d/3d - - if (self.compile_config.enable_dynamic_nd_tiling is True) or ( - self.compile_config.enable_dynamic_nd_tiling is None - and self.is_dynamic - and TORCH_VERSION < (2, 11, 0) - and nnodes > 300 * nconv - ): - # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on - # symbolic numels, so dynamic-shape transpose/permute/channels-last kernels - # degrade to untiled Grid1D. Forcing prefer_nd_tiling restores ND tiling - # (WAN 2.2 VAE 540p decode: ~1.45x. - self.inductor_compile_config["triton.prefer_nd_tiling"] = True - self.inductor_compile_config["triton.max_tiles"] = 3 - self.inductor_compile_config["triton.tile_reductions"] = True + post_grad_pass_manager.snapshot_original_inductor_configs(self.inductor_compile_config) def _init_cache(self) -> str: hash_key = compute_hash( @@ -634,8 +605,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFun magi_logger.info("Dynamo traced files (for compilation cache):\n%s", "\n".join(self.traced_files)) compilation_counter.num_graphs_seen += 1 - self._configure_custom_passes_by_graph_info(graph, example_inputs) - self._init_cache() self.full_graph_pass_manager(graph) diff --git a/magi_compiler/passes/pass_base/__init__.py b/magi_compiler/passes/pass_base/__init__.py index 17dd699..8919555 100644 --- a/magi_compiler/passes/pass_base/__init__.py +++ b/magi_compiler/passes/pass_base/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
from .inductor_pass import InductorPass -from .magi_inductor_pass import MagiInductorPass +from .magi_inductor_pass import MagiInductorPass, snapshot_original_inductor_configs from .pass_context import get_pass_context, pass_context -__all__ = ["InductorPass", "pass_context", "get_pass_context", "MagiInductorPass"] +__all__ = ["InductorPass", "pass_context", "get_pass_context", "MagiInductorPass", "snapshot_original_inductor_configs"] diff --git a/magi_compiler/passes/pass_base/magi_inductor_pass.py b/magi_compiler/passes/pass_base/magi_inductor_pass.py index b01e955..fd5d5ae 100644 --- a/magi_compiler/passes/pass_base/magi_inductor_pass.py +++ b/magi_compiler/passes/pass_base/magi_inductor_pass.py @@ -12,10 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +from collections.abc import Iterable + +import torch +from torch.fx.experimental.symbolic_shapes import has_free_symbols + from .inductor_pass import InductorPass +DEFAULT_CONV_HEAVY_THRESHOLD = 300 + + class MagiInductorPass(InductorPass): """ Base class for inductor passes. """ + + # If a pass needs to modify any Inductor configuration (``torch._inductor.config``), + # it **MUST** declare all affected config keys here. The declared keys will be + # automatically snapshotted and restored after the subgraph compilation ends + # to prevent global leakage. + # + # Note: Only passes running in ``PostGradPassManager`` are allowed to mutate Inductor + # configs. Passes running in ``FullGraphPassManager`` **MUST NOT** modify them, as full-graph + # passes do not trigger compilation and cannot be patched/isolated. + inductor_config_keys_potentially_mutated_by_this_pass: tuple[str, ...] = () + + def __init__(self, force_on: bool = False): + """ + Initialize the pass. 
+ + ### Switch Mapping + The external tri-state config is mapped to this bi-state `force_on` by the Pass Manager: + + User Config (Tri-state) --> Pass Manager Action --> Pass `force_on` (Bi-state) + ----------------------- ------------------- -------------------------- + False Skip registration (Not instantiated) + True Add pass force_on = True (Bypass heuristics) + None (Auto) Add pass force_on = False (Run heuristics, if any) + + :param force_on: If True, force enable the pass, bypassing any auto-detection heuristics. + If False, run in auto mode (relying on pass-specific heuristics). + """ + super().__init__() + self.force_on = force_on + + def is_dynamic(self, graph: torch.fx.Graph) -> bool: + """Determine if the graph has dynamic shapes by checking if any placeholder carries free symbols.""" + placeholder_vals = (n.meta.get("val", n.meta.get("example_value")) for n in graph.nodes if n.op == "placeholder") + return any(v is not None and has_free_symbols(v) for v in placeholder_vals) + + def is_conv_heavy(self, graph: torch.fx.Graph, threshold: int = DEFAULT_CONV_HEAVY_THRESHOLD) -> bool: + """Determine if the graph is convolution-heavy (dense in convolutions).""" + nnodes = len(list(graph.nodes)) + nconv = sum(1 for n in graph.nodes if n.target == torch.ops.aten.convolution.default) + return nnodes < threshold * nconv + + +def snapshot_original_inductor_configs(passes: Iterable, inductor_compile_config: dict) -> None: + """Snapshot the original values of global Inductor configs that passes potentially mutate. + + The captured original values are stored in ``inductor_compile_config``. When ``standalone_compile`` + calls ``compile_fx``, it automatically passes this config as ``config_patches`` to Inductor's + ``config.patch`` context manager. No matter what values the passes temporarily set these fields to + during compilation, they will be safely restored to their pre-compilation state on scope exit. 
+ """ + cfg = torch._inductor.config + for pass_ in passes: + for key in getattr(pass_, "inductor_config_keys_potentially_mutated_by_this_pass", ()): + snapshot = functools.reduce(getattr, key.split("."), cfg) + inductor_compile_config.setdefault(key, snapshot) diff --git a/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py b/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py new file mode 100644 index 0000000..54dfe2d --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch.torch_version import TorchVersion + +from ...magi_depyf.timeline import emit_pass_lifecycle +from ..pass_base import MagiInductorPass + + +class ND_TilingWorkaroundPass(MagiInductorPass): + inductor_config_keys_potentially_mutated_by_this_pass = ( + "triton.prefer_nd_tiling", + "triton.max_tiles", + "triton.tile_reductions", + ) + + def __init__(self, force_on: bool = False, torch_version: TorchVersion = None): + super().__init__(force_on=force_on) + self.torch_version = torch_version if torch_version is not None else TorchVersion(torch.__version__) + + @emit_pass_lifecycle + def __call__(self, graph: torch.fx.Graph): + if not self.force_on: + if self.torch_version >= (2, 11, 0) or not self.is_dynamic(graph) or self.is_conv_heavy(graph): + return False + + # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on + # symbolic numels, so dynamic-shape transpose/permute/channels-last kernels + # degrade to untiled Grid1D. Forcing prefer_nd_tiling restores ND tiling + # (WAN 2.2 VAE 540p decode: ~1.45x). 
+ torch._inductor.config.triton.prefer_nd_tiling = True + torch._inductor.config.triton.max_tiles = 3 + torch._inductor.config.triton.tile_reductions = True diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index 3323ea3..b2c0e7f 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -14,13 +14,15 @@ import functools +import torch from torch import fx as fx from torch._inductor.custom_graph_pass import CustomGraphPass +from torch.torch_version import TorchVersion from ...config import PassConfig, get_compile_config from ...utils import magi_logger, set_env_var from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG -from ..pass_base import InductorPass, get_pass_context +from ..pass_base import InductorPass, get_pass_context, snapshot_original_inductor_configs from .fix_functionalization import FixFunctionalizationPass from .post_cleanup import PostCleanupPass @@ -80,6 +82,15 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config + if pass_config.enable_nd_tiling_workaround != False: + from .nd_tiling_workaround import ND_TilingWorkaroundPass + + self.add( + ND_TilingWorkaroundPass( + force_on=pass_config.enable_nd_tiling_workaround == True, torch_version=TorchVersion(torch.__version__) + ) + ) + if pass_config.enable_mm_epilogue_fusion: compile_config = get_compile_config() if compile_config.has_cutlass: @@ -100,6 +111,10 @@ def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) self.passes.append(pass_) + def snapshot_original_inductor_configs(self, inductor_compile_config: dict) -> None: + """Snapshot original values of global Inductor configs potentially mutated by passes.""" + snapshot_original_inductor_configs(self.passes, inductor_compile_config) + def uuid(self): """ The PostGradPassManager is set as 
a custom pass in the Inductor and diff --git a/tests/feature_tests/conftest.py b/tests/feature_tests/conftest.py new file mode 100644 index 0000000..fe1b26c --- /dev/null +++ b/tests/feature_tests/conftest.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures for Inductor-pass feature tests. + +``fake_mode`` plus the tensor/graph builders below are reused by both +``test_magi_inductor_pass.py`` (base-class utilities) and +``test_dynamic_nd_tiling.py`` (ND_TilingWorkaroundPass decision logic). +""" + +import pytest +import torch +import torch.fx as fx + + +@pytest.fixture +def fake_mode(): + """A FakeTensorMode backed by a fresh ShapeEnv for symbolic shapes.""" + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + return FakeTensorMode(shape_env=ShapeEnv()) + + +def static_tensor(fake_mode): + """A FakeTensor with fully concrete (non-symbolic) dims.""" + with fake_mode: + return torch.empty(4, 8) + + +def dynamic_tensor(fake_mode): + """A FakeTensor whose first dim is a free symbol (mimics a dynamic batch).""" + sym = fake_mode.shape_env.create_unbacked_symint() + with fake_mode: + return torch.empty(sym, 8) + + +def build_graph_module(fake_mode, *, placeholder_vals=(), placeholder_meta_key="example_value", n_conv=0, n_filler=0): + """Build a tiny fx.GraphModule for exercising pass decision logic. 
+ + ``placeholder_vals`` populate each placeholder's ``meta[placeholder_meta_key]`` + (drives ``is_dynamic``; ``MagiInductorPass.is_dynamic`` reads ``meta["val"]`` + first and falls back to ``meta["example_value"]``). ``n_conv`` adds + ``aten.convolution.default`` call nodes and ``n_filler`` adds plain call + nodes to inflate the node count (drives the ``is_conv_heavy`` heuristic). + """ + graph = fx.Graph() + inputs = [] + for i, ev in enumerate(placeholder_vals): + node = graph.placeholder(f"arg_{i}") + node.meta[placeholder_meta_key] = ev + inputs.append(node) + if not inputs: + inputs.append(graph.placeholder("arg_0")) + + x = inputs[0] + for c in range(n_conv): + weight = graph.placeholder(f"weight_{c}") + x = graph.call_function(torch.ops.aten.convolution.default, args=(x, weight)) + for _ in range(n_filler): + x = graph.call_function(torch.ops.aten.relu.default, args=(x,)) + graph.output((x,)) + return fx.GraphModule(torch.nn.Module(), graph) diff --git a/tests/feature_tests/test_dynamic_nd_tiling.py b/tests/feature_tests/test_dynamic_nd_tiling.py index 0ce8e98..e543b80 100644 --- a/tests/feature_tests/test_dynamic_nd_tiling.py +++ b/tests/feature_tests/test_dynamic_nd_tiling.py @@ -12,234 +12,116 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Logic tests for the dynamic-shape Triton ND-tiling workaround. +"""Decision-logic tests for ``ND_TilingWorkaroundPass``. -All the decision logic lives in -``MagiBackend._configure_custom_passes_by_graph_info``, which (a) probes -``is_dynamic`` from the graph's free symbols and (b) injects the ND-tiling -Inductor config based on the tri-state ``enable_dynamic_nd_tiling`` config: +When applicable, the pass flips three ``torch._inductor.config`` triton keys +(``prefer_nd_tiling`` / ``max_tiles`` / ``tile_reductions``) ON. 
Whether it does +so is driven by the ``enable_nd_tiling_workaround`` config: - * ``True`` -> force on - * ``False`` -> force off - * ``None`` -> auto: dynamic shapes AND PyTorch < 2.11.0 AND ``nnodes > 300 * nconv`` + * ``True`` -> force on, skip heuristics + * ``False`` -> pass not registered at all + * ``None`` -> auto: on iff dynamic shapes AND PyTorch < 2.11.0 AND not conv-heavy -The end-to-end speedup is validated separately in +These tests assert that mapping. The shared base-class utilities (``is_dynamic``, +``is_conv_heavy``, config snapshot/anti-leakage) are tested in +``test_magi_inductor_pass.py``; the end-to-end speedup in ``tests/perf_tests/test_dynamic_nd_tiling_perf.py``. """ import pytest import torch -import torch.fx as fx from torch.torch_version import TorchVersion -from magi_compiler.config import CompileConfig, get_compile_config -from magi_compiler.magi_backend import magi_backend as mb -from magi_compiler.magi_backend.magi_backend import MagiBackend +from magi_compiler.config import PassConfig +from magi_compiler.passes.piecewise_graph.nd_tiling_workaround import ND_TilingWorkaroundPass +from tests.feature_tests.conftest import build_graph_module, dynamic_tensor, static_tensor -ND_TILING_KEYS = ("triton.prefer_nd_tiling", "triton.max_tiles", "triton.tile_reductions") +@pytest.fixture(autouse=True) +def _restore_inductor_config(): + """Snapshot/restore the three triton keys around every test. -def _set_torch_version(monkeypatch, version): - """Patch the module-level parsed torch version used by the gating logic. - - ``TORCH_VERSION`` is parsed once at import time, so tests override the - parsed constant directly instead of patching ``torch.__version__``. - """ - monkeypatch.setattr(mb, "TORCH_VERSION", TorchVersion(version)) - - -def _make_backend(*, enable_dynamic_nd_tiling=None): - """Build a MagiBackend without running the heavy __init__. 
- - We bypass __init__ (which would spin up a CompilerManager) and set only the - attributes that ``_configure_custom_passes_by_graph_info`` reads. + The pass mutates the process-global ``torch._inductor.config`` directly, so + without this fixture one test could leak into the next. """ - backend = MagiBackend.__new__(MagiBackend) - backend.compile_config = get_compile_config().model_copy(update={"enable_dynamic_nd_tiling": enable_dynamic_nd_tiling}) - backend.inductor_compile_config = {} - return backend - - -@pytest.fixture -def fake_mode(): - """A FakeTensorMode backed by a fresh ShapeEnv for symbolic shapes.""" - from torch._subclasses.fake_tensor import FakeTensorMode - from torch.fx.experimental.symbolic_shapes import ShapeEnv - - return FakeTensorMode(shape_env=ShapeEnv()) - - -def _static_tensor(fake_mode): - """A FakeTensor with fully concrete (non-symbolic) dims.""" - with fake_mode: - return torch.empty(4, 8) - - -def _dynamic_tensor(fake_mode): - """A FakeTensor whose first dim is a free symbol (mimics a dynamic batch).""" - sym = fake_mode.shape_env.create_unbacked_symint() - with fake_mode: - return torch.empty(sym, 8) - - -def _symint(fake_mode): - """A bare (scalar) ``torch.SymInt`` carrying a free symbol.""" - return fake_mode.shape_env.create_unbacked_symint() - - -def _build_graph(fake_mode, *, placeholder_vals=(), n_conv=0, n_filler=0): - """Build a tiny fx.GraphModule for the decision logic. - - ``placeholder_vals`` populate placeholder ``meta["example_value"]`` (drive - ``is_dynamic``). ``n_conv`` adds ``aten.convolution.default`` call nodes - (each fed a weight placeholder carrying a 5-D ``meta["val"]`` so the - conv-dim bookkeeping works), and ``n_filler`` adds plain call nodes to - inflate the node count (drives the ``nnodes > 300 * nconv`` heuristic). 
- """ - graph = fx.Graph() - inputs = [] - for i, ev in enumerate(placeholder_vals): - node = graph.placeholder(f"arg_{i}") - node.meta["example_value"] = ev - inputs.append(node) - if not inputs: - inputs.append(graph.placeholder("arg_0")) - - with fake_mode: - weight_val = torch.empty(8, 8, 3, 3, 3) # 5-D weight => conv3d (dim()-2 == 3) - - x = inputs[0] - for c in range(n_conv): - weight = graph.placeholder(f"weight_{c}") - weight.meta["val"] = weight_val - x = graph.call_function(torch.ops.aten.convolution.default, args=(x, weight)) - for _ in range(n_filler): - x = graph.call_function(torch.ops.aten.relu.default, args=(x,)) - graph.output((x,)) - return fx.GraphModule(torch.nn.Module(), graph) - + cfg = torch._inductor.config + saved = (cfg.triton.prefer_nd_tiling, cfg.triton.max_tiles, cfg.triton.tile_reductions) + try: + yield + finally: + cfg.triton.prefer_nd_tiling, cfg.triton.max_tiles, cfg.triton.tile_reductions = saved -# ── is_dynamic detection ───────────────────────────────────────────────── +def _make_pass(*, force_on=False, version="2.9.1"): + return ND_TilingWorkaroundPass(force_on=force_on, torch_version=TorchVersion(version)) -def _is_dynamic(fake_mode, *, placeholder_vals=(), example_inputs=()): - backend = _make_backend() - gm = _build_graph(fake_mode, placeholder_vals=placeholder_vals) - backend._configure_custom_passes_by_graph_info(gm, list(example_inputs)) - return backend.is_dynamic - -def test_is_dynamic_static(fake_mode): - static = _static_tensor(fake_mode) - assert _is_dynamic(fake_mode, placeholder_vals=[static], example_inputs=[static]) is False - - -def test_is_dynamic_via_placeholder_symint(fake_mode): - """A symbolic dim on a placeholder example_value marks the compilation dynamic.""" - dynamic = _dynamic_tensor(fake_mode) - assert _is_dynamic(fake_mode, placeholder_vals=[dynamic], example_inputs=[]) is True - - -def test_is_dynamic_via_example_inputs(fake_mode): - """A symbolic dim on an example input marks the compilation 
dynamic.""" - dynamic = _dynamic_tensor(fake_mode) - assert _is_dynamic(fake_mode, placeholder_vals=[], example_inputs=[dynamic]) is True - - -def test_is_dynamic_ignores_none_placeholder(fake_mode): - """A missing/None placeholder example_value must not crash has_free_symbols.""" - static = _static_tensor(fake_mode) - assert _is_dynamic(fake_mode, placeholder_vals=[None, static], example_inputs=[static]) is False +def _assert_injected(injected): + cfg = torch._inductor.config + if injected: + assert cfg.triton.prefer_nd_tiling is True + assert cfg.triton.max_tiles == 3 + assert cfg.triton.tile_reductions is True + else: + assert cfg.triton.prefer_nd_tiling is False + assert cfg.triton.max_tiles is None + assert cfg.triton.tile_reductions is False -def test_is_dynamic_ignores_plain_int(fake_mode): - """Non-symbolic scalar inputs (plain ints) are treated as static.""" - assert _is_dynamic(fake_mode, placeholder_vals=[], example_inputs=[3, 8]) is False +def _auto_eligible_graph(fake_mode): + """Dynamic graph with 1 conv + plenty of filler nodes (nnodes >= 300 * nconv).""" + return build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=320) -def test_is_dynamic_via_bare_symint_input(fake_mode): - """A bare SymInt scalar input (no .shape) is correctly detected as dynamic.""" - assert _is_dynamic(fake_mode, placeholder_vals=[], example_inputs=[_symint(fake_mode)]) is True +# ── config field tri-state ─────────────────────────────────────────────── -# ── env var -> config field ────────────────────────────────────────────── +@pytest.mark.parametrize("value", [True, False, None]) +def test_config_field_tristate(value): + """enable_nd_tiling_workaround accepts True/False/None.""" + assert PassConfig(enable_nd_tiling_workaround=value).enable_nd_tiling_workaround is value -@pytest.mark.parametrize("env_val, expected", [("1", True), ("0", False)]) -def test_env_var_drives_config_field(monkeypatch, env_val, expected): - 
"""MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING populates the config field directly.""" - monkeypatch.setenv("MAGI_COMPILE_ENABLE_DYNAMIC_ND_TILING", env_val) - assert CompileConfig().enable_dynamic_nd_tiling is expected +def test_config_field_default_is_none(): + assert PassConfig().enable_nd_tiling_workaround is None # ── ND-tiling injection decision ───────────────────────────────────────── -# -# Injection requires either ``enable_dynamic_nd_tiling is True`` OR -# (``is None`` AND dynamic AND torch < 2.11.0 AND ``nnodes > 300 * nconv``). -# We build an auto-eligible graph (dynamic input + one conv + enough filler -# nodes so ``nnodes > 300 * nconv``) and then flip one condition per test. - - -def _assert_injected(backend, injected): - if injected: - assert backend.inductor_compile_config["triton.prefer_nd_tiling"] is True - assert backend.inductor_compile_config["triton.max_tiles"] == 3 - assert backend.inductor_compile_config["triton.tile_reductions"] is True - else: - for key in ND_TILING_KEYS: - assert key not in backend.inductor_compile_config - - -def _auto_eligible_graph(fake_mode): - """Dynamic graph with 1 conv + plenty of filler nodes (nnodes > 300 * nconv).""" - return _build_graph(fake_mode, placeholder_vals=[_dynamic_tensor(fake_mode)], n_conv=1, n_filler=320) - - -def test_force_on_injects_even_when_static(monkeypatch, fake_mode): - """enable_dynamic_nd_tiling=True forces injection regardless of graph/version.""" - _set_torch_version(monkeypatch, "2.11.0") # a "fixed" version the auto path would skip - backend = _make_backend(enable_dynamic_nd_tiling=True) - gm = _build_graph(fake_mode, placeholder_vals=[_static_tensor(fake_mode)], n_conv=1, n_filler=0) - backend._configure_custom_passes_by_graph_info(gm, []) - _assert_injected(backend, True) -def test_force_off_skips_even_when_auto_eligible(monkeypatch, fake_mode): - """enable_dynamic_nd_tiling=False skips injection even when auto would enable.""" - _set_torch_version(monkeypatch, "2.9.1") - backend = 
_make_backend(enable_dynamic_nd_tiling=False) - backend._configure_custom_passes_by_graph_info(_auto_eligible_graph(fake_mode), []) - _assert_injected(backend, False) +def test_force_on_injects_even_when_static_and_fixed_version(fake_mode): + """force_on=True flips the config regardless of graph/version.""" + pass_ = _make_pass(force_on=True, version="2.11.0") + gm = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)], n_conv=1, n_filler=0) + pass_(gm.graph) + _assert_injected(True) -def test_auto_injects_when_all_conditions_met(monkeypatch, fake_mode): - _set_torch_version(monkeypatch, "2.9.1") - backend = _make_backend(enable_dynamic_nd_tiling=None) - backend._configure_custom_passes_by_graph_info(_auto_eligible_graph(fake_mode), []) - _assert_injected(backend, True) +def test_auto_injects_when_all_conditions_met(fake_mode): + pass_ = _make_pass(version="2.9.1") + gm = _auto_eligible_graph(fake_mode) + pass_(gm.graph) + _assert_injected(True) -def test_auto_skips_on_static_shapes(monkeypatch, fake_mode): - _set_torch_version(monkeypatch, "2.9.1") - backend = _make_backend(enable_dynamic_nd_tiling=None) - gm = _build_graph(fake_mode, placeholder_vals=[_static_tensor(fake_mode)], n_conv=0, n_filler=5) - backend._configure_custom_passes_by_graph_info(gm, []) - _assert_injected(backend, False) +def test_auto_skips_on_static_shapes(fake_mode): + pass_ = _make_pass(version="2.9.1") + gm = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)], n_conv=0, n_filler=5) + pass_(gm.graph) + _assert_injected(False) -def test_auto_skips_on_fixed_version(monkeypatch, fake_mode): +def test_auto_skips_on_fixed_version(fake_mode): """Dynamic shapes but PyTorch >= 2.11.0: native coalesce path handles it.""" - _set_torch_version(monkeypatch, "2.11.0") - backend = _make_backend(enable_dynamic_nd_tiling=None) - backend._configure_custom_passes_by_graph_info(_auto_eligible_graph(fake_mode), []) - _assert_injected(backend, False) - - -def 
test_auto_skips_when_graph_too_conv_dense(monkeypatch, fake_mode): - """``nnodes <= 300 * nconv`` (conv-dense graph): the heuristic bails out.""" - _set_torch_version(monkeypatch, "2.9.1") - backend = _make_backend(enable_dynamic_nd_tiling=None) - # 1 conv + few filler nodes => nnodes well under 300 * 1 = 300. - gm = _build_graph(fake_mode, placeholder_vals=[_dynamic_tensor(fake_mode)], n_conv=1, n_filler=5) - backend._configure_custom_passes_by_graph_info(gm, []) - _assert_injected(backend, False) + pass_ = _make_pass(version="2.11.0") + gm = _auto_eligible_graph(fake_mode) + pass_(gm.graph) + _assert_injected(False) + + +def test_auto_skips_when_graph_too_conv_dense(fake_mode): + """``nnodes < 300 * nconv`` (conv-dense graph): the heuristic bails out.""" + pass_ = _make_pass(version="2.9.1") + gm = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=5) + pass_(gm.graph) + _assert_injected(False) diff --git a/tests/feature_tests/test_magi_inductor_pass.py b/tests/feature_tests/test_magi_inductor_pass.py new file mode 100644 index 0000000..4ac05bb --- /dev/null +++ b/tests/feature_tests/test_magi_inductor_pass.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the base class MagiInductorPass and its helper utilities.""" + +import torch + +from magi_compiler.config import CompileConfig +from magi_compiler.passes.pass_base import MagiInductorPass, snapshot_original_inductor_configs +from tests.feature_tests.conftest import build_graph_module, dynamic_tensor, static_tensor + + +class DummyPass(MagiInductorPass): + inductor_config_keys_potentially_mutated_by_this_pass = ("triton.prefer_nd_tiling", "triton.max_tiles") + + def __call__(self, graph: torch.fx.Graph): + # Mutate the configs during execution + torch._inductor.config.triton.prefer_nd_tiling = True + torch._inductor.config.triton.max_tiles = 3 + + +class UndeclaredPass(MagiInductorPass): + def __call__(self, graph: torch.fx.Graph): + pass + + +def test_is_dynamic(fake_mode): + """is_dynamic flags graphs whose placeholders carry free symbols.""" + pass_ = DummyPass() + + # Static graph + gm_static = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)]) + assert not pass_.is_dynamic(gm_static.graph) + + # Dynamic graph + gm_dynamic = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)]) + assert pass_.is_dynamic(gm_dynamic.graph) + + +def test_is_dynamic_reads_val_meta_key(fake_mode): + """is_dynamic prefers meta["val"], falling back to meta["example_value"].""" + pass_ = DummyPass() + + # Symbol carried under meta["val"] (the preferred key) must be detected. + gm_val = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], placeholder_meta_key="val") + assert pass_.is_dynamic(gm_val.graph) + + # A placeholder with neither meta key set is treated as static (not dynamic). 
+ gm_no_meta = build_graph_module(fake_mode) + assert not pass_.is_dynamic(gm_no_meta.graph) + + +def test_is_conv_heavy(fake_mode): + """is_conv_heavy flags graphs whose node count is dense relative to convs.""" + pass_ = DummyPass() + + # 1 conv, 50 filler nodes -> nnodes = 53 (1 input + 1 weight + 1 conv + 50 relu). + # threshold = 300, threshold * nconv = 300. nnodes < 300 -> is_conv_heavy is True. + gm_heavy = build_graph_module(fake_mode, n_conv=1, n_filler=50) + assert pass_.is_conv_heavy(gm_heavy.graph, threshold=300) + + # 1 conv, 320 filler nodes -> nnodes = 323. + # threshold = 300, threshold * nconv = 300. nnodes >= 300 -> is_conv_heavy is False. + gm_light = build_graph_module(fake_mode, n_conv=1, n_filler=320) + assert not pass_.is_conv_heavy(gm_light.graph, threshold=300) + + +def test_is_conv_heavy_zero_conv(fake_mode): + """A graph with no convolutions is never conv-heavy (threshold * 0 == 0).""" + pass_ = DummyPass() + gm_no_conv = build_graph_module(fake_mode, n_conv=0, n_filler=10) + assert not pass_.is_conv_heavy(gm_no_conv.graph, threshold=300) + + +def test_snapshot_original_inductor_configs(): + """Verify that snapshot_original_inductor_configs snapshots declared keys correctly.""" + cfg = {} + pass_ = DummyPass() + snapshot_original_inductor_configs([pass_], cfg) + + # Check that declared keys are snapshotted + for key in pass_.inductor_config_keys_potentially_mutated_by_this_pass: + assert key in cfg + + # setdefault: a value already set by the user/upstream is preserved + cfg_with_preset = {"triton.prefer_nd_tiling": True} + snapshot_original_inductor_configs([pass_], cfg_with_preset) + assert cfg_with_preset["triton.prefer_nd_tiling"] is True + + # A pass declaring an empty tuple contributes no anchors + cfg_empty = {} + snapshot_original_inductor_configs([UndeclaredPass()], cfg_empty) + assert cfg_empty == {} + + +def test_snapshot_prevents_global_leakage(fake_mode): + """Verify that snapshot_original_inductor_configs prevents global 
config leakage when used with config.patch.""" + cfg = torch._inductor.config + + # Snapshot original states to verify restoration later + orig_prefer_nd_tiling = cfg.triton.prefer_nd_tiling + orig_max_tiles = cfg.triton.max_tiles + + # Ensure they are currently at their default values (or at least we know what they are) + assert orig_prefer_nd_tiling is False + assert orig_max_tiles is None + + pass_ = DummyPass() + inductor_compile_config = {} + snapshot_original_inductor_configs([pass_], inductor_compile_config) + + # Simulate the compilation scope with config.patch + with cfg.patch(inductor_compile_config): + gm = build_graph_module(fake_mode) + pass_(gm.graph) + + # Inside the scope, the config should be mutated by the pass + assert cfg.triton.prefer_nd_tiling is True + assert cfg.triton.max_tiles == 3 + + # Outside the scope, the config must be restored to its original values, preventing leakage + assert cfg.triton.prefer_nd_tiling == orig_prefer_nd_tiling + assert cfg.triton.max_tiles == orig_max_tiles + + +def test_env_nested_delimiter_config_parsing(monkeypatch): + """Verify that nested sub-configs can be overridden via double-underscore environment variables.""" + # Simulate setting the environment variable + monkeypatch.setenv("MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND", "1") + + config = CompileConfig() + assert config.pass_config.enable_nd_tiling_workaround is True + + # Test setting to False ("0") + monkeypatch.setenv("MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND", "0") + config = CompileConfig() + assert config.pass_config.enable_nd_tiling_workaround is False + + +def test_tristate_compatibility_with_bistate_passes(): + """Verify that the tri-state configuration is fully compatible with bi-state passes. + + If a pass does not implement None/auto mode, it behaves the same as False when force_on is False. + """ + # 1. Test a pass that DOES NOT implement any auto heuristics (e.g. 
UndeclaredPass) + # When force_on is False (mapped from config = None/Auto), it should behave the same as False. + pass_bistate = UndeclaredPass(force_on=False) + assert not pass_bistate.force_on + + # 2. Verify PostGradPassManager's mapping of tri-state config to bi-state force_on: + # - config = True --> force_on = True + # - config = None --> force_on = False (runs heuristics, if any) + # - config = False --> pass is not registered at all + from magi_compiler.config import PassConfig + from magi_compiler.passes.piecewise_graph.post_grad_pass_manager import PostGradPassManager + + # Case A: config = True + pm_true = PostGradPassManager() + pm_true.configure(PassConfig(enable_nd_tiling_workaround=True)) + assert len(pm_true.passes) == 1 + assert pm_true.passes[0].force_on is True + + # Case B: config = None (Auto) + pm_none = PostGradPassManager() + pm_none.configure(PassConfig(enable_nd_tiling_workaround=None)) + assert len(pm_none.passes) == 1 + assert pm_none.passes[0].force_on is False + + # Case C: config = False + pm_false = PostGradPassManager() + pm_false.configure(PassConfig(enable_nd_tiling_workaround=False)) + assert len(pm_false.passes) == 0 diff --git a/tests/model_definition.py b/tests/model_definition.py index 8686f5b..ae6c04f 100644 --- a/tests/model_definition.py +++ b/tests/model_definition.py @@ -398,6 +398,49 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) return self.lm_head(x) +class ResBlock3D(nn.Module): + """3D conv residual block (GroupNorm + SiLU + Conv3d, ×2) with a skip path.""" + + def __init__(self, cin: int, cout: int): + super().__init__() + self.norm1 = nn.GroupNorm(32, cin) + self.conv1 = nn.Conv3d(cin, cout, 3, padding=1) + self.norm2 = nn.GroupNorm(32, cout) + self.conv2 = nn.Conv3d(cout, cout, 3, padding=1) + self.skip = nn.Conv3d(cin, cout, 1) if cin != cout else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.conv1(F.silu(self.norm1(x))) + h = 
self.conv2(F.silu(self.norm2(h))) + return h + self.skip(x) + + +class VAEDecoderLike(nn.Module): + """Stacked 3D conv resblocks + spatial upsampling, mimicking a VAE decoder.""" + + def __init__(self, zc: int = 48, base: int = 128): + super().__init__() + self.conv_in = nn.Conv3d(zc, base, 3, padding=1) + self.r1 = ResBlock3D(base, base) + self.up1 = nn.Conv3d(base, base, 3, padding=1) + self.r2 = ResBlock3D(base, base // 2) + self.up2 = nn.Conv3d(base // 2, base // 2, 3, padding=1) + self.r3 = ResBlock3D(base // 2, base // 4) + self.norm_out = nn.GroupNorm(32, base // 4) + self.conv_out = nn.Conv3d(base // 4, 3, 3, padding=1) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + x = self.conv_in(z) + x = self.r1(x) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") + x = self.up1(x) + x = self.r2(x) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") + x = self.up2(x) + x = self.r3(x) + return self.conv_out(F.silu(self.norm_out(x))) + + def create_transformer_model(config: TransformerConfig, device: torch.device) -> Transformer: """Create Transformer model diff --git a/tests/perf_tests/test_dynamic_nd_tiling_perf.py b/tests/perf_tests/test_dynamic_nd_tiling_perf.py index 0f9c1b9..fde3eba 100644 --- a/tests/perf_tests/test_dynamic_nd_tiling_perf.py +++ b/tests/perf_tests/test_dynamic_nd_tiling_perf.py @@ -21,80 +21,37 @@ transpose/permute/channels-last pointwise kernels in a dynamic-shape graph degrade to untiled Grid1D. MagiCompiler works around this by auto-enabling ``triton.prefer_nd_tiling`` (+ ``max_tiles=3`` + ``tile_reductions``) for dynamic -compilation; see ``MagiBackend._should_enable_nd_tiling``. +compilation; see ``magi_compiler.passes.piecewise_graph.nd_tiling_workaround.ND_TilingWorkaroundPass``. 
This test exercises a WAN-2.2-VAE-decode-like workload (stacked 3D conv resblocks -+ spatial upsampling) compiled with **dynamic H/W**, and checks that the -workaround is a net win versus turning it off on the *same* magi_compile path. ++ spatial upsampling) compiled with **dynamic H/W**, and checks that magi_compile +(with the workaround on) beats vanilla ``torch.compile`` on that path. Real WAN 2.2 VAE decode (540p, dynamic H/W) numbers that motivate this: - with conv channels-last layout: 1.252s -> 542ms / decode (~2.3x) - without conv channels-last: 770ms -> 535ms / decode (~1.44x) This synthetic decoder (no weights, no conv channels-last pass) reproduces the -"~1.4x" regime; the absolute ratio is GPU-dependent so the strict assertion only -runs on calibrated GPUs. +"~1.4x" regime. The absolute ratio is GPU-dependent, so ND_TILING_SPEEDUP_THRESHOLD +is set to a conservative lower bound that still proves a clear, non-noise win. """ import pytest import torch -import torch.nn as nn -import torch.nn.functional as F from magi_compiler import magi_compile +from tests.model_definition import VAEDecoderLike from tests.perf_tests import cuda_benchmark, print_perf_comparison -from tests.perf_tests.utils import is_perf_calibrated_gpu +from tests.perf_tests.utils import assert_magi_vs_torch # WAN 2.2 VAE 540p latent: [C, T, H, W]; dynamic dims are H and W. LATENT_C, LATENT_T, LATENT_H, LATENT_W = 48, 7, 34, 60 -BASE_CHANNELS = 128 -# nd_tiling(on) vs nd_tiling(off), both on the magi_compile dynamic path. -# Observed ~1.36x (off=2.209ms -> on=1.627ms) on H100; assert a conservative +# magi_compile (workaround on) vs vanilla torch.compile, both on the dynamic path. +# Observed ~1.36x (torch=2.20ms -> magi=1.63ms) on H100; assert a conservative # lower bound that still proves a clear, non-noise win. 
ND_TILING_SPEEDUP_THRESHOLD = 1.20 -class _ResBlock3D(nn.Module): - def __init__(self, cin: int, cout: int): - super().__init__() - self.norm1 = nn.GroupNorm(32, cin) - self.conv1 = nn.Conv3d(cin, cout, 3, padding=1) - self.norm2 = nn.GroupNorm(32, cout) - self.conv2 = nn.Conv3d(cout, cout, 3, padding=1) - self.skip = nn.Conv3d(cin, cout, 1) if cin != cout else nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self.conv1(F.silu(self.norm1(x))) - h = self.conv2(F.silu(self.norm2(h))) - return h + self.skip(x) - - -class VAEDecoderLike(nn.Module): - """Stacked 3D conv resblocks + spatial upsampling, mimicking VAE decode.""" - - def __init__(self, zc: int = LATENT_C, base: int = BASE_CHANNELS): - super().__init__() - self.conv_in = nn.Conv3d(zc, base, 3, padding=1) - self.r1 = _ResBlock3D(base, base) - self.up1 = nn.Conv3d(base, base, 3, padding=1) - self.r2 = _ResBlock3D(base, base // 2) - self.up2 = nn.Conv3d(base // 2, base // 2, 3, padding=1) - self.r3 = _ResBlock3D(base // 2, base // 4) - self.norm_out = nn.GroupNorm(32, base // 4) - self.conv_out = nn.Conv3d(base // 4, 3, 3, padding=1) - - def forward(self, z: torch.Tensor) -> torch.Tensor: - x = self.conv_in(z) - x = self.r1(x) - x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") - x = self.up1(x) - x = self.r2(x) - x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") - x = self.up2(x) - x = self.r3(x) - return self.conv_out(F.silu(self.norm_out(x))) - - @pytest.fixture(scope="module") def decoder_device(): return torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -105,9 +62,9 @@ def decoder_input(decoder_device): return torch.randn(1, LATENT_C, LATENT_T, LATENT_H, LATENT_W, device=decoder_device, dtype=torch.bfloat16) -def _compile_decoder(device: torch.device, enable_nd_tiling: bool): +def _compile_decoder(device: torch.device): def _patch(cfg): - cfg.enable_dynamic_nd_tiling = enable_nd_tiling + cfg.pass_config.enable_nd_tiling_workaround = True return 
cfg model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() @@ -115,28 +72,40 @@ def _patch(cfg): return magi_compile(model, dynamic_arg_dims={"z": [3, 4]}, config_patch=_patch) +def _compile_torch(device: torch.device): + model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + return torch.compile(model, backend="inductor") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") def test_nd_tiling_workaround_speedup(decoder_device, decoder_input): - """ND-tiling ON should beat ND-tiling OFF on the dynamic magi_compile path.""" - disabled = _compile_decoder(decoder_device, enable_nd_tiling=False) - enabled = _compile_decoder(decoder_device, enable_nd_tiling=True) + """ND-tiling ON should beat vanilla torch.compile on the dynamic path.""" + # Build isolated inputs to prevent dynamic shape marking leakage + eager_input = decoder_input.clone() + magi_input = decoder_input.clone() + torch_input = decoder_input.clone() + + # Explicitly mark dynamic dimensions for the vanilla torch.compile environment + torch._dynamo.mark_dynamic(torch_input, [3, 4]) + + eager_model = VAEDecoderLike().to(decoder_device).to(torch.bfloat16).eval() + magi_compiled = _compile_decoder(decoder_device) + torch_compiled = _compile_torch(decoder_device) with torch.no_grad(): - disabled_result = cuda_benchmark(lambda: disabled(decoder_input), compilation_warmup=3) - enabled_result = cuda_benchmark(lambda: enabled(decoder_input), compilation_warmup=3) + eager_result = cuda_benchmark(lambda: eager_model(eager_input)) + torch_result = cuda_benchmark(lambda: torch_compiled(torch_input), compilation_warmup=3) + magi_result = cuda_benchmark(lambda: magi_compiled(magi_input), compilation_warmup=3) - speedup = disabled_result.median / enabled_result.median + speedup = torch_result.median / magi_result.median print_perf_comparison( - "Dynamic ND-tiling: workaround ON vs OFF (magi_compile, dynamic H/W)", - disabled_result, - enabled_result, - 
extra_info=(f"latent=({LATENT_C}, {LATENT_T}, {LATENT_H}, {LATENT_W}) " f"speedup(off/on)={speedup:.2f}x"), + "Dynamic ND-tiling: magi_compile vs torch.compile (dynamic H/W)", + eager_result, + magi_result, + torch_result, + extra_info=(f"latent=({LATENT_C}, {LATENT_T}, {LATENT_H}, {LATENT_W}) " f"speedup(torch/magi)={speedup:.2f}x"), ) - if not is_perf_calibrated_gpu(): - return - assert speedup >= ND_TILING_SPEEDUP_THRESHOLD, ( - f"ND-tiling workaround should be >= {ND_TILING_SPEEDUP_THRESHOLD:.2f}x faster than disabled " - f"under dynamic shapes. Got {speedup:.2f}x " - f"(disabled={disabled_result.median:.3f}ms, enabled={enabled_result.median:.3f}ms)" + assert_magi_vs_torch( + speedup, torch_result, magi_result, label="Dynamic ND-tiling workaround", threshold=ND_TILING_SPEEDUP_THRESHOLD ) From 4134f7102e960d1fd17d3ab3effcc3c58a1674fd Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 22 Jun 2026 22:55:47 +0800 Subject: [PATCH 5/5] [Refactor] Delete tri-state gating logic --- magi_compiler/config.py | 23 ++++-- .../passes/pass_base/magi_inductor_pass.py | 20 ----- .../piecewise_graph/nd_tiling_workaround.py | 11 ++- .../piecewise_graph/post_grad_pass_manager.py | 10 +-- .../feature_tests/test_magi_inductor_pass.py | 61 +++++++-------- ...tiling.py => test_nd_tiling_workaround.py} | 74 ++++++++----------- ...f.py => test_nd_tiling_perf_workaround.py} | 35 ++++----- 7 files changed, 105 insertions(+), 129 deletions(-) rename tests/feature_tests/{test_dynamic_nd_tiling.py => test_nd_tiling_workaround.py} (57%) rename tests/perf_tests/{test_dynamic_nd_tiling_perf.py => test_nd_tiling_perf_workaround.py} (78%) diff --git a/magi_compiler/config.py b/magi_compiler/config.py index 7e793ee..0b180d1 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -16,7 +16,7 @@ import os from enum import Enum, unique from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal import torch from pydantic import BaseModel, 
Field @@ -63,14 +63,22 @@ class PassConfig(BaseModel): # TODO: Add no-op elimination pass. # TODO: Add sequence parallelism pass and async TP pass. # TODO: Add Ulysses overlap pass. - enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.") - enable_nd_tiling_workaround: Optional[bool] = Field( - None, + enable_sage_attn: bool = Field( + False, + description=( + "Whether to replace flash attention with sage attention. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_SAGE_ATTN (1/0/true/false)." + ), + ) + enable_nd_tiling_workaround: bool = Field( + True, description=( "Triton ND-tiling workaround (prefer_nd_tiling + max_tiles=3 + tile_reductions) " "for Inductor's coalesce tiling bailing out under dynamic shapes. " - "Tri-state: None = auto (decided by the Pass's internal heuristics, if any); True/False = force. " - "See MagiInductorPass.__init__ for how this maps to the Pass's force_on flag." + "True (default): register the pass and let its internal heuristics decide whether to " + "apply (currently: torch < 2.11.0 AND dynamic shapes AND conv-heavy). " + "False: do not register the pass at all. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND (1/0/true/false)." ), ) enable_mm_epilogue_fusion: bool = Field( @@ -82,7 +90,8 @@ class PassConfig(BaseModel): "(sm_90) the swiglu sub-path additionally uses the native Sm90 " "TMA + WGMMA DualGemm. The pass is a no-op on older architectures " "regardless of this flag, but the flag still controls whether it " - "is registered at all." + "is registered at all. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_MM_EPILOGUE_FUSION (1/0/true/false)." 
), ) diff --git a/magi_compiler/passes/pass_base/magi_inductor_pass.py b/magi_compiler/passes/pass_base/magi_inductor_pass.py index fd5d5ae..7f337dd 100644 --- a/magi_compiler/passes/pass_base/magi_inductor_pass.py +++ b/magi_compiler/passes/pass_base/magi_inductor_pass.py @@ -20,7 +20,6 @@ from .inductor_pass import InductorPass - DEFAULT_CONV_HEAVY_THRESHOLD = 300 @@ -39,25 +38,6 @@ class MagiInductorPass(InductorPass): # passes do not trigger compilation and cannot be patched/isolated. inductor_config_keys_potentially_mutated_by_this_pass: tuple[str, ...] = () - def __init__(self, force_on: bool = False): - """ - Initialize the pass. - - ### Switch Mapping - The external tri-state config is mapped to this bi-state `force_on` by the Pass Manager: - - User Config (Tri-state) --> Pass Manager Action --> Pass `force_on` (Bi-state) - ----------------------- ------------------- -------------------------- - False Skip registration (Not instantiated) - True Add pass force_on = True (Bypass heuristics) - None (Auto) Add pass force_on = False (Run heuristics, if any) - - :param force_on: If True, force enable the pass, bypassing any auto-detection heuristics. - If False, run in auto mode (relying on pass-specific heuristics). 
- """ - super().__init__() - self.force_on = force_on - def is_dynamic(self, graph: torch.fx.Graph) -> bool: """Determine if the graph has dynamic shapes by checking if any placeholder carries free symbols.""" placeholder_vals = (n.meta.get("val", n.meta.get("example_value")) for n in graph.nodes if n.op == "placeholder") diff --git a/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py b/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py index 54dfe2d..568ce00 100644 --- a/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py +++ b/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py @@ -29,15 +29,14 @@ class ND_TilingWorkaroundPass(MagiInductorPass): "triton.tile_reductions", ) - def __init__(self, force_on: bool = False, torch_version: TorchVersion = None): - super().__init__(force_on=force_on) - self.torch_version = torch_version if torch_version is not None else TorchVersion(torch.__version__) + def __init__(self): + super().__init__() + self.is_target_torch_version = TorchVersion(torch.__version__) < (2, 11, 0) @emit_pass_lifecycle def __call__(self, graph: torch.fx.Graph): - if not self.force_on: - if self.torch_version >= (2, 11, 0) or not self.is_dynamic(graph) or self.is_conv_heavy(graph): - return False + if not self.is_target_torch_version or not self.is_dynamic(graph) or not self.is_conv_heavy(graph): + return False # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on # symbolic numels, so dynamic-shape transpose/permute/channels-last kernels diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index b2c0e7f..b087c62 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -14,10 +14,8 @@ import functools -import torch from torch import fx as fx from torch._inductor.custom_graph_pass import CustomGraphPass -from 
torch.torch_version import TorchVersion from ...config import PassConfig, get_compile_config from ...utils import magi_logger, set_env_var @@ -82,14 +80,10 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config - if pass_config.enable_nd_tiling_workaround != False: + if pass_config.enable_nd_tiling_workaround: from .nd_tiling_workaround import ND_TilingWorkaroundPass - self.add( - ND_TilingWorkaroundPass( - force_on=pass_config.enable_nd_tiling_workaround == True, torch_version=TorchVersion(torch.__version__) - ) - ) + self.add(ND_TilingWorkaroundPass()) if pass_config.enable_mm_epilogue_fusion: compile_config = get_compile_config() diff --git a/tests/feature_tests/test_magi_inductor_pass.py b/tests/feature_tests/test_magi_inductor_pass.py index 4ac05bb..79d15da 100644 --- a/tests/feature_tests/test_magi_inductor_pass.py +++ b/tests/feature_tests/test_magi_inductor_pass.py @@ -14,12 +14,16 @@ """Unit tests for the base class MagiInductorPass and its helper utilities.""" +import pytest import torch +from pydantic import ValidationError from magi_compiler.config import CompileConfig from magi_compiler.passes.pass_base import MagiInductorPass, snapshot_original_inductor_configs from tests.feature_tests.conftest import build_graph_module, dynamic_tensor, static_tensor +_ND_TILING_ENV = "MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND" + class DummyPass(MagiInductorPass): inductor_config_keys_potentially_mutated_by_this_pass = ("triton.prefer_nd_tiling", "triton.max_tiles") @@ -134,50 +138,49 @@ def test_snapshot_prevents_global_leakage(fake_mode): assert cfg.triton.max_tiles == orig_max_tiles -def test_env_nested_delimiter_config_parsing(monkeypatch): - """Verify that nested sub-configs can be overridden via double-underscore environment variables.""" - # Simulate setting the environment variable - monkeypatch.setenv("MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND", "1") 
+@pytest.mark.parametrize("env_value, expected", [("1", True), ("true", True), ("0", False), ("false", False)]) +def test_env_nested_delimiter_config_parsing(monkeypatch, env_value, expected): + """A nested sub-config field is overridable via the MAGI_COMPILE_{SUBCONFIG}__{FIELD} env var. + pydantic parses the bool field here, so only truthy/falsy literals + (1/0/true/false) are accepted as on/off. + """ + monkeypatch.setenv(_ND_TILING_ENV, env_value) config = CompileConfig() - assert config.pass_config.enable_nd_tiling_workaround is True + assert config.pass_config.enable_nd_tiling_workaround is expected - # Test setting to False ("0") - monkeypatch.setenv("MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND", "0") + +def test_env_unset_defaults_to_true(monkeypatch): + """When the env var is unset, the binary field defaults to True.""" + monkeypatch.delenv(_ND_TILING_ENV, raising=False) config = CompileConfig() - assert config.pass_config.enable_nd_tiling_workaround is False + assert config.pass_config.enable_nd_tiling_workaround is True -def test_tristate_compatibility_with_bistate_passes(): - """Verify that the tri-state configuration is fully compatible with bi-state passes. +@pytest.mark.parametrize("env_value", ["none", "null", "maybe", ""]) +def test_env_rejects_non_bool_strings(monkeypatch, env_value): + """The field is a bool, so non-bool strings raise a ValidationError. - If a pass does not implement None/auto mode, it behaves the same as False when force_on is False. + Only 1/0/true/false round-trip; everything else is rejected. + """ + monkeypatch.setenv(_ND_TILING_ENV, env_value) + with pytest.raises(ValidationError): + CompileConfig() + + +def test_config_binary_registration_mapping(): + """PostGradPassManager registers the pass iff enable_nd_tiling_workaround is True. + + - config = True --> pass is registered + - config = False --> pass is not registered at all """ - # 1. Test a pass that DOES NOT implement any auto heuristics (e.g. 
UndeclaredPass) - # When force_on is False (mapped from config = None/Auto), it should behave the same as False. - pass_bistate = UndeclaredPass(force_on=False) - assert not pass_bistate.force_on - - # 2. Verify PostGradPassManager's mapping of tri-state config to bi-state force_on: - # - config = True --> force_on = True - # - config = None --> force_on = False (runs heuristics, if any) - # - config = False --> pass is not registered at all from magi_compiler.config import PassConfig from magi_compiler.passes.piecewise_graph.post_grad_pass_manager import PostGradPassManager - # Case A: config = True pm_true = PostGradPassManager() pm_true.configure(PassConfig(enable_nd_tiling_workaround=True)) assert len(pm_true.passes) == 1 - assert pm_true.passes[0].force_on is True - - # Case B: config = None (Auto) - pm_none = PostGradPassManager() - pm_none.configure(PassConfig(enable_nd_tiling_workaround=None)) - assert len(pm_none.passes) == 1 - assert pm_none.passes[0].force_on is False - # Case C: config = False pm_false = PostGradPassManager() pm_false.configure(PassConfig(enable_nd_tiling_workaround=False)) assert len(pm_false.passes) == 0 diff --git a/tests/feature_tests/test_dynamic_nd_tiling.py b/tests/feature_tests/test_nd_tiling_workaround.py similarity index 57% rename from tests/feature_tests/test_dynamic_nd_tiling.py rename to tests/feature_tests/test_nd_tiling_workaround.py index e543b80..411edbc 100644 --- a/tests/feature_tests/test_dynamic_nd_tiling.py +++ b/tests/feature_tests/test_nd_tiling_workaround.py @@ -15,22 +15,21 @@ """Decision-logic tests for ``ND_TilingWorkaroundPass``. When applicable, the pass flips three ``torch._inductor.config`` triton keys -(``prefer_nd_tiling`` / ``max_tiles`` / ``tile_reductions``) ON. Whether it does -so is driven by the ``enable_nd_tiling_workaround`` config: +(``prefer_nd_tiling`` / ``max_tiles`` / ``tile_reductions``) ON. 
The binary +``enable_nd_tiling_workaround`` config controls registration: - * ``True`` -> force on, skip heuristics - * ``False`` -> pass not registered at all - * ``None`` -> auto: on iff dynamic shapes AND PyTorch < 2.11.0 AND not conv-heavy + * ``True`` (default) -> register the pass; its internal heuristics then decide: + apply iff dynamic shapes AND PyTorch < 2.11.0 AND conv-heavy. + * ``False`` -> pass not registered at all. -These tests assert that mapping. The shared base-class utilities (``is_dynamic``, -``is_conv_heavy``, config snapshot/anti-leakage) are tested in -``test_magi_inductor_pass.py``; the end-to-end speedup in -``tests/perf_tests/test_dynamic_nd_tiling_perf.py``. +These tests assert the registered pass's heuristic decision. The shared +base-class utilities (``is_dynamic``, ``is_conv_heavy``, config +snapshot/anti-leakage) are tested in ``test_magi_inductor_pass.py``; the +end-to-end speedup in ``tests/perf_tests/test_nd_tiling_perf_workaround.py``. """ import pytest import torch -from torch.torch_version import TorchVersion from magi_compiler.config import PassConfig from magi_compiler.passes.piecewise_graph.nd_tiling_workaround import ND_TilingWorkaroundPass @@ -52,8 +51,17 @@ def _restore_inductor_config(): cfg.triton.prefer_nd_tiling, cfg.triton.max_tiles, cfg.triton.tile_reductions = saved -def _make_pass(*, force_on=False, version="2.9.1"): - return ND_TilingWorkaroundPass(force_on=force_on, torch_version=TorchVersion(version)) +def _make_pass(*, is_target_torch_version=True): + """Build the pass and pin its version gate. + + The pass reads ``torch.__version__`` at construction and caches whether it is + a target version (< 2.11.0) in ``is_target_torch_version``. Tests override that + cached flag directly so the version branch is exercised regardless of the + installed torch. 
+ """ + pass_ = ND_TilingWorkaroundPass() + pass_.is_target_torch_version = is_target_torch_version + return pass_ def _assert_injected(injected): @@ -69,43 +77,25 @@ def _assert_injected(injected): def _auto_eligible_graph(fake_mode): - """Dynamic graph with 1 conv + plenty of filler nodes (nnodes >= 300 * nconv).""" - return build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=320) - + """Dynamic, conv-heavy graph (nnodes < 300 * nconv): the workaround applies.""" + return build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=5) -# ── config field tri-state ─────────────────────────────────────────────── - -@pytest.mark.parametrize("value", [True, False, None]) -def test_config_field_tristate(value): - """enable_nd_tiling_workaround accepts True/False/None.""" +@pytest.mark.parametrize("value", [True, False]) +def test_config_field_binary(value): + """enable_nd_tiling_workaround accepts True/False (default True covered in test_magi_inductor_pass).""" assert PassConfig(enable_nd_tiling_workaround=value).enable_nd_tiling_workaround is value -def test_config_field_default_is_none(): - assert PassConfig().enable_nd_tiling_workaround is None - - -# ── ND-tiling injection decision ───────────────────────────────────────── - - -def test_force_on_injects_even_when_static_and_fixed_version(fake_mode): - """force_on=True flips the config regardless of graph/version.""" - pass_ = _make_pass(force_on=True, version="2.11.0") - gm = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)], n_conv=1, n_filler=0) - pass_(gm.graph) - _assert_injected(True) - - def test_auto_injects_when_all_conditions_met(fake_mode): - pass_ = _make_pass(version="2.9.1") + pass_ = _make_pass(is_target_torch_version=True) gm = _auto_eligible_graph(fake_mode) pass_(gm.graph) _assert_injected(True) def test_auto_skips_on_static_shapes(fake_mode): - pass_ = _make_pass(version="2.9.1") + pass_ = 
_make_pass(is_target_torch_version=True) gm = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)], n_conv=0, n_filler=5) pass_(gm.graph) _assert_injected(False) @@ -113,15 +103,15 @@ def test_auto_skips_on_static_shapes(fake_mode): def test_auto_skips_on_fixed_version(fake_mode): """Dynamic shapes but PyTorch >= 2.11.0: native coalesce path handles it.""" - pass_ = _make_pass(version="2.11.0") + pass_ = _make_pass(is_target_torch_version=False) gm = _auto_eligible_graph(fake_mode) pass_(gm.graph) _assert_injected(False) -def test_auto_skips_when_graph_too_conv_dense(fake_mode): - """``nnodes < 300 * nconv`` (conv-dense graph): the heuristic bails out.""" - pass_ = _make_pass(version="2.9.1") - gm = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=5) +def test_auto_skips_when_graph_not_conv_heavy(fake_mode): + """``nnodes >= 300 * nconv`` (conv-sparse graph): low conv ratio, ND-tiling gives little, so skip.""" + pass_ = _make_pass(is_target_torch_version=True) + gm = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=320) pass_(gm.graph) _assert_injected(False) diff --git a/tests/perf_tests/test_dynamic_nd_tiling_perf.py b/tests/perf_tests/test_nd_tiling_perf_workaround.py similarity index 78% rename from tests/perf_tests/test_dynamic_nd_tiling_perf.py rename to tests/perf_tests/test_nd_tiling_perf_workaround.py index fde3eba..e08260f 100644 --- a/tests/perf_tests/test_dynamic_nd_tiling_perf.py +++ b/tests/perf_tests/test_nd_tiling_perf_workaround.py @@ -19,13 +19,16 @@ On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on symbolic numels (``tiling_utils.extract_normalized_read_writes`` returns ``None``), so transpose/permute/channels-last pointwise kernels in a dynamic-shape graph -degrade to untiled Grid1D. 
MagiCompiler works around this by auto-enabling -``triton.prefer_nd_tiling`` (+ ``max_tiles=3`` + ``tile_reductions``) for dynamic -compilation; see ``magi_compiler.passes.piecewise_graph.nd_tiling_workaround.ND_TilingWorkaroundPass``. +degrade to untiled Grid1D. ``ND_TilingWorkaroundPass`` works around this by +enabling ``triton.prefer_nd_tiling`` (+ ``max_tiles=3`` + ``tile_reductions``) +when the post-grad graph is dynamic AND conv-heavy (the regime where the +degraded kernels dominate); see +``magi_compiler.passes.piecewise_graph.nd_tiling_workaround.ND_TilingWorkaroundPass``. This test exercises a WAN-2.2-VAE-decode-like workload (stacked 3D conv resblocks -+ spatial upsampling) compiled with **dynamic H/W**, and checks that magi_compile -(with the workaround on) beats vanilla ``torch.compile`` on that path. ++ spatial upsampling) compiled with **dynamic H/W** — a dynamic, conv-heavy graph +that triggers the pass — and checks that magi_compile beats vanilla +``torch.compile`` on that path. Real WAN 2.2 VAE decode (540p, dynamic H/W) numbers that motivate this: - with conv channels-last layout: 1.252s -> 542ms / decode (~2.3x) @@ -52,18 +55,16 @@ ND_TILING_SPEEDUP_THRESHOLD = 1.20 -@pytest.fixture(scope="module") -def decoder_device(): - return torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -@pytest.fixture(scope="module") -def decoder_input(decoder_device): - return torch.randn(1, LATENT_C, LATENT_T, LATENT_H, LATENT_W, device=decoder_device, dtype=torch.bfloat16) +@pytest.fixture(scope="function") +def decoder_input(device): + return torch.randn(1, LATENT_C, LATENT_T, LATENT_H, LATENT_W, device=device, dtype=torch.bfloat16) def _compile_decoder(device: torch.device): def _patch(cfg): + # This decoder is dynamic + conv-heavy, so the pass's heuristics fire and + # the workaround is applied. The pass mutates triton configs only inside the + # compilation's config.patch scope, so nothing leaks to the torch baseline. 
cfg.pass_config.enable_nd_tiling_workaround = True return cfg @@ -78,7 +79,7 @@ def _compile_torch(device: torch.device): @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") -def test_nd_tiling_workaround_speedup(decoder_device, decoder_input): +def test_nd_tiling_workaround_speedup(device, decoder_input): """ND-tiling ON should beat vanilla torch.compile on the dynamic path.""" # Build isolated inputs to prevent dynamic shape marking leakage eager_input = decoder_input.clone() @@ -88,9 +89,9 @@ def test_nd_tiling_workaround_speedup(decoder_device, decoder_input): # Explicitly mark dynamic dimensions for the vanilla torch.compile environment torch._dynamo.mark_dynamic(torch_input, [3, 4]) - eager_model = VAEDecoderLike().to(decoder_device).to(torch.bfloat16).eval() - magi_compiled = _compile_decoder(decoder_device) - torch_compiled = _compile_torch(decoder_device) + eager_model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + magi_compiled = _compile_decoder(device) + torch_compiled = _compile_torch(device) with torch.no_grad(): eager_result = cuda_benchmark(lambda: eager_model(eager_input))