Skip to content

fix(test): chain DFlash verify/replay produce valid logits at all positions#346

Open
davide221 wants to merge 1 commit into
mainfrom
fix/dflash-chain-verify-mask-argmax
Open

fix(test): chain DFlash verify/replay produce valid logits at all positions#346
davide221 wants to merge 1 commit into
mainfrom
fix/dflash-chain-verify-mask-argmax

Conversation

@davide221
Copy link
Copy Markdown
Contributor

Summary

The chain (non-DDTree) DFlash decode path in test_dflash aborts after step 0 on Qwen3.6 and generates 0 tokens. Running the same prompt with --ddtree works, so the bug is specific to the chain verify/replay. This fixes it; chain mode now decodes correctly.

Root cause

Two pre-existing bugs in server/test/test_dflash.cpp, both independent of the DDTree path:

  1. Causal-mask stride mismatch (the main bug). The batched verify and the legacy replay call build_causal_mask(...) without kv_pad_override, so the mask buffer is strided by align_up(win_len, kq_stride_pad), while the mask tensor is allocated with stride align_up(max_ctx + n_tokens, kq_stride_pad). Only query row 0 lands at the correct offset; rows 1.. read an unwritten region, so attention (and therefore logits) is zeroed for every verify/replay position > 0. DDTree already passes the override; the chain path did not. Empirically, logits maxabs was p0=26.5, p1..=0.0; with the override, p1..=26.x.

  2. Broken GPU argmax read. The chain verify read sg.argmax_tokens, whose CUDA ggml_argmax returns -1 after position 0 even when logits are valid (the same defect fixed for DDTree in 1b3882d). Switched to reading full sg.logits + CPU argmax_f32 per position.

With (1) alone the verify is correct but the replay still corrupts last_tok; both are needed.

Validation

RTX 3090 (sm_86) / CUDA 12, Qwen3.6-27B Q4_K_M target + 3.6 DFlash draft, on main (f59f2a3):

  • Before: [step 0] accept_n=2 bonus=-1, then silent abort, 0 tokens.
  • After: 7 draft steps, accepted=41/112 (36.6%/step), ~50 tok/s, coherent output identical to the DDTree/server path through the deterministic prefix (diverges only post-EOS under temp=1.0 sampling).

Notes

This is the chain test-harness path; the production server uses DDTree (unaffected). Issue #259 is a different (V100/Volta MMA) problem.

🧙 Built with WOZCODE

…itions

The chain (non-DDTree) decode path in test_dflash aborted after step 0 on
Qwen3.6 and generated 0 tokens; running with --ddtree worked, so the bug was
specific to the chain verify/replay. Two pre-existing root causes, both
independent of the DDTree path:

1. build_causal_mask for the batched verify and the legacy replay was called
   without kv_pad_override, so the mask buffer was strided by
   align_up(win_len) while the mask tensor was allocated with stride
   align_up(max_ctx + n_tokens). Only query row 0 landed at the right offset;
   rows 1.. read an unwritten region, zeroing attention (and logits) for every
   verify/replay position > 0. DDTree already passed this; the chain did not.

2. The chain verify read sg.argmax_tokens, whose CUDA ggml_argmax returns -1
   after position 0 even when logits are valid (same defect fixed for DDTree
   in 1b3882d). Switched to reading full logits + CPU argmax per position.

Verified on main (f59f2a3), RTX 3090 / Qwen3.6-27B Q4_K_M: 7 steps,
accepted=41/112 (36.6%/step), coherent output matching the DDTree path.

Co-Authored-By: WOZCODE <contact@withwoz.com>
Copy link
Copy Markdown
Contributor

@cubic-dev-ai cubic-dev-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No issues found across 1 file

Re-trigger cubic

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant