From 447d05234349e61c5da60c95c88b065680a61fec Mon Sep 17 00:00:00 2001
From: Sergei Blinov
Date: Tue, 21 Apr 2026 13:43:33 +0200
Subject: [PATCH] fix: support quantized MLX safetensors in extract-index

Unsloth Gemma 4 MLX checkpoints store packed U32 weights with separate
scales and biases; extraction skipped those tensors (including the
embedding table) and then failed on Apple Silicon. Dequantize them during
loading and streaming extraction so quantized MLX models can build
vindexes.
---
 .../larql-models/src/loading/safetensors.rs   | 188 +++++++--
 crates/larql-models/src/quant/mlx_affine.rs   | 204 ++++++++++
 crates/larql-models/src/quant/mod.rs          |   3 +-
 crates/larql-vindex/src/extract/streaming.rs  | 326 +++++++++++----
 crates/larql-vindex/src/format/quant/mod.rs   |   3 +-
 crates/larql-vindex/src/format/weights.rs     | 377 +++++++++++++++---
 6 files changed, 934 insertions(+), 167 deletions(-)
 create mode 100644 crates/larql-models/src/quant/mlx_affine.rs

diff --git a/crates/larql-models/src/loading/safetensors.rs b/crates/larql-models/src/loading/safetensors.rs
index 0212cfe6..edce25c8 100644
--- a/crates/larql-models/src/loading/safetensors.rs
+++ b/crates/larql-models/src/loading/safetensors.rs
@@ -8,8 +8,8 @@
 use std::path::{Path, PathBuf};
 
 use ndarray::Array2;
 
-use crate::weights::ModelWeights;
 use crate::detect::ModelError;
+use crate::weights::ModelWeights;
 
 /// Load model weights from a directory or file.
 ///
@@ -43,16 +43,17 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError> {
+    let mlx_affine_group_size = mlx_affine_group_size_from_config(path);
     let st_files: Vec<PathBuf> = std::fs::read_dir(path)?
         .filter_map(|e| e.ok())
@@ -97,7 +98,9 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError>
                     Ok(d) => d,
                     Err(_) => continue,
@@ -108,18 +111,48 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError>
-                    1 => { vectors.insert(key, data); }
+                    1 => {
+                        vectors.insert(key, data);
+                    }
                     _ => {}
                 }
             }
         } else {
             // Standard float path
             for (name, view) in st.tensors() {
+                if name.ends_with(".scales") || name.ends_with(".biases") {
+                    continue;
+                }
                 let key = normalize_key(&name, prefixes);
                 let shape = view.shape();
-                let data = match tensor_to_f32(&view) {
-                    Ok(d) => d,
-                    Err(_) => continue,
+                let data = match view.dtype() {
+                    safetensors::Dtype::U32 if shape.len() == 2 => {
+                        match dequantize_mlx_affine_tensor(&st, &name, &view, mlx_affine_group_size)
+                        {
+                            Ok((data, _, cols)) => {
+                                // Replace packed width with dequantized width.
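+                                // e.g. with 4-bit codes a [2048, 256] packed
+                                // U32 tensor holds eight codes per word and
+                                // dequantizes to [2048, 2048]:
+                                // cols = packed_cols * 32 / bits.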
+                                let mut dequant_shape = shape.to_vec();
+                                dequant_shape[1] = cols;
+                                match dequant_shape.len() {
+                                    2 => {
+                                        let arr = Array2::from_shape_vec(
+                                            (dequant_shape[0], dequant_shape[1]),
+                                            data,
+                                        )
+                                        .map_err(|e| ModelError::Parse(e.to_string()))?;
+                                        tensors.insert(key, arr.into_shared());
+                                        continue;
+                                    }
+                                    _ => continue,
+                                }
+                            }
+                            Err(_) => continue,
+                        }
+                    }
+                    _ => match tensor_to_f32(&view) {
+                        Ok(d) => d,
+                        Err(_) => continue,
+                    },
                 };
                 match shape.len() {
                     2 => {
@@ -127,9 +160,13 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError>
                     }
-                    1 => { vectors.insert(key, data); }
+                    1 => {
+                        vectors.insert(key, data);
+                    }
                     // 0D scalar tensors (e.g., layer_scalar) → store as 1-element vector
-                    0 => { vectors.insert(key, data); }
+                    0 => {
+                        vectors.insert(key, data);
+                    }
                     _ => {}
                 }
             }
@@ -167,6 +204,19 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError>
 }
 
+/// Read the MLX affine quantization group size from config.json, if present.
+pub fn mlx_affine_group_size_from_config(model_dir: &Path) -> Option<usize> {
+    let config_path = model_dir.join("config.json");
+    let text = std::fs::read_to_string(config_path).ok()?;
+    let json: serde_json::Value = serde_json::from_str(&text).ok()?;
+
+    json.get("quantization")
+        .or_else(|| json.get("quantization_config"))
+        .and_then(|q| q.get("group_size"))
+        .and_then(|v| v.as_u64())
+        .map(|v| v as usize)
+}
+
 /// Resolve a HuggingFace model ID or path to a local directory or GGUF file.
 pub fn resolve_model_path(model: &str) -> Result<PathBuf, ModelError> {
     let path = PathBuf::from(model);
@@ -191,11 +241,17 @@ pub fn resolve_model_path(model: &str) -> Result<PathBuf, ModelError> {
     if let Ok(entries) = std::fs::read_dir(&hf_cache) {
         for entry in entries.flatten() {
             let p = entry.path();
-            if !p.is_dir() { continue; }
+            if !p.is_dir() {
+                continue;
+            }
             // Prefer snapshot with safetensors files
-            let has_st = std::fs::read_dir(&p).ok().map(|rd| {
-                rd.flatten().any(|e| e.path().extension().is_some_and(|ext| ext == "safetensors"))
-            }).unwrap_or(false);
+            let has_st = std::fs::read_dir(&p)
+                .ok()
+                .map(|rd| {
+                    rd.flatten()
+                        .any(|e| e.path().extension().is_some_and(|ext| ext == "safetensors"))
+                })
+                .unwrap_or(false);
             if has_st {
                 return Ok(p);
             }
@@ -239,20 +295,26 @@ fn dequantize_mxfp4_experts(
 ) -> Result<(), ModelError> {
     // Find all gate_up_proj_blocks tensors (one per layer)
     for name in tensor_names {
-        if !name.ends_with(".gate_up_proj_blocks") { continue; }
+        if !name.ends_with(".gate_up_proj_blocks") {
+            continue;
+        }
         let scales_name = name.replace("_blocks", "_scales");
         let down_blocks_name = name.replace("gate_up_proj_blocks", "down_proj_blocks");
         let down_scales_name = name.replace("gate_up_proj_blocks", "down_proj_scales");
 
         // Get tensor views
-        let blocks_view = st.tensor(name)
+        let blocks_view = st
+            .tensor(name)
             .map_err(|e| ModelError::Parse(format!("MXFP4 blocks: {e}")))?;
-        let scales_view = st.tensor(&scales_name)
+        let scales_view = st
+            .tensor(&scales_name)
             .map_err(|e| ModelError::Parse(format!("MXFP4 scales: {e}")))?;
 
         let shape = blocks_view.shape();
-        if shape.len() != 4 { continue; }
+        if shape.len() != 4 {
+            continue;
+        }
         let num_experts = shape[0];
         let out_features = shape[1]; // 2*hidden for gate_up, hidden for down
@@ -262,8 +324,11 @@ fn dequantize_mxfp4_experts(
         // Dequantize gate_up (fused: first half = gate, second half = up)
         let expert_data = crate::quant::mxfp4::dequantize_all_experts(
-            blocks_view.data(), scales_view.data(),
-            num_experts, out_features, groups,
+            blocks_view.data(),
+            scales_view.data(),
+            num_experts,
+            out_features,
+            groups,
         );
 
         // Extract layer number from key
@@ -280,12 +345,18 @@ fn dequantize_mxfp4_experts(
             let gate_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w1.weight");
             let up_key =
                 format!("{layer_prefix}.block_sparse_moe.experts.{e}.w3.weight");
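+            // Layout recap: the fused gate_up matrix is [num_experts,
+            // 2*hidden, groups]; rows [0, half) become w1 (gate) and rows
+            // [half, 2*half) become w3 (up) under the expert naming used here.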
-            tensors.insert(gate_key,
+            tensors.insert(
+                gate_key,
                 Array2::from_shape_vec((half, in_features), gate_data)
-                    .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared());
-            tensors.insert(up_key,
+                    .map_err(|e| ModelError::Parse(e.to_string()))?
+                    .into_shared(),
+            );
+            tensors.insert(
+                up_key,
                 Array2::from_shape_vec((half, in_features), up_data)
-                    .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared());
+                    .map_err(|e| ModelError::Parse(e.to_string()))?
+                    .into_shared(),
+            );
         }
 
         // Dequantize down projection
@@ -297,14 +368,21 @@ fn dequantize_mxfp4_experts(
             let down_in = down_groups * 32;
 
             let down_experts = crate::quant::mxfp4::dequantize_all_experts(
-                db.data(), ds.data(), num_experts, down_out, down_groups,
+                db.data(),
+                ds.data(),
+                num_experts,
+                down_out,
+                down_groups,
             );
 
             for (e, data) in down_experts.iter().enumerate() {
                 let down_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w2.weight");
-                tensors.insert(down_key,
+                tensors.insert(
+                    down_key,
                     Array2::from_shape_vec((down_out, down_in), data.clone())
-                        .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared());
+                        .map_err(|e| ModelError::Parse(e.to_string()))?
+                        .into_shared(),
+                );
             }
         }
     }
@@ -316,9 +394,12 @@ fn dequantize_mxfp4_experts(
         let s = router_view.shape();
         if s.len() == 2 {
             let router_key = format!("{layer_prefix}.block_sparse_moe.gate.weight");
-            tensors.insert(router_key,
+            tensors.insert(
+                router_key,
                 Array2::from_shape_vec((s[0], s[1]), data)
-                    .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared());
+                    .map_err(|e| ModelError::Parse(e.to_string()))?
+                    .into_shared(),
+            );
         }
     }
 }
@@ -336,6 +417,53 @@ fn normalize_key(key: &str, prefixes: &[&str]) -> String {
     key.to_string()
 }
 
+fn dequantize_mlx_affine_tensor(
+    st: &safetensors::SafeTensors,
+    name: &str,
+    view: &safetensors::tensor::TensorView<'_>,
+    group_size: Option<usize>,
+) -> Result<(Vec<f32>, usize, usize), ModelError> {
+    let group_size = group_size.ok_or_else(|| {
+        ModelError::Parse(format!(
+            "missing MLX affine group_size for quantized tensor: {name}"
+        ))
+    })?;
+
+    let shape = view.shape();
+    if shape.len() != 2 {
+        return Err(ModelError::UnsupportedDtype(format!("{:?}", view.dtype())));
+    }
+
+    let stem = name
+        .strip_suffix(".weight")
+        .ok_or_else(|| ModelError::UnsupportedDtype(format!("{:?}", view.dtype())))?;
+    let scales_name = format!("{stem}.scales");
+    let biases_name = format!("{stem}.biases");
+
+    let scales_view = st
+        .tensor(&scales_name)
+        .map_err(|e| ModelError::Parse(format!("MLX affine scales {scales_name}: {e}")))?;
+    let biases_view = st.tensor(&biases_name).ok();
+
+    let scales = tensor_to_f32(&scales_view)?;
+    let biases = match biases_view {
+        Some(biases_view) => Some(tensor_to_f32(&biases_view)?),
+        None => None,
+    };
+
+    let (data, cols) = crate::quant::mlx_affine::dequantize_u32_matrix_bytes(
+        view.data(),
+        shape[0],
+        shape[1],
+        &scales,
+        biases.as_deref(),
+        group_size,
+    )
+    .map_err(ModelError::Parse)?;
+
+    Ok((data, shape[0], cols))
+}
+
 fn tensor_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>, ModelError> {
     use crate::quant::half;
     match view.dtype() {
diff --git a/crates/larql-models/src/quant/mlx_affine.rs b/crates/larql-models/src/quant/mlx_affine.rs
new file mode 100644
index 00000000..4eedcc18
--- /dev/null
+++ b/crates/larql-models/src/quant/mlx_affine.rs
@@ -0,0 +1,204 @@
+//! MLX affine quantization — packed U32 weights + per-group scales/biases.
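+//!
+//! Worked example (illustrative numbers): with `bits = 4` and
+//! `group_size = 64`, each group packs 64 codes into 64 * 4 / 32 = 8 U32
+//! words, and a code `q` with its group's `scale` and `bias` decodes as
+//! `bias + scale * q as f32`.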
+
+/// Infer bits-per-weight from packed width and grouping.
+pub fn infer_bits(packed_cols: usize, groups: usize, group_size: usize) -> Result<usize, String> {
+    let packed_bits = packed_cols
+        .checked_mul(32)
+        .ok_or_else(|| "packed column count overflow".to_string())?;
+    let cols = groups
+        .checked_mul(group_size)
+        .ok_or_else(|| "group shape overflow".to_string())?;
+
+    if cols == 0 || packed_bits % cols != 0 {
+        return Err(format!(
+            "cannot infer MLX affine bits: packed_cols={packed_cols}, groups={groups}, group_size={group_size}"
+        ));
+    }
+
+    let bits = packed_bits / cols;
+    if bits == 0 || bits > 32 {
+        return Err(format!("invalid MLX affine bit width: {bits}"));
+    }
+    Ok(bits)
+}
+
+/// Dequantize an MLX affine quantized 2D weight matrix stored as packed U32.
+///
+/// Layout matches `mlx.core.quantize(..., mode="affine")`:
+/// - packed codes are concatenated little-endian within each row
+/// - `scales` and `biases` are per-row, per-group
+/// - dequantization is `bias + scale * code`
+pub fn dequantize_u32_matrix_bytes(
+    packed_bytes: &[u8],
+    rows: usize,
+    packed_cols: usize,
+    scales: &[f32],
+    biases: Option<&[f32]>,
+    group_size: usize,
+) -> Result<(Vec<f32>, usize), String> {
+    if packed_bytes.len() % 4 != 0 {
+        return Err("MLX affine packed weights must be U32-aligned".to_string());
+    }
+
+    let packed: Vec<u32> = packed_bytes
+        .chunks_exact(4)
+        .map(|b| u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+        .collect();
+
+    if packed.len() != rows.saturating_mul(packed_cols) {
+        return Err(format!(
+            "MLX affine packed size mismatch: got {} words, expected {}",
+            packed.len(),
+            rows.saturating_mul(packed_cols)
+        ));
+    }
+
+    if rows == 0 {
+        return Ok((Vec::new(), 0));
+    }
+
+    if scales.len() % rows != 0 {
+        return Err(format!(
+            "MLX affine scales shape mismatch: {} values for {rows} rows",
+            scales.len()
+        ));
+    }
+    let groups = scales.len() / rows;
+    let cols = groups
+        .checked_mul(group_size)
+        .ok_or_else(|| "MLX affine output shape overflow".to_string())?;
+    let bits = infer_bits(packed_cols, groups, group_size)?;
+
+    if let Some(biases) = biases {
+        if biases.len() != scales.len() {
+            return Err(format!(
+                "MLX affine biases shape mismatch: {} values, expected {}",
+                biases.len(),
+                scales.len()
+            ));
+        }
+    }
+
+    let mask = if bits == 32 {
+        u64::MAX
+    } else {
+        (1u64 << bits) - 1
+    };
+
+    let mut out = vec![0.0; rows * cols];
+
+    for row in 0..rows {
+        let packed_row = &packed[row * packed_cols..(row + 1) * packed_cols];
+        let row_scales = &scales[row * groups..(row + 1) * groups];
+        let row_biases = biases.map(|b| &b[row * groups..(row + 1) * groups]);
+
+        let mut out_col = 0usize;
+        let mut acc = 0u64;
+        let mut acc_bits = 0usize;
+
+        for &word in packed_row {
+            acc |= (word as u64) << acc_bits;
+            acc_bits += 32;
+
+            while acc_bits >= bits && out_col < cols {
+                let group = out_col / group_size;
+                let code = (acc & mask) as f32;
+                let scale = row_scales[group];
+                let bias = row_biases.map(|b| b[group]).unwrap_or(0.0);
+                out[row * cols + out_col] = bias + scale * code;
+                acc >>= bits;
+                acc_bits -= bits;
+                out_col += 1;
+            }
+        }
+
+        if out_col != cols {
+            return Err(format!(
+                "MLX affine unpack ended early: row {row}, decoded {out_col}/{cols} values"
+            ));
+        }
+    }
+
+    Ok((out, cols))
+}
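+
+// Usage sketch (hypothetical caller): one row of 64 four-bit codes arrives
+// as 8 packed U32 words (32 bytes); all-zero codes decode to the group bias.
+//
+//     let packed = vec![0u8; 32];
+//     let (values, cols) =
+//         dequantize_u32_matrix_bytes(&packed, 1, 8, &[0.5], Some(&[-2.0]), 64)?;
+//     assert_eq!(cols, 64);
+//     assert!(values.iter().all(|&v| v == -2.0));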
+
+#[cfg(test)]
+mod tests {
+    use super::dequantize_u32_matrix_bytes;
+
+    fn pack_codes(codes: &[u32], bits: usize) -> Vec<u32> {
+        let mut out = Vec::new();
+        let mut acc = 0u64;
+        let mut acc_bits = 0usize;
+
+        for &code in codes {
+            acc |= (code as u64) << acc_bits;
+            acc_bits += bits;
+
+            while acc_bits >= 32 {
+                out.push((acc & 0xFFFF_FFFF) as u32);
+                acc >>= 32;
+                acc_bits -= 32;
+            }
+        }
+
+        if acc_bits > 0 {
+            out.push(acc as u32);
+        }
+
+        out
+    }
+
+    fn approx_eq(a: &[f32], b: &[f32]) {
+        assert_eq!(a.len(), b.len());
+        for (idx, (&lhs, &rhs)) in a.iter().zip(b.iter()).enumerate() {
+            assert!(
+                (lhs - rhs).abs() < 1e-6,
+                "mismatch at {idx}: {lhs} vs {rhs}"
+            );
+        }
+    }
+
+    #[test]
+    fn dequantizes_affine_u32_for_multiple_bit_widths() {
+        for bits in [4usize, 5, 6, 8] {
+            let rows = 2usize;
+            let groups = 2usize;
+            let group_size = 64usize;
+            let cols = groups * group_size;
+            let max_code = (1u32 << bits) - 1;
+
+            let scales = vec![0.25, -0.5, 1.5, -0.125];
+            let biases = vec![1.0, -3.0, 2.5, 7.0];
+
+            let mut packed = Vec::new();
+            let mut expected = Vec::new();
+
+            for row in 0..rows {
+                let codes: Vec<u32> = (0..cols)
+                    .map(|col| ((row * cols + col) as u32) & max_code)
+                    .collect();
+                packed.extend(pack_codes(&codes, bits));
+
+                for (col, code) in codes.into_iter().enumerate() {
+                    let group = row * groups + (col / group_size);
+                    expected.push(biases[group] + scales[group] * code as f32);
+                }
+            }
+
+            let packed_bytes: Vec<u8> = packed.iter().flat_map(|w| w.to_le_bytes()).collect();
+            let (actual, actual_cols) = dequantize_u32_matrix_bytes(
+                &packed_bytes,
+                rows,
+                cols * bits / 32,
+                &scales,
+                Some(&biases),
+                group_size,
+            )
+            .unwrap();
+
+            assert_eq!(actual_cols, cols);
+            approx_eq(&actual, &expected);
+        }
+    }
+}
diff --git a/crates/larql-models/src/quant/mod.rs b/crates/larql-models/src/quant/mod.rs
index dacb8bb1..fe1ca700 100644
--- a/crates/larql-models/src/quant/mod.rs
+++ b/crates/larql-models/src/quant/mod.rs
@@ -8,6 +8,7 @@
 //! This module handles data format encoding/decoding only.
 //! Compute operations (matvec, vecmat, GPU shaders) are in `larql-compute`.
 
-pub mod half;
 pub mod ggml;
+pub mod half;
+pub mod mlx_affine;
 pub mod mxfp4;
diff --git a/crates/larql-vindex/src/extract/streaming.rs b/crates/larql-vindex/src/extract/streaming.rs
index 7378859a..1dac6b20 100644
--- a/crates/larql-vindex/src/extract/streaming.rs
+++ b/crates/larql-vindex/src/extract/streaming.rs
@@ -44,6 +44,8 @@ pub fn build_vindex_streaming(
         .map_err(|e| VindexError::Parse(e.to_string()))?;
     let prefixes = arch.key_prefixes_to_strip();
     let cfg = arch.config();
+    let mlx_affine_group_size =
+        larql_models::loading::safetensors::mlx_affine_group_size_from_config(model_dir);
 
     let num_layers = cfg.num_layers;
     let hidden_size = cfg.hidden_size;
@@ -75,18 +77,24 @@ pub fn build_vindex_streaming(
     }
 
     callbacks.on_stage("loading");
-    eprintln!("  Streaming mode: {} safetensors shards (mmap'd, not loaded)", st_files.len());
+    eprintln!(
+        "  Streaming mode: {} safetensors shards (mmap'd, not loaded)",
+        st_files.len()
+    );
 
     // (shards vec was for an earlier design — tensor_index + shard_mmaps is the actual approach)
     // SAFETY: We need to hold both the mmap and the SafeTensors that borrows from it.
     // We use a two-phase approach: first mmap all files, then deserialize.
     // The mmaps are kept alive in `shard_mmaps` for the lifetime of the function.
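+    // Minimal standalone sketch of the two-phase pattern (hypothetical file
+    // name, error handling elided):
+    //
+    //     let file = std::fs::File::open("model.safetensors")?;
+    //     let mmap = unsafe { memmap2::Mmap::map(&file)? };        // phase 1
+    //     let st = safetensors::SafeTensors::deserialize(&mmap)?;  // phase 2
+    //     let view = st.tensor("model.embed_tokens.weight")?;      // zero-copy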
-    let shard_mmaps: Vec<MmapShard> = st_files.iter().map(|path| {
-        let file = std::fs::File::open(path).unwrap();
-        let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() };
-        MmapShard { _file: file, mmap }
-    }).collect();
+    let shard_mmaps: Vec<MmapShard> = st_files
+        .iter()
+        .map(|path| {
+            let file = std::fs::File::open(path).unwrap();
+            let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() };
+            MmapShard { _file: file, mmap }
+        })
+        .collect();
 
     // Build a tensor index: key → (shard_idx, tensor_name)
     // We need to find which shard contains each tensor.
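+    // Example entry (hypothetical key): "layers.0.mlp.gate_proj.weight"
+    // maps to (2, "model.layers.0.mlp.gate_proj.weight"), i.e. an index into
+    // `shard_mmaps` plus the un-normalized tensor name inside that shard.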
@@ -123,18 +131,21 @@ pub fn build_vindex_streaming(
             let blocks_key = arch.packed_gate_up_blocks_key(layer).unwrap_or_default();
             let scales_key = arch.packed_gate_up_scales_key(layer).unwrap_or_default();
 
-            if let (Some(blocks_info), Some(scales_info)) = (
-                tensor_index.get(&blocks_key),
-                tensor_index.get(&scales_key),
-            ) {
-                let blocks_st = safetensors::SafeTensors::deserialize(&shard_mmaps[blocks_info.0].mmap)
+            if let (Some(blocks_info), Some(scales_info)) =
+                (tensor_index.get(&blocks_key), tensor_index.get(&scales_key))
+            {
+                let blocks_st =
+                    safetensors::SafeTensors::deserialize(&shard_mmaps[blocks_info.0].mmap)
+                        .map_err(|e| VindexError::Parse(e.to_string()))?;
+                let scales_st =
+                    safetensors::SafeTensors::deserialize(&shard_mmaps[scales_info.0].mmap)
+                        .map_err(|e| VindexError::Parse(e.to_string()))?;
+
+                let blocks_view = blocks_st
+                    .tensor(&blocks_info.1)
                     .map_err(|e| VindexError::Parse(e.to_string()))?;
-                let scales_st = safetensors::SafeTensors::deserialize(&shard_mmaps[scales_info.0].mmap)
-                    .map_err(|e| VindexError::Parse(e.to_string()))?;
-
-                let blocks_view = blocks_st.tensor(&blocks_info.1)
-                    .map_err(|e| VindexError::Parse(e.to_string()))?;
-                let scales_view = scales_st.tensor(&scales_info.1)
+                let scales_view = scales_st
+                    .tensor(&scales_info.1)
                     .map_err(|e| VindexError::Parse(e.to_string()))?;
 
                 let shape = blocks_view.shape();
@@ -145,7 +156,11 @@ pub fn build_vindex_streaming(
                 let half = out_features / 2; // gate portion
 
                 let experts = crate::format::quant::mxfp4::dequantize_all_experts(
-                    blocks_view.data(), scales_view.data(), n_exp, out_features, groups,
+                    blocks_view.data(),
+                    scales_view.data(),
+                    n_exp,
+                    out_features,
+                    groups,
                 );
 
                 let mut total_features = 0usize;
@@ -160,7 +175,10 @@ pub fn build_vindex_streaming(
 
                 if total_features > 0 {
                     layer_infos.push(VindexLayerInfo {
-                        layer, num_features: total_features, offset, length: layer_bytes,
+                        layer,
+                        num_features: total_features,
+                        offset,
+                        length: layer_bytes,
                         num_experts: Some(n_exp),
                         num_features_per_expert: Some(half),
                     });
@@ -179,7 +197,12 @@ pub fn build_vindex_streaming(
                     None => continue,
                 };
 
-                if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &gate_key)? {
+                if let Some(tensor) = get_tensor_f32(
+                    &shard_mmaps,
+                    &tensor_index,
+                    &gate_key,
+                    mlx_affine_group_size,
+                )? {
                     features_per_expert = tensor.shape()[0];
                     total_features += features_per_expert;
                     let data = tensor.as_slice().unwrap();
@@ -189,7 +212,10 @@ pub fn build_vindex_streaming(
 
             if total_features > 0 {
                 layer_infos.push(VindexLayerInfo {
-                    layer, num_features: total_features, offset, length: layer_bytes,
+                    layer,
+                    num_features: total_features,
+                    offset,
+                    length: layer_bytes,
                     num_experts: Some(n_experts),
                     num_features_per_expert: Some(features_per_expert),
                 });
@@ -198,13 +224,22 @@ pub fn build_vindex_streaming(
         } else {
             // Dense: single gate matrix per layer
             let gate_key = normalize_key(&arch.ffn_gate_key(layer), prefixes);
-            if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &gate_key)? {
+            if let Some(tensor) = get_tensor_f32(
+                &shard_mmaps,
+                &tensor_index,
+                &gate_key,
+                mlx_affine_group_size,
+            )? {
                 let num_features = tensor.shape()[0];
                 let data = tensor.as_slice().unwrap();
                 let length = write_floats(&mut gate_file, data, dtype)?;
                 layer_infos.push(VindexLayerInfo {
-                    layer, num_features, offset, length,
-                    num_experts: None, num_features_per_expert: None,
+                    layer,
+                    num_features,
+                    offset,
+                    length,
+                    num_experts: None,
+                    num_features_per_expert: None,
                 });
                 offset += length;
             }
@@ -222,11 +257,17 @@ pub fn build_vindex_streaming(
     let mut router_file = BufWriter::new(std::fs::File::create(&router_path)?);
 
     for layer in 0..num_layers {
-        let router_key = arch.moe_router_key(layer)
+        let router_key = arch
+            .moe_router_key(layer)
            .map(|k| normalize_key(&k, prefixes))
            .unwrap_or_default();
 
-        if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &router_key)? {
+        if let Some(tensor) = get_tensor_f32(
+            &shard_mmaps,
+            &tensor_index,
+            &router_key,
+            mlx_affine_group_size,
+        )? {
            let data = tensor.as_slice().unwrap();
            let bytes = crate::config::dtype::encode_floats(data, dtype);
            router_file.write_all(&bytes)?;
@@ -234,7 +275,12 @@ pub fn build_vindex_streaming(
 
         // Also try router bias
         let bias_key = router_key.replace(".weight", ".bias");
-        if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &bias_key)? {
+        if let Some(tensor) = get_tensor_f32(
+            &shard_mmaps,
+            &tensor_index,
+            &bias_key,
+            mlx_affine_group_size,
+        )? {
            let data = tensor.as_slice().unwrap();
            let bytes = crate::config::dtype::encode_floats(data, dtype);
            // Write bias after weight for each layer
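+            // Resulting router.bin layout, concatenated per layer:
+            //   [layer 0 weight][layer 0 bias?][layer 1 weight][layer 1 bias?] ...
+            // (a bias block is present only when `<router>.bias` exists)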
@@ -248,8 +294,13 @@ pub fn build_vindex_streaming(
     // ── 2. Embeddings ──
     callbacks.on_stage("embeddings");
     let embed_key = normalize_key(arch.embed_key(), prefixes);
-    let embed = get_tensor_f32(&shard_mmaps, &tensor_index, &embed_key)?
-        .ok_or_else(|| VindexError::MissingTensor(embed_key.clone()))?;
+    let embed = get_tensor_f32(
+        &shard_mmaps,
+        &tensor_index,
+        &embed_key,
+        mlx_affine_group_size,
+    )?
+    .ok_or_else(|| VindexError::MissingTensor(embed_key.clone()))?;
     let vocab_size = embed.shape()[0];
     let embed_data = embed.as_slice().unwrap();
     let embed_bytes = crate::config::dtype::encode_floats(embed_data, dtype);
@@ -261,44 +312,61 @@ pub fn build_vindex_streaming(
     let mut all_down_meta: Vec<Option<Vec<Option<crate::FeatureMeta>>>> = vec![None; num_layers];
 
     // Build whole-word vocab once
-    let (_ww_ids, _ww_embed) = super::build::build_whole_word_vocab(tokenizer, &embed, vocab_size, hidden_size);
+    let (_ww_ids, _ww_embed) =
+        super::build::build_whole_word_vocab(tokenizer, &embed, vocab_size, hidden_size);
 
     for (layer, layer_down_meta) in all_down_meta.iter_mut().enumerate().take(num_layers) {
         callbacks.on_layer_start("down", layer, num_layers);
         let start = std::time::Instant::now();
 
         // Get down matrices for this layer
-        let down_matrices: Vec<Array2<f32>> = if expert_format == larql_models::ExpertFormat::PackedMxfp4 {
+        let down_matrices: Vec<Array2<f32>> = if expert_format
+            == larql_models::ExpertFormat::PackedMxfp4
+        {
             // MXFP4: dequantize down_proj_blocks
             let blocks_key = arch.packed_down_blocks_key(layer).unwrap_or_default();
             let scales_key = arch.packed_down_scales_key(layer).unwrap_or_default();
-            if let (Some(bi), Some(si)) = (tensor_index.get(&blocks_key), tensor_index.get(&scales_key)) {
+            if let (Some(bi), Some(si)) =
+                (tensor_index.get(&blocks_key), tensor_index.get(&scales_key))
+            {
                 let bst = safetensors::SafeTensors::deserialize(&shard_mmaps[bi.0].mmap)
                     .map_err(|e| VindexError::Parse(e.to_string()))?;
                 let sst = safetensors::SafeTensors::deserialize(&shard_mmaps[si.0].mmap)
                     .map_err(|e| VindexError::Parse(e.to_string()))?;
-                let bv = bst.tensor(&bi.1).map_err(|e| VindexError::Parse(e.to_string()))?;
-                let sv = sst.tensor(&si.1).map_err(|e| VindexError::Parse(e.to_string()))?;
+                let bv = bst
+                    .tensor(&bi.1)
+                    .map_err(|e| VindexError::Parse(e.to_string()))?;
+                let sv = sst
+                    .tensor(&si.1)
+                    .map_err(|e| VindexError::Parse(e.to_string()))?;
                 let shape = bv.shape();
                 let n_exp = shape[0];
                 let out_features = shape[1];
                 let groups = shape[2];
                 let in_features = groups * 32;
                 let experts = crate::format::quant::mxfp4::dequantize_all_experts(
-                    bv.data(), sv.data(), n_exp, out_features, groups,
+                    bv.data(),
+                    sv.data(),
+                    n_exp,
+                    out_features,
+                    groups,
                 );
-                experts.into_iter().map(|data| {
-                    Array2::from_shape_vec((out_features, in_features), data).unwrap()
-                }).collect()
+                experts
+                    .into_iter()
+                    .map(|data| Array2::from_shape_vec((out_features, in_features), data).unwrap())
+                    .collect()
             } else {
-                callbacks.on_layer_done("down", layer, 0.0); continue;
+                callbacks.on_layer_done("down", layer, 0.0);
+                continue;
             }
         } else if is_moe && n_experts > 0 {
             let mut mats = Vec::new();
             for expert in 0..n_experts {
                 if let Some(key) = arch.expert_ffn_down_key(layer, expert) {
                     let nk = normalize_key(&key, prefixes);
-                    if let Some(t) = get_tensor_f32(&shard_mmaps, &tensor_index, &nk)? {
+                    if let Some(t) =
+                        get_tensor_f32(&shard_mmaps, &tensor_index, &nk, mlx_affine_group_size)?
+                    {
                         mats.push(t);
                    }
                }
@@ -306,9 +374,17 @@ pub fn build_vindex_streaming(
             mats
         } else {
             let down_key = normalize_key(&arch.ffn_down_key(layer), prefixes);
-            match get_tensor_f32(&shard_mmaps, &tensor_index, &down_key)? {
+            match get_tensor_f32(
+                &shard_mmaps,
+                &tensor_index,
+                &down_key,
+                mlx_affine_group_size,
+            )? {
                Some(t) => vec![t],
-                None => { callbacks.on_layer_done("down", layer, 0.0); continue; }
+                None => {
+                    callbacks.on_layer_done("down", layer, 0.0);
+                    continue;
+                }
            }
        };
 
@@ -324,10 +400,16 @@ pub fn build_vindex_streaming(
 
        for batch_start in (0..num_features).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(num_features);
-            callbacks.on_feature_progress("down", layer, feature_offset + batch_start,
-                down_matrices.iter().map(|m| m.shape()[1]).sum());
+            callbacks.on_feature_progress(
+                "down",
+                layer,
+                feature_offset + batch_start,
+                down_matrices.iter().map(|m| m.shape()[1]).sum(),
+            );
 
-            let w_chunk = w_down.slice(ndarray::s![.., batch_start..batch_end]).to_owned();
+            let w_chunk = w_down
+                .slice(ndarray::s![.., batch_start..batch_end])
+                .to_owned();
            let cpu = larql_compute::CpuBackend;
            use larql_compute::ComputeBackend;
            let chunk_logits = cpu.matmul(embed.view(), w_chunk.view());
@@ -342,29 +424,42 @@ pub fn build_vindex_streaming(
                scores.truncate(k);
                scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
 
-                let top_k_entries: Vec<larql_models::TopKEntry> = scores.into_iter()
+                let top_k_entries: Vec<larql_models::TopKEntry> = scores
+                    .into_iter()
                    .filter_map(|(idx, logit)| {
-                        tokenizer.decode(&[idx as u32], true).ok()
+                        tokenizer
+                            .decode(&[idx as u32], true)
+                            .ok()
                            .map(|s| s.trim().to_string())
                            .filter(|s| !s.is_empty())
-                            .map(|token| larql_models::TopKEntry { token, token_id: idx as u32, logit })
+                            .map(|token| larql_models::TopKEntry {
+                                token,
+                                token_id: idx as u32,
+                                logit,
+                            })
                    })
                    .collect();
 
-                let (top_token, top_token_id, c_score) = if let Some(first) = top_k_entries.first() {
-                    (first.token.clone(), first.token_id, first.logit)
-                } else {
-                    (String::new(), 0, 0.0)
-                };
+                let (top_token, top_token_id, c_score) =
+                    if let Some(first) = top_k_entries.first() {
+                        (first.token.clone(), first.token_id, first.logit)
+                    } else {
+                        (String::new(), 0, 0.0)
+                    };
 
                let feat_idx = feature_offset + feat;
                if layer_down_meta.is_none() {
                    *layer_down_meta = Some(Vec::new());
                }
                if let Some(ref mut metas) = layer_down_meta {
-                    while metas.len() <= feat_idx { metas.push(None); }
+                    while metas.len() <= feat_idx {
+                        metas.push(None);
+                    }
                    metas[feat_idx] = Some(crate::FeatureMeta {
-                        top_token, top_token_id, c_score, top_k: top_k_entries,
+                        top_token,
+                        top_token_id,
+                        c_score,
+                        top_k: top_k_entries,
                    });
                }
            }
@@ -380,7 +475,8 @@ pub fn build_vindex_streaming(
    // ── 4. Tokenizer ──
    callbacks.on_stage("tokenizer");
-    let tokenizer_json = tokenizer.to_string(true)
+    let tokenizer_json = tokenizer
+        .to_string(true)
        .map_err(|e| VindexError::Parse(format!("tokenizer serialize: {e}")))?;
    std::fs::write(output_dir.join("tokenizer.json"), tokenizer_json)?;
    callbacks.on_stage_done("tokenizer", 0.0);
@@ -391,7 +487,10 @@ pub fn build_vindex_streaming(
        version: 2,
        model: model_name.to_string(),
        family: family.clone(),
-        num_layers, hidden_size, intermediate_size, vocab_size,
+        num_layers,
+        hidden_size,
+        intermediate_size,
+        vocab_size,
        embed_scale,
        layers: layer_infos,
        down_top_k,
@@ -421,7 +520,9 @@ pub fn build_vindex_streaming(
                shared_expert: arch.num_shared_experts() > 0,
                router_type: "top_k_softmax".to_string(),
            })
-        } else { None },
+        } else {
+            None
+        },
        // Per-layer geometry (Gemma 4)
        global_head_dim: cfg.global_head_dim,
        num_global_kv_heads: cfg.num_global_kv_heads,
@@ -437,8 +538,8 @@ pub fn build_vindex_streaming(
    };
 
    // Write preliminary index.json (needed by write_model_weights which reads dtype from it)
-    let config_json = serde_json::to_string_pretty(&config)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let config_json =
+        serde_json::to_string_pretty(&config).map_err(|e| VindexError::Parse(e.to_string()))?;
    std::fs::write(output_dir.join("index.json"), config_json)?;
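+    // index.json at this point carries the extraction metadata, e.g.
+    // (abridged, illustrative values):
+    //   { "version": 2, "model": "...", "num_layers": 48,
+    //     "hidden_size": 2048, "layers": [ ... ], "down_top_k": ... }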
 
    // ── 6. Model weights (if extract level requires them) ──
@@ -449,6 +550,7 @@ pub fn build_vindex_streaming(
            tensor_index: &tensor_index,
            arch: &*arch,
            num_layers,
+            mlx_affine_group_size,
        };
        crate::format::weights::write_model_weights(&streaming_source, output_dir, callbacks)?;
        // write_model_weights updates index.json with has_model_weights=true
@@ -456,11 +558,11 @@ pub fn build_vindex_streaming(
 
    // Final checksums
    let config_text = std::fs::read_to_string(output_dir.join("index.json"))?;
-    let mut config: VindexConfig = serde_json::from_str(&config_text)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let mut config: VindexConfig =
+        serde_json::from_str(&config_text).map_err(|e| VindexError::Parse(e.to_string()))?;
    config.checksums = crate::format::checksums::compute_checksums(output_dir).ok();
-    let config_json = serde_json::to_string_pretty(&config)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let config_json =
+        serde_json::to_string_pretty(&config).map_err(|e| VindexError::Parse(e.to_string()))?;
    std::fs::write(output_dir.join("index.json"), config_json)?;
 
    Ok(())
@@ -471,6 +573,7 @@ fn get_tensor_f32(
    shards: &[MmapShard],
    index: &HashMap<String, (usize, String)>,
    key: &str,
+    mlx_affine_group_size: Option<usize>,
 ) -> Result<Option<Array2<f32>>, VindexError> {
    let (shard_idx, tensor_name) = match index.get(key) {
        Some(v) => v,
@@ -480,20 +583,36 @@ fn get_tensor_f32(
    let st = safetensors::SafeTensors::deserialize(&shards[*shard_idx].mmap)
        .map_err(|e| VindexError::Parse(e.to_string()))?;
 
-    let view = st.tensor(tensor_name)
+    let view = st
+        .tensor(tensor_name)
        .map_err(|e| VindexError::Parse(e.to_string()))?;
 
    let shape = view.shape();
-    if shape.len() != 2 { return Ok(None); }
+    if shape.len() != 2 {
+        return Ok(None);
+    }
 
    let data = match view.dtype() {
-        safetensors::Dtype::F32 => {
-            view.data().chunks_exact(4)
-                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
-                .collect()
-        }
+        safetensors::Dtype::F32 => view
+            .data()
+            .chunks_exact(4)
+            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+            .collect(),
        safetensors::Dtype::F16 => crate::format::quant::half::decode_f16(view.data()),
        safetensors::Dtype::BF16 => crate::format::quant::half::decode_bf16(view.data()),
+        safetensors::Dtype::U32 => {
+            let Some(group_size) = mlx_affine_group_size else {
+                return Ok(None);
+            };
+            let Some((data, cols)) =
+                dequantize_mlx_affine_tensor(shards, index, key, shape, view.data(), group_size)?
+            else {
+                return Ok(None);
+            };
+            let arr = Array2::from_shape_vec((shape[0], cols), data)
+                .map_err(|e| VindexError::Parse(e.to_string()))?;
+            return Ok(Some(arr));
+        }
        _ => return Ok(None), // skip non-float
    };
 
@@ -502,6 +621,71 @@ fn get_tensor_f32(
    Ok(Some(arr))
 }
 
+fn dequantize_mlx_affine_tensor(
+    shards: &[MmapShard],
+    index: &HashMap<String, (usize, String)>,
+    key: &str,
+    weight_shape: &[usize],
+    packed: &[u8],
+    group_size: usize,
+) -> Result<Option<(Vec<f32>, usize)>, VindexError> {
+    let Some(stem) = key.strip_suffix(".weight") else {
+        return Ok(None);
+    };
+
+    let scales_key = format!("{stem}.scales");
+    let biases_key = format!("{stem}.biases");
+
+    let (scales_shard, scales_name) = match index.get(&scales_key) {
+        Some(v) => v,
+        None => return Ok(None),
+    };
+    let scales_st = safetensors::SafeTensors::deserialize(&shards[*scales_shard].mmap)
+        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let scales_view = scales_st
+        .tensor(scales_name)
+        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let scales = tensor_view_to_f32(&scales_view)?;
+
+    let biases = if let Some((bias_shard, bias_name)) = index.get(&biases_key) {
+        let biases_st = safetensors::SafeTensors::deserialize(&shards[*bias_shard].mmap)
+            .map_err(|e| VindexError::Parse(e.to_string()))?;
+        let biases_view = biases_st
+            .tensor(bias_name)
+            .map_err(|e| VindexError::Parse(e.to_string()))?;
+        Some(tensor_view_to_f32(&biases_view)?)
+    } else {
+        None
+    };
+
+    let (data, cols) = crate::format::quant::mlx_affine::dequantize_u32_matrix_bytes(
+        packed,
+        weight_shape[0],
+        weight_shape[1],
+        &scales,
+        biases.as_deref(),
+        group_size,
+    )
+    .map_err(VindexError::Parse)?;
+
+    Ok(Some((data, cols)))
+}
+
+fn tensor_view_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>, VindexError> {
+    match view.dtype() {
+        safetensors::Dtype::F32 => Ok(view
+            .data()
+            .chunks_exact(4)
+            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+            .collect()),
+        safetensors::Dtype::F16 => Ok(crate::format::quant::half::decode_f16(view.data())),
+        safetensors::Dtype::BF16 => Ok(crate::format::quant::half::decode_bf16(view.data())),
+        other => Err(VindexError::Parse(format!(
+            "unsupported tensor dtype: {other:?}"
+        ))),
+    }
+}
+
 fn normalize_key(key: &str, prefixes: &[&str]) -> String {
    for prefix in prefixes {
        if let Some(stripped) = key.strip_prefix(prefix) {
diff --git a/crates/larql-vindex/src/format/quant/mod.rs b/crates/larql-vindex/src/format/quant/mod.rs
index 6d82a79f..7e6b5a34 100644
--- a/crates/larql-vindex/src/format/quant/mod.rs
+++ b/crates/larql-vindex/src/format/quant/mod.rs
@@ -1,5 +1,6 @@
 //! Quantization and dequantization — re-exports from larql-models.
 
-pub use larql_models::quant::half;
 pub use larql_models::quant::ggml;
+pub use larql_models::quant::half;
+pub use larql_models::quant::mlx_affine;
 pub use larql_models::quant::mxfp4;
diff --git a/crates/larql-vindex/src/format/weights.rs b/crates/larql-vindex/src/format/weights.rs
index c35842aa..fc27ef46 100644
--- a/crates/larql-vindex/src/format/weights.rs
+++ b/crates/larql-vindex/src/format/weights.rs
@@ -18,11 +18,11 @@ use std::path::Path;
 use ndarray::Array2;
 use serde::{Deserialize, Serialize};
 
+use crate::config::{VindexConfig, VindexModelConfig};
 use crate::error::VindexError;
 use crate::extract::callbacks::IndexBuildCallbacks;
-use crate::config::{VindexConfig, VindexModelConfig};
-use crate::index::core::IndexLoadCallbacks;
 use crate::format::load::load_vindex_config;
+use crate::index::core::IndexLoadCallbacks;
 
 use larql_models::ModelWeights;
 
@@ -102,6 +102,7 @@ pub struct StreamingWeights<'a> {
    pub tensor_index: &'a HashMap<String, (usize, String)>,
    pub arch: &'a dyn larql_models::ModelArchitecture,
    pub num_layers: usize,
+    pub mlx_affine_group_size: Option<usize>,
 }
 
 impl<'a> StreamingWeights<'a> {
@@ -109,32 +110,208 @@ impl<'a> StreamingWeights<'a> {
    fn read_tensor_raw(&self, key: &str) -> Option<(Vec<f32>, Vec<usize>)> {
        let (shard_idx, tensor_name) = self.tensor_index.get(key)?;
        let st = safetensors::SafeTensors::deserialize(self.shard_mmaps[*shard_idx]).ok()?;
        let view = st.tensor(tensor_name).ok()?;
-        let shape = view.shape().to_vec();
+        let mut shape = view.shape().to_vec();
        let data = match view.dtype() {
-            safetensors::Dtype::F32 => {
-                view.data().chunks_exact(4)
-                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
-                    .collect()
-            }
+            safetensors::Dtype::F32 => view
+                .data()
+                .chunks_exact(4)
+                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+                .collect(),
            safetensors::Dtype::F16 => crate::format::quant::half::decode_f16(view.data()),
            safetensors::Dtype::BF16 => crate::format::quant::half::decode_bf16(view.data()),
+            safetensors::Dtype::U32 if shape.len() == 2 => {
+                let group_size = self.mlx_affine_group_size?;
+                let stem = key.strip_suffix(".weight")?;
+
+                let (scales_shard, scales_name) =
+                    self.tensor_index.get(&format!("{stem}.scales"))?;
+                let scales_st =
+                    safetensors::SafeTensors::deserialize(self.shard_mmaps[*scales_shard]).ok()?;
+                let scales_view = scales_st.tensor(scales_name).ok()?;
+                let scales = tensor_view_to_f32(&scales_view).ok()?;
+
+                let biases = if let Some((bias_shard, bias_name)) =
+                    self.tensor_index.get(&format!("{stem}.biases"))
+                {
+                    let bias_st =
+                        safetensors::SafeTensors::deserialize(self.shard_mmaps[*bias_shard])
+                            .ok()?;
+                    let bias_view = bias_st.tensor(bias_name).ok()?;
+                    Some(tensor_view_to_f32(&bias_view).ok()?)
+                } else {
+                    None
+                };
+
+                let (data, cols) = crate::format::quant::mlx_affine::dequantize_u32_matrix_bytes(
+                    view.data(),
+                    shape[0],
+                    shape[1],
+                    &scales,
+                    biases.as_deref(),
+                    group_size,
+                )
+                .ok()?;
+                shape[1] = cols;
+                data
+            }
            _ => return None,
        };
        Some((data, shape))
    }
 }
 
+fn tensor_view_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>, VindexError> {
+    match view.dtype() {
+        safetensors::Dtype::F32 => Ok(view
+            .data()
+            .chunks_exact(4)
+            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+            .collect()),
+        safetensors::Dtype::F16 => Ok(crate::format::quant::half::decode_f16(view.data())),
+        safetensors::Dtype::BF16 => Ok(crate::format::quant::half::decode_bf16(view.data())),
+        other => Err(VindexError::Parse(format!(
+            "unsupported tensor dtype: {other:?}"
+        ))),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::StreamingWeights;
+    use crate::format::weights::WeightSource;
+    use std::collections::HashMap;
+
+    fn pack_codes(codes: &[u32], bits: usize) -> Vec<u32> {
+        let mut out = Vec::new();
+        let mut acc = 0u64;
+        let mut acc_bits = 0usize;
+
+        for &code in codes {
+            acc |= (code as u64) << acc_bits;
+            acc_bits += bits;
+
+            while acc_bits >= 32 {
+                out.push((acc & 0xFFFF_FFFF) as u32);
+                acc >>= 32;
+                acc_bits -= 32;
+            }
+        }
+
+        if acc_bits > 0 {
+            out.push(acc as u32);
+        }
+
+        out
+    }
+
+    #[test]
+    fn streaming_weights_dequantize_mlx_affine_u32() {
+        let bits = 4usize;
+        let group_size = 64usize;
+        let codes: Vec<u32> = (0..group_size).map(|i| (i as u32) & 0xF).collect();
+        let packed = pack_codes(&codes, bits);
+        let packed_bytes: Vec<u8> = packed.iter().flat_map(|w| w.to_le_bytes()).collect();
+        let scales = vec![0.5f32];
+        let biases = vec![-2.0f32];
+        let scale_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
+        let bias_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
+
+        let tensors = [
+            (
+                "embed_tokens.weight".to_string(),
+                safetensors::tensor::TensorView::new(
+                    safetensors::Dtype::U32,
+                    vec![1, packed.len()],
+                    &packed_bytes,
+                )
+                .unwrap(),
+            ),
+            (
+                "embed_tokens.scales".to_string(),
+                safetensors::tensor::TensorView::new(
+                    safetensors::Dtype::F32,
+                    vec![1, 1],
+                    &scale_bytes,
+                )
+                .unwrap(),
+            ),
+            (
+                "embed_tokens.biases".to_string(),
+                safetensors::tensor::TensorView::new(
+                    safetensors::Dtype::F32,
+                    vec![1, 1],
+                    &bias_bytes,
+                )
+                .unwrap(),
+            ),
+        ];
+        let serialized = safetensors::tensor::serialize(tensors, &None).unwrap();
+
+        let mut tensor_index = HashMap::new();
+        tensor_index.insert(
+            "embed_tokens.weight".to_string(),
+            (0usize, "embed_tokens.weight".to_string()),
+        );
+        tensor_index.insert(
+            "embed_tokens.scales".to_string(),
+            (0usize, "embed_tokens.scales".to_string()),
+        );
+        tensor_index.insert(
+            "embed_tokens.biases".to_string(),
+            (0usize, "embed_tokens.biases".to_string()),
+        );
+
+        let arch = larql_models::detect_from_json(&serde_json::json!({
+            "model_type": "gemma4",
+            "text_config": {
+                "model_type": "gemma4_text",
+                "num_hidden_layers": 1,
+                "hidden_size": 64,
+                "intermediate_size": 64,
+                "head_dim": 8,
+                "num_attention_heads": 8,
+                "num_key_value_heads": 4,
+                "vocab_size": 1
+            }
+        }));
+
+        let shards = [serialized.as_slice()];
+        let source = StreamingWeights {
+            shard_mmaps: &shards,
+            tensor_index: &tensor_index,
+            arch: &*arch,
+            num_layers: 1,
+            mlx_affine_group_size: Some(group_size),
+        };
+
+        let (data, rows, cols) = source.get_tensor("embed_tokens.weight").unwrap();
+        assert_eq!((rows, cols), (1, group_size));
+
+        for (idx, value) in data.iter().enumerate() {
+            let expected = -2.0 + 0.5 * codes[idx] as f32;
+            assert!(
+                (value - expected).abs() < 1e-6,
+                "mismatch at {idx}: {value} vs {expected}"
+            );
+        }
+    }
+}
+
 impl<'a> WeightSource for StreamingWeights<'a> {
    fn get_tensor(&self, key: &str) -> Option<(Vec<f32>, usize, usize)> {
        let (data, shape) = self.read_tensor_raw(key)?;
-        if shape.len() != 2 { return None; }
+        if shape.len() != 2 {
+            return None;
+        }
        Some((data, shape[0], shape[1]))
    }
 
    fn get_vector(&self, key: &str) -> Option<Vec<f32>> {
        let (data, shape) = self.read_tensor_raw(key)?;
-        if shape.len() != 1 { return None; }
+        if shape.len() != 1 {
+            return None;
+        }
        Some(data)
    }
 
@@ -207,9 +384,11 @@ pub fn write_model_weights(
            if let Some((data, rows, cols)) = source.get_tensor(key) {
                let len = write_floats(&mut attn_file, &data, dtype)?;
                entries.push(WeightEntry {
-                    key: key.clone(), kind: "tensor".into(),
+                    key: key.clone(),
+                    kind: "tensor".into(),
                    shape: vec![rows, cols],
-                    offset: attn_offset, length: len,
+                    offset: attn_offset,
+                    length: len,
                    file: "attn_weights.bin".into(),
                });
                attn_offset += len;
@@ -217,14 +396,19 @@ pub fn write_model_weights(
        }
 
        // QK norms (1D vectors, stored alongside attention)
-        for key in [arch.attn_q_norm_key(layer), arch.attn_k_norm_key(layer)].iter().flatten() {
+        for key in [arch.attn_q_norm_key(layer), arch.attn_k_norm_key(layer)]
+            .iter()
+            .flatten()
+        {
            if let Some(data) = source.get_vector(key) {
                let bytes = crate::config::dtype::encode_floats(&data, dtype);
                attn_file.write_all(&bytes)?;
                entries.push(WeightEntry {
-                    key: key.clone(), kind: "vector".into(),
+                    key: key.clone(),
+                    kind: "vector".into(),
                    shape: vec![data.len()],
-                    offset: attn_offset, length: bytes.len() as u64,
+                    offset: attn_offset,
+                    length: bytes.len() as u64,
                    file: "attn_weights.bin".into(),
                });
                attn_offset += bytes.len() as u64;
@@ -253,9 +437,11 @@ pub fn write_model_weights(
            if let Some((data, rows, cols)) = source.get_tensor(&key) {
                let len = write_floats(&mut up_file, &data, dtype)?;
                entries.push(WeightEntry {
-                    key, kind: "tensor".into(),
+                    key,
+                    kind: "tensor".into(),
                    shape: vec![rows, cols],
-                    offset: up_offset, length: len,
+                    offset: up_offset,
+                    length: len,
                    file: "up_weights.bin".into(),
                });
                up_offset += len;
@@ -265,9 +451,11 @@ pub fn write_model_weights(
            if let Some((data, rows, cols)) = source.get_tensor(&key) {
                let len = write_floats(&mut down_file, &data, dtype)?;
                entries.push(WeightEntry {
-                    key, kind: "tensor".into(),
+                    key,
+                    kind: "tensor".into(),
                    shape: vec![rows, cols],
-                    offset: down_offset, length: len,
+                    offset: down_offset,
+                    length: len,
                    file: "down_weights.bin".into(),
                });
                down_offset += len;
@@ -278,9 +466,11 @@ pub fn write_model_weights(
            if let Some((data, rows, cols)) = source.get_tensor(&key) {
                let len = write_floats(&mut up_file, &data, dtype)?;
                entries.push(WeightEntry {
-                    key, kind: "tensor".into(),
+                    key,
+                    kind: "tensor".into(),
                    shape: vec![rows, cols],
-                    offset: up_offset, length: len,
+                    offset: up_offset,
+                    length: len,
                    file: "up_weights.bin".into(),
                });
                up_offset += len;
@@ -291,9 +481,11 @@ pub fn write_model_weights(
            if let Some((data, rows, cols)) = source.get_tensor(&up_key) {
                let len = write_floats(&mut up_file, &data, dtype)?;
                entries.push(WeightEntry {
-                    key: up_key, kind: "tensor".into(),
+                    key: up_key,
+                    kind: "tensor".into(),
                    shape: vec![rows, cols],
-                    offset: up_offset, length: len,
+                    offset: up_offset,
+                    length: len,
                    file: "up_weights.bin".into(),
                });
                up_offset += len;
@@ -303,9 +495,11 @@ pub fn write_model_weights(
            if let Some((data, rows, cols)) = source.get_tensor(&down_key) {
                let len = write_floats(&mut down_file, &data, dtype)?;
                entries.push(WeightEntry {
-                    key: down_key, kind: "tensor".into(),
+                    key: down_key,
+                    kind: "tensor".into(),
                    shape: vec![rows, cols],
-                    offset: down_offset, length: len,
+                    offset: down_offset,
+                    length: len,
                    file: "down_weights.bin".into(),
                });
                down_offset += len;
@@ -329,16 +523,21 @@ pub fn write_model_weights(
            Some(arch.post_attention_layernorm_key(layer)),
            arch.pre_feedforward_layernorm_key(layer),
            arch.post_feedforward_layernorm_key(layer),
-        ].into_iter().flatten().collect();
+        ]
+        .into_iter()
+        .flatten()
+        .collect();
 
        for key in norm_keys {
            if let Some(data) = source.get_vector(&key) {
                let bytes = crate::config::dtype::encode_floats(&data, dtype);
                norms_file.write_all(&bytes)?;
                entries.push(WeightEntry {
-                    key, kind: "vector".into(),
+                    key,
+                    kind: "vector".into(),
                    shape: vec![data.len()],
-                    offset: norms_offset, length: bytes.len() as u64,
+                    offset: norms_offset,
+                    length: bytes.len() as u64,
                    file: "norms.bin".into(),
                });
                norms_offset += bytes.len() as u64;
@@ -351,9 +550,11 @@ pub fn write_model_weights(
        let bytes = crate::config::dtype::encode_floats(&data, dtype);
        norms_file.write_all(&bytes)?;
        entries.push(WeightEntry {
-            key: "norm.weight".into(), kind: "vector".into(),
+            key: "norm.weight".into(),
+            kind: "vector".into(),
            shape: vec![data.len()],
-            offset: norms_offset, length: bytes.len() as u64,
+            offset: norms_offset,
+            length: bytes.len() as u64,
            file: "norms.bin".into(),
        });
    }
@@ -364,23 +565,25 @@ pub fn write_model_weights(
        let lm_bytes = crate::config::dtype::encode_floats(&data, dtype);
        std::fs::write(dir.join("lm_head.bin"), &lm_bytes)?;
        entries.push(WeightEntry {
-            key: "lm_head.weight".into(), kind: "tensor".into(),
+            key: "lm_head.weight".into(),
+            kind: "tensor".into(),
            shape: vec![rows, cols],
-            offset: 0, length: lm_bytes.len() as u64,
+            offset: 0,
+            length: lm_bytes.len() as u64,
            file: "lm_head.bin".into(),
        });
    }
 
    // ── Manifest ──
-    let manifest_json = serde_json::to_string_pretty(&entries)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let manifest_json =
+        serde_json::to_string_pretty(&entries).map_err(|e| VindexError::Parse(e.to_string()))?;
    std::fs::write(dir.join("weight_manifest.json"), manifest_json)?;
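+    // Each weight_manifest.json entry records where a tensor lives, e.g.
+    // (illustrative values):
+    //   { "key": "layers.0.self_attn.q_proj.weight", "kind": "tensor",
+    //     "shape": [2048, 2048], "offset": 0, "length": 8388608,
+    //     "file": "attn_weights.bin" }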
 
    // ── Update index.json ──
    let config_path = dir.join("index.json");
    let config_text = std::fs::read_to_string(&config_path)?;
-    let mut config: VindexConfig = serde_json::from_str(&config_text)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let mut config: VindexConfig =
+        serde_json::from_str(&config_text).map_err(|e| VindexError::Parse(e.to_string()))?;
 
    config.has_model_weights = true;
 
@@ -415,15 +618,19 @@ pub fn write_model_weights(
        query_pre_attn_scalar: cfg.query_pre_attn_scalar,
    });
 
-    let config_json = serde_json::to_string_pretty(&config)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let config_json =
+        serde_json::to_string_pretty(&config).map_err(|e| VindexError::Parse(e.to_string()))?;
    std::fs::write(&config_path, config_json)?;
 
    callbacks.on_stage_done("model_weights", start.elapsed().as_secs_f64() * 1000.0);
    Ok(())
 }
 
-fn write_floats(w: &mut impl Write, data: &[f32], dtype: crate::config::dtype::StorageDtype) -> Result<u64, VindexError> {
+fn write_floats(
+    w: &mut impl Write,
+    data: &[f32],
+    dtype: crate::config::dtype::StorageDtype,
+) -> Result<u64, VindexError> {
    let bytes = crate::config::dtype::encode_floats(data, dtype);
    w.write_all(&bytes)?;
    Ok(bytes.len() as u64)
@@ -444,9 +651,10 @@ pub fn load_model_weights(
        ));
    }
 
-    let model_cfg = config.model_config.as_ref().ok_or_else(|| {
-        VindexError::Parse("vindex missing model_config in index.json".into())
-    })?;
+    let model_cfg = config
+        .model_config
+        .as_ref()
+        .ok_or_else(|| VindexError::Parse("vindex missing model_config in index.json".into()))?;
 
    // Reconstruct full architecture config — includes per-layer geometry for Gemma 4.
    let mut arch_obj = serde_json::json!({
@@ -463,19 +671,45 @@ pub fn load_model_weights(
    });
    // Pass through Gemma 4 per-layer geometry fields (if present in vindex config).
    let obj = arch_obj.as_object_mut().unwrap();
-    if let Some(v) = model_cfg.global_head_dim { obj.insert("global_head_dim".into(), v.into()); }
-    if let Some(v) = model_cfg.num_global_kv_heads { obj.insert("num_global_key_value_heads".into(), v.into()); }
-    if let Some(v) = model_cfg.partial_rotary_factor { obj.insert("partial_rotary_factor".into(), v.into()); }
-    if let Some(v) = model_cfg.sliding_window_pattern { obj.insert("sliding_window_pattern".into(), v.into()); }
-    if let Some(ref v) = model_cfg.layer_types { obj.insert("layer_types".into(), serde_json::to_value(v).unwrap_or_default()); }
-    if model_cfg.attention_k_eq_v { obj.insert("attention_k_eq_v".into(), true.into()); }
-    if let Some(v) = model_cfg.num_kv_shared_layers { obj.insert("num_kv_shared_layers".into(), v.into()); }
-    if let Some(v) = model_cfg.per_layer_embed_dim { obj.insert("hidden_size_per_layer_input".into(), v.into()); }
-    if let Some(v) = model_cfg.rope_local_base { obj.insert("rope_local_base_freq".into(), v.into()); }
-    if let Some(v) = model_cfg.query_pre_attn_scalar { obj.insert("query_pre_attn_scalar".into(), v.into()); }
+    if let Some(v) = model_cfg.global_head_dim {
+        obj.insert("global_head_dim".into(), v.into());
+    }
+    if let Some(v) = model_cfg.num_global_kv_heads {
+        obj.insert("num_global_key_value_heads".into(), v.into());
+    }
+    if let Some(v) = model_cfg.partial_rotary_factor {
+        obj.insert("partial_rotary_factor".into(), v.into());
+    }
+    if let Some(v) = model_cfg.sliding_window_pattern {
+        obj.insert("sliding_window_pattern".into(), v.into());
+    }
+    if let Some(ref v) = model_cfg.layer_types {
+        obj.insert(
+            "layer_types".into(),
+            serde_json::to_value(v).unwrap_or_default(),
+        );
+    }
+    if model_cfg.attention_k_eq_v {
+        obj.insert("attention_k_eq_v".into(), true.into());
+    }
+    if let Some(v) = model_cfg.num_kv_shared_layers {
+        obj.insert("num_kv_shared_layers".into(), v.into());
+    }
+    if let Some(v) = model_cfg.per_layer_embed_dim {
+        obj.insert("hidden_size_per_layer_input".into(), v.into());
+    }
+    if let Some(v) = model_cfg.rope_local_base {
+        obj.insert("rope_local_base_freq".into(), v.into());
+    }
+    if let Some(v) = model_cfg.query_pre_attn_scalar {
+        obj.insert("query_pre_attn_scalar".into(), v.into());
+    }
    let arch = larql_models::detect_from_json(&arch_obj);
 
-    callbacks.on_file_start("embeddings", &dir.join("embeddings.bin").display().to_string());
+    callbacks.on_file_start(
+        "embeddings",
+        &dir.join("embeddings.bin").display().to_string(),
+    );
    let embed_file = std::fs::File::open(dir.join("embeddings.bin"))?;
    let embed_mmap = unsafe { memmap2::Mmap::map(&embed_file)? };
 
    // Detect actual dtype from file size (may differ from index.json global dtype)
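+    // e.g. a [262144, 2048] embedding table (values illustrative) occupies
+    // 262144 * 2048 * 2 bytes as f16/bf16 and twice that as f32, so the
+    // element width falls out of file_len / (vocab_size * hidden_size).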
@@ -497,8 +731,8 @@ pub fn load_model_weights(
    callbacks.on_file_start("model_weights", "weight_manifest.json");
    let manifest_text = std::fs::read_to_string(&manifest_path)?;
-    let entries: Vec<WeightEntry> = serde_json::from_str(&manifest_text)
-        .map_err(|e| VindexError::Parse(e.to_string()))?;
+    let entries: Vec<WeightEntry> =
+        serde_json::from_str(&manifest_text).map_err(|e| VindexError::Parse(e.to_string()))?;
 
    let mut mmap_cache: HashMap<String, memmap2::Mmap> = HashMap::new();
    let mut tensors: HashMap<String, ndarray::ArcArray2<f32>> = HashMap::new();
 
    let mut lm_head_loaded: Option<Array2<f32>> = None;
 
    for entry in &entries {
-        let filename = if entry.file.is_empty() { "model_weights.bin".to_string() } else { entry.file.clone() };
+        let filename = if entry.file.is_empty() {
+            "model_weights.bin".to_string()
+        } else {
+            entry.file.clone()
+        };
 
        if !mmap_cache.contains_key(&filename) {
            let fpath = dir.join(&filename);
@@ -522,11 +760,15 @@ pub fn load_model_weights(
            Some(m) => m.as_ref(),
            None => continue,
        };
-        if data.is_empty() { continue; }
+        if data.is_empty() {
+            continue;
+        }
 
        let byte_offset = entry.offset as usize;
        let byte_count = entry.length as usize;
-        if byte_offset + byte_count > data.len() { continue; }
+        if byte_offset + byte_count > data.len() {
+            continue;
+        }
        let raw_bytes = &data[byte_offset..byte_offset + byte_count];
        // Detect actual dtype from byte count vs expected shape.
        // Gate vector conversion may have changed index.json dtype to f32
@@ -568,9 +810,9 @@ pub fn load_model_weights(
            let float_count = info.num_features * config.hidden_size;
            if float_offset + float_count <= gate_floats.len() {
                let gate_data = &gate_floats[float_offset..float_offset + float_count];
-                let gate_matrix = Array2::from_shape_vec(
-                    (info.num_features, config.hidden_size), gate_data.to_vec(),
-                ).map_err(|e| VindexError::Parse(e.to_string()))?;
+                let gate_matrix =
+                    Array2::from_shape_vec((info.num_features, config.hidden_size), gate_data.to_vec())
+                        .map_err(|e| VindexError::Parse(e.to_string()))?;
                tensors.insert(arch.ffn_gate_key(info.layer), gate_matrix.into_shared());
            }
        }
@@ -582,7 +824,10 @@ pub fn load_model_weights(
    let lm_head = lm_head_loaded.unwrap_or_else(|| embed.clone());
 
    Ok(ModelWeights {
-        tensors, vectors, embed, lm_head,
+        tensors,
+        vectors,
+        embed,
+        lm_head,
        num_layers: cfg.num_layers,
        hidden_size: cfg.hidden_size,
        intermediate_size: cfg.intermediate_size,
@@ -598,10 +843,14 @@ pub fn load_model_weights(
 /// Find the tokenizer path near a model or vindex directory.
 pub fn find_tokenizer_path(dir: &Path) -> Option<PathBuf> {
    let p = dir.join("tokenizer.json");
-    if p.exists() { return Some(p); }
+    if p.exists() {
+        return Some(p);
+    }
    if let Some(parent) = dir.parent() {
        let p = parent.join("tokenizer.json");
-        if p.exists() { return Some(p); }
+        if p.exists() {
+            return Some(p);
+        }
    }
    None
 }