diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 59db755..2c281ea 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -14,7 +14,7 @@ jobs: integration_test: name: Integration Test runs-on: [self-hosted, magi-compiler] - timeout-minutes: 30 + timeout-minutes: 40 env: http_proxy: ${{ secrets.HTTP_PROXY }} https_proxy: ${{ secrets.HTTPS_PROXY }} diff --git a/magi_compiler/config.py b/magi_compiler/config.py index a303093..0b180d1 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -63,7 +63,24 @@ class PassConfig(BaseModel): # TODO: Add no-op elimination pass. # TODO: Add sequence parallelism pass and async TP pass. # TODO: Add Ulysses overlap pass. - enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.") + enable_sage_attn: bool = Field( + False, + description=( + "Whether to replace flash attention with sage attention. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_SAGE_ATTN (1/0/true/false)." + ), + ) + enable_nd_tiling_workaround: bool = Field( + True, + description=( + "Triton ND-tiling workaround (prefer_nd_tiling + max_tiles=3 + tile_reductions) " + "for Inductor's coalesce tiling bailing out under dynamic shapes. " + "True (default): register the pass and let its internal heuristics decide whether to " + "apply (currently: torch < 2.11.0 AND dynamic shapes AND conv-heavy). " + "False: do not register the pass at all. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND (1/0/true/false)." + ), + ) enable_mm_epilogue_fusion: bool = Field( False, description=( @@ -73,7 +90,8 @@ class PassConfig(BaseModel): "(sm_90) the swiglu sub-path additionally uses the native Sm90 " "TMA + WGMMA DualGemm. The pass is a no-op on older architectures " "regardless of this flag, but the flag still controls whether it " - "is registered at all." + "is registered at all. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_MM_EPILOGUE_FUSION (1/0/true/false)." ), ) @@ -171,6 +189,10 @@ class CompileConfig(BaseSettings): model_config = SettingsConfigDict( env_prefix="MAGI_COMPILE_", + # Nested sub-configs (e.g. pass_config, offload_config) are reachable via + # ``MAGI_COMPILE___`` env vars, e.g. + # ``MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND=1``. + env_nested_delimiter="__", populate_by_name=True, cli_parse_args=True, cli_ignore_unknown_args=True, diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 0d010e3..9569bb2 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -523,6 +523,8 @@ def _configure_custom_passes(self): self.inductor_compile_config[post_grad_key] = post_grad_pass_manager + post_grad_pass_manager.snapshot_original_inductor_configs(self.inductor_compile_config) + def _init_cache(self) -> str: hash_key = compute_hash( [ diff --git a/magi_compiler/passes/pass_base/__init__.py b/magi_compiler/passes/pass_base/__init__.py index 17dd699..8919555 100644 --- a/magi_compiler/passes/pass_base/__init__.py +++ b/magi_compiler/passes/pass_base/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from .inductor_pass import InductorPass -from .magi_inductor_pass import MagiInductorPass +from .magi_inductor_pass import MagiInductorPass, snapshot_original_inductor_configs from .pass_context import get_pass_context, pass_context -__all__ = ["InductorPass", "pass_context", "get_pass_context", "MagiInductorPass"] +__all__ = ["InductorPass", "pass_context", "get_pass_context", "MagiInductorPass", "snapshot_original_inductor_configs"] diff --git a/magi_compiler/passes/pass_base/magi_inductor_pass.py b/magi_compiler/passes/pass_base/magi_inductor_pass.py index b01e955..7f337dd 100644 --- a/magi_compiler/passes/pass_base/magi_inductor_pass.py +++ b/magi_compiler/passes/pass_base/magi_inductor_pass.py @@ -12,10 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +from collections.abc import Iterable + +import torch +from torch.fx.experimental.symbolic_shapes import has_free_symbols + from .inductor_pass import InductorPass +DEFAULT_CONV_HEAVY_THRESHOLD = 300 + class MagiInductorPass(InductorPass): """ Base class for inductor passes. """ + + # If a pass needs to modify any Inductor configuration (``torch._inductor.config``), + # it **MUST** declare all affected config keys here. The declared keys will be + # automatically snapshotted and restored after the subgraph compilation ends + # to prevent global leakage. + # + # Note: Only passes running in ``PostGradPassManager`` are allowed to mutate Inductor + # configs. Passes running in ``FullGraphPassManager`` **MUST NOT** modify them, as full-graph + # passes do not trigger compilation and cannot be patched/isolated. + inductor_config_keys_potentially_mutated_by_this_pass: tuple[str, ...] = () + + def is_dynamic(self, graph: torch.fx.Graph) -> bool: + """Determine if the graph has dynamic shapes by checking if any placeholder carries free symbols.""" + placeholder_vals = (n.meta.get("val", n.meta.get("example_value")) for n in graph.nodes if n.op == "placeholder") + return any(v is not None and has_free_symbols(v) for v in placeholder_vals) + + def is_conv_heavy(self, graph: torch.fx.Graph, threshold: int = DEFAULT_CONV_HEAVY_THRESHOLD) -> bool: + """Determine if the graph is convolution-heavy (dense in convolutions).""" + nnodes = len(list(graph.nodes)) + nconv = sum(1 for n in graph.nodes if n.target == torch.ops.aten.convolution.default) + return nnodes < threshold * nconv + + +def snapshot_original_inductor_configs(passes: Iterable, inductor_compile_config: dict) -> None: + """Snapshot the original values of global Inductor configs that passes potentially mutate. + + The captured original values are stored in ``inductor_compile_config``. When ``standalone_compile`` + calls ``compile_fx``, it automatically passes this config as ``config_patches`` to Inductor's + ``config.patch`` context manager. No matter what values the passes temporarily set these fields to + during compilation, they will be safely restored to their pre-compilation state on scope exit. + """ + cfg = torch._inductor.config + for pass_ in passes: + for key in getattr(pass_, "inductor_config_keys_potentially_mutated_by_this_pass", ()): + snapshot = functools.reduce(getattr, key.split("."), cfg) + inductor_compile_config.setdefault(key, snapshot) diff --git a/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py b/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py new file mode 100644 index 0000000..568ce00 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/nd_tiling_workaround.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch.torch_version import TorchVersion + +from ...magi_depyf.timeline import emit_pass_lifecycle +from ..pass_base import MagiInductorPass + + +class ND_TilingWorkaroundPass(MagiInductorPass): + inductor_config_keys_potentially_mutated_by_this_pass = ( + "triton.prefer_nd_tiling", + "triton.max_tiles", + "triton.tile_reductions", + ) + + def __init__(self): + super().__init__() + self.is_target_torch_version = TorchVersion(torch.__version__) < (2, 11, 0) + + @emit_pass_lifecycle + def __call__(self, graph: torch.fx.Graph): + if not self.is_target_torch_version or not self.is_dynamic(graph) or not self.is_conv_heavy(graph): + return False + + # On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on + # symbolic numels, so dynamic-shape transpose/permute/channels-last kernels + # degrade to untiled Grid1D. Forcing prefer_nd_tiling restores ND tiling + # (WAN 2.2 VAE 540p decode: ~1.45x). + torch._inductor.config.triton.prefer_nd_tiling = True + torch._inductor.config.triton.max_tiles = 3 + torch._inductor.config.triton.tile_reductions = True diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index 3323ea3..b087c62 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -20,7 +20,7 @@ from ...config import PassConfig, get_compile_config from ...utils import magi_logger, set_env_var from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG -from ..pass_base import InductorPass, get_pass_context +from ..pass_base import InductorPass, get_pass_context, snapshot_original_inductor_configs from .fix_functionalization import FixFunctionalizationPass from .post_cleanup import PostCleanupPass @@ -80,6 +80,11 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config + if pass_config.enable_nd_tiling_workaround: + from .nd_tiling_workaround import ND_TilingWorkaroundPass + + self.add(ND_TilingWorkaroundPass()) + if pass_config.enable_mm_epilogue_fusion: compile_config = get_compile_config() if compile_config.has_cutlass: @@ -100,6 +105,10 @@ def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) self.passes.append(pass_) + def snapshot_original_inductor_configs(self, inductor_compile_config: dict) -> None: + """Snapshot original values of global Inductor configs potentially mutated by passes.""" + snapshot_original_inductor_configs(self.passes, inductor_compile_config) + def uuid(self): """ The PostGradPassManager is set as a custom pass in the Inductor and diff --git a/tests/feature_tests/conftest.py b/tests/feature_tests/conftest.py new file mode 100644 index 0000000..fe1b26c --- /dev/null +++ b/tests/feature_tests/conftest.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures for Inductor-pass feature tests. + +``fake_mode`` plus the tensor/graph builders below are reused by both +``test_magi_inductor_pass.py`` (base-class utilities) and +``test_dynamic_nd_tiling.py`` (ND_TilingWorkaroundPass decision logic). +""" + +import pytest +import torch +import torch.fx as fx + + +@pytest.fixture +def fake_mode(): + """A FakeTensorMode backed by a fresh ShapeEnv for symbolic shapes.""" + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + return FakeTensorMode(shape_env=ShapeEnv()) + + +def static_tensor(fake_mode): + """A FakeTensor with fully concrete (non-symbolic) dims.""" + with fake_mode: + return torch.empty(4, 8) + + +def dynamic_tensor(fake_mode): + """A FakeTensor whose first dim is a free symbol (mimics a dynamic batch).""" + sym = fake_mode.shape_env.create_unbacked_symint() + with fake_mode: + return torch.empty(sym, 8) + + +def build_graph_module(fake_mode, *, placeholder_vals=(), placeholder_meta_key="example_value", n_conv=0, n_filler=0): + """Build a tiny fx.GraphModule for exercising pass decision logic. + + ``placeholder_vals`` populate each placeholder's ``meta[placeholder_meta_key]`` + (drives ``is_dynamic``; ``MagiInductorPass.is_dynamic`` reads ``meta["val"]`` + first and falls back to ``meta["example_value"]``). ``n_conv`` adds + ``aten.convolution.default`` call nodes and ``n_filler`` adds plain call + nodes to inflate the node count (drives the ``is_conv_heavy`` heuristic). + """ + graph = fx.Graph() + inputs = [] + for i, ev in enumerate(placeholder_vals): + node = graph.placeholder(f"arg_{i}") + node.meta[placeholder_meta_key] = ev + inputs.append(node) + if not inputs: + inputs.append(graph.placeholder("arg_0")) + + x = inputs[0] + for c in range(n_conv): + weight = graph.placeholder(f"weight_{c}") + x = graph.call_function(torch.ops.aten.convolution.default, args=(x, weight)) + for _ in range(n_filler): + x = graph.call_function(torch.ops.aten.relu.default, args=(x,)) + graph.output((x,)) + return fx.GraphModule(torch.nn.Module(), graph) diff --git a/tests/feature_tests/test_magi_inductor_pass.py b/tests/feature_tests/test_magi_inductor_pass.py new file mode 100644 index 0000000..79d15da --- /dev/null +++ b/tests/feature_tests/test_magi_inductor_pass.py @@ -0,0 +1,186 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the base class MagiInductorPass and its helper utilities.""" + +import pytest +import torch +from pydantic import ValidationError + +from magi_compiler.config import CompileConfig +from magi_compiler.passes.pass_base import MagiInductorPass, snapshot_original_inductor_configs +from tests.feature_tests.conftest import build_graph_module, dynamic_tensor, static_tensor + +_ND_TILING_ENV = "MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND" + + +class DummyPass(MagiInductorPass): + inductor_config_keys_potentially_mutated_by_this_pass = ("triton.prefer_nd_tiling", "triton.max_tiles") + + def __call__(self, graph: torch.fx.Graph): + # Mutate the configs during execution + torch._inductor.config.triton.prefer_nd_tiling = True + torch._inductor.config.triton.max_tiles = 3 + + +class UndeclaredPass(MagiInductorPass): + def __call__(self, graph: torch.fx.Graph): + pass + + +def test_is_dynamic(fake_mode): + """is_dynamic flags graphs whose placeholders carry free symbols.""" + pass_ = DummyPass() + + # Static graph + gm_static = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)]) + assert not pass_.is_dynamic(gm_static.graph) + + # Dynamic graph + gm_dynamic = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)]) + assert pass_.is_dynamic(gm_dynamic.graph) + + +def test_is_dynamic_reads_val_meta_key(fake_mode): + """is_dynamic prefers meta["val"], falling back to meta["example_value"].""" + pass_ = DummyPass() + + # Symbol carried under meta["val"] (the preferred key) must be detected. + gm_val = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], placeholder_meta_key="val") + assert pass_.is_dynamic(gm_val.graph) + + # A placeholder with neither meta key set is treated as static (not dynamic). + gm_no_meta = build_graph_module(fake_mode) + assert not pass_.is_dynamic(gm_no_meta.graph) + + +def test_is_conv_heavy(fake_mode): + """is_conv_heavy flags graphs whose node count is dense relative to convs.""" + pass_ = DummyPass() + + # 1 conv, 50 filler nodes -> nnodes = 53 (1 input + 1 weight + 1 conv + 50 relu). + # threshold = 300, threshold * nconv = 300. nnodes < 300 -> is_conv_heavy is True. + gm_heavy = build_graph_module(fake_mode, n_conv=1, n_filler=50) + assert pass_.is_conv_heavy(gm_heavy.graph, threshold=300) + + # 1 conv, 320 filler nodes -> nnodes = 323. + # threshold = 300, threshold * nconv = 300. nnodes >= 300 -> is_conv_heavy is False. + gm_light = build_graph_module(fake_mode, n_conv=1, n_filler=320) + assert not pass_.is_conv_heavy(gm_light.graph, threshold=300) + + +def test_is_conv_heavy_zero_conv(fake_mode): + """A graph with no convolutions is never conv-heavy (threshold * 0 == 0).""" + pass_ = DummyPass() + gm_no_conv = build_graph_module(fake_mode, n_conv=0, n_filler=10) + assert not pass_.is_conv_heavy(gm_no_conv.graph, threshold=300) + + +def test_snapshot_original_inductor_configs(): + """Verify that snapshot_original_inductor_configs snapshots declared keys correctly.""" + cfg = {} + pass_ = DummyPass() + snapshot_original_inductor_configs([pass_], cfg) + + # Check that declared keys are snapshotted + for key in pass_.inductor_config_keys_potentially_mutated_by_this_pass: + assert key in cfg + + # setdefault: a value already set by the user/upstream is preserved + cfg_with_preset = {"triton.prefer_nd_tiling": True} + snapshot_original_inductor_configs([pass_], cfg_with_preset) + assert cfg_with_preset["triton.prefer_nd_tiling"] is True + + # A pass declaring an empty tuple contributes no anchors + cfg_empty = {} + snapshot_original_inductor_configs([UndeclaredPass()], cfg_empty) + assert cfg_empty == {} + + +def test_snapshot_prevents_global_leakage(fake_mode): + """Verify that snapshot_original_inductor_configs prevents global config leakage when used with config.patch.""" + cfg = torch._inductor.config + + # Snapshot original states to verify restoration later + orig_prefer_nd_tiling = cfg.triton.prefer_nd_tiling + orig_max_tiles = cfg.triton.max_tiles + + # Ensure they are currently at their default values (or at least we know what they are) + assert orig_prefer_nd_tiling is False + assert orig_max_tiles is None + + pass_ = DummyPass() + inductor_compile_config = {} + snapshot_original_inductor_configs([pass_], inductor_compile_config) + + # Simulate the compilation scope with config.patch + with cfg.patch(inductor_compile_config): + gm = build_graph_module(fake_mode) + pass_(gm.graph) + + # Inside the scope, the config should be mutated by the pass + assert cfg.triton.prefer_nd_tiling is True + assert cfg.triton.max_tiles == 3 + + # Outside the scope, the config must be restored to its original values, preventing leakage + assert cfg.triton.prefer_nd_tiling == orig_prefer_nd_tiling + assert cfg.triton.max_tiles == orig_max_tiles + + +@pytest.mark.parametrize("env_value, expected", [("1", True), ("true", True), ("0", False), ("false", False)]) +def test_env_nested_delimiter_config_parsing(monkeypatch, env_value, expected): + """A nested sub-config field is overridable via the MAGI_COMPILE___ env var. + + pydantic parses the bool field here, so only truthy/falsy literals + (1/0/true/false) are accepted as on/off. + """ + monkeypatch.setenv(_ND_TILING_ENV, env_value) + config = CompileConfig() + assert config.pass_config.enable_nd_tiling_workaround is expected + + +def test_env_unset_defaults_to_true(monkeypatch): + """When the env var is unset, the binary field defaults to True.""" + monkeypatch.delenv(_ND_TILING_ENV, raising=False) + config = CompileConfig() + assert config.pass_config.enable_nd_tiling_workaround is True + + +@pytest.mark.parametrize("env_value", ["none", "null", "maybe", ""]) +def test_env_rejects_non_bool_strings(monkeypatch, env_value): + """The field is a bool, so non-bool strings raise a ValidationError. + + Only 1/0/true/false round-trip; everything else is rejected. + """ + monkeypatch.setenv(_ND_TILING_ENV, env_value) + with pytest.raises(ValidationError): + CompileConfig() + + +def test_config_binary_registration_mapping(): + """PostGradPassManager registers the pass iff enable_nd_tiling_workaround is True. + + - config = True --> pass is registered + - config = False --> pass is not registered at all + """ + from magi_compiler.config import PassConfig + from magi_compiler.passes.piecewise_graph.post_grad_pass_manager import PostGradPassManager + + pm_true = PostGradPassManager() + pm_true.configure(PassConfig(enable_nd_tiling_workaround=True)) + assert len(pm_true.passes) == 1 + + pm_false = PostGradPassManager() + pm_false.configure(PassConfig(enable_nd_tiling_workaround=False)) + assert len(pm_false.passes) == 0 diff --git a/tests/feature_tests/test_nd_tiling_workaround.py b/tests/feature_tests/test_nd_tiling_workaround.py new file mode 100644 index 0000000..411edbc --- /dev/null +++ b/tests/feature_tests/test_nd_tiling_workaround.py @@ -0,0 +1,117 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decision-logic tests for ``ND_TilingWorkaroundPass``. + +When applicable, the pass flips three ``torch._inductor.config`` triton keys +(``prefer_nd_tiling`` / ``max_tiles`` / ``tile_reductions``) ON. The binary +``enable_nd_tiling_workaround`` config controls registration: + + * ``True`` (default) -> register the pass; its internal heuristics then decide: + apply iff dynamic shapes AND PyTorch < 2.11.0 AND conv-heavy. + * ``False`` -> pass not registered at all. + +These tests assert the registered pass's heuristic decision. The shared +base-class utilities (``is_dynamic``, ``is_conv_heavy``, config +snapshot/anti-leakage) are tested in ``test_magi_inductor_pass.py``; the +end-to-end speedup in ``tests/perf_tests/test_nd_tiling_perf_workaround.py``. +""" + +import pytest +import torch + +from magi_compiler.config import PassConfig +from magi_compiler.passes.piecewise_graph.nd_tiling_workaround import ND_TilingWorkaroundPass +from tests.feature_tests.conftest import build_graph_module, dynamic_tensor, static_tensor + + +@pytest.fixture(autouse=True) +def _restore_inductor_config(): + """Snapshot/restore the three triton keys around every test. + + The pass mutates the process-global ``torch._inductor.config`` directly, so + without this fixture one test could leak into the next. + """ + cfg = torch._inductor.config + saved = (cfg.triton.prefer_nd_tiling, cfg.triton.max_tiles, cfg.triton.tile_reductions) + try: + yield + finally: + cfg.triton.prefer_nd_tiling, cfg.triton.max_tiles, cfg.triton.tile_reductions = saved + + +def _make_pass(*, is_target_torch_version=True): + """Build the pass and pin its version gate. + + The pass reads ``torch.__version__`` at construction and caches whether it is + a target version (< 2.11.0) in ``is_target_torch_version``. Tests override that + cached flag directly so the version branch is exercised regardless of the + installed torch. + """ + pass_ = ND_TilingWorkaroundPass() + pass_.is_target_torch_version = is_target_torch_version + return pass_ + + +def _assert_injected(injected): + cfg = torch._inductor.config + if injected: + assert cfg.triton.prefer_nd_tiling is True + assert cfg.triton.max_tiles == 3 + assert cfg.triton.tile_reductions is True + else: + assert cfg.triton.prefer_nd_tiling is False + assert cfg.triton.max_tiles is None + assert cfg.triton.tile_reductions is False + + +def _auto_eligible_graph(fake_mode): + """Dynamic, conv-heavy graph (nnodes < 300 * nconv): the workaround applies.""" + return build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=5) + + +@pytest.mark.parametrize("value", [True, False]) +def test_config_field_binary(value): + """enable_nd_tiling_workaround accepts True/False (default True covered in test_magi_inductor_pass).""" + assert PassConfig(enable_nd_tiling_workaround=value).enable_nd_tiling_workaround is value + + +def test_auto_injects_when_all_conditions_met(fake_mode): + pass_ = _make_pass(is_target_torch_version=True) + gm = _auto_eligible_graph(fake_mode) + pass_(gm.graph) + _assert_injected(True) + + +def test_auto_skips_on_static_shapes(fake_mode): + pass_ = _make_pass(is_target_torch_version=True) + gm = build_graph_module(fake_mode, placeholder_vals=[static_tensor(fake_mode)], n_conv=0, n_filler=5) + pass_(gm.graph) + _assert_injected(False) + + +def test_auto_skips_on_fixed_version(fake_mode): + """Dynamic shapes but PyTorch >= 2.11.0: native coalesce path handles it.""" + pass_ = _make_pass(is_target_torch_version=False) + gm = _auto_eligible_graph(fake_mode) + pass_(gm.graph) + _assert_injected(False) + + +def test_auto_skips_when_graph_not_conv_heavy(fake_mode): + """``nnodes >= 300 * nconv`` (conv-sparse graph): low conv ratio, ND-tiling gives little, so skip.""" + pass_ = _make_pass(is_target_torch_version=True) + gm = build_graph_module(fake_mode, placeholder_vals=[dynamic_tensor(fake_mode)], n_conv=1, n_filler=320) + pass_(gm.graph) + _assert_injected(False) diff --git a/tests/model_definition.py b/tests/model_definition.py index 8686f5b..ae6c04f 100644 --- a/tests/model_definition.py +++ b/tests/model_definition.py @@ -398,6 +398,49 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) return self.lm_head(x) +class ResBlock3D(nn.Module): + """3D conv residual block (GroupNorm + SiLU + Conv3d, ×2) with a skip path.""" + + def __init__(self, cin: int, cout: int): + super().__init__() + self.norm1 = nn.GroupNorm(32, cin) + self.conv1 = nn.Conv3d(cin, cout, 3, padding=1) + self.norm2 = nn.GroupNorm(32, cout) + self.conv2 = nn.Conv3d(cout, cout, 3, padding=1) + self.skip = nn.Conv3d(cin, cout, 1) if cin != cout else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.conv1(F.silu(self.norm1(x))) + h = self.conv2(F.silu(self.norm2(h))) + return h + self.skip(x) + + +class VAEDecoderLike(nn.Module): + """Stacked 3D conv resblocks + spatial upsampling, mimicking a VAE decoder.""" + + def __init__(self, zc: int = 48, base: int = 128): + super().__init__() + self.conv_in = nn.Conv3d(zc, base, 3, padding=1) + self.r1 = ResBlock3D(base, base) + self.up1 = nn.Conv3d(base, base, 3, padding=1) + self.r2 = ResBlock3D(base, base // 2) + self.up2 = nn.Conv3d(base // 2, base // 2, 3, padding=1) + self.r3 = ResBlock3D(base // 2, base // 4) + self.norm_out = nn.GroupNorm(32, base // 4) + self.conv_out = nn.Conv3d(base // 4, 3, 3, padding=1) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + x = self.conv_in(z) + x = self.r1(x) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") + x = self.up1(x) + x = self.r2(x) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") + x = self.up2(x) + x = self.r3(x) + return self.conv_out(F.silu(self.norm_out(x))) + + def create_transformer_model(config: TransformerConfig, device: torch.device) -> Transformer: """Create Transformer model diff --git a/tests/perf_tests/test_nd_tiling_perf_workaround.py b/tests/perf_tests/test_nd_tiling_perf_workaround.py new file mode 100644 index 0000000..e08260f --- /dev/null +++ b/tests/perf_tests/test_nd_tiling_perf_workaround.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance test: Triton ND-tiling workaround under dynamic shapes. + +Background +---------- +On PyTorch < 2.11.0, Inductor's coalesce tiling analysis bails out on symbolic +numels (``tiling_utils.extract_normalized_read_writes`` returns ``None``), so +transpose/permute/channels-last pointwise kernels in a dynamic-shape graph +degrade to untiled Grid1D. ``ND_TilingWorkaroundPass`` works around this by +enabling ``triton.prefer_nd_tiling`` (+ ``max_tiles=3`` + ``tile_reductions``) +when the post-grad graph is dynamic AND conv-heavy (the regime where the +degraded kernels dominate); see +``magi_compiler.passes.piecewise_graph.nd_tiling_workaround.ND_TilingWorkaroundPass``. + +This test exercises a WAN-2.2-VAE-decode-like workload (stacked 3D conv resblocks ++ spatial upsampling) compiled with **dynamic H/W** — a dynamic, conv-heavy graph +that triggers the pass — and checks that magi_compile beats vanilla +``torch.compile`` on that path. + +Real WAN 2.2 VAE decode (540p, dynamic H/W) numbers that motivate this: + - with conv channels-last layout: 1.252s -> 542ms / decode (~2.3x) + - without conv channels-last: 770ms -> 535ms / decode (~1.44x) +This synthetic decoder (no weights, no conv channels-last pass) reproduces the +"~1.4x" regime. The absolute ratio is GPU-dependent, so ND_TILING_SPEEDUP_THRESHOLD +is set to a conservative lower bound that still proves a clear, non-noise win. +""" + +import pytest +import torch + +from magi_compiler import magi_compile +from tests.model_definition import VAEDecoderLike +from tests.perf_tests import cuda_benchmark, print_perf_comparison +from tests.perf_tests.utils import assert_magi_vs_torch + +# WAN 2.2 VAE 540p latent: [C, T, H, W]; dynamic dims are H and W. +LATENT_C, LATENT_T, LATENT_H, LATENT_W = 48, 7, 34, 60 + +# magi_compile (workaround on) vs vanilla torch.compile, both on the dynamic path. +# Observed ~1.36x (torch=2.20ms -> magi=1.63ms) on H100; assert a conservative +# lower bound that still proves a clear, non-noise win. +ND_TILING_SPEEDUP_THRESHOLD = 1.20 + + +@pytest.fixture(scope="function") +def decoder_input(device): + return torch.randn(1, LATENT_C, LATENT_T, LATENT_H, LATENT_W, device=device, dtype=torch.bfloat16) + + +def _compile_decoder(device: torch.device): + def _patch(cfg): + # This decoder is dynamic + conv-heavy, so the pass's heuristics fire and + # the workaround is applied. The pass mutates triton configs only inside the + # compilation's config.patch scope, so nothing leaks to the torch baseline. + cfg.pass_config.enable_nd_tiling_workaround = True + return cfg + + model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + # Dynamic H, W (latent dims 3, 4) — the regime where coalesce tiling bails out. + return magi_compile(model, dynamic_arg_dims={"z": [3, 4]}, config_patch=_patch) + + +def _compile_torch(device: torch.device): + model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + return torch.compile(model, backend="inductor") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_nd_tiling_workaround_speedup(device, decoder_input): + """ND-tiling ON should beat vanilla torch.compile on the dynamic path.""" + # Build isolated inputs to prevent dynamic shape marking leakage + eager_input = decoder_input.clone() + magi_input = decoder_input.clone() + torch_input = decoder_input.clone() + + # Explicitly mark dynamic dimensions for the vanilla torch.compile environment + torch._dynamo.mark_dynamic(torch_input, [3, 4]) + + eager_model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + magi_compiled = _compile_decoder(device) + torch_compiled = _compile_torch(device) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_model(eager_input)) + torch_result = cuda_benchmark(lambda: torch_compiled(torch_input), compilation_warmup=3) + magi_result = cuda_benchmark(lambda: magi_compiled(magi_input), compilation_warmup=3) + + speedup = torch_result.median / magi_result.median + print_perf_comparison( + "Dynamic ND-tiling: magi_compile vs torch.compile (dynamic H/W)", + eager_result, + magi_result, + torch_result, + extra_info=(f"latent=({LATENT_C}, {LATENT_T}, {LATENT_H}, {LATENT_W}) " f"speedup(torch/magi)={speedup:.2f}x"), + ) + + assert_magi_vs_torch( + speedup, torch_result, magi_result, label="Dynamic ND-tiling workaround", threshold=ND_TILING_SPEEDUP_THRESHOLD + )