Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion magi_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ class PassConfig(BaseModel):
"Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_SAGE_ATTN (1/0/true/false)."
),
)
enable_conv_channels_last: bool = Field(
True,
description=(
"Forces channels-last (NHWC/NDHWC) inputs/weights at conv boundaries "
"so cuDNN can select faster layout-optimized kernels. "
"True (default): register and let its internal heuristics decide whether to "
"apply (currently: static shapes AND conv-heavy graphs). "
"False: do not register the pass at all. "
"Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_CONV_CHANNELS_LAST (1/0/true/false)."
),
)
enable_nd_tiling_workaround: bool = Field(
True,
description=(
Expand All @@ -91,7 +102,7 @@ class PassConfig(BaseModel):
"TMA + WGMMA DualGemm. The pass is a no-op on older architectures "
"regardless of this flag, but the flag still controls whether it "
"is registered at all. "
"Env var: MAGI_COMPILE_PASS_CONFIG__ENABLE_MM_EPILOGUE_FUSION (1/0/true/false)."
"Settable via the MAGI_COMPILE_PASS_CONFIG__ENABLE_MM_EPILOGUE_FUSION env var."
),
)

Expand Down
177 changes: 177 additions & 0 deletions magi_compiler/passes/piecewise_graph/conv_channels_last.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conv2d/Conv3d channels-last layout pass for the post-grad ATen graph.

Forces channels-last (NHWC for 4D, NDHWC for 5D) at every ``aten.convolution``
boundary by graph rewriting only -- no patching of PyTorch internals.

Mechanism Under the Hood:
1. **FX-Meta Stride Injection**: The pass inserts ``aten.clone`` nodes before
each conv input/weight and manually configures their ``node.meta["val"]``
to carry channels-last FakeTensors. Because Inductor's clone lowering
ignores ``memory_format`` (a known PyTorch upstream TODO), the channels-last
signal is carried purely within the FX meta strides.
2. **Inductor Constraint Co-design**: With ``layout_optimization=False`` set
by this pass, Inductor's pre-registered ``constrain_conv_to_fx_strides``
(in ``torch/_inductor/kernel/conv.py``) fires. It reads our modified FX input
meta strides and triggers ``require_stride_order`` at the conv boundary.
3. **Zero-Cost Strided Allocation**: The clone lowers to a Pointwise kernel
with ``FlexibleLayout``. Consequently, the ``require_stride_order`` constraint
is zero-cost: the upstream buffer is allocated directly in channels-last
layout without generating an extra transpose copy kernel.
4. **cuDNN Memory-Format Probe**: With channels-last inputs safely matching
stride constraints, ``conv_layout()`` naturally infers a channels-last
cuDNN output ``FixedLayout``.

The pass only fires on static, conv-heavy graphs; dynamic-shape or conv-sparse
graphs are skipped (their channels-last transpose tiles badly / gains little).
"""

import torch
from torch import fx

from ...magi_depyf.timeline import emit_pass_lifecycle
from ...utils import magi_logger
from ..pass_base import MagiInductorPass

aten = torch.ops.aten


def _meta_val(node: fx.Node) -> torch.Tensor | None:
val = node.meta.get("val") if hasattr(node, "meta") else None
return val if isinstance(val, torch.Tensor) else None


# Single-input, layout-transparent ops the conv stride constraint can hoist through.
_HOISTABLE_OPS = (aten.constant_pad_nd.default,)


class ConvChannelsLastPass(MagiInductorPass):
"""Make conv2d/conv3d inputs channels-last on the post-grad ATen graph.

