[Feat] Add Triton ND-tiling workaround for dynamic shapes & harden MagiInductorPass Base Class #35
Merged
jiahy0825 merged 5 commits on Jun 23, 2026
jiahy0825 reviewed Jun 20, 2026
🗂️ PR Category
📝 Description
Problem
On PyTorch < 2.11.0, Inductor's coalesce tiling bails out on symbolic numels under
dynamic shapes, so transpose/permute/channels-last pointwise kernels degrade to
untiled Grid1D — a big hit on memory-bound, conv-dense workloads like VAE decode.
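To illustrate why losing tiling hurts a transpose, here is a toy model that counts how many cache-line-sized segments one block of work touches. It is not Inductor's cost model: the 32-element line size, the 1024x1024 shape, and all function names are assumptions chosen for illustration.

```python
# Toy model (an assumption, not Inductor's analysis): count distinct
# cache-line-sized segments touched by one block of a transposed copy.
LINE_ELEMS = 32  # assumed number of elements per coalesced segment


def lines_touched(offsets):
    """Number of distinct LINE_ELEMS-sized segments covered by the offsets."""
    return len({off // LINE_ELEMS for off in offsets})


def transpose_traffic(block, rows, cols):
    """Return (read_lines, write_lines) for out[i, j] = inp[j, i].

    out has shape (rows, cols) and inp has shape (cols, rows); both row-major.
    `block` is the list of (i, j) output coordinates one block handles.
    """
    writes = lines_touched(i * cols + j for i, j in block)
    reads = lines_touched(j * rows + i for i, j in block)
    return reads, writes


ROWS = COLS = 1024

# Untiled 1D block: 1024 consecutive output elements (one full output row).
flat_block = [(0, j) for j in range(1024)]
# 2D tile: the same 1024 elements arranged as a 32 x 32 patch.
tile_block = [(i, j) for i in range(32) for j in range(32)]

flat_reads, flat_writes = transpose_traffic(flat_block, ROWS, COLS)
tile_reads, tile_writes = transpose_traffic(tile_block, ROWS, COLS)

print("1D block:", flat_reads, "read lines,", flat_writes, "write lines")
print("2D tile :", tile_reads, "read lines,", tile_writes, "write lines")
```

In this model the 1D block writes 32 segments but reads 1024, because every read lands in a different input row. The 32x32 tile touches 32 segments on each side for the same amount of work.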
Fix
Add `ND_TilingWorkaroundPass`, a post-grad pass that applies the Triton ND-tiling workaround during compilation: `prefer_nd_tiling=True`, `max_tiles=3`, `tile_reductions=True`.

Gating (fires iff all three hold)
- `torch.__version__ < 2.11.0` (fixed upstream in 2.11.0)
- the graph has dynamic shapes
- the graph is conv-heavy (`nnodes < 300 * nconv`)

Config: `pass_config.enable_nd_tiling_workaround: bool = True` (`False` = don't register). Env: `MAGI_COMPILE_PASS_CONFIG__ENABLE_ND_TILING_WORKAROUND`.

Performance (WAN 2.2 VAE decode, 540p, dynamic H/W)
Supporting change: harden `MagiInductorPass`
To make the workaround self-contained, the base class now lets a pass safely mutate `torch._inductor.config`: a pass declares the keys it touches via `inductor_config_keys_potentially_mutated_by_this_pass`, which are snapshotted and restored through Inductor's `config.patch` after compilation, so nothing leaks into global config.
Also adds reusable graph predicates `is_dynamic` and `is_conv_heavy`.

Tests
- `test_nd_tiling_workaround.py`, `test_magi_inductor_pass.py`: gating, binary config → registration, env-var parsing, config snapshot/restore.
- `test_nd_tiling_perf_workaround.py`: VAE-decode-like dynamic-H/W workload, `magi_compile` vs `torch.compile`; ~1.36x on H100, asserts `>= 1.20x`.
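
For readers who want the gating logic at a glance, the three conditions can be sketched as a pure-Python predicate. This is not the code in this PR: the function names, signatures, and the version-string parser are hypothetical. The dynamic-shape condition is inferred from the problem statement, and only the `2.11.0` cutoff and the `nnodes < 300 * nconv` heuristic come from the description.

```python
import re

# Release that, per the PR description, carries the upstream fix.
FIXED_IN = (2, 11, 0)


def parse_release(version: str) -> tuple:
    """Parse the leading numeric release, ignoring suffixes like '+cu128'.

    Hypothetical helper. Note that a pre-release such as '2.11.0a0' parses
    as (2, 11, 0) here, so it counts as already fixed.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part or 0) for part in match.groups())


def conv_heavy_by_counts(nnodes: int, nconv: int) -> bool:
    """Heuristic from the description: few graph nodes per convolution."""
    return nnodes < 300 * nconv


def should_apply_workaround(torch_version: str, has_dynamic_shapes: bool,
                            nnodes: int, nconv: int) -> bool:
    """True only when all three gating conditions hold."""
    return (
        parse_release(torch_version) < FIXED_IN
        and has_dynamic_shapes
        and conv_heavy_by_counts(nnodes, nconv)
    )


if __name__ == "__main__":
    print(should_apply_workaround("2.9.1+cu128", True, nnodes=1200, nconv=40))
    print(should_apply_workaround("2.11.0", True, nnodes=1200, nconv=40))
```

Comparing parsed tuples rather than raw strings matters in this sketch: as plain strings, `"2.9.0" < "2.11.0"` is false, which would wrongly disable the workaround on 2.9.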