diff --git a/magi_compiler/config.py b/magi_compiler/config.py index 0b180d1..44fbd03 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -70,6 +70,17 @@ class PassConfig(BaseModel): "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_SAGE_ATTN (1/0/true/false)." ), ) + enable_conv_channels_last: bool = Field( + True, + description=( + "Forces channels-last (NHWC/NDHWC) inputs/weights at conv boundaries " + "so cuDNN can select faster layout-optimized kernels. " + "True (default): register and let its internal heuristics decide whether to " + "apply (currently: static shapes AND conv-heavy graphs). " + "False: do not register the pass at all. " + "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_CONV_CHANNELS_LAST (1/0/true/false)." + ), + ) enable_nd_tiling_workaround: bool = Field( True, description=( @@ -91,7 +102,7 @@ class PassConfig(BaseModel): "TMA + WGMMA DualGemm. The pass is a no-op on older architectures " "regardless of this flag, but the flag still controls whether it " "is registered at all. " - "Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_MM_EPILOGUE_FUSION (1/0/true/false)." + "Settable via the MAGI_COMPILE_PASS_CONFIG__ENABLE_MM_EPILOGUE_FUSION env var." ), ) diff --git a/magi_compiler/passes/piecewise_graph/conv_channels_last.py b/magi_compiler/passes/piecewise_graph/conv_channels_last.py new file mode 100644 index 0000000..c447566 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/conv_channels_last.py @@ -0,0 +1,177 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conv2d/Conv3d channels-last layout pass for the post-grad ATen graph. + +Forces channels-last (NHWC for 4D, NDHWC for 5D) at every ``aten.convolution`` +boundary by graph rewriting only -- no patching of PyTorch internals. + +Mechanism Under the Hood: +1. **FX-Meta Stride Injection**: The pass inserts ``aten.clone`` nodes before + each conv input/weight and manually configures their ``node.meta["val"]`` + to carry channels-last FakeTensors. Because Inductor's clone lowering + ignores ``memory_format`` (a known PyTorch upstream TODO), the channels-last + signal is carried purely within the FX meta strides. +2. **Inductor Constraint Co-design**: With ``layout_optimization=False`` set + by this pass, Inductor's pre-registered ``constrain_conv_to_fx_strides`` + (in ``torch/_inductor/kernel/conv.py``) fires. It reads our modified FX input + meta strides and triggers ``require_stride_order`` at the conv boundary. +3. **Zero-Cost Strided Allocation**: The clone lowers to a Pointwise kernel + with ``FlexibleLayout``. Consequently, the ``require_stride_order`` constraint + is zero-cost: the upstream buffer is allocated directly in channels-last + layout without generating an extra transpose copy kernel. +4. **cuDNN Memory-Format Probe**: With channels-last inputs safely matching + stride constraints, ``conv_layout()`` naturally infers a channels-last + cuDNN output ``FixedLayout``. + +The pass only fires on static, conv-heavy graphs; dynamic-shape or conv-sparse +graphs are skipped (their channels-last transpose tiles badly / gains little). +""" + +import torch +from torch import fx + +from ...magi_depyf.timeline import emit_pass_lifecycle +from ...utils import magi_logger +from ..pass_base import MagiInductorPass + +aten = torch.ops.aten + + +def _meta_val(node: fx.Node) -> torch.Tensor | None: + val = node.meta.get("val") if hasattr(node, "meta") else None + return val if isinstance(val, torch.Tensor) else None + + +# Single-input, layout-transparent ops the conv stride constraint can hoist through. +_HOISTABLE_OPS = (aten.constant_pad_nd.default,) + + +class ConvChannelsLastPass(MagiInductorPass): + """Make conv2d/conv3d inputs channels-last on the post-grad ATen graph. + + If the conv input comes from a single-consumer, layout-transparent op + (e.g., ``constant_pad_nd``), the clone is hoisted above it and its meta + rewritten to channels-last. This keeps the pad kernel coalesced instead of + triggering an extra memory-bound NC(D)HW -> N(D)HWC transpose. + """ + + inductor_config_keys_potentially_mutated_by_this_pass = ("layout_optimization",) + + @emit_pass_lifecycle + def __call__(self, graph: fx.Graph) -> bool: + # Only rewrite static, conv-heavy graphs. Channels-last inserts an + # NC(D)HW->N(D)HWC transpose; under dynamic shapes Inductor tiles it + # badly, and on conv-sparse graphs the few cuDNN channels-last kernels + # don't pay for the extra copies. + # TODO: If tiling optimization is upgraded to support conv layout opt + # under dynamic shapes, we can remove the ``is_dynamic`` check. + if self.is_dynamic(graph) or not self.is_conv_heavy(graph): + return False + + torch._inductor.config.layout_optimization = False + + # (input node, memory_format) -> clone node, so a weight shared by + # several convs (or a tensor feeding several convs) is cloned once. + clone_cache: dict[tuple[fx.Node, torch.memory_format], fx.Node] = {} + num_hoisted = 0 + + def channels_last_clone(inp: fx.Node, memory_format, insert_point) -> fx.Node | None: + key = (inp, memory_format) + cached = clone_cache.get(key) + if cached is not None: + return cached + inp_val = _meta_val(inp) + if inp_val is None: + return None + if inp_val.is_contiguous(memory_format=memory_format): + return None # already channels-last per meta + with graph.inserting_before(insert_point): + cl = graph.call_function(aten.clone.default, (inp,), {"memory_format": memory_format}) + cl.meta = {**inp.meta} + cl.meta["val"] = inp_val.clone(memory_format=memory_format) + clone_cache[key] = cl + return cl + + def make_channels_last(node: fx.Node, memory_format, depth: int = 0) -> bool: + """Make ``node``'s FX meta channels-last; return True on success.""" + nonlocal num_hoisted + node_val = _meta_val(node) + if node_val is None: + return False + if node_val.is_contiguous(memory_format=memory_format): + return True # already channels-last per meta + + # Hoist through single-consumer layout-transparent ops: rewrite this + # op's meta to channels-last and recurse on its input, so the + # transpose fuses with the upstream producer instead of the pad kernel. + if depth < 8 and node.op == "call_function" and node.target in _HOISTABLE_OPS and len(node.users) == 1: + src = node.args[0] + if isinstance(src, fx.Node): + if not make_channels_last(src, memory_format, depth + 1): + # Chain top: materialise the layout change here, above + # the hoistable op. + cl = channels_last_clone(src, memory_format, node) + if cl is None: + return False + node.replace_input_with(src, cl) + node.meta["val"] = node_val.clone(memory_format=memory_format) + num_hoisted += 1 + return True + return False + + num_converted = 0 + for conv in list(graph.nodes): + if conv.op != "call_function" or conv.target != aten.convolution.default: + continue + x_val = _meta_val(conv.args[0]) + if x_val is None: + continue + if x_val.ndim == 4: + memory_format = torch.channels_last + elif x_val.ndim == 5: + memory_format = torch.channels_last_3d + else: + continue # conv1d etc.: leave untouched + + new_args = list(conv.args) + changed = False + for idx in (0, 1): # x, weight + inp = new_args[idx] + if not isinstance(inp, fx.Node): + continue + # Try hoisting first (rewrites pad metas upstream in place). + if idx == 0 and make_channels_last(inp, memory_format): + inp_val = _meta_val(inp) + if inp_val is not None and inp_val.is_contiguous(memory_format=memory_format): + changed = True + continue + cl = channels_last_clone(inp, memory_format, conv) + if cl is not None: + new_args[idx] = cl + changed = True + if changed: + conv.args = tuple(new_args) + num_converted += 1 + + if num_converted: + graph.lint() + magi_logger.info( + "ConvChannelsLastPass: routed %d forward conv(s) through channels-last clones " + "(%d clone node(s) inserted, %d pad meta(s) hoisted to channels-last)", + num_converted, + len(clone_cache), + num_hoisted, + ) + return (num_converted) > 0 diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index b087c62..b9b6d27 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -80,6 +80,11 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config + if pass_config.enable_conv_channels_last: + from .conv_channels_last import ConvChannelsLastPass + + self.add(ConvChannelsLastPass()) + if pass_config.enable_nd_tiling_workaround: from .nd_tiling_workaround import ND_TilingWorkaroundPass diff --git a/tests/feature_tests/test_conv_channels_last_switch.py b/tests/feature_tests/test_conv_channels_last_switch.py new file mode 100644 index 0000000..d758aeb --- /dev/null +++ b/tests/feature_tests/test_conv_channels_last_switch.py @@ -0,0 +1,119 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decision-logic tests for ``ConvChannelsLastPass``. + +When applicable, the pass rewrites static, conv-heavy graphs (by inserting +``aten.clone`` layout-changing nodes) to trigger channels-last convolutions. +The binary ``enable_conv_channels_last`` config controls registration: + + * ``True`` (default) -> register the pass; its internal heuristics then decide: + apply iff static shapes AND conv-heavy. + * ``False`` -> pass not registered at all. + +These tests assert the registered pass's heuristic decision. The shared +base-class utilities (``is_dynamic``, ``is_conv_heavy``, config +snapshot/anti-leakage) are tested in ``test_magi_inductor_pass.py``; the +end-to-end speedup in ``tests/perf_tests/test_conv_channels_last_perf.py``. +""" + +import pytest +import torch +import torch.fx as fx + +from magi_compiler.config import PassConfig +from magi_compiler.passes.piecewise_graph.conv_channels_last import ConvChannelsLastPass + +aten = torch.ops.aten + + +def _build_conv_graph(fake_mode, *, dynamic: bool, n_conv: int = 1, n_filler: int = 0) -> fx.Graph: + """Build a tiny conv3d graph for the channels-last decision logic. + + ``dynamic`` makes the input placeholder's first dim a free symbol (drives + ``is_dynamic``). ``n_conv`` adds ``aten.convolution.default`` nodes (5-D + inputs/weights => conv3d, so the pass targets them). ``n_filler`` adds plain + ``relu`` nodes to inflate the node count (drives ``nnodes vs 300 * nconv``). + + Inputs are kept contiguous (NCDHW) so the pass has a real layout change to + make: a successful rewrite inserts ``aten.clone`` (channels_last_3d) nodes. + """ + graph = fx.Graph() + if dynamic: + sym = fake_mode.shape_env.create_unbacked_symint() + with fake_mode: + x_val = torch.empty(sym, 8, 4, 4, 4) + else: + with fake_mode: + x_val = torch.empty(2, 8, 4, 4, 4) + + x = graph.placeholder("x") + x.meta["val"] = x_val + + with fake_mode: + weight_val = torch.empty(8, 8, 3, 3, 3) + out_val = torch.empty(2, 8, 4, 4, 4) + + node = x + for c in range(n_conv): + weight = graph.placeholder(f"weight_{c}") + weight.meta["val"] = weight_val + node = graph.call_function(aten.convolution.default, args=(node, weight)) + node.meta["val"] = out_val + for _ in range(n_filler): + node = graph.call_function(aten.relu.default, args=(node,)) + node.meta["val"] = out_val + graph.output((node,)) + return graph + + +def _num_channels_last_clones(graph: fx.Graph) -> int: + """Count ``aten.clone`` nodes the pass inserts to force channels-last.""" + return sum(1 for n in graph.nodes if n.op == "call_function" and n.target == aten.clone.default) + + +def _run_pass(graph: fx.Graph) -> int: + """Run the pass on ``graph`` and return how many channels-last clones it inserted.""" + ConvChannelsLastPass()(graph) + return _num_channels_last_clones(graph) + + +@pytest.mark.parametrize("value", [True, False]) +def test_config_field_binary(value): + """enable_conv_channels_last accepts True/False (default True covered in test_magi_inductor_pass).""" + assert PassConfig(enable_conv_channels_last=value).enable_conv_channels_last is value + + +def test_auto_skips_dynamic(fake_mode): + """auto: a dynamic graph is skipped (no clones inserted).""" + graph = _build_conv_graph(fake_mode, dynamic=True, n_conv=1, n_filler=0) + assert _run_pass(graph) == 0 + + +def test_auto_skips_static_conv_sparse(fake_mode): + """auto: a static but conv-sparse graph (nnodes >= 300 * nconv, i.e. not is_conv_heavy) is skipped.""" + graph = _build_conv_graph(fake_mode, dynamic=False, n_conv=1, n_filler=320) + assert _run_pass(graph) == 0 + + +def test_auto_rewrites_static_conv_heavy(fake_mode): + """auto: a static, conv-heavy graph (nnodes < 300 * nconv, i.e. is_conv_heavy) gets rewritten.""" + graph = _build_conv_graph(fake_mode, dynamic=False, n_conv=1, n_filler=0) + assert _run_pass(graph) > 0 + + +def test_auto_skips_dynamic_conv_heavy(fake_mode): + """auto: dynamic dominates -- even a conv-heavy dynamic graph is skipped.""" + graph = _build_conv_graph(fake_mode, dynamic=True, n_conv=1, n_filler=0) + assert _run_pass(graph) == 0 diff --git a/tests/feature_tests/test_magi_inductor_pass.py b/tests/feature_tests/test_magi_inductor_pass.py index 79d15da..2d5ecfa 100644 --- a/tests/feature_tests/test_magi_inductor_pass.py +++ b/tests/feature_tests/test_magi_inductor_pass.py @@ -166,21 +166,3 @@ def test_env_rejects_non_bool_strings(monkeypatch, env_value): monkeypatch.setenv(_ND_TILING_ENV, env_value) with pytest.raises(ValidationError): CompileConfig() - - -def test_config_binary_registration_mapping(): - """PostGradPassManager registers the pass iff enable_nd_tiling_workaround is True. - - - config = True --> pass is registered - - config = False --> pass is not registered at all - """ - from magi_compiler.config import PassConfig - from magi_compiler.passes.piecewise_graph.post_grad_pass_manager import PostGradPassManager - - pm_true = PostGradPassManager() - pm_true.configure(PassConfig(enable_nd_tiling_workaround=True)) - assert len(pm_true.passes) == 1 - - pm_false = PostGradPassManager() - pm_false.configure(PassConfig(enable_nd_tiling_workaround=False)) - assert len(pm_false.passes) == 0 diff --git a/tests/perf_tests/test_conv_channels_last_perf.py b/tests/perf_tests/test_conv_channels_last_perf.py new file mode 100644 index 0000000..129e9e4 --- /dev/null +++ b/tests/perf_tests/test_conv_channels_last_perf.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance test: conv channels-last layout pass under static shapes. + +Background +---------- +cuDNN's channels-last (NHWC/NDHWC) conv kernels beat contiguous NC(D)HW on +Ampere+. ``ConvChannelsLastPass`` rewrites the post-grad ATen graph so every +``aten.convolution`` (conv2d/conv3d) consumes channels-last inputs/weights, +letting cuDNN pick those kernels; it sets ``layout_optimization=False`` so layout +is owned entirely by the pass. The pass applies only when the post-grad graph is +static AND conv-heavy (the regime where channels-last pays off and its transpose +doesn't tile badly); see +``magi_compiler.passes.piecewise_graph.conv_channels_last.ConvChannelsLastPass``. + +This test exercises a WAN-2.2-VAE-decode-like workload (stacked 3D conv resblocks ++ spatial upsampling) compiled with **static shapes** — a static, conv-heavy graph +that triggers the pass — and checks that magi_compile beats vanilla +``torch.compile`` on that path. + +Real WAN 2.2 VAE decode (540p, static shapes) numbers that motivate this: + - with conv channels-last layout: 520ms -> 430ms / decode (**~1.2x speedup**) +This synthetic, weightless decoder reproduces this regime. The absolute ratio is +GPU-dependent, so CONV_CHANNELS_LAST_SPEEDUP_THRESHOLD is set to a conservative +lower bound of 1.20 that still proves a clear, non-noise win. +""" + +import pytest +import torch + +from magi_compiler import magi_compile +from tests.model_definition import VAEDecoderLike +from tests.perf_tests import cuda_benchmark, print_perf_comparison +from tests.perf_tests.utils import assert_magi_vs_torch + +# WAN 2.2 VAE 540p latent: [C, T, H, W]; compiled with static shapes here. +LATENT_C, LATENT_T, LATENT_H, LATENT_W = 48, 7, 34, 60 + +# magi_compile (channels-last on) vs vanilla torch.compile, both on the static path. +# Real 540p decode lands ~1.2x; assert a conservative lower bound (calibrated GPUs only). +CONV_CHANNELS_LAST_SPEEDUP_THRESHOLD = 1.20 + + +@pytest.fixture(scope="function") +def decoder_input(device): + return torch.randn(1, LATENT_C, LATENT_T, LATENT_H, LATENT_W, device=device, dtype=torch.bfloat16) + + +def _compile_decoder(device: torch.device): + def _patch(cfg): + # This decoder is static + conv-dense, so the pass's heuristics fire and + # channels-last is applied. The pass mutates layout_optimization only inside + # the compilation's config.patch scope, so nothing leaks to the torch baseline. + cfg.pass_config.enable_conv_channels_last = True + return cfg + + model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + # Empty dims => fully static; the pass forces channels-last without dynamic shapes. + return magi_compile(model, dynamic_arg_dims={"z": []}, config_patch=_patch) + + +def _compile_torch(device: torch.device): + model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + return torch.compile(model, fullgraph=True, backend="inductor") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_conv_channels_last_vs_torch_compile(device, decoder_input): + """Channels-last pass ON should beat vanilla torch.compile on the static path.""" + # Build isolated inputs so the three paths don't share state. + eager_input = decoder_input.clone() + magi_input = decoder_input.clone() + torch_input = decoder_input.clone() + + eager_model = VAEDecoderLike().to(device).to(torch.bfloat16).eval() + magi_compiled = _compile_decoder(device) + torch_compiled = _compile_torch(device) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_model(eager_input)) + torch_result = cuda_benchmark(lambda: torch_compiled(torch_input), compilation_warmup=3) + magi_result = cuda_benchmark(lambda: magi_compiled(magi_input), compilation_warmup=3) + + speedup = torch_result.median / magi_result.median + print_perf_comparison( + "Conv channels-last: magi_compile vs torch.compile (static shapes)", + eager_result, + magi_result, + torch_result, + extra_info=(f"latent=({LATENT_C}, {LATENT_T}, {LATENT_H}, {LATENT_W}) " f"speedup(torch/magi)={speedup:.2f}x"), + ) + + assert_magi_vs_torch( + speedup, torch_result, magi_result, label="Conv channels-last", threshold=CONV_CHANNELS_LAST_SPEEDUP_THRESHOLD + )