Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions crates/inference/src/forward/metal_qwen35.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,9 @@ kernel void rms_norm_qwen35(
}
}

// ===== Interleaved Partial RoPE for Qwen3.5 full-attention layers =====
// Pairs are (2i, 2i+1) for i in 0..half_rope_dim.
// ===== Stride-Half Partial RoPE for Qwen3.5 full-attention layers =====
// Pairs are (i, half+i) for i in 0..half_rope_dim — HF rotate_half convention,
// matching MLX nn.RoPE(traditional=False). Kernel name kept for ABI continuity.
// Only first rope_dim dimensions are rotated; the rest are untouched.
// Operates on x[num_heads * head_dim] for a single position.
kernel void partial_rope_interleaved(
Expand All @@ -365,9 +366,9 @@ kernel void partial_rope_interleaved(
float cos_val = cos_tab[cs_base + pair];
float sin_val = sin_tab[cs_base + pair];

// Interleaved: rotate (2*pair, 2*pair+1) within each head
uint idx0 = base + 2 * pair;
uint idx1 = base + 2 * pair + 1;
// Stride-half: rotate (pair, half_rope_dim + pair) within each head
uint idx0 = base + pair;
uint idx1 = base + half_rope_dim + pair;
float x0 = x[idx0];
float x1 = x[idx1];
x[idx0] = x0 * cos_val - x1 * sin_val;
Expand Down Expand Up @@ -13554,17 +13555,23 @@ kernel void decode_attention_reference(
let logits = state.forward_step(42, 0);
assert_eq!(logits.len(), cfg.vocab_size);
let actual = &logits[..10];
// Math: token 42's embedding has a single non-zero element (=1.0). After
// input_layernorm (qwen35_rms_norm uses 1+gamma weights), the norm output
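// is x * sqrt(hidden)/||x|| * (1 + gamma) = 1 * sqrt(512) * 2 ≈ 45.2548 in exact
// arithmetic; the pinned value is 45.2433, which is consistent with an epsilon of
// 1e-6 inside the norm denominator: 2 / sqrt(1/512 + 1e-6) (TODO confirm eps).
// All attention and FFN weights are zero, so the residual stream stays at
// the input embedding. final_norm produces the same 45.2433 magnitude.
// Tied lm_head: logit[v] = embed[v, 0] * 45.2433 → -45.24, 0, 45.24 pattern.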
let expected = [
-22.6216266_f32,
-45.243256_f32,
0.0,
22.6216266,
-22.6216266,
45.243256,
-45.243256,
0.0,
22.6216266,
-22.6216266,
45.243256,
-45.243256,
0.0,
22.6216266,
-22.6216266,
45.243256,
-45.243256,
];
let max_abs_diff = actual
.iter()
Expand Down Expand Up @@ -15024,6 +15031,7 @@ kernel void decode_attention_reference(
stop_token_ids: vec![],
enable_thinking: false,
enable_mtp: Some(false),
grammar: None,
};

let out = with_self_spec_env(|| {
Expand Down Expand Up @@ -15118,6 +15126,7 @@ kernel void decode_attention_reference(
stop_token_ids: vec![],
enable_thinking: false,
enable_mtp: Some(false),
grammar: None,
};

let mut state = with_self_spec_env(|| {
Expand Down
15 changes: 9 additions & 6 deletions crates/inference/src/model/qwen35/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,11 @@ impl Qwen35Model {
let _ = q_dim; // used by caller for gate_z
}

/// Apply partial RoPE with INTERLEAVED pairing.
/// Qwen3.5 uses mrope_interleaved=true: pairs are (0,1), (2,3), (4,5), ...
/// Apply partial RoPE with STRIDE-HALF (HF rotate_half / MLX traditional=False) pairing.
/// Pairs are (i, half+i) for i in 0..half. The `mrope_interleaved` flag in config
/// applies only to multimodal-position interleaving (M-RoPE sections); for 1-D text
/// positions, Qwen3.5 uses standard stride-half pairing (verified empirically against
/// MLX nn.RoPE(traditional=False) — diff ~1e-5 with stride-half, ~70 with interleaved).
/// Only the first `rope_dim` dimensions are rotated.
pub(crate) fn apply_partial_rope(
&self,
Expand All @@ -399,10 +402,10 @@ impl Qwen35Model {
for i in 0..half {
let cos_val = self.rope.cos_at(base + i);
let sin_val = self.rope.sin_at(base + i);
let x0 = head_vec[2 * i];
let x1 = head_vec[2 * i + 1];
head_vec[2 * i] = x0 * cos_val - x1 * sin_val;
head_vec[2 * i + 1] = x0 * sin_val + x1 * cos_val;
let x0 = head_vec[i];
let x1 = head_vec[half + i];
head_vec[i] = x0 * cos_val - x1 * sin_val;
head_vec[half + i] = x0 * sin_val + x1 * cos_val;
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/inference/src/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,7 @@ impl<'a> MtpVerifier<'a> {
}
}

/// Interleaved partial RoPE: rotate pairs (2i, 2i+1) for i in 0..rope_dim/2.
/// Stride-half partial RoPE: rotate pairs (i, half+i) for i in 0..rope_dim/2 — matches HF rotate_half / MLX traditional=False.
fn mtp_apply_partial_rope(
head_vec: &mut [f32],
position: usize,
Expand All @@ -1098,10 +1098,10 @@ fn mtp_apply_partial_rope(
for i in 0..half {
let cos_val = rope.cos_at(base + i);
let sin_val = rope.sin_at(base + i);
let x0 = head_vec[2 * i];
let x1 = head_vec[2 * i + 1];
head_vec[2 * i] = x0 * cos_val - x1 * sin_val;
head_vec[2 * i + 1] = x0 * sin_val + x1 * cos_val;
let x0 = head_vec[i];
let x1 = head_vec[half + i];
head_vec[i] = x0 * cos_val - x1 * sin_val;
head_vec[half + i] = x0 * sin_val + x1 * cos_val;
}
}

Expand Down
Loading