diff --git a/crates/larql-models/src/loading/safetensors.rs b/crates/larql-models/src/loading/safetensors.rs index 0212cfe6..edce25c8 100644 --- a/crates/larql-models/src/loading/safetensors.rs +++ b/crates/larql-models/src/loading/safetensors.rs
@@ -8,8 +8,8 @@ use std::path::{Path, PathBuf}; use ndarray::Array2; -use crate::weights::ModelWeights; use crate::detect::ModelError; +use crate::weights::ModelWeights; /// Load model weights from a directory or file. ///
@@ -43,16 +43,17 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError> = std::fs::read_dir(path)? .filter_map(|e| e.ok())
@@ -97,7 +98,9 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError> Ok(d) => d, Err(_) => continue,
@@ -108,18 +111,48 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError> { vectors.insert(key, data); } + 1 => { + vectors.insert(key, data); + } _ => {} } } } else { // Standard float path for (name, view) in st.tensors() { + if name.ends_with(".scales") || name.ends_with(".biases") { + continue; + } let key = normalize_key(&name, prefixes); let shape = view.shape(); - let data = match tensor_to_f32(&view) { - Ok(d) => d, - Err(_) => continue, + let data = match view.dtype() { + safetensors::Dtype::U32 if shape.len() == 2 => { + match dequantize_mlx_affine_tensor(&st, &name, &view, mlx_affine_group_size) + { + Ok((data, _, cols)) => { + // Replace packed width with dequantized width. + let mut dequant_shape = shape.to_vec(); + dequant_shape[1] = cols; + match dequant_shape.len() { + 2 => { + let arr = Array2::from_shape_vec( + (dequant_shape[0], dequant_shape[1]), + data, + ) + .map_err(|e| ModelError::Parse(e.to_string()))?; + tensors.insert(key, arr.into_shared()); + continue; + } + _ => continue, + } + } + Err(_) => continue, + } + } + _ => match tensor_to_f32(&view) { + Ok(d) => d, + Err(_) => continue, + }, }; match shape.len() { 2 => {
@@ -127,9 +160,13 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError> { vectors.insert(key, data); } + 1 => { + vectors.insert(key, data); + } // 0D scalar tensors (e.g., layer_scalar) → store as 1-element vector - 0 => { vectors.insert(key, data); } + 0 => { + vectors.insert(key, data); + } _ => {} } }
@@ -167,6 +204,19 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError> +pub fn mlx_affine_group_size_from_config(model_dir: &Path) -> Option<usize> { + let config_path = model_dir.join("config.json"); + let text = std::fs::read_to_string(config_path).ok()?; + let json: serde_json::Value = serde_json::from_str(&text).ok()?; + + json.get("quantization") + .or_else(|| json.get("quantization_config")) + .and_then(|q| q.get("group_size")) + .and_then(|v| v.as_u64()) + .map(|v| v as usize) +} + /// Resolve a HuggingFace model ID or path to a local directory or GGUF file.
pub fn resolve_model_path(model: &str) -> Result { let path = PathBuf::from(model); @@ -191,11 +241,17 @@ pub fn resolve_model_path(model: &str) -> Result { if let Ok(entries) = std::fs::read_dir(&hf_cache) { for entry in entries.flatten() { let p = entry.path(); - if !p.is_dir() { continue; } + if !p.is_dir() { + continue; + } // Prefer snapshot with safetensors files - let has_st = std::fs::read_dir(&p).ok().map(|rd| { - rd.flatten().any(|e| e.path().extension().is_some_and(|ext| ext == "safetensors")) - }).unwrap_or(false); + let has_st = std::fs::read_dir(&p) + .ok() + .map(|rd| { + rd.flatten() + .any(|e| e.path().extension().is_some_and(|ext| ext == "safetensors")) + }) + .unwrap_or(false); if has_st { return Ok(p); } @@ -239,20 +295,26 @@ fn dequantize_mxfp4_experts( ) -> Result<(), ModelError> { // Find all gate_up_proj_blocks tensors (one per layer) for name in tensor_names { - if !name.ends_with(".gate_up_proj_blocks") { continue; } + if !name.ends_with(".gate_up_proj_blocks") { + continue; + } let scales_name = name.replace("_blocks", "_scales"); let down_blocks_name = name.replace("gate_up_proj_blocks", "down_proj_blocks"); let down_scales_name = name.replace("gate_up_proj_blocks", "down_proj_scales"); // Get tensor views - let blocks_view = st.tensor(name) + let blocks_view = st + .tensor(name) .map_err(|e| ModelError::Parse(format!("MXFP4 blocks: {e}")))?; - let scales_view = st.tensor(&scales_name) + let scales_view = st + .tensor(&scales_name) .map_err(|e| ModelError::Parse(format!("MXFP4 scales: {e}")))?; let shape = blocks_view.shape(); - if shape.len() != 4 { continue; } + if shape.len() != 4 { + continue; + } let num_experts = shape[0]; let out_features = shape[1]; // 2*hidden for gate_up, hidden for down @@ -262,8 +324,11 @@ fn dequantize_mxfp4_experts( // Dequantize gate_up (fused: first half = gate, second half = up) let expert_data = crate::quant::mxfp4::dequantize_all_experts( - blocks_view.data(), scales_view.data(), - num_experts, out_features, groups, + blocks_view.data(), + scales_view.data(), + num_experts, + out_features, + groups, ); // Extract layer number from key @@ -280,12 +345,18 @@ fn dequantize_mxfp4_experts( let gate_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w1.weight"); let up_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w3.weight"); - tensors.insert(gate_key, + tensors.insert( + gate_key, Array2::from_shape_vec((half, in_features), gate_data) - .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared()); - tensors.insert(up_key, + .map_err(|e| ModelError::Parse(e.to_string()))? + .into_shared(), + ); + tensors.insert( + up_key, Array2::from_shape_vec((half, in_features), up_data) - .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared()); + .map_err(|e| ModelError::Parse(e.to_string()))? + .into_shared(), + ); } // Dequantize down projection @@ -297,14 +368,21 @@ fn dequantize_mxfp4_experts( let down_in = down_groups * 32; let down_experts = crate::quant::mxfp4::dequantize_all_experts( - db.data(), ds.data(), num_experts, down_out, down_groups, + db.data(), + ds.data(), + num_experts, + down_out, + down_groups, ); for (e, data) in down_experts.iter().enumerate() { let down_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w2.weight"); - tensors.insert(down_key, + tensors.insert( + down_key, Array2::from_shape_vec((down_out, down_in), data.clone()) - .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared()); + .map_err(|e| ModelError::Parse(e.to_string()))? 
+ .into_shared(), + ); } } }
@@ -316,9 +394,12 @@ fn dequantize_mxfp4_experts( let s = router_view.shape(); if s.len() == 2 { let router_key = format!("{layer_prefix}.block_sparse_moe.gate.weight"); - tensors.insert(router_key, + tensors.insert( + router_key, Array2::from_shape_vec((s[0], s[1]), data) - .map_err(|e| ModelError::Parse(e.to_string()))?.into_shared()); + .map_err(|e| ModelError::Parse(e.to_string()))? + .into_shared(), + ); } } }
@@ -336,6 +417,53 @@ fn normalize_key(key: &str, prefixes: &[&str]) -> String { key.to_string() } +fn dequantize_mlx_affine_tensor( + st: &safetensors::SafeTensors, + name: &str, + view: &safetensors::tensor::TensorView<'_>, + group_size: Option<usize>, +) -> Result<(Vec<f32>, usize, usize), ModelError> { + let group_size = group_size.ok_or_else(|| { + ModelError::Parse(format!( + "missing MLX affine group_size for quantized tensor: {name}" + )) + })?; + + let shape = view.shape(); + if shape.len() != 2 { + return Err(ModelError::UnsupportedDtype(format!("{:?}", view.dtype()))); + } + + let stem = name + .strip_suffix(".weight") + .ok_or_else(|| ModelError::UnsupportedDtype(format!("{:?}", view.dtype())))?; + let scales_name = format!("{stem}.scales"); + let biases_name = format!("{stem}.biases"); + + let scales_view = st + .tensor(&scales_name) + .map_err(|e| ModelError::Parse(format!("MLX affine scales {scales_name}: {e}")))?; + let biases_view = st.tensor(&biases_name).ok(); + + let scales = tensor_to_f32(&scales_view)?; + let biases = match biases_view { + Some(biases_view) => Some(tensor_to_f32(&biases_view)?), + None => None, + }; + + let (data, cols) = crate::quant::mlx_affine::dequantize_u32_matrix_bytes( + view.data(), + shape[0], + shape[1], + &scales, + biases.as_deref(), + group_size, + ) + .map_err(ModelError::Parse)?; + + Ok((data, shape[0], cols)) +} + fn tensor_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>, ModelError> { use crate::quant::half; match view.dtype() {
diff --git a/crates/larql-models/src/quant/mlx_affine.rs b/crates/larql-models/src/quant/mlx_affine.rs new file mode 100644 index 00000000..4eedcc18 --- /dev/null +++ b/crates/larql-models/src/quant/mlx_affine.rs
@@ -0,0 +1,204 @@ +//! MLX affine quantization — packed U32 weights + per-group scales/biases. + +/// Infer bits-per-weight from packed width and grouping. +pub fn infer_bits(packed_cols: usize, groups: usize, group_size: usize) -> Result<usize, String> { + let packed_bits = packed_cols + .checked_mul(32) + .ok_or_else(|| "packed column count overflow".to_string())?; + let cols = groups + .checked_mul(group_size) + .ok_or_else(|| "group shape overflow".to_string())?; + + if cols == 0 || packed_bits % cols != 0 { + return Err(format!( + "cannot infer MLX affine bits: packed_cols={packed_cols}, groups={groups}, group_size={group_size}" + )); + } + + let bits = packed_bits / cols; + if bits == 0 || bits > 32 { + return Err(format!("invalid MLX affine bit width: {bits}")); + } + Ok(bits) +} + +/// Dequantize an MLX affine quantized 2D weight matrix stored as packed U32.
+/// +/// Layout matches `mlx.core.quantize(..., mode="affine")`: +/// - packed codes are concatenated little-endian within each row +/// - `scales` and `biases` are per-row, per-group +/// - dequantization is `bias + scale * code` +pub fn dequantize_u32_matrix_bytes( + packed_bytes: &[u8], + rows: usize, + packed_cols: usize, + scales: &[f32], + biases: Option<&[f32]>, + group_size: usize, +) -> Result<(Vec<f32>, usize), String> { + if packed_bytes.len() % 4 != 0 { + return Err("MLX affine packed weights must be U32-aligned".to_string()); + } + + let packed: Vec<u32> = packed_bytes + .chunks_exact(4) + .map(|b| u32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(); + + if packed.len() != rows.saturating_mul(packed_cols) { + return Err(format!( + "MLX affine packed size mismatch: got {} words, expected {}", + packed.len(), + rows.saturating_mul(packed_cols) + )); + } + + if rows == 0 { + return Ok((Vec::new(), 0)); + } + + if scales.len() % rows != 0 { + return Err(format!( + "MLX affine scales shape mismatch: {} values for {rows} rows", + scales.len() + )); + } + let groups = scales.len() / rows; + let cols = groups + .checked_mul(group_size) + .ok_or_else(|| "MLX affine output shape overflow".to_string())?; + let bits = infer_bits(packed_cols, groups, group_size)?; + + if let Some(biases) = biases { + if biases.len() != scales.len() { + return Err(format!( + "MLX affine biases shape mismatch: {} values, expected {}", + biases.len(), + scales.len() + )); + } + } + + let mask = if bits == 32 { + u64::MAX + } else { + (1u64 << bits) - 1 + }; + + let mut out = vec![0.0; rows * cols]; + + for row in 0..rows { + let packed_row = &packed[row * packed_cols..(row + 1) * packed_cols]; + let row_scales = &scales[row * groups..(row + 1) * groups]; + let row_biases = biases.map(|b| &b[row * groups..(row + 1) * groups]); + + let mut out_col = 0usize; + let mut acc = 0u64; + let mut acc_bits = 0usize; + + for &word in packed_row { + acc |= (word as u64) << acc_bits; + acc_bits += 32; + + while acc_bits >= bits && out_col < cols { + let group = out_col / group_size; + let code = (acc & mask) as f32; + let scale = row_scales[group]; + let bias = row_biases.map(|b| b[group]).unwrap_or(0.0); + out[row * cols + out_col] = bias + scale * code; + acc >>= bits; + acc_bits -= bits; + out_col += 1; + } + } + + if out_col != cols { + return Err(format!( + "MLX affine unpack ended early: row {row}, decoded {out_col}/{cols} values" + )); + } + } + + Ok((out, cols)) +} + +#[cfg(test)] +mod tests { + use super::dequantize_u32_matrix_bytes; + + fn pack_codes(codes: &[u32], bits: usize) -> Vec<u32> { + let mut out = Vec::new(); + let mut acc = 0u64; + let mut acc_bits = 0usize; + + for &code in codes { + acc |= (code as u64) << acc_bits; + acc_bits += bits; + + while acc_bits >= 32 { + out.push((acc & 0xFFFF_FFFF) as u32); + acc >>= 32; + acc_bits -= 32; + } + } + + if acc_bits > 0 { + out.push(acc as u32); + } + + out + } + + fn approx_eq(a: &[f32], b: &[f32]) { + assert_eq!(a.len(), b.len()); + for (idx, (&lhs, &rhs)) in a.iter().zip(b.iter()).enumerate() { + assert!( + (lhs - rhs).abs() < 1e-6, + "mismatch at {idx}: {lhs} vs {rhs}" + ); + } + } + + #[test] + fn dequantizes_affine_u32_for_multiple_bit_widths() { + for bits in [4usize, 5, 6, 8] { + let rows = 2usize; + let groups = 2usize; + let group_size = 64usize; + let cols = groups * group_size; + let max_code = (1u32 << bits) - 1; + + let scales = vec![0.25, -0.5, 1.5, -0.125]; + let biases = vec![1.0, -3.0, 2.5, 7.0]; + + let mut packed = Vec::new(); + let mut expected =
Vec::new(); + + for row in 0..rows { + let codes: Vec = (0..cols) + .map(|col| ((row * cols + col) as u32) & max_code) + .collect(); + packed.extend(pack_codes(&codes, bits)); + + for (col, code) in codes.into_iter().enumerate() { + let group = row * groups + (col / group_size); + expected.push(biases[group] + scales[group] * code as f32); + } + } + + let packed_bytes: Vec = packed.iter().flat_map(|w| w.to_le_bytes()).collect(); + let (actual, actual_cols) = dequantize_u32_matrix_bytes( + &packed_bytes, + rows, + cols * bits / 32, + &scales, + Some(&biases), + group_size, + ) + .unwrap(); + + assert_eq!(actual_cols, cols); + approx_eq(&actual, &expected); + } + } +} diff --git a/crates/larql-models/src/quant/mod.rs b/crates/larql-models/src/quant/mod.rs index dacb8bb1..fe1ca700 100644 --- a/crates/larql-models/src/quant/mod.rs +++ b/crates/larql-models/src/quant/mod.rs @@ -8,6 +8,7 @@ //! This module handles data format encoding/decoding only. //! Compute operations (matvec, vecmat, GPU shaders) are in `larql-compute`. -pub mod half; pub mod ggml; +pub mod half; +pub mod mlx_affine; pub mod mxfp4; diff --git a/crates/larql-vindex/src/extract/streaming.rs b/crates/larql-vindex/src/extract/streaming.rs index 7378859a..1dac6b20 100644 --- a/crates/larql-vindex/src/extract/streaming.rs +++ b/crates/larql-vindex/src/extract/streaming.rs @@ -44,6 +44,8 @@ pub fn build_vindex_streaming( .map_err(|e| VindexError::Parse(e.to_string()))?; let prefixes = arch.key_prefixes_to_strip(); let cfg = arch.config(); + let mlx_affine_group_size = + larql_models::loading::safetensors::mlx_affine_group_size_from_config(model_dir); let num_layers = cfg.num_layers; let hidden_size = cfg.hidden_size; @@ -75,18 +77,24 @@ pub fn build_vindex_streaming( } callbacks.on_stage("loading"); - eprintln!(" Streaming mode: {} safetensors shards (mmap'd, not loaded)", st_files.len()); + eprintln!( + " Streaming mode: {} safetensors shards (mmap'd, not loaded)", + st_files.len() + ); // (shards vec was for an earlier design — tensor_index + shard_mmaps is the actual approach) // SAFETY: We need to hold both the mmap and the SafeTensors that borrows from it. // We use a two-phase approach: first mmap all files, then deserialize. // The mmaps are kept alive in `shard_mmaps` for the lifetime of the function. - let shard_mmaps: Vec = st_files.iter().map(|path| { - let file = std::fs::File::open(path).unwrap(); - let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() }; - MmapShard { _file: file, mmap } - }).collect(); + let shard_mmaps: Vec = st_files + .iter() + .map(|path| { + let file = std::fs::File::open(path).unwrap(); + let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() }; + MmapShard { _file: file, mmap } + }) + .collect(); // Build a tensor index: key → (shard_idx, tensor_name) // We need to find which shard contains each tensor. 
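Note for reviewers: the index-building loop itself falls outside these hunks. A minimal sketch of what the comment above describes, assuming the HashMap<String, (usize, String)> index type that `get_tensor_f32` consumes below; `build_tensor_index` is a hypothetical name, not code from this change:

fn build_tensor_index(
    shard_mmaps: &[MmapShard],
    prefixes: &[&str],
) -> Result<std::collections::HashMap<String, (usize, String)>, VindexError> {
    let mut index = std::collections::HashMap::new();
    for (shard_idx, shard) in shard_mmaps.iter().enumerate() {
        // Deserializing parses only the safetensors header; tensor bytes stay mmap'd.
        let st = safetensors::SafeTensors::deserialize(&shard.mmap)
            .map_err(|e| VindexError::Parse(e.to_string()))?;
        for name in st.names() {
            // Map normalized key -> (owning shard, original tensor name).
            index.insert(normalize_key(name, prefixes), (shard_idx, name.to_string()));
        }
    }
    Ok(index)
}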
@@ -123,18 +131,21 @@ pub fn build_vindex_streaming( let blocks_key = arch.packed_gate_up_blocks_key(layer).unwrap_or_default(); let scales_key = arch.packed_gate_up_scales_key(layer).unwrap_or_default(); - if let (Some(blocks_info), Some(scales_info)) = ( - tensor_index.get(&blocks_key), - tensor_index.get(&scales_key), - ) { - let blocks_st = safetensors::SafeTensors::deserialize(&shard_mmaps[blocks_info.0].mmap) + if let (Some(blocks_info), Some(scales_info)) = + (tensor_index.get(&blocks_key), tensor_index.get(&scales_key)) + { + let blocks_st = + safetensors::SafeTensors::deserialize(&shard_mmaps[blocks_info.0].mmap) + .map_err(|e| VindexError::Parse(e.to_string()))?; + let scales_st = + safetensors::SafeTensors::deserialize(&shard_mmaps[scales_info.0].mmap) + .map_err(|e| VindexError::Parse(e.to_string()))?; + + let blocks_view = blocks_st + .tensor(&blocks_info.1) .map_err(|e| VindexError::Parse(e.to_string()))?; - let scales_st = safetensors::SafeTensors::deserialize(&shard_mmaps[scales_info.0].mmap) - .map_err(|e| VindexError::Parse(e.to_string()))?; - - let blocks_view = blocks_st.tensor(&blocks_info.1) - .map_err(|e| VindexError::Parse(e.to_string()))?; - let scales_view = scales_st.tensor(&scales_info.1) + let scales_view = scales_st + .tensor(&scales_info.1) .map_err(|e| VindexError::Parse(e.to_string()))?; let shape = blocks_view.shape(); @@ -145,7 +156,11 @@ pub fn build_vindex_streaming( let half = out_features / 2; // gate portion let experts = crate::format::quant::mxfp4::dequantize_all_experts( - blocks_view.data(), scales_view.data(), n_exp, out_features, groups, + blocks_view.data(), + scales_view.data(), + n_exp, + out_features, + groups, ); let mut total_features = 0usize; @@ -160,7 +175,10 @@ pub fn build_vindex_streaming( if total_features > 0 { layer_infos.push(VindexLayerInfo { - layer, num_features: total_features, offset, length: layer_bytes, + layer, + num_features: total_features, + offset, + length: layer_bytes, num_experts: Some(n_exp), num_features_per_expert: Some(half), }); @@ -179,7 +197,12 @@ pub fn build_vindex_streaming( None => continue, }; - if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &gate_key)? { + if let Some(tensor) = get_tensor_f32( + &shard_mmaps, + &tensor_index, + &gate_key, + mlx_affine_group_size, + )? { features_per_expert = tensor.shape()[0]; total_features += features_per_expert; let data = tensor.as_slice().unwrap(); @@ -189,7 +212,10 @@ pub fn build_vindex_streaming( if total_features > 0 { layer_infos.push(VindexLayerInfo { - layer, num_features: total_features, offset, length: layer_bytes, + layer, + num_features: total_features, + offset, + length: layer_bytes, num_experts: Some(n_experts), num_features_per_expert: Some(features_per_expert), }); @@ -198,13 +224,22 @@ pub fn build_vindex_streaming( } else { // Dense: single gate matrix per layer let gate_key = normalize_key(&arch.ffn_gate_key(layer), prefixes); - if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &gate_key)? { + if let Some(tensor) = get_tensor_f32( + &shard_mmaps, + &tensor_index, + &gate_key, + mlx_affine_group_size, + )? 
{ let num_features = tensor.shape()[0]; let data = tensor.as_slice().unwrap(); let length = write_floats(&mut gate_file, data, dtype)?; layer_infos.push(VindexLayerInfo { - layer, num_features, offset, length, - num_experts: None, num_features_per_expert: None, + layer, + num_features, + offset, + length, + num_experts: None, + num_features_per_expert: None, }); offset += length; } @@ -222,11 +257,17 @@ pub fn build_vindex_streaming( let mut router_file = BufWriter::new(std::fs::File::create(&router_path)?); for layer in 0..num_layers { - let router_key = arch.moe_router_key(layer) + let router_key = arch + .moe_router_key(layer) .map(|k| normalize_key(&k, prefixes)) .unwrap_or_default(); - if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &router_key)? { + if let Some(tensor) = get_tensor_f32( + &shard_mmaps, + &tensor_index, + &router_key, + mlx_affine_group_size, + )? { let data = tensor.as_slice().unwrap(); let bytes = crate::config::dtype::encode_floats(data, dtype); router_file.write_all(&bytes)?; @@ -234,7 +275,12 @@ pub fn build_vindex_streaming( // Also try router bias let bias_key = router_key.replace(".weight", ".bias"); - if let Some(tensor) = get_tensor_f32(&shard_mmaps, &tensor_index, &bias_key)? { + if let Some(tensor) = get_tensor_f32( + &shard_mmaps, + &tensor_index, + &bias_key, + mlx_affine_group_size, + )? { let data = tensor.as_slice().unwrap(); let bytes = crate::config::dtype::encode_floats(data, dtype); // Write bias after weight for each layer @@ -248,8 +294,13 @@ pub fn build_vindex_streaming( // ── 2. Embeddings ── callbacks.on_stage("embeddings"); let embed_key = normalize_key(arch.embed_key(), prefixes); - let embed = get_tensor_f32(&shard_mmaps, &tensor_index, &embed_key)? - .ok_or_else(|| VindexError::MissingTensor(embed_key.clone()))?; + let embed = get_tensor_f32( + &shard_mmaps, + &tensor_index, + &embed_key, + mlx_affine_group_size, + )? 
+ .ok_or_else(|| VindexError::MissingTensor(embed_key.clone()))?; let vocab_size = embed.shape()[0]; let embed_data = embed.as_slice().unwrap(); let embed_bytes = crate::config::dtype::encode_floats(embed_data, dtype); @@ -261,44 +312,61 @@ pub fn build_vindex_streaming( let mut all_down_meta: Vec>>> = vec![None; num_layers]; // Build whole-word vocab once - let (_ww_ids, _ww_embed) = super::build::build_whole_word_vocab(tokenizer, &embed, vocab_size, hidden_size); + let (_ww_ids, _ww_embed) = + super::build::build_whole_word_vocab(tokenizer, &embed, vocab_size, hidden_size); for (layer, layer_down_meta) in all_down_meta.iter_mut().enumerate().take(num_layers) { callbacks.on_layer_start("down", layer, num_layers); let start = std::time::Instant::now(); // Get down matrices for this layer - let down_matrices: Vec> = if expert_format == larql_models::ExpertFormat::PackedMxfp4 { + let down_matrices: Vec> = if expert_format + == larql_models::ExpertFormat::PackedMxfp4 + { // MXFP4: dequantize down_proj_blocks let blocks_key = arch.packed_down_blocks_key(layer).unwrap_or_default(); let scales_key = arch.packed_down_scales_key(layer).unwrap_or_default(); - if let (Some(bi), Some(si)) = (tensor_index.get(&blocks_key), tensor_index.get(&scales_key)) { + if let (Some(bi), Some(si)) = + (tensor_index.get(&blocks_key), tensor_index.get(&scales_key)) + { let bst = safetensors::SafeTensors::deserialize(&shard_mmaps[bi.0].mmap) .map_err(|e| VindexError::Parse(e.to_string()))?; let sst = safetensors::SafeTensors::deserialize(&shard_mmaps[si.0].mmap) .map_err(|e| VindexError::Parse(e.to_string()))?; - let bv = bst.tensor(&bi.1).map_err(|e| VindexError::Parse(e.to_string()))?; - let sv = sst.tensor(&si.1).map_err(|e| VindexError::Parse(e.to_string()))?; + let bv = bst + .tensor(&bi.1) + .map_err(|e| VindexError::Parse(e.to_string()))?; + let sv = sst + .tensor(&si.1) + .map_err(|e| VindexError::Parse(e.to_string()))?; let shape = bv.shape(); let n_exp = shape[0]; let out_features = shape[1]; let groups = shape[2]; let in_features = groups * 32; let experts = crate::format::quant::mxfp4::dequantize_all_experts( - bv.data(), sv.data(), n_exp, out_features, groups, + bv.data(), + sv.data(), + n_exp, + out_features, + groups, ); - experts.into_iter().map(|data| { - Array2::from_shape_vec((out_features, in_features), data).unwrap() - }).collect() + experts + .into_iter() + .map(|data| Array2::from_shape_vec((out_features, in_features), data).unwrap()) + .collect() } else { - callbacks.on_layer_done("down", layer, 0.0); continue; + callbacks.on_layer_done("down", layer, 0.0); + continue; } } else if is_moe && n_experts > 0 { let mut mats = Vec::new(); for expert in 0..n_experts { if let Some(key) = arch.expert_ffn_down_key(layer, expert) { let nk = normalize_key(&key, prefixes); - if let Some(t) = get_tensor_f32(&shard_mmaps, &tensor_index, &nk)? { + if let Some(t) = + get_tensor_f32(&shard_mmaps, &tensor_index, &nk, mlx_affine_group_size)? + { mats.push(t); } } @@ -306,9 +374,17 @@ pub fn build_vindex_streaming( mats } else { let down_key = normalize_key(&arch.ffn_down_key(layer), prefixes); - match get_tensor_f32(&shard_mmaps, &tensor_index, &down_key)? { + match get_tensor_f32( + &shard_mmaps, + &tensor_index, + &down_key, + mlx_affine_group_size, + )? 
{ Some(t) => vec![t], - None => { callbacks.on_layer_done("down", layer, 0.0); continue; } + None => { + callbacks.on_layer_done("down", layer, 0.0); + continue; + } } }; @@ -324,10 +400,16 @@ pub fn build_vindex_streaming( for batch_start in (0..num_features).step_by(batch_size) { let batch_end = (batch_start + batch_size).min(num_features); - callbacks.on_feature_progress("down", layer, feature_offset + batch_start, - down_matrices.iter().map(|m| m.shape()[1]).sum()); + callbacks.on_feature_progress( + "down", + layer, + feature_offset + batch_start, + down_matrices.iter().map(|m| m.shape()[1]).sum(), + ); - let w_chunk = w_down.slice(ndarray::s![.., batch_start..batch_end]).to_owned(); + let w_chunk = w_down + .slice(ndarray::s![.., batch_start..batch_end]) + .to_owned(); let cpu = larql_compute::CpuBackend; use larql_compute::ComputeBackend; let chunk_logits = cpu.matmul(embed.view(), w_chunk.view()); @@ -342,29 +424,42 @@ pub fn build_vindex_streaming( scores.truncate(k); scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); - let top_k_entries: Vec = scores.into_iter() + let top_k_entries: Vec = scores + .into_iter() .filter_map(|(idx, logit)| { - tokenizer.decode(&[idx as u32], true).ok() + tokenizer + .decode(&[idx as u32], true) + .ok() .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) - .map(|token| larql_models::TopKEntry { token, token_id: idx as u32, logit }) + .map(|token| larql_models::TopKEntry { + token, + token_id: idx as u32, + logit, + }) }) .collect(); - let (top_token, top_token_id, c_score) = if let Some(first) = top_k_entries.first() { - (first.token.clone(), first.token_id, first.logit) - } else { - (String::new(), 0, 0.0) - }; + let (top_token, top_token_id, c_score) = + if let Some(first) = top_k_entries.first() { + (first.token.clone(), first.token_id, first.logit) + } else { + (String::new(), 0, 0.0) + }; let feat_idx = feature_offset + feat; if layer_down_meta.is_none() { *layer_down_meta = Some(Vec::new()); } if let Some(ref mut metas) = layer_down_meta { - while metas.len() <= feat_idx { metas.push(None); } + while metas.len() <= feat_idx { + metas.push(None); + } metas[feat_idx] = Some(crate::FeatureMeta { - top_token, top_token_id, c_score, top_k: top_k_entries, + top_token, + top_token_id, + c_score, + top_k: top_k_entries, }); } } @@ -380,7 +475,8 @@ pub fn build_vindex_streaming( // ── 4. 
Tokenizer ── callbacks.on_stage("tokenizer"); - let tokenizer_json = tokenizer.to_string(true) + let tokenizer_json = tokenizer + .to_string(true) .map_err(|e| VindexError::Parse(format!("tokenizer serialize: {e}")))?; std::fs::write(output_dir.join("tokenizer.json"), tokenizer_json)?; callbacks.on_stage_done("tokenizer", 0.0); @@ -391,7 +487,10 @@ pub fn build_vindex_streaming( version: 2, model: model_name.to_string(), family: family.clone(), - num_layers, hidden_size, intermediate_size, vocab_size, + num_layers, + hidden_size, + intermediate_size, + vocab_size, embed_scale, layers: layer_infos, down_top_k, @@ -421,7 +520,9 @@ pub fn build_vindex_streaming( shared_expert: arch.num_shared_experts() > 0, router_type: "top_k_softmax".to_string(), }) - } else { None }, + } else { + None + }, // Per-layer geometry (Gemma 4) global_head_dim: cfg.global_head_dim, num_global_kv_heads: cfg.num_global_kv_heads, @@ -437,8 +538,8 @@ pub fn build_vindex_streaming( }; // Write preliminary index.json (needed by write_model_weights which reads dtype from it) - let config_json = serde_json::to_string_pretty(&config) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let config_json = + serde_json::to_string_pretty(&config).map_err(|e| VindexError::Parse(e.to_string()))?; std::fs::write(output_dir.join("index.json"), config_json)?; // ── 6. Model weights (if extract level requires them) ── @@ -449,6 +550,7 @@ pub fn build_vindex_streaming( tensor_index: &tensor_index, arch: &*arch, num_layers, + mlx_affine_group_size, }; crate::format::weights::write_model_weights(&streaming_source, output_dir, callbacks)?; // write_model_weights updates index.json with has_model_weights=true @@ -456,11 +558,11 @@ pub fn build_vindex_streaming( // Final checksums let config_text = std::fs::read_to_string(output_dir.join("index.json"))?; - let mut config: VindexConfig = serde_json::from_str(&config_text) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let mut config: VindexConfig = + serde_json::from_str(&config_text).map_err(|e| VindexError::Parse(e.to_string()))?; config.checksums = crate::format::checksums::compute_checksums(output_dir).ok(); - let config_json = serde_json::to_string_pretty(&config) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let config_json = + serde_json::to_string_pretty(&config).map_err(|e| VindexError::Parse(e.to_string()))?; std::fs::write(output_dir.join("index.json"), config_json)?; Ok(()) @@ -471,6 +573,7 @@ fn get_tensor_f32( shards: &[MmapShard], index: &HashMap, key: &str, + mlx_affine_group_size: Option, ) -> Result>, VindexError> { let (shard_idx, tensor_name) = match index.get(key) { Some(v) => v, @@ -480,20 +583,36 @@ fn get_tensor_f32( let st = safetensors::SafeTensors::deserialize(&shards[*shard_idx].mmap) .map_err(|e| VindexError::Parse(e.to_string()))?; - let view = st.tensor(tensor_name) + let view = st + .tensor(tensor_name) .map_err(|e| VindexError::Parse(e.to_string()))?; let shape = view.shape(); - if shape.len() != 2 { return Ok(None); } + if shape.len() != 2 { + return Ok(None); + } let data = match view.dtype() { - safetensors::Dtype::F32 => { - view.data().chunks_exact(4) - .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) - .collect() - } + safetensors::Dtype::F32 => view + .data() + .chunks_exact(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(), safetensors::Dtype::F16 => crate::format::quant::half::decode_f16(view.data()), safetensors::Dtype::BF16 => crate::format::quant::half::decode_bf16(view.data()), + 
safetensors::Dtype::U32 => { + let Some(group_size) = mlx_affine_group_size else { + return Ok(None); + }; + let Some((data, cols)) = + dequantize_mlx_affine_tensor(shards, index, key, shape, view.data(), group_size)? + else { + return Ok(None); + }; + let arr = Array2::from_shape_vec((shape[0], cols), data) + .map_err(|e| VindexError::Parse(e.to_string()))?; + return Ok(Some(arr)); + } _ => return Ok(None), // skip non-float };
@@ -502,6 +621,71 @@ Ok(Some(arr)) } +fn dequantize_mlx_affine_tensor( + shards: &[MmapShard], + index: &HashMap<String, (usize, String)>, + key: &str, + weight_shape: &[usize], + packed: &[u8], + group_size: usize, +) -> Result<Option<(Vec<f32>, usize)>, VindexError> { + let Some(stem) = key.strip_suffix(".weight") else { + return Ok(None); + }; + + let scales_key = format!("{stem}.scales"); + let biases_key = format!("{stem}.biases"); + + let (scales_shard, scales_name) = match index.get(&scales_key) { + Some(v) => v, + None => return Ok(None), + }; + let scales_st = safetensors::SafeTensors::deserialize(&shards[*scales_shard].mmap) + .map_err(|e| VindexError::Parse(e.to_string()))?; + let scales_view = scales_st + .tensor(scales_name) + .map_err(|e| VindexError::Parse(e.to_string()))?; + let scales = tensor_view_to_f32(&scales_view)?; + + let biases = if let Some((bias_shard, bias_name)) = index.get(&biases_key) { + let biases_st = safetensors::SafeTensors::deserialize(&shards[*bias_shard].mmap) + .map_err(|e| VindexError::Parse(e.to_string()))?; + let biases_view = biases_st + .tensor(bias_name) + .map_err(|e| VindexError::Parse(e.to_string()))?; + Some(tensor_view_to_f32(&biases_view)?) + } else { + None + }; + + let (data, cols) = crate::format::quant::mlx_affine::dequantize_u32_matrix_bytes( + packed, + weight_shape[0], + weight_shape[1], + &scales, + biases.as_deref(), + group_size, + ) + .map_err(VindexError::Parse)?; + + Ok(Some((data, cols))) +} + +fn tensor_view_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>, VindexError> { + match view.dtype() { + safetensors::Dtype::F32 => Ok(view + .data() + .chunks_exact(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect()), + safetensors::Dtype::F16 => Ok(crate::format::quant::half::decode_f16(view.data())), + safetensors::Dtype::BF16 => Ok(crate::format::quant::half::decode_bf16(view.data())), + other => Err(VindexError::Parse(format!( + "unsupported tensor dtype: {other:?}" + ))), + } +} + fn normalize_key(key: &str, prefixes: &[&str]) -> String { for prefix in prefixes { if let Some(stripped) = key.strip_prefix(prefix) {
diff --git a/crates/larql-vindex/src/format/quant/mod.rs b/crates/larql-vindex/src/format/quant/mod.rs index 6d82a79f..7e6b5a34 100644 --- a/crates/larql-vindex/src/format/quant/mod.rs +++ b/crates/larql-vindex/src/format/quant/mod.rs
@@ -1,5 +1,6 @@ //! Quantization and dequantization — re-exports from larql-models.
-pub use larql_models::quant::half; pub use larql_models::quant::ggml; +pub use larql_models::quant::half; +pub use larql_models::quant::mlx_affine; pub use larql_models::quant::mxfp4; diff --git a/crates/larql-vindex/src/format/weights.rs b/crates/larql-vindex/src/format/weights.rs index c35842aa..fc27ef46 100644 --- a/crates/larql-vindex/src/format/weights.rs +++ b/crates/larql-vindex/src/format/weights.rs @@ -18,11 +18,11 @@ use std::path::Path; use ndarray::Array2; use serde::{Deserialize, Serialize}; +use crate::config::{VindexConfig, VindexModelConfig}; use crate::error::VindexError; use crate::extract::callbacks::IndexBuildCallbacks; -use crate::config::{VindexConfig, VindexModelConfig}; -use crate::index::core::IndexLoadCallbacks; use crate::format::load::load_vindex_config; +use crate::index::core::IndexLoadCallbacks; use larql_models::ModelWeights; @@ -102,6 +102,7 @@ pub struct StreamingWeights<'a> { pub tensor_index: &'a HashMap, pub arch: &'a dyn larql_models::ModelArchitecture, pub num_layers: usize, + pub mlx_affine_group_size: Option, } impl<'a> StreamingWeights<'a> { @@ -109,32 +110,208 @@ impl<'a> StreamingWeights<'a> { let (shard_idx, tensor_name) = self.tensor_index.get(key)?; let st = safetensors::SafeTensors::deserialize(self.shard_mmaps[*shard_idx]).ok()?; let view = st.tensor(tensor_name).ok()?; - let shape = view.shape().to_vec(); + let mut shape = view.shape().to_vec(); let data = match view.dtype() { - safetensors::Dtype::F32 => { - view.data().chunks_exact(4) - .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) - .collect() - } + safetensors::Dtype::F32 => view + .data() + .chunks_exact(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(), safetensors::Dtype::F16 => crate::format::quant::half::decode_f16(view.data()), safetensors::Dtype::BF16 => crate::format::quant::half::decode_bf16(view.data()), + safetensors::Dtype::U32 if shape.len() == 2 => { + let group_size = self.mlx_affine_group_size?; + let stem = key.strip_suffix(".weight")?; + + let (scales_shard, scales_name) = + self.tensor_index.get(&format!("{stem}.scales"))?; + let scales_st = + safetensors::SafeTensors::deserialize(self.shard_mmaps[*scales_shard]).ok()?; + let scales_view = scales_st.tensor(scales_name).ok()?; + let scales = tensor_view_to_f32(&scales_view).ok()?; + + let biases = if let Some((bias_shard, bias_name)) = + self.tensor_index.get(&format!("{stem}.biases")) + { + let bias_st = + safetensors::SafeTensors::deserialize(self.shard_mmaps[*bias_shard]) + .ok()?; + let bias_view = bias_st.tensor(bias_name).ok()?; + Some(tensor_view_to_f32(&bias_view).ok()?) 
+ } else { + None + }; + + let (data, cols) = crate::format::quant::mlx_affine::dequantize_u32_matrix_bytes( + view.data(), + shape[0], + shape[1], + &scales, + biases.as_deref(), + group_size, + ) + .ok()?; + shape[1] = cols; + data + } _ => return None, }; Some((data, shape)) } } +fn tensor_view_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result, VindexError> { + match view.dtype() { + safetensors::Dtype::F32 => Ok(view + .data() + .chunks_exact(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect()), + safetensors::Dtype::F16 => Ok(crate::format::quant::half::decode_f16(view.data())), + safetensors::Dtype::BF16 => Ok(crate::format::quant::half::decode_bf16(view.data())), + other => Err(VindexError::Parse(format!( + "unsupported tensor dtype: {other:?}" + ))), + } +} + +#[cfg(test)] +mod tests { + use super::StreamingWeights; + use crate::format::weights::WeightSource; + use std::collections::HashMap; + + fn pack_codes(codes: &[u32], bits: usize) -> Vec { + let mut out = Vec::new(); + let mut acc = 0u64; + let mut acc_bits = 0usize; + + for &code in codes { + acc |= (code as u64) << acc_bits; + acc_bits += bits; + + while acc_bits >= 32 { + out.push((acc & 0xFFFF_FFFF) as u32); + acc >>= 32; + acc_bits -= 32; + } + } + + if acc_bits > 0 { + out.push(acc as u32); + } + + out + } + + #[test] + fn streaming_weights_dequantize_mlx_affine_u32() { + let bits = 4usize; + let group_size = 64usize; + let codes: Vec = (0..group_size).map(|i| (i as u32) & 0xF).collect(); + let packed = pack_codes(&codes, bits); + let packed_bytes: Vec = packed.iter().flat_map(|w| w.to_le_bytes()).collect(); + let scales = vec![0.5f32]; + let biases = vec![-2.0f32]; + let scale_bytes: Vec = scales.iter().flat_map(|v| v.to_le_bytes()).collect(); + let bias_bytes: Vec = biases.iter().flat_map(|v| v.to_le_bytes()).collect(); + + let tensors = [ + ( + "embed_tokens.weight".to_string(), + safetensors::tensor::TensorView::new( + safetensors::Dtype::U32, + vec![1, packed.len()], + &packed_bytes, + ) + .unwrap(), + ), + ( + "embed_tokens.scales".to_string(), + safetensors::tensor::TensorView::new( + safetensors::Dtype::F32, + vec![1, 1], + &scale_bytes, + ) + .unwrap(), + ), + ( + "embed_tokens.biases".to_string(), + safetensors::tensor::TensorView::new( + safetensors::Dtype::F32, + vec![1, 1], + &bias_bytes, + ) + .unwrap(), + ), + ]; + let serialized = safetensors::tensor::serialize(tensors, &None).unwrap(); + + let mut tensor_index = HashMap::new(); + tensor_index.insert( + "embed_tokens.weight".to_string(), + (0usize, "embed_tokens.weight".to_string()), + ); + tensor_index.insert( + "embed_tokens.scales".to_string(), + (0usize, "embed_tokens.scales".to_string()), + ); + tensor_index.insert( + "embed_tokens.biases".to_string(), + (0usize, "embed_tokens.biases".to_string()), + ); + + let arch = larql_models::detect_from_json(&serde_json::json!({ + "model_type": "gemma4", + "text_config": { + "model_type": "gemma4_text", + "num_hidden_layers": 1, + "hidden_size": 64, + "intermediate_size": 64, + "head_dim": 8, + "num_attention_heads": 8, + "num_key_value_heads": 4, + "vocab_size": 1 + } + })); + + let shards = [serialized.as_slice()]; + let source = StreamingWeights { + shard_mmaps: &shards, + tensor_index: &tensor_index, + arch: &*arch, + num_layers: 1, + mlx_affine_group_size: Some(group_size), + }; + + let (data, rows, cols) = source.get_tensor("embed_tokens.weight").unwrap(); + assert_eq!((rows, cols), (1, group_size)); + + for (idx, value) in data.iter().enumerate() { + let expected = -2.0 
+ 0.5 * codes[idx] as f32; + assert!( + (value - expected).abs() < 1e-6, + "mismatch at {idx}: {value} vs {expected}" + ); + } + } +} + impl<'a> WeightSource for StreamingWeights<'a> { fn get_tensor(&self, key: &str) -> Option<(Vec, usize, usize)> { let (data, shape) = self.read_tensor_raw(key)?; - if shape.len() != 2 { return None; } + if shape.len() != 2 { + return None; + } Some((data, shape[0], shape[1])) } fn get_vector(&self, key: &str) -> Option> { let (data, shape) = self.read_tensor_raw(key)?; - if shape.len() != 1 { return None; } + if shape.len() != 1 { + return None; + } Some(data) } @@ -207,9 +384,11 @@ pub fn write_model_weights( if let Some((data, rows, cols)) = source.get_tensor(key) { let len = write_floats(&mut attn_file, &data, dtype)?; entries.push(WeightEntry { - key: key.clone(), kind: "tensor".into(), + key: key.clone(), + kind: "tensor".into(), shape: vec![rows, cols], - offset: attn_offset, length: len, + offset: attn_offset, + length: len, file: "attn_weights.bin".into(), }); attn_offset += len; @@ -217,14 +396,19 @@ pub fn write_model_weights( } // QK norms (1D vectors, stored alongside attention) - for key in [arch.attn_q_norm_key(layer), arch.attn_k_norm_key(layer)].iter().flatten() { + for key in [arch.attn_q_norm_key(layer), arch.attn_k_norm_key(layer)] + .iter() + .flatten() + { if let Some(data) = source.get_vector(key) { let bytes = crate::config::dtype::encode_floats(&data, dtype); attn_file.write_all(&bytes)?; entries.push(WeightEntry { - key: key.clone(), kind: "vector".into(), + key: key.clone(), + kind: "vector".into(), shape: vec![data.len()], - offset: attn_offset, length: bytes.len() as u64, + offset: attn_offset, + length: bytes.len() as u64, file: "attn_weights.bin".into(), }); attn_offset += bytes.len() as u64; @@ -253,9 +437,11 @@ pub fn write_model_weights( if let Some((data, rows, cols)) = source.get_tensor(&key) { let len = write_floats(&mut up_file, &data, dtype)?; entries.push(WeightEntry { - key, kind: "tensor".into(), + key, + kind: "tensor".into(), shape: vec![rows, cols], - offset: up_offset, length: len, + offset: up_offset, + length: len, file: "up_weights.bin".into(), }); up_offset += len; @@ -265,9 +451,11 @@ pub fn write_model_weights( if let Some((data, rows, cols)) = source.get_tensor(&key) { let len = write_floats(&mut down_file, &data, dtype)?; entries.push(WeightEntry { - key, kind: "tensor".into(), + key, + kind: "tensor".into(), shape: vec![rows, cols], - offset: down_offset, length: len, + offset: down_offset, + length: len, file: "down_weights.bin".into(), }); down_offset += len; @@ -278,9 +466,11 @@ pub fn write_model_weights( if let Some((data, rows, cols)) = source.get_tensor(&key) { let len = write_floats(&mut up_file, &data, dtype)?; entries.push(WeightEntry { - key, kind: "tensor".into(), + key, + kind: "tensor".into(), shape: vec![rows, cols], - offset: up_offset, length: len, + offset: up_offset, + length: len, file: "up_weights.bin".into(), }); up_offset += len; @@ -291,9 +481,11 @@ pub fn write_model_weights( if let Some((data, rows, cols)) = source.get_tensor(&up_key) { let len = write_floats(&mut up_file, &data, dtype)?; entries.push(WeightEntry { - key: up_key, kind: "tensor".into(), + key: up_key, + kind: "tensor".into(), shape: vec![rows, cols], - offset: up_offset, length: len, + offset: up_offset, + length: len, file: "up_weights.bin".into(), }); up_offset += len; @@ -303,9 +495,11 @@ pub fn write_model_weights( if let Some((data, rows, cols)) = source.get_tensor(&down_key) { let len = write_floats(&mut 
down_file, &data, dtype)?; entries.push(WeightEntry { - key: down_key, kind: "tensor".into(), + key: down_key, + kind: "tensor".into(), shape: vec![rows, cols], - offset: down_offset, length: len, + offset: down_offset, + length: len, file: "down_weights.bin".into(), }); down_offset += len; @@ -329,16 +523,21 @@ pub fn write_model_weights( Some(arch.post_attention_layernorm_key(layer)), arch.pre_feedforward_layernorm_key(layer), arch.post_feedforward_layernorm_key(layer), - ].into_iter().flatten().collect(); + ] + .into_iter() + .flatten() + .collect(); for key in norm_keys { if let Some(data) = source.get_vector(&key) { let bytes = crate::config::dtype::encode_floats(&data, dtype); norms_file.write_all(&bytes)?; entries.push(WeightEntry { - key, kind: "vector".into(), + key, + kind: "vector".into(), shape: vec![data.len()], - offset: norms_offset, length: bytes.len() as u64, + offset: norms_offset, + length: bytes.len() as u64, file: "norms.bin".into(), }); norms_offset += bytes.len() as u64; @@ -351,9 +550,11 @@ pub fn write_model_weights( let bytes = crate::config::dtype::encode_floats(&data, dtype); norms_file.write_all(&bytes)?; entries.push(WeightEntry { - key: "norm.weight".into(), kind: "vector".into(), + key: "norm.weight".into(), + kind: "vector".into(), shape: vec![data.len()], - offset: norms_offset, length: bytes.len() as u64, + offset: norms_offset, + length: bytes.len() as u64, file: "norms.bin".into(), }); } @@ -364,23 +565,25 @@ pub fn write_model_weights( let lm_bytes = crate::config::dtype::encode_floats(&data, dtype); std::fs::write(dir.join("lm_head.bin"), &lm_bytes)?; entries.push(WeightEntry { - key: "lm_head.weight".into(), kind: "tensor".into(), + key: "lm_head.weight".into(), + kind: "tensor".into(), shape: vec![rows, cols], - offset: 0, length: lm_bytes.len() as u64, + offset: 0, + length: lm_bytes.len() as u64, file: "lm_head.bin".into(), }); } // ── Manifest ── - let manifest_json = serde_json::to_string_pretty(&entries) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let manifest_json = + serde_json::to_string_pretty(&entries).map_err(|e| VindexError::Parse(e.to_string()))?; std::fs::write(dir.join("weight_manifest.json"), manifest_json)?; // ── Update index.json ── let config_path = dir.join("index.json"); let config_text = std::fs::read_to_string(&config_path)?; - let mut config: VindexConfig = serde_json::from_str(&config_text) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let mut config: VindexConfig = + serde_json::from_str(&config_text).map_err(|e| VindexError::Parse(e.to_string()))?; config.has_model_weights = true; @@ -415,15 +618,19 @@ pub fn write_model_weights( query_pre_attn_scalar: cfg.query_pre_attn_scalar, }); - let config_json = serde_json::to_string_pretty(&config) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let config_json = + serde_json::to_string_pretty(&config).map_err(|e| VindexError::Parse(e.to_string()))?; std::fs::write(&config_path, config_json)?; callbacks.on_stage_done("model_weights", start.elapsed().as_secs_f64() * 1000.0); Ok(()) } -fn write_floats(w: &mut impl Write, data: &[f32], dtype: crate::config::dtype::StorageDtype) -> Result { +fn write_floats( + w: &mut impl Write, + data: &[f32], + dtype: crate::config::dtype::StorageDtype, +) -> Result { let bytes = crate::config::dtype::encode_floats(data, dtype); w.write_all(&bytes)?; Ok(bytes.len() as u64) @@ -444,9 +651,10 @@ pub fn load_model_weights( )); } - let model_cfg = config.model_config.as_ref().ok_or_else(|| { - VindexError::Parse("vindex missing 
model_config in index.json".into()) - })?; + let model_cfg = config + .model_config + .as_ref() + .ok_or_else(|| VindexError::Parse("vindex missing model_config in index.json".into()))?; // Reconstruct full architecture config — includes per-layer geometry for Gemma 4. let mut arch_obj = serde_json::json!({ @@ -463,19 +671,45 @@ pub fn load_model_weights( }); // Pass through Gemma 4 per-layer geometry fields (if present in vindex config). let obj = arch_obj.as_object_mut().unwrap(); - if let Some(v) = model_cfg.global_head_dim { obj.insert("global_head_dim".into(), v.into()); } - if let Some(v) = model_cfg.num_global_kv_heads { obj.insert("num_global_key_value_heads".into(), v.into()); } - if let Some(v) = model_cfg.partial_rotary_factor { obj.insert("partial_rotary_factor".into(), v.into()); } - if let Some(v) = model_cfg.sliding_window_pattern { obj.insert("sliding_window_pattern".into(), v.into()); } - if let Some(ref v) = model_cfg.layer_types { obj.insert("layer_types".into(), serde_json::to_value(v).unwrap_or_default()); } - if model_cfg.attention_k_eq_v { obj.insert("attention_k_eq_v".into(), true.into()); } - if let Some(v) = model_cfg.num_kv_shared_layers { obj.insert("num_kv_shared_layers".into(), v.into()); } - if let Some(v) = model_cfg.per_layer_embed_dim { obj.insert("hidden_size_per_layer_input".into(), v.into()); } - if let Some(v) = model_cfg.rope_local_base { obj.insert("rope_local_base_freq".into(), v.into()); } - if let Some(v) = model_cfg.query_pre_attn_scalar { obj.insert("query_pre_attn_scalar".into(), v.into()); } + if let Some(v) = model_cfg.global_head_dim { + obj.insert("global_head_dim".into(), v.into()); + } + if let Some(v) = model_cfg.num_global_kv_heads { + obj.insert("num_global_key_value_heads".into(), v.into()); + } + if let Some(v) = model_cfg.partial_rotary_factor { + obj.insert("partial_rotary_factor".into(), v.into()); + } + if let Some(v) = model_cfg.sliding_window_pattern { + obj.insert("sliding_window_pattern".into(), v.into()); + } + if let Some(ref v) = model_cfg.layer_types { + obj.insert( + "layer_types".into(), + serde_json::to_value(v).unwrap_or_default(), + ); + } + if model_cfg.attention_k_eq_v { + obj.insert("attention_k_eq_v".into(), true.into()); + } + if let Some(v) = model_cfg.num_kv_shared_layers { + obj.insert("num_kv_shared_layers".into(), v.into()); + } + if let Some(v) = model_cfg.per_layer_embed_dim { + obj.insert("hidden_size_per_layer_input".into(), v.into()); + } + if let Some(v) = model_cfg.rope_local_base { + obj.insert("rope_local_base_freq".into(), v.into()); + } + if let Some(v) = model_cfg.query_pre_attn_scalar { + obj.insert("query_pre_attn_scalar".into(), v.into()); + } let arch = larql_models::detect_from_json(&arch_obj); - callbacks.on_file_start("embeddings", &dir.join("embeddings.bin").display().to_string()); + callbacks.on_file_start( + "embeddings", + &dir.join("embeddings.bin").display().to_string(), + ); let embed_file = std::fs::File::open(dir.join("embeddings.bin"))?; let embed_mmap = unsafe { memmap2::Mmap::map(&embed_file)? 
}; // Detect actual dtype from file size (may differ from index.json global dtype) @@ -497,8 +731,8 @@ pub fn load_model_weights( callbacks.on_file_start("model_weights", "weight_manifest.json"); let manifest_text = std::fs::read_to_string(&manifest_path)?; - let entries: Vec = serde_json::from_str(&manifest_text) - .map_err(|e| VindexError::Parse(e.to_string()))?; + let entries: Vec = + serde_json::from_str(&manifest_text).map_err(|e| VindexError::Parse(e.to_string()))?; let mut mmap_cache: HashMap = HashMap::new(); let mut tensors: HashMap = HashMap::new(); @@ -506,7 +740,11 @@ pub fn load_model_weights( let mut lm_head_loaded: Option = None; for entry in &entries { - let filename = if entry.file.is_empty() { "model_weights.bin".to_string() } else { entry.file.clone() }; + let filename = if entry.file.is_empty() { + "model_weights.bin".to_string() + } else { + entry.file.clone() + }; if !mmap_cache.contains_key(&filename) { let fpath = dir.join(&filename); @@ -522,11 +760,15 @@ pub fn load_model_weights( Some(m) => m.as_ref(), None => continue, }; - if data.is_empty() { continue; } + if data.is_empty() { + continue; + } let byte_offset = entry.offset as usize; let byte_count = entry.length as usize; - if byte_offset + byte_count > data.len() { continue; } + if byte_offset + byte_count > data.len() { + continue; + } let raw_bytes = &data[byte_offset..byte_offset + byte_count]; // Detect actual dtype from byte count vs expected shape. // Gate vector conversion may have changed index.json dtype to f32 @@ -568,9 +810,9 @@ pub fn load_model_weights( let float_count = info.num_features * config.hidden_size; if float_offset + float_count <= gate_floats.len() { let gate_data = &gate_floats[float_offset..float_offset + float_count]; - let gate_matrix = Array2::from_shape_vec( - (info.num_features, config.hidden_size), gate_data.to_vec(), - ).map_err(|e| VindexError::Parse(e.to_string()))?; + let gate_matrix = + Array2::from_shape_vec((info.num_features, config.hidden_size), gate_data.to_vec()) + .map_err(|e| VindexError::Parse(e.to_string()))?; tensors.insert(arch.ffn_gate_key(info.layer), gate_matrix.into_shared()); } } @@ -582,7 +824,10 @@ pub fn load_model_weights( let lm_head = lm_head_loaded.unwrap_or_else(|| embed.clone()); Ok(ModelWeights { - tensors, vectors, embed, lm_head, + tensors, + vectors, + embed, + lm_head, num_layers: cfg.num_layers, hidden_size: cfg.hidden_size, intermediate_size: cfg.intermediate_size, @@ -598,10 +843,14 @@ pub fn load_model_weights( /// Find the tokenizer path near a model or vindex directory. pub fn find_tokenizer_path(dir: &Path) -> Option { let p = dir.join("tokenizer.json"); - if p.exists() { return Some(p); } + if p.exists() { + return Some(p); + } if let Some(parent) = dir.parent() { let p = parent.join("tokenizer.json"); - if p.exists() { return Some(p); } + if p.exists() { + return Some(p); + } } None }
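Reviewer note: to sanity-check the affine math added in crates/larql-models/src/quant/mlx_affine.rs, here is a self-contained sketch of the round trip for the 4-bit case (illustrative values only; the real module handles bit widths that straddle u32 boundaries via a 64-bit accumulator, which this simplified packing does not):

fn main() {
    let bits = 4usize;
    let group_size = 8usize; // real MLX checkpoints typically use 32/64/128
    let codes: Vec<u32> = (0..group_size as u32).collect(); // one row, one group
    // Pack little-endian: code i occupies bits [i*bits, i*bits + bits) of the word stream.
    let mut words = vec![0u32; (group_size * bits + 31) / 32];
    for (i, &c) in codes.iter().enumerate() {
        words[i * bits / 32] |= c << ((i * bits) % 32);
    }
    // Per-group affine parameters; dequantization is bias + scale * code.
    let (scale, bias) = (0.5f32, -2.0f32);
    let mask = (1u32 << bits) - 1;
    let out: Vec<f32> = (0..group_size)
        .map(|i| {
            let code = (words[i * bits / 32] >> ((i * bits) % 32)) & mask;
            bias + scale * code as f32
        })
        .collect();
    assert_eq!(out[3], -2.0 + 0.5 * 3.0);
    println!("{out:?}"); // [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5]
}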