[Feat] Add Triton ND-tiling workaround for dynamic shapes & harden MagiInductorPass Base Class #35

Merged
jiahy0825 merged 5 commits into SandAI-org:main from themistbeforedawn:feat/enable-nd-tiling-workaround-for-dynamic-shapes on Jun 23, 2026

Conversation

@themistbeforedawn (Collaborator) commented Jun 16, 2026

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

Problem

On PyTorch < 2.11.0, Inductor's coalesce tiling bails out on symbolic numels under
dynamic shapes, so transpose/permute/channels-last pointwise kernels degrade to
untiled Grid1D. This is costly on memory-bound, conv-dense workloads such as VAE
decode.

Fix

Add ND_TilingWorkaroundPass, a post-grad pass that applies the Triton ND-tiling
workaround during compilation: prefer_nd_tiling=True, max_tiles=3,
tile_reductions=True.

Gating (fires iff all three hold)

  1. torch.__version__ < 2.11.0 (fixed upstream in 2.11.0)
  2. dynamic shapes (a placeholder carries free symbols)
  3. conv-heavy graph (nnodes < 300 * nconv)

Config: pass_config.enable_nd_tiling_workaround: bool = True (when False, the pass
is not registered). Env: MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND.
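The three gating conditions above can be sketched as a pure-Python predicate. This is only an illustration, not the code in this PR: the function names, the regex-based version parsing, and the plain integer/boolean inputs are assumptions made for the example (the real pass inspects an FX graph and torch.__version__ directly). Only the 2.11.0 cutoff and the 300 * nconv ratio come from the description.

```python
import re

FIXED_IN = (2, 11, 0)    # from the PR text: fixed upstream in 2.11.0
CONV_HEAVY_RATIO = 300   # from the PR text: nnodes < 300 * nconv


def parse_release(version: str) -> tuple:
    """Extract (major, minor, patch) from strings such as '2.10.0+cu121'.

    Hypothetical helper: local/dev suffixes are ignored, so a 2.11.0
    nightly is treated the same as the 2.11.0 release.
    """
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(g or 0) for g in m.groups())


def should_apply_workaround(torch_version: str, has_free_symbols: bool,
                            nnodes: int, nconv: int) -> bool:
    """Return True only if all three gating conditions hold."""
    old_torch = parse_release(torch_version) < FIXED_IN
    conv_heavy = nnodes < CONV_HEAVY_RATIO * nconv
    return old_torch and has_free_symbols and conv_heavy
```

Note that a graph with no convolutions can never pass the third condition, because nnodes < 0 is always false.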

Performance (WAN 2.2 VAE decode, 540p, dynamic H/W)

torch.compile   magi_compile   Speedup
770 ms          535 ms         ~1.44x

Supporting change: harden MagiInductorPass

To make the workaround self-contained, the base class now lets a pass safely mutate
torch._inductor.config: a pass declares the keys it touches via
inductor_config_keys_potentially_mutated_by_this_pass, and those keys are
snapshotted and restored through Inductor's config.patch after compilation, so
nothing leaks into the global config.
Also adds reusable graph predicates is_dynamic and is_conv_heavy.
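The snapshot-and-restore idea can be shown with a small stand-in. The PR itself goes through Inductor's config.patch; the sketch below instead uses a plain dict and a hypothetical restore_keys context manager so it runs without PyTorch. The "triton." key prefix and the starting values are assumptions for illustration only.

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel for keys that did not exist before the pass ran


@contextmanager
def restore_keys(config: dict, keys):
    """Snapshot the declared keys, then restore them on exit (even on error)."""
    snapshot = {k: config.get(k, _MISSING) for k in keys}
    try:
        yield config
    finally:
        for key, old in snapshot.items():
            if old is _MISSING:
                config.pop(key, None)   # key was added inside the block
            else:
                config[key] = old       # key was overwritten inside the block


# Stand-in for the global Inductor config; values are placeholders.
fake_config = {"triton.prefer_nd_tiling": False, "triton.max_tiles": 2}
declared = ["triton.prefer_nd_tiling", "triton.max_tiles", "triton.tile_reductions"]

with restore_keys(fake_config, declared):
    # Settings named in the PR description.
    fake_config["triton.prefer_nd_tiling"] = True
    fake_config["triton.max_tiles"] = 3
    fake_config["triton.tile_reductions"] = True
    inside = dict(fake_config)  # what compilation would see
```

Declaring the keys up front is what makes the restore targeted: only the listed keys are touched on exit, and a key the pass introduced is removed rather than left behind.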

Tests

  • 22 logic cases (test_nd_tiling_workaround.py, test_magi_inductor_pass.py):
    gating, binary config → registration, env-var parsing, config snapshot/restore.
  • Perf (test_nd_tiling_perf_workaround.py): VAE-decode-like dynamic-H/W workload,
    magi_compile vs torch.compile; ~1.36x on H100, asserts >= 1.20x.

Comment thread magi_compiler/magi_backend/magi_backend.py Outdated
@themistbeforedawn themistbeforedawn force-pushed the feat/enable-nd-tiling-workaround-for-dynamic-shapes branch from f02b02a to 89f9bb8 Compare June 22, 2026 11:57
@themistbeforedawn themistbeforedawn force-pushed the feat/enable-nd-tiling-workaround-for-dynamic-shapes branch 2 times, most recently from cb04db9 to 5f907c2 Compare June 23, 2026 02:57
@themistbeforedawn themistbeforedawn force-pushed the feat/enable-nd-tiling-workaround-for-dynamic-shapes branch from 5f907c2 to 4134f71 Compare June 23, 2026 03:05

@jiahy0825 (Collaborator) left a comment

LGTM

@jiahy0825 jiahy0825 merged commit 54f9836 into SandAI-org:main Jun 23, 2026
1 check passed
@themistbeforedawn themistbeforedawn changed the title [Feature] Enable nd tiling workaround for dynamic shapes [Feat] Add Triton ND-tiling workaround for dynamic shapes & harden MagiInductorPass Base Class Jun 23, 2026

2 participants