From 7daacc39789c57aa4a27e5a2a3af6fbd65c8a127 Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Sun, 24 May 2026 22:49:49 -0400 Subject: [PATCH] =?UTF-8?q?fix(inference):=20RoPE=20pairing=20=E2=80=94=20?= =?UTF-8?q?stride-half=20(not=20interleaved)=20closes=200.74=20PPL=20gap?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `apply_partial_rope` and Metal `partial_rope_interleaved` rotated consecutive pairs (2i, 2i+1) — the GPT-J / `traditional=True` convention. Qwen3.5 is trained with stride-half pairs (i, half+i) — HF transformers' `rotate_half` and MLX's `nn.RoPE(traditional=False)`. The comment "Qwen3.5 uses mrope_interleaved=true" was misread: `mrope_interleaved` in config controls multimodal-position section interleaving (M-RoPE for video/image tokens), not the 1-D text pairing convention. Empirically verified against MLX's nn.RoPE(traditional=False, rope_dim=64, base=1e7, position=5): stride-half matches with max-diff 8e-6; interleaved diverges with max-diff 67.5. WikiText-2 PPL on Qwen3.5-0.8B (window=512, stride=256, 2041 scored tokens): before: 16.6242 (lattice) vs 15.8580 (MLX) → +0.77 PPL gap after: 15.8870 (lattice) vs 15.8580 (MLX) → +0.029 PPL gap argmax agreement at pos 0: 0% before (lat=695 echoed input token, mlx=220) → 97.3% after across single 512-token window. The wrong RoPE scrambled positional information in all 6 full-attention layers of the 24-layer hybrid (75% GDN, 25% full attention) stack. The hybrid transformer's hidden state stopped transforming, producing the "embedding leakage" signature where logits peaked at the input token — diagnosed earlier this session, now explained. 
Files: - crates/inference/src/model/qwen35/forward.rs:391 — CPU apply_partial_rope - crates/inference/src/forward/metal_qwen35.rs:346 — Metal kernel (name kept for ABI continuity) - crates/inference/src/speculative.rs:1090 — mtp_apply_partial_rope - crates/inference/src/forward/metal_qwen35.rs golden snapshot — updated from stale -22.62 (pre-(1+gamma)) to the correct -45.24 derived value - crates/inference/src/forward/metal_qwen35.rs test inits — add missing `grammar: None` field so tests compile Tests: 843 pass, 0 fail. Clippy clean. Co-Authored-By: Claude Opus 4.7 --- crates/inference/src/forward/metal_qwen35.rs | 33 +++++++++++++------- crates/inference/src/model/qwen35/forward.rs | 15 +++++---- crates/inference/src/speculative.rs | 10 +++--- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/crates/inference/src/forward/metal_qwen35.rs b/crates/inference/src/forward/metal_qwen35.rs index 3dfd0ef2..8ee2d615 100644 --- a/crates/inference/src/forward/metal_qwen35.rs +++ b/crates/inference/src/forward/metal_qwen35.rs @@ -339,8 +339,9 @@ kernel void rms_norm_qwen35( } } -// ===== Interleaved Partial RoPE for Qwen3.5 full-attention layers ===== -// Pairs are (2i, 2i+1) for i in 0..half_rope_dim. +// ===== Stride-Half Partial RoPE for Qwen3.5 full-attention layers ===== +// Pairs are (i, half+i) for i in 0..half_rope_dim — HF rotate_half convention, +// matching MLX nn.RoPE(traditional=False). Kernel name kept for ABI continuity. // Only first rope_dim dimensions are rotated; the rest are untouched. // Operates on x[num_heads * head_dim] for a single position. 
kernel void partial_rope_interleaved( @@ -365,9 +366,9 @@ kernel void partial_rope_interleaved( float cos_val = cos_tab[cs_base + pair]; float sin_val = sin_tab[cs_base + pair]; - // Interleaved: rotate (2*pair, 2*pair+1) within each head - uint idx0 = base + 2 * pair; - uint idx1 = base + 2 * pair + 1; + // Stride-half: rotate (pair, half_rope_dim + pair) within each head + uint idx0 = base + pair; + uint idx1 = base + half_rope_dim + pair; float x0 = x[idx0]; float x1 = x[idx1]; x[idx0] = x0 * cos_val - x1 * sin_val; @@ -13554,17 +13555,23 @@ kernel void decode_attention_reference( let logits = state.forward_step(42, 0); assert_eq!(logits.len(), cfg.vocab_size); let actual = &logits[..10]; + // Math: token 42's embedding has a single non-zero element (=1.0). After + // input_layernorm (qwen35_rms_norm uses 1+gamma weights), the norm output + // is x * sqrt(hidden)/||x|| * (1 + gamma) = 1 * sqrt(512) * 2 = 45.2548 without eps; 45.2433 with the norm's eps (presumably 1e-6). + // All attention and FFN weights are zero, so the residual stream stays at + // the input embedding. final_norm produces the same 45.2433 magnitude. + // Tied lm_head: logit[v] = embed[v, 0] * 45.2433 → -45.24, 0, 45.24 pattern. 
let expected = [ - -22.6216266_f32, + -45.243256_f32, 0.0, - 22.6216266, - -22.6216266, + 45.243256, + -45.243256, 0.0, - 22.6216266, - -22.6216266, + 45.243256, + -45.243256, 0.0, - 22.6216266, - -22.6216266, + 45.243256, + -45.243256, ]; let max_abs_diff = actual .iter() @@ -15024,6 +15031,7 @@ kernel void decode_attention_reference( stop_token_ids: vec![], enable_thinking: false, enable_mtp: Some(false), + grammar: None, }; let out = with_self_spec_env(|| { @@ -15118,6 +15126,7 @@ kernel void decode_attention_reference( stop_token_ids: vec![], enable_thinking: false, enable_mtp: Some(false), + grammar: None, }; let mut state = with_self_spec_env(|| { diff --git a/crates/inference/src/model/qwen35/forward.rs b/crates/inference/src/model/qwen35/forward.rs index fa5e4d94..472ac330 100644 --- a/crates/inference/src/model/qwen35/forward.rs +++ b/crates/inference/src/model/qwen35/forward.rs @@ -385,8 +385,11 @@ impl Qwen35Model { let _ = q_dim; // used by caller for gate_z } - /// Apply partial RoPE with INTERLEAVED pairing. - /// Qwen3.5 uses mrope_interleaved=true: pairs are (0,1), (2,3), (4,5), ... + /// Apply partial RoPE with STRIDE-HALF (HF rotate_half / MLX traditional=False) pairing. + /// Pairs are (i, half+i) for i in 0..half. The `mrope_interleaved` flag in config + /// applies only to multimodal-position interleaving (M-RoPE sections); for 1-D text + /// positions, Qwen3.5 uses standard stride-half pairing (verified empirically against + /// MLX nn.RoPE(traditional=False) — diff ~1e-5 with stride-half, ~70 with interleaved). /// Only the first `rope_dim` dimensions are rotated. 
pub(crate) fn apply_partial_rope( &self, @@ -399,10 +402,10 @@ impl Qwen35Model { for i in 0..half { let cos_val = self.rope.cos_at(base + i); let sin_val = self.rope.sin_at(base + i); - let x0 = head_vec[2 * i]; - let x1 = head_vec[2 * i + 1]; - head_vec[2 * i] = x0 * cos_val - x1 * sin_val; - head_vec[2 * i + 1] = x0 * sin_val + x1 * cos_val; + let x0 = head_vec[i]; + let x1 = head_vec[half + i]; + head_vec[i] = x0 * cos_val - x1 * sin_val; + head_vec[half + i] = x0 * sin_val + x1 * cos_val; } } diff --git a/crates/inference/src/speculative.rs b/crates/inference/src/speculative.rs index 6b3acca1..c799ce58 100644 --- a/crates/inference/src/speculative.rs +++ b/crates/inference/src/speculative.rs @@ -1086,7 +1086,7 @@ impl<'a> MtpVerifier<'a> { } } -/// Interleaved partial RoPE: rotate pairs (2i, 2i+1) for i in 0..rope_dim/2. +/// Stride-half partial RoPE: rotate pairs (i, half+i) for i in 0..rope_dim/2 — matches HF rotate_half / MLX traditional=False. fn mtp_apply_partial_rope( head_vec: &mut [f32], position: usize, @@ -1098,10 +1098,10 @@ fn mtp_apply_partial_rope( for i in 0..half { let cos_val = rope.cos_at(base + i); let sin_val = rope.sin_at(base + i); - let x0 = head_vec[2 * i]; - let x1 = head_vec[2 * i + 1]; - head_vec[2 * i] = x0 * cos_val - x1 * sin_val; - head_vec[2 * i + 1] = x0 * sin_val + x1 * cos_val; + let x0 = head_vec[i]; + let x1 = head_vec[half + i]; + head_vec[i] = x0 * cos_val - x1 * sin_val; + head_vec[half + i] = x0 * sin_val + x1 * cos_val; } }