If the conv input comes from a single-consumer, layout-transparent op
(e.g., ``constant_pad_nd``), the clone is hoisted above it and its meta
rewritten to channels-last. This keeps the pad kernel coalesced instead of
triggering an extra memory-bound NC(D)HW -> N(D)HWC transpose.
"""

inductor_config_keys_potentially_mutated_by_this_pass = ("layout_optimization",)

@emit_pass_lifecycle
def __call__(self, graph: fx.Graph) -> bool:
# Only rewrite static, conv-heavy graphs. Channels-last inserts an
# NC(D)HW->N(D)HWC transpose; under dynamic shapes Inductor tiles it
# badly, and on conv-sparse graphs the few cuDNN channels-last kernels
# don't pay for the extra copies.
# TODO: If tiling optimization is upgraded to support conv layout opt
# under dynamic shapes, we can remove the ``is_dynamic`` check.
if self.is_dynamic(graph) or not self.is_conv_heavy(graph):
return False

torch._inductor.config.layout_optimization = False
Comment thread
themistbeforedawn marked this conversation as resolved.

# (input node, memory_format) -> clone node, so a weight shared by
# several convs (or a tensor feeding several convs) is cloned once.
clone_cache: dict[tuple[fx.Node, torch.memory_format], fx.Node] = {}
num_hoisted = 0

def channels_last_clone(inp: fx.Node, memory_format, insert_point) -> fx.Node | None:
key = (inp, memory_format)
cached = clone_cache.get(key)
if cached is not None:
return cached
inp_val = _meta_val(inp)
if inp_val is None:
return None
if inp_val.is_contiguous(memory_format=memory_format):
return None # already channels-last per meta
with graph.inserting_before(insert_point):
cl = graph.call_function(aten.clone.default, (inp,), {"memory_format": memory_format})
cl.meta = {**inp.meta}
cl.meta["val"] = inp_val.clone(memory_format=memory_format)
clone_cache[key] = cl
return cl

def make_channels_last(node: fx.Node, memory_format, depth: int = 0) -> bool:
"""Make ``node``'s FX meta channels-last; return True on success."""
nonlocal num_hoisted
node_val = _meta_val(node)
if node_val is None:
return False
if node_val.is_contiguous(memory_format=memory_format):
return True # already channels-last per meta

# Hoist through single-consumer layout-transparent ops: rewrite this
# op's meta to channels-last and recurse on its input, so the
# transpose fuses with the upstream producer instead of the pad kernel.
if depth < 8 and node.op == "call_function" and node.target in _HOISTABLE_OPS and len(node.users) == 1:
src = node.args[0]
if isinstance(src, fx.Node):
if not make_channels_last(src, memory_format, depth + 1):
# Chain top: materialise the layout change here, above
# the hoistable op.
cl = channels_last_clone(src, memory_format, node)
if cl is None:
return False
node.replace_input_with(src, cl)
node.meta["val"] = node_val.clone(memory_format=memory_format)
num_hoisted += 1
return True
return False

num_converted = 0
for conv in list(graph.nodes):
if conv.op != "call_function" or conv.target != aten.convolution.default:
continue
x_val = _meta_val(conv.args[0])
if x_val is None:
continue
if x_val.ndim == 4:
memory_format = torch.channels_last
elif x_val.ndim == 5:
memory_format = torch.channels_last_3d
else:
continue # conv1d etc.: leave untouched

new_args = list(conv.args)
changed = False
for idx in (0, 1): # x, weight
inp = new_args[idx]
if not isinstance(inp, fx.Node):
continue
# Try hoisting first (rewrites pad metas upstream in place).
if idx == 0 and make_channels_last(inp, memory_format):
inp_val = _meta_val(inp)
if inp_val is not None and inp_val.is_contiguous(memory_format=memory_format):
changed = True
continue
cl = channels_last_clone(inp, memory_format, conv)
if cl is not None:
new_args[idx] = cl
changed = True
if changed:
conv.args = tuple(new_args)
num_converted += 1

if num_converted:
graph.lint()
magi_logger.info(
"ConvChannelsLastPass: routed %d forward conv(s) through channels-last clones "
"(%d clone node(s) inserted, %d pad meta(s) hoisted to channels-last)",
num_converted,
len(clone_cache),
num_hoisted,
)
return (num_converted) > 0
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __call__(self, graph: fx.Graph):
def configure(self, pass_config: PassConfig):
self.pass_config = pass_config

if pass_config.enable_conv_channels_last:
from .conv_channels_last import ConvChannelsLastPass

self.add(ConvChannelsLastPass())

if pass_config.enable_nd_tiling_workaround:
from .nd_tiling_workaround import ND_TilingWorkaroundPass

Expand Down
119 changes: 119 additions & 0 deletions tests/feature_tests/test_conv_channels_last_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decision-logic tests for ``ConvChannelsLastPass``.

When applicable, the pass rewrites static, conv-heavy graphs (by inserting
``aten.clone`` layout-changing nodes) to trigger channels-last convolutions.
The binary ``enable_conv_channels_last`` config controls registration:

* ``True`` (default) -> register the pass; its internal heuristics then decide:
apply iff static shapes AND conv-heavy.
* ``False`` -> pass not registered at all.

These tests assert the registered pass's heuristic decision. The shared
base-class utilities (``is_dynamic``, ``is_conv_heavy``, config
snapshot/anti-leakage) are tested in ``test_magi_inductor_pass.py``; the
end-to-end speedup in ``tests/perf_tests/test_conv_channels_last_perf.py``.
"""

import pytest
import torch
import torch.fx as fx

from magi_compiler.config import PassConfig
from magi_compiler.passes.piecewise_graph.conv_channels_last import ConvChannelsLastPass

aten = torch.ops.aten


def _build_conv_graph(fake_mode, *, dynamic: bool, n_conv: int = 1, n_filler: int = 0) -> fx.Graph:
"""Build a tiny conv3d graph for the channels-last decision logic.

