-
Notifications
You must be signed in to change notification settings - Fork 24
[Feat] Add channels-last layout optimization pass for conv-heavy models #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
jiahy0825
merged 3 commits into
SandAI-org:main
from
themistbeforedawn:feat/conv-channels-last-optimization-pass
Jun 23, 2026
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
177 changes: 177 additions & 0 deletions
177
magi_compiler/passes/piecewise_graph/conv_channels_last.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Conv2d/Conv3d channels-last layout pass for the post-grad ATen graph. | ||
|
|
||
| Forces channels-last (NHWC for 4D, NDHWC for 5D) at every ``aten.convolution`` | ||
| boundary by graph rewriting only -- no patching of PyTorch internals. | ||
|
|
||
| Mechanism Under the Hood: | ||
| 1. **FX-Meta Stride Injection**: The pass inserts ``aten.clone`` nodes before | ||
| each conv input/weight and manually configures their ``node.meta["val"]`` | ||
| to carry channels-last FakeTensors. Because Inductor's clone lowering | ||
| ignores ``memory_format`` (a known PyTorch upstream TODO), the channels-last | ||
| signal is carried purely within the FX meta strides. | ||
| 2. **Inductor Constraint Co-design**: With ``layout_optimization=False`` set | ||
| by this pass, Inductor's pre-registered ``constrain_conv_to_fx_strides`` | ||
| (in ``torch/_inductor/kernel/conv.py``) fires. It reads our modified FX input | ||
| meta strides and triggers ``require_stride_order`` at the conv boundary. | ||
| 3. **Zero-Cost Strided Allocation**: The clone lowers to a Pointwise kernel | ||
| with ``FlexibleLayout``. Consequently, the ``require_stride_order`` constraint | ||
| is zero-cost: the upstream buffer is allocated directly in channels-last | ||
| layout without generating an extra transpose copy kernel. | ||
| 4. **cuDNN Memory-Format Probe**: With channels-last inputs safely matching | ||
| stride constraints, ``conv_layout()`` naturally infers a channels-last | ||
| cuDNN output ``FixedLayout``. | ||
|
|
||
| The pass only fires on static, conv-heavy graphs; dynamic-shape or conv-sparse | ||
| graphs are skipped (their channels-last transpose tiles badly / gains little). | ||
| """ | ||
|
|
||
| import torch | ||
| from torch import fx | ||
|
|
||
| from ...magi_depyf.timeline import emit_pass_lifecycle | ||
| from ...utils import magi_logger | ||
| from ..pass_base import MagiInductorPass | ||
|
|
||
| aten = torch.ops.aten | ||
|
|
||
|
|
||
| def _meta_val(node: fx.Node) -> torch.Tensor | None: | ||
| val = node.meta.get("val") if hasattr(node, "meta") else None | ||
| return val if isinstance(val, torch.Tensor) else None | ||
|
|
||
|
|
||
| # Single-input, layout-transparent ops the conv stride constraint can hoist through. | ||
| _HOISTABLE_OPS = (aten.constant_pad_nd.default,) | ||
|
|
||
|
|
||
| class ConvChannelsLastPass(MagiInductorPass): | ||
| """Make conv2d/conv3d inputs channels-last on the post-grad ATen graph. | ||
|
|
||
| If the conv input comes from a single-consumer, layout-transparent op | ||
| (e.g., ``constant_pad_nd``), the clone is hoisted above it and its meta | ||
| rewritten to channels-last. This keeps the pad kernel coalesced instead of | ||
| triggering an extra memory-bound NC(D)HW -> N(D)HWC transpose. | ||
| """ | ||
|
|
||
| inductor_config_keys_potentially_mutated_by_this_pass = ("layout_optimization",) | ||
|
|
||
| @emit_pass_lifecycle | ||
| def __call__(self, graph: fx.Graph) -> bool: | ||
| # Only rewrite static, conv-heavy graphs. Channels-last inserts an | ||
| # NC(D)HW->N(D)HWC transpose; under dynamic shapes Inductor tiles it | ||
| # badly, and on conv-sparse graphs the few cuDNN channels-last kernels | ||
| # don't pay for the extra copies. | ||
| # TODO: If tiling optimization is upgraded to support conv layout opt | ||
| # under dynamic shapes, we can remove the ``is_dynamic`` check. | ||
| if self.is_dynamic(graph) or not self.is_conv_heavy(graph): | ||
| return False | ||
|
|
||
| torch._inductor.config.layout_optimization = False | ||
|
|
||
| # (input node, memory_format) -> clone node, so a weight shared by | ||
| # several convs (or a tensor feeding several convs) is cloned once. | ||
| clone_cache: dict[tuple[fx.Node, torch.memory_format], fx.Node] = {} | ||
| num_hoisted = 0 | ||
|
|
||
| def channels_last_clone(inp: fx.Node, memory_format, insert_point) -> fx.Node | None: | ||
| key = (inp, memory_format) | ||
| cached = clone_cache.get(key) | ||
| if cached is not None: | ||
| return cached | ||
| inp_val = _meta_val(inp) | ||
| if inp_val is None: | ||
| return None | ||
| if inp_val.is_contiguous(memory_format=memory_format): | ||
| return None # already channels-last per meta | ||
| with graph.inserting_before(insert_point): | ||
| cl = graph.call_function(aten.clone.default, (inp,), {"memory_format": memory_format}) | ||
| cl.meta = {**inp.meta} | ||
| cl.meta["val"] = inp_val.clone(memory_format=memory_format) | ||
| clone_cache[key] = cl | ||
| return cl | ||
|
|
||
| def make_channels_last(node: fx.Node, memory_format, depth: int = 0) -> bool: | ||
| """Make ``node``'s FX meta channels-last; return True on success.""" | ||
| nonlocal num_hoisted | ||
| node_val = _meta_val(node) | ||
| if node_val is None: | ||
| return False | ||
| if node_val.is_contiguous(memory_format=memory_format): | ||
| return True # already channels-last per meta | ||
|
|
||
| # Hoist through single-consumer layout-transparent ops: rewrite this | ||
| # op's meta to channels-last and recurse on its input, so the | ||
| # transpose fuses with the upstream producer instead of the pad kernel. | ||
| if depth < 8 and node.op == "call_function" and node.target in _HOISTABLE_OPS and len(node.users) == 1: | ||
| src = node.args[0] | ||
| if isinstance(src, fx.Node): | ||
| if not make_channels_last(src, memory_format, depth + 1): | ||
| # Chain top: materialise the layout change here, above | ||
| # the hoistable op. | ||
| cl = channels_last_clone(src, memory_format, node) | ||
| if cl is None: | ||
| return False | ||
| node.replace_input_with(src, cl) | ||
| node.meta["val"] = node_val.clone(memory_format=memory_format) | ||
| num_hoisted += 1 | ||
| return True | ||
| return False | ||
|
|
||
| num_converted = 0 | ||
| for conv in list(graph.nodes): | ||
| if conv.op != "call_function" or conv.target != aten.convolution.default: | ||
| continue | ||
| x_val = _meta_val(conv.args[0]) | ||
| if x_val is None: | ||
| continue | ||
| if x_val.ndim == 4: | ||
| memory_format = torch.channels_last | ||
| elif x_val.ndim == 5: | ||
| memory_format = torch.channels_last_3d | ||
| else: | ||
| continue # conv1d etc.: leave untouched | ||
|
|
||
| new_args = list(conv.args) | ||
| changed = False | ||
| for idx in (0, 1): # x, weight | ||
| inp = new_args[idx] | ||
| if not isinstance(inp, fx.Node): | ||
| continue | ||
| # Try hoisting first (rewrites pad metas upstream in place). | ||
| if idx == 0 and make_channels_last(inp, memory_format): | ||
| inp_val = _meta_val(inp) | ||
| if inp_val is not None and inp_val.is_contiguous(memory_format=memory_format): | ||
| changed = True | ||
| continue | ||
| cl = channels_last_clone(inp, memory_format, conv) | ||
| if cl is not None: | ||
| new_args[idx] = cl | ||
| changed = True | ||
| if changed: | ||
| conv.args = tuple(new_args) | ||
| num_converted += 1 | ||
|
|
||
| if num_converted: | ||
| graph.lint() | ||
| magi_logger.info( | ||
| "ConvChannelsLastPass: routed %d forward conv(s) through channels-last clones " | ||
| "(%d clone node(s) inserted, %d pad meta(s) hoisted to channels-last)", | ||
| num_converted, | ||
| len(clone_cache), | ||
| num_hoisted, | ||
| ) | ||
| return (num_converted) > 0 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Decision-logic tests for ``ConvChannelsLastPass``. | ||
|
|
||
| When applicable, the pass rewrites static, conv-heavy graphs (by inserting | ||
| ``aten.clone`` layout-changing nodes) to trigger channels-last convolutions. | ||
| The binary ``enable_conv_channels_last`` config controls registration: | ||
|
|
||
| * ``True`` (default) -> register the pass; its internal heuristics then decide: | ||
| apply iff static shapes AND conv-heavy. | ||
| * ``False`` -> pass not registered at all. | ||
|
|
||
| These tests assert the registered pass's heuristic decision. The shared | ||
| base-class utilities (``is_dynamic``, ``is_conv_heavy``, config | ||
| snapshot/anti-leakage) are tested in ``test_magi_inductor_pass.py``; the | ||
| end-to-end speedup in ``tests/perf_tests/test_conv_channels_last_perf.py``. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.fx as fx | ||
|
|
||
| from magi_compiler.config import PassConfig | ||
| from magi_compiler.passes.piecewise_graph.conv_channels_last import ConvChannelsLastPass | ||
|
|
||
| aten = torch.ops.aten | ||
|
|
||
|
|
||
| def _build_conv_graph(fake_mode, *, dynamic: bool, n_conv: int = 1, n_filler: int = 0) -> fx.Graph: | ||
| """Build a tiny conv3d graph for the channels-last decision logic. | ||
|
|
||
| ``dynamic`` makes the input placeholder's first dim a free symbol (drives | ||
| ``is_dynamic``). ``n_conv`` adds ``aten.convolution.default`` nodes (5-D | ||
| inputs/weights => conv3d, so the pass targets them). ``n_filler`` adds plain | ||
| ``relu`` nodes to inflate the node count (drives ``nnodes vs 300 * nconv``). | ||
|
|
||
| Inputs are kept contiguous (NCDHW) so the pass has a real layout change to | ||
| make: a successful rewrite inserts ``aten.clone`` (channels_last_3d) nodes. | ||
| """ | ||
| graph = fx.Graph() | ||
| if dynamic: | ||
| sym = fake_mode.shape_env.create_unbacked_symint() | ||
| with fake_mode: | ||
| x_val = torch.empty(sym, 8, 4, 4, 4) | ||
| else: | ||
| with fake_mode: | ||
| x_val = torch.empty(2, 8, 4, 4, 4) | ||
|
|
||
| x = graph.placeholder("x") | ||
| x.meta["val"] = x_val | ||
|
|
||
| with fake_mode: | ||
| weight_val = torch.empty(8, 8, 3, 3, 3) | ||
| out_val = torch.empty(2, 8, 4, 4, 4) | ||
|
|
||
| node = x | ||
| for c in range(n_conv): | ||
| weight = graph.placeholder(f"weight_{c}") | ||
| weight.meta["val"] = weight_val | ||
| node = graph.call_function(aten.convolution.default, args=(node, weight)) | ||
| node.meta["val"] = out_val | ||
| for _ in range(n_filler): | ||
| node = graph.call_function(aten.relu.default, args=(node,)) | ||
| node.meta["val"] = out_val | ||
| graph.output((node,)) | ||
| return graph | ||
|
|
||
|
|
||
| def _num_channels_last_clones(graph: fx.Graph) -> int: | ||
| """Count ``aten.clone`` nodes the pass inserts to force channels-last.""" | ||
| return sum(1 for n in graph.nodes if n.op == "call_function" and n.target == aten.clone.default) | ||
|
|
||
|
|
||
| def _run_pass(graph: fx.Graph) -> int: | ||
| """Run the pass on ``graph`` and return how many channels-last clones it inserted.""" | ||
| ConvChannelsLastPass()(graph) | ||
| return _num_channels_last_clones(graph) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("value", [True, False]) | ||
| def test_config_field_binary(value): | ||
| """enable_conv_channels_last accepts True/False (default True covered in test_magi_inductor_pass).""" | ||
| assert PassConfig(enable_conv_channels_last=value).enable_conv_channels_last is value | ||
|
|
||
|
|
||
| def test_auto_skips_dynamic(fake_mode): | ||
| """auto: a dynamic graph is skipped (no clones inserted).""" | ||
| graph = _build_conv_graph(fake_mode, dynamic=True, n_conv=1, n_filler=0) | ||
| assert _run_pass(graph) == 0 | ||
|
|
||
|
|
||
| def test_auto_skips_static_conv_sparse(fake_mode): | ||
| """auto: a static but conv-sparse graph (nnodes >= 300 * nconv, i.e. not is_conv_heavy) is skipped.""" | ||
| graph = _build_conv_graph(fake_mode, dynamic=False, n_conv=1, n_filler=320) | ||
| assert _run_pass(graph) == 0 | ||
|
|
||
|
|
||
| def test_auto_rewrites_static_conv_heavy(fake_mode): | ||
| """auto: a static, conv-heavy graph (nnodes < 300 * nconv, i.e. is_conv_heavy) gets rewritten.""" | ||
| graph = _build_conv_graph(fake_mode, dynamic=False, n_conv=1, n_filler=0) | ||
| assert _run_pass(graph) > 0 | ||
|
|
||
|
|
||
| def test_auto_skips_dynamic_conv_heavy(fake_mode): | ||
| """auto: dynamic dominates -- even a conv-heavy dynamic graph is skipped.""" | ||
| graph = _build_conv_graph(fake_mode, dynamic=True, n_conv=1, n_filler=0) | ||
| assert _run_pass(graph) == 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.