
Refactor: multi-batch q_padded and simplify online softmax in scope2/scope12 (#93)

Merged
zhangqi-chen merged 1 commit into hw-native-sys:main from zhangqi-chen:refactor/qwen3-scope2-scope12-multibatch-simplify
Apr 10, 2026

Conversation

@zhangqi-chen
Collaborator

@zhangqi-chen zhangqi-chen commented Apr 10, 2026

Summary

  • Hoist all_q_padded allocation outside the batch loop and index by batch offset for true multi-batch decode support in scope2 and scope12
  • Remove redundant zero-initialization workarounds for intermediate tensors (all_raw_scores, all_exp_padded, all_oi_tmp, all_cur_mi, all_cur_li)
  • Simplify online softmax accumulation (Stage 5) — initialize from first block via pl.slice instead of pl.full, loop from sb=1, and merge normalisation into the same incore block
  • Remove zero-init loops for q_proj/k_proj/v_proj in scope12
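The Stage 5 change follows the standard online-softmax recurrence. As a rough illustration (not the kernel code — the kernels use the pl DSL, and the function name and block shapes here are hypothetical), a NumPy sketch of initializing the running state from the first block and merging normalisation into the same pass:

```python
import numpy as np

def online_softmax_attention(scores_blocks, v_blocks):
    """Online softmax over score blocks: initialize the running state from
    block 0 (instead of neutral -inf/0 fills), fold in the remaining blocks,
    and normalise once at the end of the same pass."""
    # Initialize directly from the first block, mirroring the pl.slice-based init
    m = scores_blocks[0].max(axis=-1, keepdims=True)  # running max
    p = np.exp(scores_blocks[0] - m)
    l = p.sum(axis=-1, keepdims=True)                 # running denominator
    o = p @ v_blocks[0]                               # unnormalised output

    for sb in range(1, len(scores_blocks)):           # loop from sb=1
        m_new = np.maximum(m, scores_blocks[sb].max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                     # rescale old state
        p = np.exp(scores_blocks[sb] - m_new)
        l = l * scale + p.sum(axis=-1, keepdims=True)
        o = o * scale + p @ v_blocks[sb]
        m = m_new

    return o / l  # normalisation merged into the same block
```

The result matches a full softmax over the concatenated score blocks, which is what makes the block-0 initialization safe to substitute for the neutral-value pre-initialization.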



coderabbitai bot commented Apr 10, 2026

📝 Walkthrough

Two Qwen3 attention kernel files were refactored to move Q padding outside the per-batch loops with batch-dependent indexing, remove explicit projection zero-initialization, and simplify online softmax accumulation by initializing from the precomputed block-0 outputs and accumulating incrementally.

Changes

Cohort / File(s): Qwen3 Decode Attention Kernels — examples/models/qwen3/qwen3_32b_decode_scope12.py, examples/models/qwen3/qwen3_32b_decode_scope2.py

Summary: Removed explicit zero-initialization for Q/K/V projections. Restructured all_q_padded construction to happen once per program, outside the per-batch loops, with shape [batch * total_q_groups * Q_HEAD_PAD, head_dim], and updated index offsets to include b * total_q_groups * Q_HEAD_PAD. Simplified Stage 5 online softmax accumulation by removing the SB-batched pre-initialization and replacing it with direct initialization from the sb=0 slices followed by in-core accumulation over the remaining blocks.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Batches now grouped with care so neat,
Q padding organized, per-batch, complete!
Softmax flows smooth from block-zero's embrace,
Accumulating grace through each attention space. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title accurately describes the main changes, hoisting the q_padded allocation for multi-batch support and simplifying the online softmax logic in scope2/scope12.
  • Description check — ✅ Passed: the description directly relates to the changeset, detailing the hoisting of the all_q_padded allocation, the removal of zero-initialization workarounds, the simplification of Stage 5 online softmax, and the removal of zero-init loops in scope12.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



…cope2/scope12

- Hoist q_padded allocation outside the batch loop and index by batch offset,
  supporting true multi-batch decode
- Remove redundant zero-initialization workarounds for intermediate tensors
  (all_raw_scores, all_exp_padded, all_oi_tmp, all_cur_mi, all_cur_li)
- Simplify online softmax accumulation (Stage 5) by initializing from the
  first block via pl.slice instead of pl.full, then looping from sb=1
- Remove separate Stage 6 — merge normalisation into the same incore block
- Remove zero-init loops for q_proj/k_proj/v_proj in scope12
@zhangqi-chen zhangqi-chen force-pushed the refactor/qwen3-scope2-scope12-multibatch-simplify branch from 68198d9 to 92511d4 on April 10, 2026 at 03:45

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/models/qwen3/qwen3_32b_decode_scope12.py`:
- Around line 164-172: The hoisted Q buffer all_q_padded is allocated for batch
* total_q_groups groups but the producer only writes the first group per KV
head; either fully materialize every (ki, qg) group or explicitly reject
multi-group configs. Fix by iterating/writing for every group: update the
producer loop that uses pl.range(batch * total_q_groups) / pl.assemble to cover
all q groups (use total_q_groups and q_per_kv or q_groups in the loop/index
arithmetic so each group's rows at idx * Q_HEAD_PAD + Q_HEAD_BATCH +
group_offset are filled), or if you prefer a simpler change add a guard in the
function (raise/exit) when total_q_groups != 1 (check total_q_groups, q_per_kv
vs Q_HEAD_BATCH) and document that only single-group is supported; ensure
references to all_q_padded, Q_HEAD_PAD, Q_HEAD_BATCH, total_q_groups, and the
pl.assemble write are updated accordingly.

In `@examples/models/qwen3/qwen3_32b_decode_scope2.py`:
- Around line 68-75: The loop that assembles all_q_padded only writes the first
Q_HEAD_BATCH rows per KV group, so when q_groups (num_heads // num_kv_heads) > 1
later gi slices will read uninitialized rows; update the assembly to either (A)
add an inner loop over qg and write into index (ki * q_groups + qg) * Q_HEAD_PAD
+ qi for every q group (ensure use of variables all_q_padded, Q_HEAD_PAD,
Q_HEAD_BATCH, total_q_groups, ki, qi, qg), or (B) explicitly assert/reject
configs where q_groups != 1 (e.g., check num_heads // num_kv_heads == 1) and
raise an error; apply the same fix pattern to the other occurrences mentioned
around the later assembly uses (the blocks at the other referenced locations).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: c2868c8a-7043-43e3-9aca-c1e986f1742c

📥 Commits

Reviewing files that changed from the base of the PR and between 4f729f3 and 92511d4.

📒 Files selected for processing (2)
  • examples/models/qwen3/qwen3_32b_decode_scope12.py
  • examples/models/qwen3/qwen3_32b_decode_scope2.py

Comment on lines +164 to +172
# Padding q
all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
    for idx in pl.range(batch * total_q_groups):
        all_q_padded = pl.assemble(
            all_q_padded,
            pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
            [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
        )

⚠️ Potential issue | 🟠 Major

The hoisted Q buffer still only fills one group per KV head.

This buffer is allocated and consumed as total_q_groups, but the producer still writes rows as if q_groups == 1. With any non-default configuration where q_per_kv > Q_HEAD_BATCH, the later groups will read uninitialized all_q_padded rows. Please either materialize every (ki, qg) group here or add an explicit guard that only the single-group case is supported.

Also applies to: 224-225, 232-232


Comment on lines +68 to +75
all_q_padded = pl.create_tensor([batch * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16)
with pl.incore():
    for idx in pl.range(batch * total_q_groups):
        all_q_padded = pl.assemble(
            all_q_padded,
            pl.cast(pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16),
            [idx * Q_HEAD_PAD + Q_HEAD_BATCH, 0],
        )

⚠️ Potential issue | 🟠 Major

Populate all Q groups or enforce q_groups == 1.

all_q_padded is now sized and sliced by total_q_groups, but these writes still only populate the first group for each KV head (... + ki * Q_HEAD_PAD + qi). If num_heads // num_kv_heads > Q_HEAD_BATCH, later gi slices for qg > 0 will read untouched rows and produce wrong attention. Either add a qg loop here and write to (ki * q_groups + qg) * Q_HEAD_PAD + qi, or reject unsupported configs explicitly.

Also applies to: 143-144, 151-151
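The reviewer's option (A) can be sketched in plain Python to show that the proposed (ki * q_groups + qg) * Q_HEAD_PAD + qi index covers every group's rows, not just the first per KV head. The sizes below are hypothetical and the actual pl.assemble writes are elided; only the index arithmetic is the point:

```python
# Hypothetical sizes for illustration only
num_kv_heads, q_groups, Q_HEAD_PAD, Q_HEAD_BATCH = 2, 3, 4, 1

# Rows written by the current code: only the first group per KV head
original = {ki * Q_HEAD_PAD + qi
            for ki in range(num_kv_heads)
            for qi in range(Q_HEAD_BATCH)}

# Rows written with the suggested inner qg loop: every (ki, qg) group
fixed = {(ki * q_groups + qg) * Q_HEAD_PAD + qi
         for ki in range(num_kv_heads)
         for qg in range(q_groups)
         for qi in range(Q_HEAD_BATCH)}
```

With q_groups > 1 the `fixed` set touches the first Q_HEAD_BATCH rows of all num_kv_heads * q_groups groups, whereas `original` leaves the qg > 0 groups uninitialized, which is exactly the uninitialized-read hazard the comment describes.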


@zhangqi-chen zhangqi-chen merged commit de3b89d into hw-native-sys:main Apr 10, 2026
5 checks passed
@zhangqi-chen zhangqi-chen deleted the refactor/qwen3-scope2-scope12-multibatch-simplify branch April 10, 2026 03:51