Refactor: multi-batch q_padded and simplify online softmax in scope2/scope12#93
Conversation
📝 Walkthrough

Two Qwen3 attention kernel files are refactored to move Q padding outside the per-batch loops with batch-dependent indexing, remove explicit projection zero-initialization, and simplify online softmax accumulation by initializing from precomputed block-0 outputs and accumulating incrementally.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: 2 passed ✅, 1 failed ❌ (warning)
…cope2/scope12

- Hoist `q_padded` allocation outside the batch loop and index by batch offset, supporting true multi-batch decode
- Remove redundant zero-initialization workarounds for intermediate tensors (`all_raw_scores`, `all_exp_padded`, `all_oi_tmp`, `all_cur_mi`, `all_cur_li`)
- Simplify online softmax accumulation (Stage 5) by initializing from the first block via `pl.slice` instead of `pl.full`, then looping from `sb=1`
- Remove the separate Stage 6 and merge normalisation into the same incore block
- Remove zero-init loops for `q_proj`/`k_proj`/`v_proj` in scope12
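The Stage 5 change (seeding the running max, denominator, and output from block 0, then accumulating from `sb = 1` with normalisation applied once at the end) can be sketched in NumPy. The kernel itself uses the `pl` DSL, so all names and shapes here are illustrative:

```python
import numpy as np

def online_softmax_attention(scores, values):
    """Numerically stable attention over a list of score/value block tiles.

    Instead of starting the running max at -inf and the accumulators at
    zero, the accumulators are seeded from block 0 and the loop starts at 1,
    mirroring the pl.slice-based initialization described above.
    """
    m = scores[0].max(axis=-1, keepdims=True)   # running row max from block 0
    e = np.exp(scores[0] - m)
    l = e.sum(axis=-1, keepdims=True)           # running softmax denominator
    o = e @ values[0]                           # unnormalized output
    for sb in range(1, len(scores)):            # accumulate remaining blocks
        m_new = np.maximum(m, scores[sb].max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)               # rescale old accumulators
        e = np.exp(scores[sb] - m_new)
        l = l * scale + e.sum(axis=-1, keepdims=True)
        o = o * scale + e @ values[sb]
        m = m_new
    return o / l                                # normalize once at the end
```

The final division is the merged normalisation step; the rescaling by `exp(m - m_new)` is what keeps the incremental result identical to a full softmax over the concatenated blocks.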
Force-pushed from 68198d9 to 92511d4 (Compare)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_scope12.py`:
- Around line 164-172: The hoisted Q buffer all_q_padded is allocated for batch
* total_q_groups groups but the producer only writes the first group per KV
head; either fully materialize every (ki, qg) group or explicitly reject
multi-group configs. Fix by iterating/writing for every group: update the
producer loop that uses pl.range(batch * total_q_groups) / pl.assemble to cover
all q groups (use total_q_groups and q_per_kv or q_groups in the loop/index
arithmetic so each group's rows at idx * Q_HEAD_PAD + Q_HEAD_BATCH +
group_offset are filled), or if you prefer a simpler change add a guard in the
function (raise/exit) when total_q_groups != 1 (check total_q_groups, q_per_kv
vs Q_HEAD_BATCH) and document that only single-group is supported; ensure
references to all_q_padded, Q_HEAD_PAD, Q_HEAD_BATCH, total_q_groups, and the
pl.assemble write are updated accordingly.
In `@examples/models/qwen3/qwen3_32b_decode_scope2.py`:
- Around line 68-75: The loop that assembles all_q_padded only writes the first
Q_HEAD_BATCH rows per KV group, so when q_groups (num_heads // num_kv_heads) > 1
later gi slices will read uninitialized rows; update the assembly to either (A)
add an inner loop over qg and write into index (ki * q_groups + qg) * Q_HEAD_PAD
+ qi for every q group (ensure use of variables all_q_padded, Q_HEAD_PAD,
Q_HEAD_BATCH, total_q_groups, ki, qi, qg), or (B) explicitly assert/reject
configs where q_groups != 1 (e.g., check num_heads // num_kv_heads == 1) and
raise an error; apply the same fix pattern to the other occurrences mentioned
around the later assembly uses (the blocks at the other referenced locations).
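The index arithmetic from fix (A) can be checked in plain NumPy. This is a sketch under assumed sizes, not the kernel code; the point is that an inner `qg` loop places every group's Q rows at `(ki * q_groups + qg) * Q_HEAD_PAD + qi` and zero-fills the remainder of each slot:

```python
import numpy as np

# Hypothetical stand-ins for the kernel's pl-DSL buffer sizes.
num_kv_heads, q_groups = 2, 2
Q_HEAD_PAD, Q_HEAD_BATCH, head_dim = 4, 3, 8

# Source Q rows, laid out as (ki, qg, qi).
q = np.arange(num_kv_heads * q_groups * Q_HEAD_BATCH * head_dim, dtype=float)
q = q.reshape(num_kv_heads * q_groups * Q_HEAD_BATCH, head_dim)

# NaN marks rows the producer never wrote.
buf = np.full((num_kv_heads * q_groups * Q_HEAD_PAD, head_dim), np.nan)
for ki in range(num_kv_heads):
    for qg in range(q_groups):          # the inner loop the fix adds
        g = ki * q_groups + qg          # flat group index
        for qi in range(Q_HEAD_BATCH):  # real Q rows for this group
            buf[g * Q_HEAD_PAD + qi] = q[g * Q_HEAD_BATCH + qi]
        # zero padding rows filling out the slot
        buf[g * Q_HEAD_PAD + Q_HEAD_BATCH:(g + 1) * Q_HEAD_PAD] = 0.0

assert not np.isnan(buf).any()          # every (ki, qg) group materialized
```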
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: c2868c8a-7043-43e3-9aca-c1e986f1742c
📒 Files selected for processing (2)
- examples/models/qwen3/qwen3_32b_decode_scope12.py
- examples/models/qwen3/qwen3_32b_decode_scope2.py
```python
# Padding q
all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
    for idx in pl.range(batch * total_q_groups):
        all_q_padded = pl.assemble(
            all_q_padded,
            pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
            [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
        )
```
The hoisted Q buffer still only fills one group per KV head.
This buffer is allocated and consumed with `batch * total_q_groups` groups, but the producer still writes rows as if `q_groups == 1`. With any non-default configuration where `q_per_kv > Q_HEAD_BATCH`, the later groups will read uninitialized `all_q_padded` rows. Please either materialize every `(ki, qg)` group here or add an explicit guard so that only the single-group case is supported.
Also applies to: 224-225, 232
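The failure mode the comment describes can be reproduced in a few lines of NumPy (all sizes hypothetical): the buffer is sized for `q_groups` slots per KV head, but a producer indexing only by `ki` leaves `num_kv_heads * (q_groups - 1)` slots untouched.

```python
import numpy as np

num_kv_heads, q_groups, Q_HEAD_PAD, head_dim = 2, 2, 4, 8
buf = np.full((num_kv_heads * q_groups * Q_HEAD_PAD, head_dim), np.nan)

# Producer writing as if q_groups == 1: only ki * Q_HEAD_PAD + qi is filled.
for ki in range(num_kv_heads):
    buf[ki * Q_HEAD_PAD:(ki + 1) * Q_HEAD_PAD] = 1.0

# A consumer slicing by gi = ki * q_groups + qg hits untouched rows for qg > 0.
gi = 1 * q_groups + 1                   # ki = 1, qg = 1
assert np.isnan(buf[gi * Q_HEAD_PAD:(gi + 1) * Q_HEAD_PAD]).all()
```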
```python
all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
    for idx in pl.range(batch * total_q_groups):
        all_q_padded = pl.assemble(
            all_q_padded,
            pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
            [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
        )
```
Populate all Q groups or enforce q_groups == 1.
all_q_padded is now sized and sliced by total_q_groups, but these writes still only populate the first group for each KV head (... + ki * Q_HEAD_PAD + qi). If num_heads // num_kv_heads > Q_HEAD_BATCH, later gi slices for qg > 0 will read untouched rows and produce wrong attention. Either add a qg loop here and write to (ki * q_groups + qg) * Q_HEAD_PAD + qi, or reject unsupported configs explicitly.
Also applies to: 143-144, 151
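Fix option (B) from the comment above can be sketched as a plain Python guard (names hypothetical; the real kernel would raise before building any buffers). For example, a 64-query-head / 8-KV-head config gives `q_per_kv = 8`, which passes the check only if `Q_HEAD_BATCH >= 8`:

```python
def check_single_group(num_heads: int, num_kv_heads: int, q_head_batch: int) -> int:
    """Fail fast instead of computing wrong attention when the padded Q
    buffer cannot hold more than one query group per KV head."""
    q_per_kv = num_heads // num_kv_heads
    if q_per_kv > q_head_batch:
        raise ValueError(
            f"unsupported config: q_per_kv={q_per_kv} exceeds "
            f"Q_HEAD_BATCH={q_head_batch}; only the single-group case "
            "is materialized"
        )
    return q_per_kv
```

Rejecting the config explicitly is the cheaper fix, at the cost of never supporting multi-group decode; materializing every group (option A) keeps the kernel general.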
Summary

- Hoist `all_q_padded` allocation outside the batch loop and index by batch offset for true multi-batch decode support in scope2 and scope12
- Remove redundant zero-initialization of intermediate tensors (`all_raw_scores`, `all_exp_padded`, `all_oi_tmp`, `all_cur_mi`, `all_cur_li`)
- Simplify the Stage 5 online softmax: initialize from the first block via `pl.slice` instead of `pl.full`, loop from `sb=1`, and merge normalisation into the same incore block
- Remove zero-init loops for `q_proj`/`k_proj`/`v_proj` in scope12