Update: chunk Qwen3 decode scope1 projections #104
zhangqi-chen merged 1 commit into hw-native-sys:main
Walkthrough
Changed Scope-1 tiling and loop structure in qwen3 decode: introduced a larger RMSNorm chunk size (512) for Scope-1, and converted Q/K/V projection output-block loops to parallel/core-group scopes with chunked-loop optimization, relocating projection assembly into those scopes.
Code Review
This pull request increases the K_CHUNK size and optimizes the Q, K, and V projection stages by wrapping loops in a chunked_loop_optimizer and converting them to parallel loops. A review comment points out that the pl.parallel calls should include an explicit start index to ensure compatibility with the DSL and avoid potential runtime errors.
-    for ob in pl.range(q_out_blocks):
-        q0 = ob * Q_OUT_CHUNK
+    with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
+        for ob in pl.parallel(q_out_blocks, chunk=4):
The pl.parallel function in this repository is consistently used with at least two positional arguments for the start and stop indices (e.g., pl.parallel(0, q_out_blocks, ...)), as seen in other model examples. Using a single argument may not be supported by the DSL and could lead to incorrect loop bounds or runtime errors.
Suggested change:
-        for ob in pl.parallel(q_out_blocks, chunk=4):
+        for ob in pl.parallel(0, q_out_blocks, chunk=4):
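To illustrate what explicit bounds buy, here is a plain-Python sketch (not the `pl` DSL) of a loop helper that takes a start and a stop index and groups iterations into chunks. The name `chunked_indices` is hypothetical, and reading `chunk=4` as "four consecutive iterations per group" is an assumption made only for this example; the DSL's actual semantics may differ.

```python
def chunked_indices(start, stop, chunk):
    """Yield lists of consecutive indices in [start, stop), at most `chunk` per list."""
    if chunk <= 0:
        raise ValueError("chunk must be positive")
    for base in range(start, stop, chunk):
        # The last group is shorter when (stop - start) is not a multiple of chunk.
        yield list(range(base, min(base + chunk, stop)))


# Hypothetical block count, chosen so that the final group is partial.
q_out_blocks = 10
groups = list(chunked_indices(0, q_out_blocks, chunk=4))
for group in groups:
    print(group)
```

With both bounds spelled out, the helper cannot mistake a lone positional argument for the start index, which is the ambiguity the review comment warns about.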
Apply the larger scope1 reduction chunk and chunked Q/KV projection loops to both the standalone scope1 example and the full decode example. In the full decode path, keep scope3 chunking at 128 by introducing a scope1-specific chunk constant, so only scope1 uses the wider 512-way reduction and projection tiling. Benchmarks on a2a3 device 1 show lower task counts and lower end-to-end runtime for both the scope1-only path and the full decode path.
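The size of the improvement can be worked out from the wall times and task counts reported in this PR's summary. The sketch below is only arithmetic on those reported figures; the `summarize` helper is hypothetical, and it assumes the first pair of numbers is the scope1-only path and the `--max-seq` pair is the full decode path.

```python
def summarize(before_us, after_us, before_tasks, after_tasks):
    """Derive speedup, relative time saved, and task reduction from reported figures."""
    return {
        "speedup": before_us / after_us,
        "time_saved_pct": 100.0 * (before_us - after_us) / before_us,
        "tasks_removed": before_tasks - after_tasks,
    }


# Figures copied from the PR summary (wall time in microseconds, task counts).
scope1_only = summarize(525.04, 350.02, 161, 37)
full_decode = summarize(3198.22, 3080.66, 1503, 1379)

for name, stats in (("scope1 only", scope1_only), ("full decode", full_decode)):
    print(f"{name}: {stats['speedup']:.2f}x, "
          f"{stats['time_saved_pct']:.1f}% less wall time, "
          f"{stats['tasks_removed']} fewer tasks")
```

Both runs drop the same number of tasks, which is consistent with the change touching only scope1 while the rest of the decode path is unchanged.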
# Scope 1 tiling constants.
K_CHUNK = 128
SCOPE1_K_CHUNK = 512
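As a rough illustration of why the wider chunk reduces work items, the sketch below counts how many reduction chunks a given length splits into under each constant. `HIDDEN_SIZE = 5120` is an assumption (the PR does not state the reduction length), and `num_chunks` is a hypothetical helper, not code from the repository.

```python
HIDDEN_SIZE = 5120  # assumption: reduction length; not stated in this PR

K_CHUNK = 128          # chunk size kept for the other scopes
SCOPE1_K_CHUNK = 512   # wider chunk used only by scope 1


def num_chunks(total, chunk):
    """Number of chunks needed to cover `total` elements (ceiling division)."""
    return -(-total // chunk)


steps_with_k_chunk = num_chunks(HIDDEN_SIZE, K_CHUNK)
steps_with_scope1_chunk = num_chunks(HIDDEN_SIZE, SCOPE1_K_CHUNK)
print(steps_with_k_chunk, steps_with_scope1_chunk)
```

Under this assumed length, the 512-wide chunk needs a quarter as many reduction steps as the 128-wide one, since 512 is four times 128.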
Summary
- Apply the larger scope1 reduction chunk and chunked Q/KV projection loops to examples/models/qwen3/qwen3_32b_decode_scope1.py and the scope1 section of examples/models/qwen3/qwen3_32b_decode.py.
- Keep scope3 chunking at K_CHUNK = 128 by introducing a scope1-specific chunk constant in qwen3_32b_decode.py.
- Scope1-only benchmark on a2a3 device 1 with runtime profiling: origin/main wall time 525.04 us and 161 tasks; updated branch wall time 350.02 us and 37 tasks.
- Full decode benchmark on a2a3 device 1 with --max-seq --runtime-profiling: before change wall time 3198.22 us and 1503 tasks; after change wall time 3080.66 us and 1379 tasks.

Related Issues