Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request primarily focuses on upgrading the GLM-4.5 Attention and Qwen3-next Gated Delta Rule modules to be compatible with PyPTO 3.0. The migration rewrites the core logic of these attention mechanisms within the new PyPTO framework. Additionally, minor optimizations and configuration updates were applied to the Deepseek V3.2 decode_back module, likely to align with the new PyPTO 3.0 standards and improve performance.
📝 Walkthrough

This pull request introduces new attention kernel implementations for GLM-4.5 and Qwen3-next models with comprehensive test suites, along with parameter and backend adjustments to the existing Deepseek-V3.2 backward pass module.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant GLMAttention as GLM Attention Program
    participant QKVTensors as Q/K/V Cache
    participant Compute as Softmax & Matmul
    participant Output as Output Buffer
    Client->>GLMAttention: invoke glm_flash_attention()
    GLMAttention->>QKVTensors: load query tiles (s1_cfg)
    QKVTensors-->>GLMAttention: query blocks
    loop For each KV block via block_table
        GLMAttention->>QKVTensors: fetch K/V from block indices
        QKVTensors-->>GLMAttention: K/V block data
        GLMAttention->>Compute: compute Q·K^T in FP32
        Compute->>Compute: track per-tile max for stability
        Compute->>Compute: apply softmax (log-sum-exp update)
        Compute->>GLMAttention: attention probs
        GLMAttention->>Compute: multiply probs × V (FP32)
        Compute-->>GLMAttention: weighted values
    end
    GLMAttention->>Compute: normalize by softmax denom
    Compute->>Compute: cast FP32 → BF16
    GLMAttention->>Output: assemble results into attn_out
```
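The per-tile max tracking and log-sum-exp update shown above are the standard online-softmax recurrence used by flash-attention-style kernels. A minimal pure-Python sketch of that recurrence (illustrative only, not the PyPTO kernel from this PR):

```python
import math

def online_logsumexp(logit_blocks):
    """Streaming softmax statistics over blocks of attention logits.

    Maintains a running max `m` and a running denominator `denom` so that
    after all blocks, softmax(x_i) = exp(x_i - m) / denom, without ever
    materializing the full logit row. This mirrors the "track per-tile max"
    and "log-sum-exp update" steps in the diagram.
    """
    m = float("-inf")
    denom = 0.0
    for block in logit_blocks:
        m_new = max(m, max(block))
        # rescale the old denominator to the new max, then add this block
        denom = denom * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, denom
```

In the real kernel, a running weighted-V accumulator is rescaled by the same `exp(m - m_new)` factor each step, and the final normalization divides by `denom` before the FP32 to BF16 cast.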
```mermaid
sequenceDiagram
    participant Client
    participant ChunkGatedDeltaRule as Gated Delta Rule Program
    participant Inputs as Query/Key/Value/Beta/Gate
    participant PreAttn as Pre-Attention Compute
    participant InverseBlock as Inverse Transform
    participant Recurrent as Recurrent State Update
    participant StateOut as Output & State
    Client->>ChunkGatedDeltaRule: invoke chunk_gated_delta_rule()
    loop For each batch segment
        loop For each chunk (l_cfg tiles)
            ChunkGatedDeltaRule->>Inputs: slice Q/K/V/Beta/Gate
            Inputs-->>ChunkGatedDeltaRule: chunk tensors
            ChunkGatedDeltaRule->>PreAttn: L2-normalize Q/K
            PreAttn->>PreAttn: compute cumulative gate & decay
            PreAttn->>PreAttn: build attention blocks
            PreAttn-->>ChunkGatedDeltaRule: intermediates
            ChunkGatedDeltaRule->>InverseBlock: blockwise inverse transform
            InverseBlock-->>ChunkGatedDeltaRule: inverse result
            ChunkGatedDeltaRule->>Recurrent: compute value outputs
            Recurrent->>Recurrent: update state via matmul
            Recurrent-->>ChunkGatedDeltaRule: chunk output & new state
            ChunkGatedDeltaRule->>StateOut: assemble into core_attn_out
            StateOut->>StateOut: propagate state to next chunk
        end
    end
    ChunkGatedDeltaRule->>StateOut: finalize last_state_data
```
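The recurrent state update in the diagram belongs to the delta-rule family: decay the state by the gate, then apply a rank-1 correction that moves the state's prediction for the current key toward the current value. A per-token pure-Python sketch of one common formulation (the names, layout, and exact gating here are assumptions for illustration, not the chunked PyPTO implementation):

```python
def gated_delta_rule_step(S, k, v, beta, alpha):
    """One token of a gated delta rule on a [d_k x d_v] state S.

    Decay the state by gate alpha, then nudge the prediction S^T k
    toward the target value v with learning rate beta:
        S' = alpha * S
        S  = S' + beta * outer(k, v - S'^T k)
    The chunked kernel processes many tokens at once but propagates the
    same state between chunks.
    """
    d_k, d_v = len(S), len(S[0])
    S = [[alpha * S[i][j] for j in range(d_v)] for i in range(d_k)]  # gate decay
    # prediction for this key: pred[j] = sum_i S[i][j] * k[i]
    pred = [sum(S[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]
    return [
        [S[i][j] + beta * k[i] * (v[j] - pred[j]) for j in range(d_v)]
        for i in range(d_k)
    ]
```

With `beta = 1`, `alpha = 1`, and a unit-norm key, a second update with the same (k, v) pair is a no-op, since the state already predicts v for that key.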
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Code Review
This pull request migrates GLM and Qwen3-next models to PTO 3.0 and includes some adjustments to the DeepSeek V3.2 decode back implementation. The new model implementations introduce several critical and high-severity bugs related to tensor dimension mismatches and incorrect test configurations. Additionally, a comment in the DeepSeek V3.2 file contradicts the actual code changes regarding tile sizes. Addressing these issues is crucial for the correctness and effectiveness of the migrated code and its tests.
```python
attn_inv_prev = pl.slice(attn_inv, [i, min_length], [0, 0])
row = pl.slice(attn, [1, min_length], [i, 0])
col = pl.slice(attn_t, [i, 1], [0, i])
prod = pl.matmul(attn_inv_prev, col)
prod_2d = pl.reshape(prod, [i, 1])
attn_update = pl.add(row, pl.transpose(prod_2d, 0, 1))
```
There appears to be a critical dimension mismatch in the inverse_pto_min_length function. In the loop, attn_inv_prev is sliced as pl.slice(attn_inv, [i, min_length], [0, 0]), which results in a tensor of shape [i, min_length]. The col tensor is sliced as pl.slice(attn_t, [i, 1], [0, i]), resulting in [i, 1]. The pl.matmul(attn_inv_prev, col) operation will produce a tensor of shape [min_length, 1]. Subsequently, pl.reshape(prod, [i, 1]) will fail if min_length is not equal to i. This indicates a fundamental issue in the matrix inversion logic.
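The reshape failure described above is an element-count constraint: a `[min_length, 1]` result holds `min_length` elements, while `pl.reshape(prod, [i, 1])` asks for exactly `i`. A small stand-in check (plain Python, not the `pl` API) makes the failure mode concrete:

```python
import math

def check_reshape(shape, new_shape):
    """Reshape is legal only when element counts match; this is the
    constraint pl.reshape(prod, [i, 1]) violates when min_length != i."""
    if math.prod(shape) != math.prod(new_shape):
        raise ValueError(
            f"cannot reshape {shape} ({math.prod(shape)} elements) "
            f"to {new_shape} ({math.prod(new_shape)} elements)"
        )
    return new_shape

# min_length == i == 16: fine; min_length = 16 but i = 4: rejected.
```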
```python
attn_inv = pl.assemble(attn_inv, attn_update, [i, 0])
result = pl.add(attn_inv, eye)
```
The eye parameter is defined as pl.Tensor[[16, l_cfg], pl.FP32], while attn_inv is [min_length, min_length]. The pl.add(attn_inv, eye) operation requires compatible tensor shapes. If min_length is l_cfg // 8, and l_cfg is 128, then min_length is 16. In this case, attn_inv would be [16, 16], but eye is [16, 128], leading to a dimension mismatch. The eye tensor should have the shape [min_length, min_length] to correctly perform element-wise addition.
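Whether `pl.add(attn_inv, eye)` fails outright or broadcasts incorrectly depends on `pl`'s semantics; under NumPy-style broadcasting (an assumption here), `[16, 16]` and `[16, 128]` are simply incompatible, as a quick shape checker shows:

```python
def broadcast_shape(a, b):
    """NumPy/PyTorch-style broadcast rule for equal-rank shapes: aligned
    trailing dimensions must be equal or 1. Illustrates why [16, 16] and
    [16, 128] cannot be element-wise added under these rules."""
    out = []
    for x, y in zip(a[::-1], b[::-1]):  # compare from the trailing dim
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible shapes {a} and {b}")
        out.append(max(x, y))
    return tuple(reversed(out))
```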
```python
eye_block = pl.slice(eye, [min_length, min_length], [0, min_length * i])
inv_block = inverse_pto_min_length(block, eye_block, min_length)
```
The eye_block is sliced from eye using pl.slice(eye, [min_length, min_length], [0, min_length * i]). Given that eye has a shape of [16, l_cfg] and min_length is l_cfg // 8, this only works when min_length equals 16 (i.e., l_cfg == 128); for larger l_cfg the first slice dimension exceeds eye's 16 rows. More fundamentally, each eye_block is expected to be a [min_length, min_length] identity sub-matrix, which the [16, l_cfg] layout only provides if eye packs identity blocks side by side. This implies eye should be an identity matrix of [l_cfg, l_cfg] or [min_length, l_cfg], and the slicing needs to be adjusted to provide the correct identity sub-matrix.
```python
result = run(
    program=program,
    tensor_specs=tensor_specs,
    golden=None,
```
The run function is called with golden=None, but golden_out is computed and available from build_tensor_specs. This means the test is not actually comparing the program's output against the golden reference, rendering the test ineffective for correctness validation. Please pass golden_out to the golden parameter.
```diff
-    golden=None,
+    golden=golden_out,
```
```python
zeros_16 = pl.create_tensor([16, 16], dtype=pl.FP32)
zeros_16 = pl.mul(zeros_16, 0.0)
```
The zeros_16 tensor is created with a fixed size [16, 16]. However, m_len (which is min_length) can vary depending on l_cfg. If min_length is not 16, this fixed-size tensor will cause incorrect assembly or runtime errors. The zeros_XX tensors should be dynamically sized based on the current m_len to ensure correctness and flexibility.
```diff
-    zeros_16 = pl.create_tensor([16, 16], dtype=pl.FP32)
-    zeros_16 = pl.mul(zeros_16, 0.0)
+    zeros_16 = pl.create_tensor([m_len, m_len], dtype=pl.FP32)
+    zeros_16 = pl.mul(zeros_16, 0.0)
```
```python
zeros_64 = pl.create_tensor([64, 64], dtype=pl.FP32)
zeros_64 = pl.mul(zeros_64, 0.0)
```
Similar to zeros_16 and zeros_32, zeros_64 is created with a fixed size [64, 64]. This is problematic as m_len is now min_length * 4. The tensor should be dynamically sized to [m_len, m_len] to match the block dimensions.
```diff
-    zeros_64 = pl.create_tensor([64, 64], dtype=pl.FP32)
-    zeros_64 = pl.mul(zeros_64, 0.0)
+    zeros_64 = pl.create_tensor([m_len, m_len], dtype=pl.FP32)
+    zeros_64 = pl.mul(zeros_64, 0.0)
```
```python
result = run(
    program=program,
    tensor_specs=tensor_specs,
    golden=None,
```
The run function is called with golden=None, but golden_data is computed and available from gen_data. This means the test is not actually comparing the program's output against the golden reference, rendering the test ineffective for correctness validation. Please pass golden_data to the golden parameter.
```diff
-    golden=None,
+    golden=golden_data,
```
```python
vj_assemble = pl.create_tensor([s2_tile_cfg, q_d_cfg], dtype=pl.FP32)
vj_assemble = pl.mul(vj_assemble, 0.0)

for i in range(block_num):
    block_idx = pl.tensor.read(block_table, [b_idx, idx + i])
    block_idx_valid = pl.max(block_idx, 0)
    v_offset = block_idx_valid * block_size_cfg
    vj_block = pl.slice(v, [block_size_cfg, 1, q_d_cfg], [v_offset, n2_idx, 0])
    vj_block_fp32 = pl.cast(vj_block, target_type=pl.FP32)
    vj_block_2d = pl.reshape(vj_block_fp32, [block_size_cfg, q_d_cfg])
    vj_assemble = pl.assemble(vj_assemble, vj_block_2d, [i * block_size_cfg, 0])
```
The block of code responsible for assembling vj_assemble is duplicated here and in the else branch (lines 166-176). This duplication can be avoided by moving this logic outside the if/else block, as all necessary variables (b_idx, idx, block_num, block_table, block_size_cfg, v, n2_idx, q_d_cfg, actual_s2_tile) are available before the conditional statement. This improves maintainability and reduces code redundancy.
```python
vj_assemble = pl.create_tensor([s2_tile_cfg, q_d_cfg], dtype=pl.FP32)
vj_assemble = pl.mul(vj_assemble, 0.0)

for i in range(block_num):
    block_idx = pl.tensor.read(block_table, [b_idx, idx + i])
    block_idx_valid = pl.max(block_idx, 0)
    v_offset = block_idx_valid * block_size_cfg
    vj_block = pl.slice(v, [block_size_cfg, 1, q_d_cfg], [v_offset, n2_idx, 0])
    vj_block_fp32 = pl.cast(vj_block, target_type=pl.FP32)
    vj_block_2d = pl.reshape(vj_block_fp32, [block_size_cfg, q_d_cfg])
    vj_assemble = pl.assemble(vj_assemble, vj_block_2d, [i * block_size_cfg, 0])
```
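The refactor suggested above is plain branch hoisting. A toy sketch of the before/after shape (the names and arithmetic are stand-ins, not the kernel code):

```python
def process_tile_before(full_tile):
    # Before: both branches repeat the identical assembly work.
    if full_tile:
        assembled = sum(i * i for i in range(4))  # stand-in for the vj loop
        return assembled + 100                    # full-tile specific handling
    assembled = sum(i * i for i in range(4))      # duplicated assembly
    return assembled - 1                          # tail-tile specific handling

def process_tile_after(full_tile):
    # After: the shared assembly is hoisted above the conditional,
    # so only the branch-specific handling differs.
    assembled = sum(i * i for i in range(4))
    return assembled + 100 if full_tile else assembled - 1
```

Both versions compute the same results; the hoisted form just has one copy of the shared loop to maintain.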
```python
Q_OUT_CHUNK = 64
MLP_OUT_CHUNK = 64
```
The comment on lines 37-38 states, "Increase tile sizes to encourage larger mixed-kernel fusion regions." However, the changes to Q_OUT_CHUNK (from 128 to 64) and MLP_OUT_CHUNK (from 512 to 64) represent a decrease in tile sizes. This creates a contradiction between the comment and the code. Please update the comment to accurately reflect the intent of these changes, or justify why decreasing tile sizes is beneficial in this context.
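The performance impact of such a change is easy to quantify: with a chunked loop, the iteration count scales inversely with the chunk size. A quick sketch (the block-count formula is assumed to be a ceiling division; the actual `MLP_OUT_BLOCKS` definition in the file may differ):

```python
import math

def out_blocks(total_width, chunk):
    """Loop iterations needed to cover an output dimension in chunks
    (assumed ceil-division formula, standing in for MLP_OUT_BLOCKS)."""
    return math.ceil(total_width / chunk)

# For a hypothetical 512-wide MLP output dimension:
#   chunk 512 -> 1 iteration; chunk 64 -> 8 iterations (an 8x increase).
```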
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/deepseek_v3_2/deepseek_v3_2_decode_back.py (1)
221-246: ⚠️ Potential issue | 🟡 Minor: Remove misleading `work_dir` references from print statements.

The `work_dir` variable is computed at lines 221-222 but is not passed to `RunConfig` (lines 228-236). However, the print statements at lines 240, 244, and 246 still reference `work_dir` as if it contains the output location for generated kernels and pass dumps. Since `work_dir` is not used by the `run()` function, these messages are misleading to users. Either remove the `work_dir` computation and update the print messages to remove the path reference, or verify that `work_dir` should be passed to `RunConfig` for the Ascend950 backend (note that `qwen3_32b_decode.py`, which also uses Ascend950, omits these print statements entirely).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v3_2/deepseek_v3_2_decode_back.py` around lines 221 - 246, The code computes work_dir but never passes it into run()/RunConfig, yet print statements reference it; remove the work_dir computation and any references to it in the prints (update messages in the result handling around result.passed/result.error) so they no longer claim a filesystem path, or alternatively if the Ascend950 backend should receive a dump directory, pass work_dir into RunConfig (e.g., add an appropriate dump_dir/dump_root argument) and ensure dump_passes/BackendType.Ascend950 use it; update the prints accordingly to reference only that RunConfig field or omit the path.
🧹 Nitpick comments (3)
examples/deepseek_v3_2/deepseek_v3_2_decode_back.py (1)
37-41: Comment is now inconsistent with the chunk values.

The comment states "Increase tile sizes to encourage larger mixed-kernel fusion regions", but `Q_OUT_CHUNK` and `MLP_OUT_CHUNK` have been decreased (64 vs. the previous 128 and 512, respectively). The `MLP_OUT_CHUNK` reduction from 512 to 64 is an 8x decrease, which will significantly increase `MLP_OUT_BLOCKS` loop iterations. Consider updating the comment to reflect the new tuning rationale, or verify these values are correct for the Ascend950 backend target.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v3_2/deepseek_v3_2_decode_back.py` around lines 37 - 41, The comment above the chunk constants is now misleading: it says "Increase tile sizes..." while K_CHUNK is larger but Q_OUT_CHUNK and MLP_OUT_CHUNK were reduced (Q_OUT_CHUNK=64, MLP_OUT_CHUNK=64), which impacts MLP_OUT_BLOCKS loops; update the comment to reflect the actual tuning rationale or confirm these reduced values are intentional for the Ascend950 backend. Locate the constants K_CHUNK, Q_OUT_CHUNK, and MLP_OUT_CHUNK in the file (deepseek_v3_2_decode_back.py) and either revise the comment to describe that Q_OUT_CHUNK and MLP_OUT_CHUNK were decreased to favor X (e.g., memory/locality/compute balance on Ascend950) or validate and revert to previous sizes if the 64 values are incorrect.

examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py (1)
346-349: Error handling could mask real failures.

Checking `"code_runner" in result.error` is brittle: if an unrelated error message happens to contain "code_runner", it would be treated as a successful compilation. Consider checking for a specific error type or code instead.

♻️ Suggested improvement

```diff
-    if not result.passed and result.error and "code_runner" in result.error:
+    if not result.passed and result.error:
+        if "code_runner not found" in result.error or "code_runner unavailable" in result.error:
+            print("Result: COMPILE OK — device run skipped (code_runner not found).")
+            print(f"    Generated kernels/orchestration: {work_dir}")
+            return result
```

Or better, if `pypto.runtime` provides a specific exception type for missing runtime, catch that instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py` around lines 346 - 349, The current check treats any message containing "code_runner" in result.error as a benign missing-runtime case; instead, change the logic to detect the missing runtime explicitly by using a structured indicator or exception: if pypto.runtime exposes a specific exception or error code (e.g., MissingRuntimeError or a result.error_code like "MISSING_CODE_RUNNER"), check for that (using isinstance(result.error, MissingRuntimeError) or result.error_code == "MISSING_CODE_RUNNER") rather than substring matching; alternatively, wrap the invocation that produces result in a try/except that catches the specific runtime-missing exception and handle that path (printing "COMPILE OK — device run skipped" and returning) while letting other errors surface normally; update the branch that references result.passed/result.error and work_dir accordingly to use this explicit check.examples/custom/qwen3_next/gated_delta_rule_impl.py (1)
20-21: Unused module-level constants.
`L` and `D` are defined but never referenced; function parameters shadow them. Consider removing them or documenting their intended use.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/custom/qwen3_next/gated_delta_rule_impl.py` around lines 20 - 21, The module-level constants L and D are declared but never used (they're shadowed by function parameters), so remove them or wire them into the implementation—either delete the unused L and D definitions, or use them as default values for the relevant function parameters (e.g., in the functions that accept length/depth args) or document their intended purpose; update references in gated_delta_rule_impl.py to rely on the parameter names or the module constants consistently so no shadowing occurs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/custom/glm_v4_5/glm_attention_test.py`:
- Around line 120-145: The golden_attention implementation assumes S1=1 and
nkv=1 and overwrites the entire head axis each iteration; update it to honor
runtime s1 and nkv by slicing per-head/group regions instead of replacing the
whole vector: compute group = nq // nkv (already present), index q, k_cache_bsnd
and v_cache_bsnd using the head slice q_bs = q[i * S1 + j,
n2_idx*group:(n2_idx+1)*group, :] and similarly for k_bs and v_bs (use the same
seq_len for k/v as before), perform the matmuls on these slices, and assign the
result into the corresponding slice of attention_output[i * S1 + j,
n2_idx*group:(n2_idx+1)*group, :] so multiple nkv iterations accumulate into
distinct head regions rather than overwriting the entire head axis.
- Around line 199-223: The test computes a reference output named golden_out but
calls run(...) with golden=None, skipping numerical verification; update the run
invocation (the call to run in this file) to pass golden=golden_out instead of
None so the RunConfig-based rtol/atol checks are exercised and numerical
correctness is validated.
In `@examples/custom/glm_v4_5/glm_attention.py`:
- Around line 69-71: Validate that tiling parameters divide evenly before using
computed counts: check that group (computed from nq_cfg//nkv_cfg) is divisible
by g_tile_cfg and that s2_cfg is aligned to s2_tile_cfg and that s2_tile_cfg is
divisible by block_size_cfg (or that block_num computed from
s2_tile_cfg//block_size_cfg does not under/overrun block_table); if any check
fails, raise/return an explicit error (or adjust the loop bounds) so downstream
loops that use g_loop and block_num cannot read/write tails out-of-bounds. Add
these checks near where group, g_loop and block_num are computed (referencing
nq_cfg, nkv_cfg, g_tile_cfg, s2_cfg, s2_tile_cfg, block_size_cfg and
block_table) and apply same validation in the other similar sections noted
(lines ~113-120, 140-147, 169-176).
In `@examples/custom/qwen3_next/gated_delta_rule_impl.py`:
- Around line 95-168: The inverse_pto function incorrectly uses hardcoded zero
tensors (zeros_16, zeros_32, zeros_64) which only work when min_length==16;
replace those with dynamically-sized zero tensors created from the current m_len
at each stage (i.e., create zeros of shape [m_len, m_len] before the 2x2
assemble step, and similarly use [m_len*2, m_len*2] or appropriate sizes for
higher-level assembles) so the pl.assemble calls in inverse_pto and the blocks
built in attn_inv_4_blocks, attn_inv_2_blocks and final attn_inv always match l
and min_length; update references to zeros_16/32/64 to use these computed zero
tensors where inv_block assembly places a zero sub-block.
- Line 40: The function pre_attn has an unused parameter l; remove l from the
parameter list of pre_attn (def pre_attn(gate_view, key_view_2d, beta_view,
tril, mask)) and update all call sites that currently pass a value for l to stop
passing it (or alternatively, if l is required for logic, incorporate it into
the function body where relevant); locate the function by name pre_attn and
adjust its callers so signatures match.
- Around line 221-231: The function build_chunk_gated_delta_rule_program is
missing a sequence-length parameter: add a new parameter t, create t_cfg
alongside b_cfg/nqk_cfg/nv_cfg/d_cfg/l_cfg (e.g., t_cfg = Var(t, "t_cfg") or
equivalent pattern used in this file), and replace all tensor dimension string
literals "T" in the type annotations for query, key, value, beta, gate, and
core_attn_out with t_cfg so those tensors use the concrete sequence-length
dimension; update the function signature and any uses of those annotations to
reference t_cfg instead of "T".
In `@examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py`:
- Around line 331-344: golden_data produced by gen_data() is never used — modify
the run invocation or add post-run checks to validate outputs: either pass a
golden callback or keep golden=None but after run(...) inspect result (the
result object returned by run) and compare result.outputs["core_attn_out"] and
result.outputs["last_state_data"] against golden_data entries using
torch.testing.assert_close with rtol=1e-3 and atol=1e-3; ensure any necessary
shape/transpose adjustments (e.g., transpose(0,1) on
golden_data["core_attn_out"]) before comparison and only run assertions if
result.passed and result.outputs exist.
- Line 207: Remove the unused tuple unpacking "b, n, s, d = value.shape" in
qwen3_next_gated_delta_rule.py (it’s unused and dimensions are taken from
key.shape later); delete that line so the function no longer creates unused
variables b, n, s, d.
- Line 315: The top-level import "from gated_delta_rule_impl import
build_chunk_gated_delta_rule_program" can fail when the script is executed from
a different CWD; update the import to be package-relative or otherwise robust:
either convert to a package-relative import (e.g., use a relative import from
the correct package) or modify sys.path at startup to include the module's
directory before importing, ensuring the symbol
build_chunk_gated_delta_rule_program is still imported from the correct module;
locate the import statement and replace it with one of these approaches so the
module resolves reliably at runtime.
---
Outside diff comments:
In `@examples/deepseek_v3_2/deepseek_v3_2_decode_back.py`:
- Around line 221-246: The code computes work_dir but never passes it into
run()/RunConfig, yet print statements reference it; remove the work_dir
computation and any references to it in the prints (update messages in the
result handling around result.passed/result.error) so they no longer claim a
filesystem path, or alternatively if the Ascend950 backend should receive a dump
directory, pass work_dir into RunConfig (e.g., add an appropriate
dump_dir/dump_root argument) and ensure dump_passes/BackendType.Ascend950 use
it; update the prints accordingly to reference only that RunConfig field or omit
the path.
---
Nitpick comments:
In `@examples/custom/qwen3_next/gated_delta_rule_impl.py`:
- Around line 20-21: The module-level constants L and D are declared but never
used (they're shadowed by function parameters), so remove them or wire them into
the implementation—either delete the unused L and D definitions, or use them as
default values for the relevant function parameters (e.g., in the functions that
accept length/depth args) or document their intended purpose; update references
in gated_delta_rule_impl.py to rely on the parameter names or the module
constants consistently so no shadowing occurs.
In `@examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py`:
- Around line 346-349: The current check treats any message containing
"code_runner" in result.error as a benign missing-runtime case; instead, change
the logic to detect the missing runtime explicitly by using a structured
indicator or exception: if pypto.runtime exposes a specific exception or error
code (e.g., MissingRuntimeError or a result.error_code like
"MISSING_CODE_RUNNER"), check for that (using isinstance(result.error,
MissingRuntimeError) or result.error_code == "MISSING_CODE_RUNNER") rather than
substring matching; alternatively, wrap the invocation that produces result in a
try/except that catches the specific runtime-missing exception and handle that
path (printing "COMPILE OK — device run skipped" and returning) while letting
other errors surface normally; update the branch that references
result.passed/result.error and work_dir accordingly to use this explicit check.
In `@examples/deepseek_v3_2/deepseek_v3_2_decode_back.py`:
- Around line 37-41: The comment above the chunk constants is now misleading: it
says "Increase tile sizes..." while K_CHUNK is larger but Q_OUT_CHUNK and
MLP_OUT_CHUNK were reduced (Q_OUT_CHUNK=64, MLP_OUT_CHUNK=64), which impacts
MLP_OUT_BLOCKS loops; update the comment to reflect the actual tuning rationale
or confirm these reduced values are intentional for the Ascend950 backend.
Locate the constants K_CHUNK, Q_OUT_CHUNK, and MLP_OUT_CHUNK in the file
(deepseek_v3_2_decode_back.py) and either revise the comment to describe that
Q_OUT_CHUNK and MLP_OUT_CHUNK were decreased to favor X (e.g.,
memory/locality/compute balance on Ascend950) or validate and revert to previous
sizes if the 64 values are incorrect.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 688a86b9-5baf-481c-a3ca-580c6e270063
📒 Files selected for processing (5)
- examples/custom/glm_v4_5/glm_attention.py
- examples/custom/glm_v4_5/glm_attention_test.py
- examples/custom/qwen3_next/gated_delta_rule_impl.py
- examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py
- examples/deepseek_v3_2/deepseek_v3_2_decode_back.py
```python
def golden_attention(q, k_cache_bsnd, v_cache_bsnd, kv_cache_actual_seq, softmax_scale):
    b = q.shape[0] // S1
    nq = q.shape[1]
    d = q.shape[2]
    nkv = k_cache_bsnd.shape[2]
    group = nq // nkv

    attention_output = torch.zeros_like(q)

    for i in range(b):
        for j in range(S1):
            for n2_idx in range(nkv):
                kv_seq_len = kv_cache_actual_seq[i].item()
                seq_len = kv_seq_len - S1 + 1 + j
                q_bs = q[i * S1 + j]
                k_bs = k_cache_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)
                v_bs = v_cache_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)

                qk_bmm_res = torch.matmul(q_bs, k_bs.transpose(1, 0))
                qk_ele_res = qk_bmm_res * softmax_scale
                softmax_res, _, _ = softmax(qk_ele_res, True)
                bmm2_res = torch.matmul(softmax_res, v_bs)

                attention_output[i * S1 + j] = bmm2_res

    return attention_output
```
The golden path only matches the default s1=1 / nkv=1 case.
S1 is hardcoded into the batch/sequence arithmetic, and each n2_idx iteration overwrites the whole head axis instead of just its group slice. compile_and_run() exposes s1 and nkv as parameters, so any non-default invocation gets a broken reference.
🧪 Make the golden reference follow the runtime parameters

```diff
 def golden_attention(q, k_cache_bsnd, v_cache_bsnd, kv_cache_actual_seq, softmax_scale):
-    b = q.shape[0] // S1
+    b = len(kv_cache_actual_seq)
+    s1 = q.shape[0] // b
     nq = q.shape[1]
     d = q.shape[2]
     nkv = k_cache_bsnd.shape[2]
     group = nq // nkv
@@
-        for j in range(S1):
+        for j in range(s1):
             for n2_idx in range(nkv):
                 kv_seq_len = kv_cache_actual_seq[i].item()
-                seq_len = kv_seq_len - S1 + 1 + j
-                q_bs = q[i * S1 + j]
+                seq_len = kv_seq_len - s1 + 1 + j
+                head_lo = n2_idx * group
+                head_hi = head_lo + group
+                q_bs = q[i * s1 + j, head_lo:head_hi]
                 k_bs = k_cache_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)
                 v_bs = v_cache_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)
@@
-                attention_output[i * S1 + j] = bmm2_res
+                attention_output[i * s1 + j, head_lo:head_hi] = bmm2_res
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def golden_attention(q, k_cache_bsnd, v_cache_bsnd, kv_cache_actual_seq, softmax_scale):
    b = len(kv_cache_actual_seq)
    s1 = q.shape[0] // b
    nq = q.shape[1]
    d = q.shape[2]
    nkv = k_cache_bsnd.shape[2]
    group = nq // nkv
    attention_output = torch.zeros_like(q)
    for i in range(b):
        for j in range(s1):
            for n2_idx in range(nkv):
                kv_seq_len = kv_cache_actual_seq[i].item()
                seq_len = kv_seq_len - s1 + 1 + j
                head_lo = n2_idx * group
                head_hi = head_lo + group
                q_bs = q[i * s1 + j, head_lo:head_hi]
                k_bs = k_cache_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)
                v_bs = v_cache_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)
                qk_bmm_res = torch.matmul(q_bs, k_bs.transpose(1, 0))
                qk_ele_res = qk_bmm_res * softmax_scale
                softmax_res, _, _ = softmax(qk_ele_res, True)
                bmm2_res = torch.matmul(softmax_res, v_bs)
                attention_output[i * s1 + j, head_lo:head_hi] = bmm2_res
    return attention_output
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/glm_v4_5/glm_attention_test.py` around lines 120 - 145, The
golden_attention implementation assumes S1=1 and nkv=1 and overwrites the entire
head axis each iteration; update it to honor runtime s1 and nkv by slicing
per-head/group regions instead of replacing the whole vector: compute group = nq
// nkv (already present), index q, k_cache_bsnd and v_cache_bsnd using the head
slice q_bs = q[i * S1 + j, n2_idx*group:(n2_idx+1)*group, :] and similarly for
k_bs and v_bs (use the same seq_len for k/v as before), perform the matmuls on
these slices, and assign the result into the corresponding slice of
attention_output[i * S1 + j, n2_idx*group:(n2_idx+1)*group, :] so multiple nkv
iterations accumulate into distinct head regions rather than overwriting the
entire head axis.
```python
    tensor_specs, golden_out = build_tensor_specs(
        batch, s1, s2, q_d, nq, nkv, block_size
    )

    program = build_glm_attention_program(
        batch=batch, s1=s1, s2=s2, q_d=q_d, nq=nq, nkv=nkv, block_size=block_size
    )

    if work_dir is None:
        work_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "glm_attention_dump"))

    result = run(
        program=program,
        tensor_specs=tensor_specs,
        golden=None,
        config=RunConfig(
            platform=platform,
            device_id=device_id,
            rtol=2e-2,
            atol=2e-2,
            strategy=OptimizationStrategy.Default,
            dump_passes=dump_passes,
            backend_type=BackendType.Ascend950,
        ),
    )
```
🧩 Analysis chain

Repository: hw-native-sys/pypto-lib

A series of repository probes verified how `run` consumes the `golden` argument and how the other examples construct and pass their reference outputs:

```shell
# How does pypto's run() consume `golden`, and what do other call sites pass?
rg -n -C4 --type=py '\bdef run\s*\(' pypto
rg -n "golden=" examples/ --type=py -B 2 -A 2
rg -n "golden_attention|golden_hello_world|golden_matmul" examples/ --type=py -B 2

# Inspect the reference implementations and call sites directly.
cat -n examples/hello_world.py | sed -n '45,105p'
cat -n examples/matmul.py | sed -n '60,125p'
cat -n examples/custom/glm_v4_5/glm_attention_test.py | sed -n '120,230p'
rg -n "build_tensor_specs|build_glm_attention_program" examples/custom/glm_v4_5/ --type=py
```
Pass the pre-computed golden output to enable numerical correctness verification.
golden_out is computed at lines 199-201, but run() is called with golden=None at line 213. This skips numerical output verification and leaves the test as a compile/device smoke test that cannot catch numerical regressions.
🔧 Fix

```diff
-        golden=None,
+        golden=golden_out,
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    tensor_specs, golden_out = build_tensor_specs(
        batch, s1, s2, q_d, nq, nkv, block_size
    )
    program = build_glm_attention_program(
        batch=batch, s1=s1, s2=s2, q_d=q_d, nq=nq, nkv=nkv, block_size=block_size
    )
    if work_dir is None:
        work_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "glm_attention_dump"))
    result = run(
        program=program,
        tensor_specs=tensor_specs,
        golden=golden_out,
        config=RunConfig(
            platform=platform,
            device_id=device_id,
            rtol=2e-2,
            atol=2e-2,
            strategy=OptimizationStrategy.Default,
            dump_passes=dump_passes,
            backend_type=BackendType.Ascend950,
        ),
    )
```
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 199-199: Unpacked variable golden_out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/glm_v4_5/glm_attention_test.py` around lines 199 - 223, The
test computes a reference output named golden_out but calls run(...) with
golden=None, skipping numerical verification; update the run invocation (the
call to run in this file) to pass golden=golden_out instead of None so the
RunConfig-based rtol/atol checks are exercised and numerical correctness is
validated.
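To make the failure mode concrete, here is a toy model of the verification flow (the real `run`/`RunConfig` behavior is assumed, not quoted): with `golden=None` the result can only say the kernel executed, never that its numbers are right.

```python
def run_stub(execute, inputs, golden=None, rtol=1e-3, atol=1e-3):
    """Toy stand-in for a runner: executes, then verifies only if a golden is given."""
    outputs = execute(inputs)
    if golden is None:
        # Smoke test only: nothing checks the numbers.
        return {"passed": True, "verified": False, "outputs": outputs}
    ok = all(
        abs(o - g) <= atol + rtol * abs(g)
        for o, g in zip(outputs, golden)
    )
    return {"passed": ok, "verified": True, "outputs": outputs}
```

A run that "passes" with `golden=None` would pass identically even if the kernel returned garbage.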
```python
    group = nq_cfg // nkv_cfg
    g_loop = nq_cfg // nkv_cfg // g_tile_cfg
    block_num = s2_tile_cfg // block_size_cfg
```
Validate the tiling divisibility before using g_loop and block_num.
These counts are computed with floor division, but the loops below assume exact divisibility. If group % g_tile_cfg != 0, tail heads are never written; if s2_cfg is not s2_tile_cfg-aligned, the last tile still reads block_num entries and can step past block_table (for example s2=640, s2_tile=512, block_size=128).
🔧 Fail fast on unsupported parameter combinations

```diff
     softmax_scale = q_d_cfg ** -0.5
+    if nq_cfg % nkv_cfg != 0:
+        raise ValueError("nq must be divisible by nkv")
     group = nq_cfg // nkv_cfg
+    if group % g_tile_cfg != 0:
+        raise ValueError("g_tile must divide nq // nkv")
+    if s2_tile_cfg % block_size_cfg != 0:
+        raise ValueError("s2_tile must be a multiple of block_size")
+    if s2_cfg % s2_tile_cfg != 0:
+        raise ValueError("s2 must be a multiple of s2_tile")
     g_loop = nq_cfg // nkv_cfg // g_tile_cfg
     block_num = s2_tile_cfg // block_size_cfg
```

Also applies to: 113-120, 140-147, 169-176
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/glm_v4_5/glm_attention.py` around lines 69 - 71, Validate
that tiling parameters divide evenly before using computed counts: check that
group (computed from nq_cfg//nkv_cfg) is divisible by g_tile_cfg and that s2_cfg
is aligned to s2_tile_cfg and that s2_tile_cfg is divisible by block_size_cfg
(or that block_num computed from s2_tile_cfg//block_size_cfg does not
under/overrun block_table); if any check fails, raise/return an explicit error
(or adjust the loop bounds) so downstream loops that use g_loop and block_num
cannot read/write tails out-of-bounds. Add these checks near where group, g_loop
and block_num are computed (referencing nq_cfg, nkv_cfg, g_tile_cfg, s2_cfg,
s2_tile_cfg, block_size_cfg and block_table) and apply same validation in the
other similar sections noted (lines ~113-120, 140-147, 169-176).
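Since the same validation recurs in four places, it can be factored into one helper. This is a sketch with parameter names mirroring the review's suggestion; it returns the loop counts only when every tiling constraint holds.

```python
def validate_tiling(nq, nkv, g_tile, s2, s2_tile, block_size):
    """Reject tiling configs whose loops would skip tail heads or overrun block_table."""
    if nq % nkv != 0:
        raise ValueError("nq must be divisible by nkv")
    group = nq // nkv
    if group % g_tile != 0:
        raise ValueError("g_tile must divide nq // nkv")
    if s2_tile % block_size != 0:
        raise ValueError("s2_tile must be a multiple of block_size")
    if s2 % s2_tile != 0:
        raise ValueError("s2 must be a multiple of s2_tile")
    g_loop = group // g_tile
    block_num = s2_tile // block_size
    return g_loop, block_num
```

With this in place, the review's example (`s2=640`, `s2_tile=512`, `block_size=128`) raises immediately instead of silently reading past the block table.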
```python
    return query_norm, key_norm


def pre_attn(gate_view, key_view_2d, beta_view, tril, mask, l):
```
Unused parameter l.
The parameter l is declared but never used in the function body. Either remove it or use it where appropriate.
🔧 Proposed fix

```diff
-def pre_attn(gate_view, key_view_2d, beta_view, tril, mask, l):
+def pre_attn(gate_view, key_view_2d, beta_view, tril, mask):
```

Also update the call site at line 263:

```diff
-    gate_cum, decay_mask, a_block, key_beta = pre_attn(
-        gate_view, key_norm, beta_view, tril_mask, mask, l_cfg
-    )
+    gate_cum, decay_mask, a_block, key_beta = pre_attn(
+        gate_view, key_norm, beta_view, tril_mask, mask
+    )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-def pre_attn(gate_view, key_view_2d, beta_view, tril, mask, l):
+def pre_attn(gate_view, key_view_2d, beta_view, tril, mask):
```
🧰 Tools
🪛 Ruff (0.15.6)
[error] 40-40: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/qwen3_next/gated_delta_rule_impl.py` at line 40, The function
pre_attn has an unused parameter l; remove l from the parameter list of pre_attn
(def pre_attn(gate_view, key_view_2d, beta_view, tril, mask)) and update all
call sites that currently pass a value for l to stop passing it (or
alternatively, if l is required for logic, incorporate it into the function body
where relevant); locate the function by name pre_attn and adjust its callers so
signatures match.
```python
def inverse_pto(attn, eye, l):
    min_length = l // 8

    attn_8_8_blocks = []
    for i in range(8):
        block = pl.slice(attn, [min_length, min_length], [min_length * i, min_length * i])
        attn_8_8_blocks.append(block)

    attn_inv_8_blocks = []
    for i in range(8):
        block = attn_8_8_blocks[i]
        eye_block = pl.slice(eye, [min_length, min_length], [0, min_length * i])
        inv_block = inverse_pto_min_length(block, eye_block, min_length)
        attn_inv_8_blocks.append(inv_block)

    m_len = min_length

    zeros_16 = pl.create_tensor([16, 16], dtype=pl.FP32)
    zeros_16 = pl.mul(zeros_16, 0.0)

    attn_inv_4_blocks = []
    for i in range(4):
        inv_1 = attn_inv_8_blocks[i * 2]
        inv_2 = attn_inv_8_blocks[i * 2 + 1]

        a_21 = pl.slice(attn, [m_len, m_len], [m_len * (i * 2 + 1), m_len * (i * 2)])
        temp1 = pl.matmul(inv_2, a_21)
        inv_21 = pl.matmul(temp1, inv_1)

        inv_block = pl.create_tensor([m_len * 2, m_len * 2], dtype=pl.FP32)
        inv_block = pl.assemble(inv_block, inv_1, [0, 0])
        inv_block = pl.assemble(inv_block, zeros_16, [0, m_len])
        inv_block = pl.assemble(inv_block, inv_21, [m_len, 0])
        inv_block = pl.assemble(inv_block, inv_2, [m_len, m_len])
        attn_inv_4_blocks.append(inv_block)

    zeros_32 = pl.create_tensor([32, 32], dtype=pl.FP32)
    zeros_32 = pl.mul(zeros_32, 0.0)

    m_len = min_length * 2
    attn_inv_2_blocks = []
    for i in range(2):
        inv_1 = attn_inv_4_blocks[i * 2]
        inv_2 = attn_inv_4_blocks[i * 2 + 1]

        a_21 = pl.slice(attn, [m_len, m_len], [m_len * (i * 2 + 1), m_len * (i * 2)])
        temp1 = pl.matmul(inv_2, a_21)
        inv_21 = pl.matmul(temp1, inv_1)

        inv_block = pl.create_tensor([m_len * 2, m_len * 2], dtype=pl.FP32)
        inv_block = pl.assemble(inv_block, inv_1, [0, 0])
        inv_block = pl.assemble(inv_block, zeros_32, [0, m_len])
        inv_block = pl.assemble(inv_block, inv_21, [m_len, 0])
        inv_block = pl.assemble(inv_block, inv_2, [m_len, m_len])
        attn_inv_2_blocks.append(inv_block)

    zeros_64 = pl.create_tensor([64, 64], dtype=pl.FP32)
    zeros_64 = pl.mul(zeros_64, 0.0)

    m_len = min_length * 4
    inv_1 = attn_inv_2_blocks[0]
    inv_2 = attn_inv_2_blocks[1]

    a_21 = pl.slice(attn, [m_len, m_len], [m_len, 0])
    temp1 = pl.matmul(inv_2, a_21)
    inv_21 = pl.matmul(temp1, inv_1)

    attn_inv = pl.create_tensor([l, l], dtype=pl.FP32)
    attn_inv = pl.assemble(attn_inv, inv_1, [0, 0])
    attn_inv = pl.assemble(attn_inv, zeros_64, [0, m_len])
    attn_inv = pl.assemble(attn_inv, inv_21, [m_len, 0])
    attn_inv = pl.assemble(attn_inv, inv_2, [m_len, m_len])

    return attn_inv
```
Hardcoded block sizes break for l != 128.
The function takes l as a parameter but hardcodes zeros_16, zeros_32, and zeros_64 tensors. These dimensions only match when l = 128 (where min_length = 16). For any other value of l, the pl.assemble calls will produce incorrect results or fail due to dimension mismatches.
For example, if l = 256, then min_length = 32, but zeros_16 at line 126 would be assembled into a 64×64 block expecting a 32×32 zero region.
🐛 Proposed fix: dynamically size zeros blocks

```diff
     m_len = min_length
-    zeros_16 = pl.create_tensor([16, 16], dtype=pl.FP32)
-    zeros_16 = pl.mul(zeros_16, 0.0)
+    zeros_m = pl.create_tensor([m_len, m_len], dtype=pl.FP32)
+    zeros_m = pl.mul(zeros_m, 0.0)
     attn_inv_4_blocks = []
     for i in range(4):
         inv_1 = attn_inv_8_blocks[i * 2]
         inv_2 = attn_inv_8_blocks[i * 2 + 1]
         a_21 = pl.slice(attn, [m_len, m_len], [m_len * (i * 2 + 1), m_len * (i * 2)])
         temp1 = pl.matmul(inv_2, a_21)
         inv_21 = pl.matmul(temp1, inv_1)
         inv_block = pl.create_tensor([m_len * 2, m_len * 2], dtype=pl.FP32)
         inv_block = pl.assemble(inv_block, inv_1, [0, 0])
-        inv_block = pl.assemble(inv_block, zeros_16, [0, m_len])
+        inv_block = pl.assemble(inv_block, zeros_m, [0, m_len])
         inv_block = pl.assemble(inv_block, inv_21, [m_len, 0])
         inv_block = pl.assemble(inv_block, inv_2, [m_len, m_len])
         attn_inv_4_blocks.append(inv_block)
```

Apply similar changes for zeros_32 and zeros_64, sizing them dynamically based on m_len at each stage.
🧰 Tools
🪛 Ruff (0.15.6)
[error] 95-95: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/qwen3_next/gated_delta_rule_impl.py` around lines 95 - 168,
The inverse_pto function incorrectly uses hardcoded zero tensors (zeros_16,
zeros_32, zeros_64) which only work when min_length==16; replace those with
dynamically-sized zero tensors created from the current m_len at each stage
(i.e., create zeros of shape [m_len, m_len] before the 2x2 assemble step, and
similarly use [m_len*2, m_len*2] or appropriate sizes for higher-level
assembles) so the pl.assemble calls in inverse_pto and the blocks built in
attn_inv_4_blocks, attn_inv_2_blocks and final attn_inv always match l and
min_length; update references to zeros_16/32/64 to use these computed zero
tensors where inv_block assembly places a zero sub-block.
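For reference, the 2x2 block identity this recursion builds on can be checked numerically: for a lower block-triangular M = [[A11, 0], [A21, A22]], the inverse's off-diagonal block is -A22⁻¹ A21 A11⁻¹. The kernel's `inv_21 = matmul(matmul(inv_2, a_21), inv_1)` omits the minus sign, which suggests its off-diagonal blocks are stored pre-negated (an assumption about the surrounding construction, not confirmed by the snippet). A NumPy sketch of the standard identity:

```python
import numpy as np

def block_tril_inverse(m, half):
    """Invert a lower block-triangular matrix via the 2x2 block identity."""
    a11, a21, a22 = m[:half, :half], m[half:, :half], m[half:, half:]
    inv1 = np.linalg.inv(a11)
    inv2 = np.linalg.inv(a22)
    out = np.zeros_like(m)
    out[:half, :half] = inv1
    out[half:, :half] = -inv2 @ a21 @ inv1   # note the minus sign
    out[half:, half:] = inv2
    return out
```

Applying the identity recursively on halves, as `inverse_pto` does across its 8 -> 4 -> 2 -> 1 block stages, only costs small inversions plus matmuls, which is why sizing the zero blocks from `m_len` (rather than hardcoding 16/32/64) is what keeps the recursion valid for any `l`.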
```python
    query: pl.Tensor[["T", nqk_cfg, d_cfg], pl.FP32],
    key: pl.Tensor[["T", nqk_cfg, d_cfg], pl.FP32],
    value: pl.Tensor[["T", nv_cfg, d_cfg], pl.FP32],
    beta: pl.Tensor[["T", nv_cfg], pl.FP32],
    gate: pl.Tensor[["T", nv_cfg], pl.FP32],
    states: pl.Tensor[[b_cfg, nv_cfg, d_cfg, d_cfg], pl.FP32],
    mask: pl.Tensor[[l_cfg, l_cfg], pl.FP32],
    tril_mask: pl.Tensor[[l_cfg, l_cfg], pl.FP32],
    eye: pl.Tensor[[16, l_cfg], pl.FP32],
    act_seq_len: pl.Tensor[[b_cfg + 1], pl.INT32],
    core_attn_out: pl.Tensor[["T", nv_cfg, d_cfg], pl.FP32],
```
🧩 Analysis chain

Repository probes compared this signature against the tensor-annotation conventions used elsewhere and inspected the builder's parameters and call sites:

```shell
# Find every tensor annotation and the dimension variables they use.
rg -n 'pl\.Tensor\[\[' --type=py | head -30
rg -n 'nqk_cfg|nv_cfg|d_cfg|b_cfg|l_cfg' --type=py -B 1 -A 1 | head -40

# Inspect the full function definition and its call sites.
sed -n '200,250p' examples/custom/qwen3_next/gated_delta_rule_impl.py
rg -n 'chunk_gated_delta_rule' --type=py -B 2 -A 5
```
Add missing sequence length dimension parameter to build_chunk_gated_delta_rule_program.
The tensor type annotations use "T" as a string literal (lines 221-225, 231), but this is inconsistent with the codebase pattern where all tensor dimensions are Python variables or expressions (e.g., b_cfg, nqk_cfg, d_cfg, l_cfg). The function currently accepts 5 parameters (b, nqk, nv, d, l) and none map to a sequence length dimension. Add a sequence length parameter t to the function signature, convert it to t_cfg like the other dimensions, and replace all "T" string literals with t_cfg in the tensor annotations.
🧰 Tools
🪛 Ruff (0.15.6)
[error] 221-221: Undefined name T
(F821)
[error] 222-222: Undefined name T
(F821)
[error] 223-223: Undefined name T
(F821)
[error] 224-224: Undefined name T
(F821)
[error] 225-225: Undefined name T
(F821)
[error] 231-231: Undefined name T
(F821)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/qwen3_next/gated_delta_rule_impl.py` around lines 221 - 231,
The function build_chunk_gated_delta_rule_program is missing a sequence-length
parameter: add a new parameter t, create t_cfg alongside
b_cfg/nqk_cfg/nv_cfg/d_cfg/l_cfg (e.g., t_cfg = Var(t, "t_cfg") or equivalent
pattern used in this file), and replace all tensor dimension string literals "T"
in the type annotations for query, key, value, beta, gate, and core_attn_out
with t_cfg so those tensors use the concrete sequence-length dimension; update
the function signature and any uses of those annotations to reference t_cfg
instead of "T".
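The convention the review asks for, every dimension a concrete Python value threaded through the builder and captured by the inner program, can be illustrated outside PyPTO. All names in this sketch are invented for the example:

```python
import numpy as np

def build_program(b, t, d):
    # Concrete dims captured by the closure, analogous to t_cfg replacing "T".
    b_cfg, t_cfg, d_cfg = b, t, d

    def program(query, states):
        # query: [t_cfg, d_cfg]; states: [b_cfg, d_cfg, d_cfg]
        assert query.shape == (t_cfg, d_cfg)
        assert states.shape == (b_cfg, d_cfg, d_cfg)
        return query @ states[0]   # toy body producing [t_cfg, d_cfg]

    return program
```

Because `t_cfg` is an ordinary variable rather than a string literal, the annotations stay checkable and linters stop flagging an undefined name.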
```python
    output_final_state = kwargs.get("output_final_state")
    use_qk_l2norm_in_kernel = kwargs.get("use_qk_l2norm_in_kernel")

    b, n, s, d = value.shape
```
Remove unused unpacking.
The variables b, n, s, d are unpacked but never used. The needed dimensions are correctly extracted from key.shape at line 214.
🔧 Proposed fix

```diff
-    b, n, s, d = value.shape
+    # Dimensions extracted from key.shape below
```

Or simply remove the line entirely.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-    b, n, s, d = value.shape
+    # Dimensions extracted from key.shape below
```
🧰 Tools
🪛 Ruff (0.15.6)
[warning] 207-207: Unpacked variable b is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 207-207: Unpacked variable n is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 207-207: Unpacked variable s is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 207-207: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py` at line 207,
Remove the unused tuple unpacking "b, n, s, d = value.shape" in
qwen3_next_gated_delta_rule.py (it’s unused and dimensions are taken from
key.shape later); delete that line so the function no longer creates unused
variables b, n, s, d.
```python
    work_dir: str | None = None,
    dump_passes: bool = True,
):
    from gated_delta_rule_impl import build_chunk_gated_delta_rule_program
```
Relative import may fail at runtime.
The import from gated_delta_rule_impl import ... assumes the current working directory contains gated_delta_rule_impl.py. This will fail if the script is run from a different directory. Consider using an absolute import or adjusting sys.path.
🔧 Proposed fix

```diff
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
 from gated_delta_rule_impl import build_chunk_gated_delta_rule_program
```

Or use a package-relative import:

```diff
-from gated_delta_rule_impl import build_chunk_gated_delta_rule_program
+from .gated_delta_rule_impl import build_chunk_gated_delta_rule_program
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py` at line 315, The
top-level import "from gated_delta_rule_impl import
build_chunk_gated_delta_rule_program" can fail when the script is executed from
a different CWD; update the import to be package-relative or otherwise robust:
either convert to a package-relative import (e.g., use a relative import from
the correct package) or modify sys.path at startup to include the module's
directory before importing, ensuring the symbol
build_chunk_gated_delta_rule_program is still imported from the correct module;
locate the import statement and replace it with one of these approaches so the
module resolves reliably at runtime.
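The `sys.path` approach can be demonstrated end to end. The stub module below stands in for `gated_delta_rule_impl`; its name, location, and contents are invented for the demonstration:

```python
import importlib
import os
import sys
import tempfile

# Create a throwaway directory with a stub module standing in for
# gated_delta_rule_impl (file name and contents are illustrative).
module_dir = tempfile.mkdtemp()
with open(os.path.join(module_dir, "impl_stub.py"), "w") as f:
    f.write("def build_chunk_gated_delta_rule_program():\n    return 'program'\n")

# Prepend the module's directory so the import resolves from any CWD,
# mirroring sys.path.insert(0, os.path.dirname(__file__)) in the real script.
sys.path.insert(0, module_dir)
impl_stub = importlib.import_module("impl_stub")
print(impl_stub.build_chunk_gated_delta_rule_program())  # prints: program
```

The same import now succeeds regardless of which directory the script is launched from, which is exactly the failure mode the review flags.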
```python
    result = run(
        program=program,
        tensor_specs=tensor_specs,
        golden=None,
        config=RunConfig(
            platform=platform,
            device_id=device_id,
            rtol=1e-3,
            atol=1e-3,
            strategy=OptimizationStrategy.Default,
            dump_passes=dump_passes,
            backend_type=BackendType.Ascend950,
        ),
    )
```
Golden outputs are computed but never used for validation.
The golden_data computed at line 317 via gen_data() is not passed to run(). With golden=None, the runtime will not validate correctness against the expected outputs. If validation is intended, pass a golden callback or compare results afterward.
🐛 Proposed fix: validate outputs against golden
Add post-run validation:
result = run(
program=program,
tensor_specs=tensor_specs,
golden=None, # or pass a golden callback
config=RunConfig(...),
)
# Validate against golden if run succeeded
if result.passed and result.outputs:
core_attn_result = result.outputs.get("core_attn_out")
last_state_result = result.outputs.get("last_state_data")
if core_attn_result is not None:
torch.testing.assert_close(
core_attn_result,
golden_data["core_attn_out"].transpose(0, 1), # adjust shape if needed
rtol=1e-3, atol=1e-3
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/custom/qwen3_next/qwen3_next_gated_delta_rule.py` around lines 331 -
344, golden_data produced by gen_data() is never used — modify the run
invocation or add post-run checks to validate outputs: either pass a golden
callback or keep golden=None but after run(...) inspect result (the result
object returned by run) and compare result.outputs["core_attn_out"] and
result.outputs["last_state_data"] against golden_data entries using
torch.testing.assert_close with rtol=1e-3 and atol=1e-3; ensure any necessary
shape/transpose adjustments (e.g., transpose(0,1) on
golden_data["core_attn_out"]) before comparison and only run assertions if
result.passed and result.outputs exist.
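The post-run comparison the prompt describes can be wrapped in one helper. `torch.testing.assert_close` is a real PyTorch API; the output names and dict shape are assumptions taken from the review's snippet:

```python
import torch

def validate_against_golden(outputs, golden, rtol=1e-3, atol=1e-3):
    """Assert every golden tensor matches the corresponding produced output."""
    for name, expected in golden.items():
        actual = outputs.get(name)
        if actual is None:
            raise AssertionError(f"missing output: {name}")
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```

Calling `validate_against_golden(result.outputs, golden_data)` after a successful run turns the smoke test into a numerical regression test, and any tolerance change lives in one place instead of being repeated per tensor.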
migrate glm and qwen3 next to pto3.0
Summary by CodeRabbit
New Features
Chores