diff --git a/crates/inference/src/forward/metal_qwen35.rs b/crates/inference/src/forward/metal_qwen35.rs index 3dfd0ef2..8ee2d615 100644 --- a/crates/inference/src/forward/metal_qwen35.rs +++ b/crates/inference/src/forward/metal_qwen35.rs @@ -339,8 +339,9 @@ kernel void rms_norm_qwen35( } } -// ===== Interleaved Partial RoPE for Qwen3.5 full-attention layers ===== -// Pairs are (2i, 2i+1) for i in 0..half_rope_dim. +// ===== Stride-Half Partial RoPE for Qwen3.5 full-attention layers ===== +// Pairs are (i, half+i) for i in 0..half_rope_dim — HF rotate_half convention, +// matching MLX nn.RoPE(traditional=False). Kernel name kept for ABI continuity. // Only first rope_dim dimensions are rotated; the rest are untouched. // Operates on x[num_heads * head_dim] for a single position. kernel void partial_rope_interleaved( @@ -365,9 +366,9 @@ kernel void partial_rope_interleaved( float cos_val = cos_tab[cs_base + pair]; float sin_val = sin_tab[cs_base + pair]; - // Interleaved: rotate (2*pair, 2*pair+1) within each head - uint idx0 = base + 2 * pair; - uint idx1 = base + 2 * pair + 1; + // Stride-half: rotate (pair, half_rope_dim + pair) within each head + uint idx0 = base + pair; + uint idx1 = base + half_rope_dim + pair; float x0 = x[idx0]; float x1 = x[idx1]; x[idx0] = x0 * cos_val - x1 * sin_val; @@ -13554,17 +13555,23 @@ kernel void decode_attention_reference( let logits = state.forward_step(42, 0); assert_eq!(logits.len(), cfg.vocab_size); let actual = &logits[..10]; + // Math: token 42's embedding has a single non-zero element (=1.0). After + // input_layernorm (qwen35_rms_norm uses 1+gamma weights), the norm output + // is x * sqrt(hidden)/||x|| * (1 + gamma) = 1 * sqrt(512) * 2 ≈ 45.2548 in the ideal case; the asserted value 45.243256 is slightly lower, presumably because of the norm's epsilon term. + // All attention and FFN weights are zero, so the residual stream stays at + // the input embedding. final_norm produces the same ≈45.2433 magnitude. + // Tied lm_head: logit[v] = embed[v, 0] * 45.2433 → -45.24, 0, 45.24 pattern. 
let expected = [ - -22.6216266_f32, + -45.243256_f32, 0.0, - 22.6216266, - -22.6216266, + 45.243256, + -45.243256, 0.0, - 22.6216266, - -22.6216266, + 45.243256, + -45.243256, 0.0, - 22.6216266, - -22.6216266, + 45.243256, + -45.243256, ]; let max_abs_diff = actual .iter() @@ -15024,6 +15031,7 @@ kernel void decode_attention_reference( stop_token_ids: vec![], enable_thinking: false, enable_mtp: Some(false), + grammar: None, }; let out = with_self_spec_env(|| { @@ -15118,6 +15126,7 @@ kernel void decode_attention_reference( stop_token_ids: vec![], enable_thinking: false, enable_mtp: Some(false), + grammar: None, }; let mut state = with_self_spec_env(|| { diff --git a/crates/inference/src/model/qwen35/forward.rs b/crates/inference/src/model/qwen35/forward.rs index fa5e4d94..472ac330 100644 --- a/crates/inference/src/model/qwen35/forward.rs +++ b/crates/inference/src/model/qwen35/forward.rs @@ -385,8 +385,11 @@ impl Qwen35Model { let _ = q_dim; // used by caller for gate_z } - /// Apply partial RoPE with INTERLEAVED pairing. - /// Qwen3.5 uses mrope_interleaved=true: pairs are (0,1), (2,3), (4,5), ... + /// Apply partial RoPE with STRIDE-HALF (HF rotate_half / MLX traditional=False) pairing. + /// Pairs are (i, half+i) for i in 0..half. The `mrope_interleaved` flag in config + /// applies only to multimodal-position interleaving (M-RoPE sections); for 1-D text + /// positions, Qwen3.5 uses standard stride-half pairing (verified empirically against + /// MLX nn.RoPE(traditional=False) — diff ~1e-5 with stride-half, ~70 with interleaved). /// Only the first `rope_dim` dimensions are rotated. 
pub(crate) fn apply_partial_rope( &self, @@ -399,10 +402,10 @@ impl Qwen35Model { for i in 0..half { let cos_val = self.rope.cos_at(base + i); let sin_val = self.rope.sin_at(base + i); - let x0 = head_vec[2 * i]; - let x1 = head_vec[2 * i + 1]; - head_vec[2 * i] = x0 * cos_val - x1 * sin_val; - head_vec[2 * i + 1] = x0 * sin_val + x1 * cos_val; + let x0 = head_vec[i]; + let x1 = head_vec[half + i]; + head_vec[i] = x0 * cos_val - x1 * sin_val; + head_vec[half + i] = x0 * sin_val + x1 * cos_val; } } diff --git a/crates/inference/src/speculative.rs b/crates/inference/src/speculative.rs index 6b3acca1..c799ce58 100644 --- a/crates/inference/src/speculative.rs +++ b/crates/inference/src/speculative.rs @@ -1086,7 +1086,7 @@ impl<'a> MtpVerifier<'a> { } } -/// Interleaved partial RoPE: rotate pairs (2i, 2i+1) for i in 0..rope_dim/2. +/// Stride-half partial RoPE: rotate pairs (i, half+i) for i in 0..rope_dim/2 — matches HF rotate_half / MLX traditional=False. fn mtp_apply_partial_rope( head_vec: &mut [f32], position: usize, @@ -1098,10 +1098,10 @@ fn mtp_apply_partial_rope( for i in 0..half { let cos_val = rope.cos_at(base + i); let sin_val = rope.sin_at(base + i); - let x0 = head_vec[2 * i]; - let x1 = head_vec[2 * i + 1]; - head_vec[2 * i] = x0 * cos_val - x1 * sin_val; - head_vec[2 * i + 1] = x0 * sin_val + x1 * cos_val; + let x0 = head_vec[i]; + let x1 = head_vec[half + i]; + head_vec[i] = x0 * cos_val - x1 * sin_val; + head_vec[half + i] = x0 * sin_val + x1 * cos_val; } }