``dynamic`` makes the input placeholder's first dim a free symbol (drives
``is_dynamic``). ``n_conv`` adds ``aten.convolution.default`` nodes (5-D
inputs/weights => conv3d, so the pass targets them). ``n_filler`` adds plain
``relu`` nodes to inflate the node count (drives ``nnodes vs 300 * nconv``).

Inputs are kept contiguous (NCDHW) so the pass has a real layout change to
make: a successful rewrite inserts ``aten.clone`` (channels_last_3d) nodes.
"""
graph = fx.Graph()
if dynamic:
sym = fake_mode.shape_env.create_unbacked_symint()
with fake_mode:
x_val = torch.empty(sym, 8, 4, 4, 4)
else:
with fake_mode:
x_val = torch.empty(2, 8, 4, 4, 4)

x = graph.placeholder("x")
x.meta["val"] = x_val

with fake_mode:
weight_val = torch.empty(8, 8, 3, 3, 3)
out_val = torch.empty(2, 8, 4, 4, 4)

node = x
for c in range(n_conv):
weight = graph.placeholder(f"weight_{c}")
weight.meta["val"] = weight_val
node = graph.call_function(aten.convolution.default, args=(node, weight))
node.meta["val"] = out_val
for _ in range(n_filler):
node = graph.call_function(aten.relu.default, args=(node,))
node.meta["val"] = out_val
graph.output((node,))
return graph


def _num_channels_last_clones(graph: fx.Graph) -> int:
"""Count ``aten.clone`` nodes the pass inserts to force channels-last."""
return sum(1 for n in graph.nodes if n.op == "call_function" and n.target == aten.clone.default)


def _run_pass(graph: fx.Graph) -> int:
"""Run the pass on ``graph`` and return how many channels-last clones it inserted."""
ConvChannelsLastPass()(graph)
return _num_channels_last_clones(graph)


@pytest.mark.parametrize("value", [True, False])
def test_config_field_binary(value):
"""enable_conv_channels_last accepts True/False (default True covered in test_magi_inductor_pass)."""
assert PassConfig(enable_conv_channels_last=value).enable_conv_channels_last is value


def test_auto_skips_dynamic(fake_mode):
"""auto: a dynamic graph is skipped (no clones inserted)."""
graph = _build_conv_graph(fake_mode, dynamic=True, n_conv=1, n_filler=0)
assert _run_pass(graph) == 0


def test_auto_skips_static_conv_sparse(fake_mode):
"""auto: a static but conv-sparse graph (nnodes >= 300 * nconv, i.e. not is_conv_heavy) is skipped."""
graph = _build_conv_graph(fake_mode, dynamic=False, n_conv=1, n_filler=320)
assert _run_pass(graph) == 0


def test_auto_rewrites_static_conv_heavy(fake_mode):
"""auto: a static, conv-heavy graph (nnodes < 300 * nconv, i.e. is_conv_heavy) gets rewritten."""
graph = _build_conv_graph(fake_mode, dynamic=False, n_conv=1, n_filler=0)
assert _run_pass(graph) > 0


def test_auto_skips_dynamic_conv_heavy(fake_mode):
"""auto: dynamic dominates -- even a conv-heavy dynamic graph is skipped."""
graph = _build_conv_graph(fake_mode, dynamic=True, n_conv=1, n_filler=0)
assert _run_pass(graph) == 0
18 changes: 0 additions & 18 deletions tests/feature_tests/test_magi_inductor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,3 @@ def test_env_rejects_non_bool_strings(monkeypatch, env_value):
monkeypatch.setenv(_ND_TILING_ENV, env_value)
with pytest.raises(ValidationError):
CompileConfig()


def test_config_binary_registration_mapping():
"""PostGradPassManager registers the pass iff enable_nd_tiling_workaround is True.

- config = True --> pass is registered
- config = False --> pass is not registered at all
"""
from magi_compiler.config import PassConfig
from magi_compiler.passes.piecewise_graph.post_grad_pass_manager import PostGradPassManager

pm_true = PostGradPassManager()
pm_true.configure(PassConfig(enable_nd_tiling_workaround=True))
assert len(pm_true.passes) == 1

pm_false = PostGradPassManager()
pm_false.configure(PassConfig(enable_nd_tiling_workaround=False))
assert len(pm_false.passes) == 0
Loading