[Feat] Add channels-last layout optimization pass for conv-heavy models#36

Merged
jiahy0825 merged 3 commits into
SandAI-org:mainfrom
themistbeforedawn:feat/conv-channels-last-optimization-pass
Jun 23, 2026
@themistbeforedawn themistbeforedawn commented Jun 22, 2026

Collaborator

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

Motivation

cuDNN's channels-last (NHWC/NDHWC) conv kernels are much faster than contiguous
NC(D)HW on Ampere+, but activations are stored contiguous by default, so cuDNN
pays an internal NCHW⇄NHWC conversion around every conv. Inductor only hoists
this for conv2d (ndim == 2 / the 4D len(...) == 4 gate); 5D conv3d has
no native channels-last path, so conv3d-dense models (e.g. VAE decode) keep
paying that per-conv cost.
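
For reviewers less familiar with the two layouts, here is a small pure-Python sketch of how the strides differ for the same logical shape. The helper names are illustrative only and are not part of this PR; the stride formula follows the usual definition of channels-last, where the channel dimension becomes the fastest-varying one.

```python
def contiguous_strides(shape):
    """Row-major (NCHW / NCDHW) strides, counted in elements."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def channels_last_strides(shape):
    """Strides for the same logical (N, C, *spatial) shape stored as NHWC / NDHWC."""
    n, c, *spatial = shape
    # Physical memory order is (N, *spatial, C); lay that out contiguously ...
    physical = contiguous_strides((n, *spatial, c))
    # ... then report the strides back in logical (N, C, *spatial) order.
    return (physical[0], physical[-1], *physical[1:-1])


shape_5d = (1, 8, 4, 16, 16)  # N, C, D, H, W
print(contiguous_strides(shape_5d))     # (8192, 1024, 256, 16, 1)
print(channels_last_strides(shape_5d))  # (8192, 1, 2048, 128, 8)
```

A channel stride of 1 is what identifies a tensor as channels-last; the logical shape is unchanged.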

What this PR adds

ConvChannelsLastPass: an opt-in post-grad ATen pass that brings channels-last
to conv3d (and conv2d) by graph rewriting only — no patching of PyTorch
internals. It sets layout_optimization=False and owns layout itself:

  • Inserts aten.clone(memory_format=channels_last(_3d)) on each conv
    input/weight and marks the clone's meta["val"] channels-last. The clone
    lowering ignores memory_format (a TODO in lowering.py), so the signal
    lives purely in the FX meta strides — which constrain_conv_to_fx_strides
    then reads to pin the conv channels-last, so cuDNN skips its internal
    conversions.
  • The clone lowers to a FlexibleLayout Pointwise, so the stride freeze is
    zero-cost: the buffer is allocated channels-last directly and fuses into the
    neighboring elementwise kernel (silu/groupnorm) instead of adding a copy.
  • Shared inputs/weights convert once (clone_cache); the conversion is hoisted
    through constant_pad_nd to fuse with the upstream producer.
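
The clone-insertion and clone_cache behavior described above can be illustrated with a toy graph. This is a simplified sketch, not the pass itself: the `Node` class and `rewrite_convs` function are hypothetical stand-ins for FX nodes and the real rewrite, and the `channels_last` flag stands in for the strides recorded in meta["val"].

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    op: str
    args: list = field(default_factory=list)
    channels_last: bool = False  # stand-in for channels-last strides in meta["val"]


def rewrite_convs(nodes):
    """Insert one layout-converting clone per distinct conv operand."""
    clone_cache = {}  # id(source node) -> its channels-last clone
    rewritten = []
    for node in nodes:
        if node.op == "conv":
            new_args = []
            for arg in node.args:
                if id(arg) not in clone_cache:
                    clone = Node("clone", [arg], channels_last=True)
                    clone_cache[id(arg)] = clone
                    rewritten.append(clone)  # placed before its first user
                new_args.append(clone_cache[id(arg)])
            node.args = new_args
            node.channels_last = True
        rewritten.append(node)
    return rewritten


# One activation shared by two convs, each with its own weight.
x, w1, w2 = Node("input"), Node("weight"), Node("weight")
conv_a, conv_b = Node("conv", [x, w1]), Node("conv", [x, w2])
graph = rewrite_convs([x, w1, w2, conv_a, conv_b])
print([n.op for n in graph])
```

In this toy, the shared input `x` is cloned once and both convs consume the same clone, which mirrors the "convert once" property; hoisting through constant_pad_nd is not modeled.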

Gating

pass_config.enable_conv_channels_last is binary:

  • True (default): Registered; its internal heuristics decide at runtime whether to apply (fires only on static, conv-heavy graphs).
  • False: Off (not registered at all).
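
As a rough sketch of the gating described above: the flag decides whether the pass exists at all, and a heuristic then decides per graph. The `GraphStats` fields, the `should_apply` name, and the 5% conv-fraction threshold are all assumptions made for illustration; the PR description does not state how "conv-heavy" is measured.

```python
from dataclasses import dataclass


@dataclass
class GraphStats:
    num_convs: int
    num_nodes: int
    has_dynamic_shapes: bool


# Hypothetical threshold; the real heuristic in the pass may differ.
MIN_CONV_FRACTION = 0.05


def should_apply(enabled, stats, min_conv_fraction=MIN_CONV_FRACTION):
    """Return True only for static, conv-heavy graphs when the flag is on."""
    if not enabled:
        return False  # flag off: in the PR the pass is not registered at all
    if stats.has_dynamic_shapes or stats.num_nodes == 0:
        return False  # dynamic (or empty) graphs are skipped
    return stats.num_convs / stats.num_nodes >= min_conv_fraction


print(should_apply(True, GraphStats(num_convs=20, num_nodes=100, has_dynamic_shapes=False)))
print(should_apply(True, GraphStats(num_convs=20, num_nodes=100, has_dynamic_shapes=True)))
```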

Performance (WAN 2.2 VAE decode, 540p, static)

| torch.compile | magi_compile | Speedup |
|---------------|--------------|---------|
| 520 ms        | 430 ms       | ~1.2x   |

Tests

  • Logic (test_conv_channels_last_switch.py): Verifies the corrected binary gating (static conv-heavy graphs are rewritten; dynamic or conv-sparse graphs are skipped) and configuration registration. Refactored to use the shared conftest fixtures and to remove fragile integration logic unrelated to the pass.
  • Perf (test_conv_channels_last_perf.py): Uses a static, conv-heavy, VAE-decode-like workload; achieves a 1.22x speedup over vanilla torch.compile (asserts >= 1.20x). Uses the centralized VAEDecoderLike and a scoped config_patch to prevent baseline config leakage.
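
For reference, the speedup figure and the >= 1.20x assertion amount to a ratio of latencies. The sketch below is illustrative and is not the code in test_conv_channels_last_perf.py; the `speedup` helper is hypothetical, and the two inputs are the 520 ms and 430 ms figures from the performance section rather than fresh measurements.

```python
from statistics import median


def speedup(baseline_ms, optimized_ms):
    """Ratio of median latencies; values above 1.0 mean the optimized run is faster."""
    return median(baseline_ms) / median(optimized_ms)


# Single samples taken from the reported WAN 2.2 VAE decode numbers.
ratio = speedup([520.0], [430.0])
print(f"{ratio:.2f}x")

# Same style of floor check the perf test is described as using.
assert ratio >= 1.20
```

Using medians over repeated runs, rather than a single sample, is one common way to keep such a floor assertion from flaking on noisy machines.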

Comment thread magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py Outdated
Comment thread tests/perf_tests/test_conv_channels_last_perf.py Outdated
Comment thread magi_compiler/passes/piecewise_graph/conv_channels_last.py Outdated
Comment thread magi_compiler/passes/piecewise_graph/conv_channels_last.py
@themistbeforedawn themistbeforedawn force-pushed the feat/conv-channels-last-optimization-pass branch from 8cb57dd to 9a1e5c8 Compare June 23, 2026 08:58

@jiahy0825 jiahy0825 left a comment

Collaborator

LGTM

@jiahy0825 jiahy0825 merged commit c497615 into SandAI-org:main Jun 23, 2026
2 checks passed