188 changes: 158 additions & 30 deletions crates/larql-models/src/loading/safetensors.rs
@@ -8,8 +8,8 @@ use std::path::{Path, PathBuf};

use ndarray::Array2;

use crate::detect::ModelError;
use crate::weights::ModelWeights;

/// Load model weights from a directory or file.
///
@@ -43,16 +43,17 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError

if !gguf_files.is_empty() {
// Use the first (or largest) GGUF file
let gguf_path = gguf_files
.into_iter()
.max_by_key(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
.unwrap();
return super::gguf::load_gguf(&gguf_path);
}

// Safetensors loading (also handles MLX format — same files, sometimes in weights/ subdir)
let arch = crate::detect_architecture(path).map_err(|e| ModelError::Parse(e.to_string()))?;
let prefixes = arch.key_prefixes_to_strip();
let mlx_affine_group_size = mlx_affine_group_size_from_config(path);

let mut st_files: Vec<PathBuf> = std::fs::read_dir(path)?
.filter_map(|e| e.ok())
@@ -97,7 +98,9 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError
for (name, view) in st.tensors() {
let key = normalize_key(&name, prefixes);
let shape = view.shape();
if name.ends_with("_blocks") || name.ends_with("_scales") { continue; }
if name.ends_with("_blocks") || name.ends_with("_scales") {
continue;
}
let data = match tensor_to_f32(&view) {
Ok(d) => d,
Err(_) => continue,
@@ -108,28 +111,62 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError
.map_err(|e| ModelError::Parse(e.to_string()))?;
tensors.insert(key, arr.into_shared());
}
1 => {
vectors.insert(key, data);
}
_ => {}
}
}
} else {
// Standard float path
for (name, view) in st.tensors() {
if name.ends_with(".scales") || name.ends_with(".biases") {
continue;
}
let key = normalize_key(&name, prefixes);
let shape = view.shape();
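// MLX affine-quantized checkpoints store 2-D weights as packed u32 words,
// with companion ".scales"/".biases" tensors (skipped above); everything
// else takes the plain float path below.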
let data = match view.dtype() {
safetensors::Dtype::U32 if shape.len() == 2 => {
match dequantize_mlx_affine_tensor(&st, &name, &view, mlx_affine_group_size)
{
Ok((data, _, cols)) => {
// Replace packed width with dequantized width.
let mut dequant_shape = shape.to_vec();
dequant_shape[1] = cols;
match dequant_shape.len() {
2 => {
let arr = Array2::from_shape_vec(
(dequant_shape[0], dequant_shape[1]),
data,
)
.map_err(|e| ModelError::Parse(e.to_string()))?;
tensors.insert(key, arr.into_shared());
continue;
}
_ => continue,
}
}
Err(_) => continue,
}
}
_ => match tensor_to_f32(&view) {
Ok(d) => d,
Err(_) => continue,
},
};
match shape.len() {
2 => {
let arr = Array2::from_shape_vec((shape[0], shape[1]), data)
.map_err(|e| ModelError::Parse(e.to_string()))?;
tensors.insert(key, arr.into_shared());
}
1 => {
vectors.insert(key, data);
}
// 0D scalar tensors (e.g., layer_scalar) → store as 1-element vector
0 => {
vectors.insert(key, data);
}
_ => {}
}
}
@@ -167,6 +204,19 @@ pub fn load_model_dir(path: impl AsRef<Path>) -> Result<ModelWeights, ModelError
})
}
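
// A hypothetical call site (the path is illustrative, not from this diff):
//
//     let weights = load_model_dir("models/my-model")?;
//
// `weights` then holds 2-D matrices and 1-D vectors keyed by names with
// architecture prefixes stripped (see `normalize_key` below).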

/// Read the global MLX affine quantization group size from config.json.
pub fn mlx_affine_group_size_from_config(model_dir: &Path) -> Option<usize> {
let config_path = model_dir.join("config.json");
let text = std::fs::read_to_string(config_path).ok()?;
let json: serde_json::Value = serde_json::from_str(&text).ok()?;

json.get("quantization")
.or_else(|| json.get("quantization_config"))
.and_then(|q| q.get("group_size"))
.and_then(|v| v.as_u64())
.map(|v| v as usize)
}
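
// A sketch of the config.json shape this expects (MLX-style; the "bits"
// field is illustrative, only "group_size" is read here):
//
//     { "quantization": { "group_size": 64, "bits": 4 } }
//
// for which mlx_affine_group_size_from_config(model_dir) returns Some(64).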

/// Resolve a HuggingFace model ID or path to a local directory or GGUF file.
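///
/// Hypothetical examples (resolution depends on the local HF cache layout):
/// an existing path like "./model.gguf" is returned as-is, while an ID like
/// "org/model" is looked up among the HuggingFace hub cache snapshots.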
pub fn resolve_model_path(model: &str) -> Result<PathBuf, ModelError> {
let path = PathBuf::from(model);
@@ -191,11 +241,17 @@ pub fn resolve_model_path(model: &str) -> Result<PathBuf, ModelError> {
if let Ok(entries) = std::fs::read_dir(&hf_cache) {
for entry in entries.flatten() {
let p = entry.path();
if !p.is_dir() {
continue;
}
// Prefer snapshot with safetensors files
let has_st = std::fs::read_dir(&p)
.ok()
.map(|rd| {
rd.flatten()
.any(|e| e.path().extension().is_some_and(|ext| ext == "safetensors"))
})
.unwrap_or(false);
if has_st {
return Ok(p);
}
Expand Down Expand Up @@ -239,20 +295,26 @@ fn dequantize_mxfp4_experts(
) -> Result<(), ModelError> {
// Find all gate_up_proj_blocks tensors (one per layer)
for name in tensor_names {
if !name.ends_with(".gate_up_proj_blocks") { continue; }
if !name.ends_with(".gate_up_proj_blocks") {
continue;
}

let scales_name = name.replace("_blocks", "_scales");
let down_blocks_name = name.replace("gate_up_proj_blocks", "down_proj_blocks");
let down_scales_name = name.replace("gate_up_proj_blocks", "down_proj_scales");

// Get tensor views
let blocks_view = st
.tensor(name)
.map_err(|e| ModelError::Parse(format!("MXFP4 blocks: {e}")))?;
let scales_view = st
.tensor(&scales_name)
.map_err(|e| ModelError::Parse(format!("MXFP4 scales: {e}")))?;

let shape = blocks_view.shape();
if shape.len() != 4 {
continue;
}

let num_experts = shape[0];
let out_features = shape[1]; // 2*hidden for gate_up, hidden for down
@@ -262,8 +324,11 @@ fn dequantize_mxfp4_experts(

// Dequantize gate_up (fused: first half = gate, second half = up)
let expert_data = crate::quant::mxfp4::dequantize_all_experts(
blocks_view.data(),
scales_view.data(),
num_experts,
out_features,
groups,
);

// Extract layer number from key
@@ -280,12 +345,18 @@ fn dequantize_mxfp4_experts(
let gate_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w1.weight");
let up_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w3.weight");

tensors.insert(
gate_key,
Array2::from_shape_vec((half, in_features), gate_data)
.map_err(|e| ModelError::Parse(e.to_string()))?
.into_shared(),
);
tensors.insert(
up_key,
Array2::from_shape_vec((half, in_features), up_data)
.map_err(|e| ModelError::Parse(e.to_string()))?
.into_shared(),
);
}

// Dequantize down projection
@@ -297,14 +368,21 @@ fn dequantize_mxfp4_experts(
let down_in = down_groups * 32;

let down_experts = crate::quant::mxfp4::dequantize_all_experts(
db.data(),
ds.data(),
num_experts,
down_out,
down_groups,
);

for (e, data) in down_experts.iter().enumerate() {
let down_key = format!("{layer_prefix}.block_sparse_moe.experts.{e}.w2.weight");
tensors.insert(
down_key,
Array2::from_shape_vec((down_out, down_in), data.clone())
.map_err(|e| ModelError::Parse(e.to_string()))?
.into_shared(),
);
}
}
}
@@ -316,9 +394,12 @@ fn dequantize_mxfp4_experts(
let s = router_view.shape();
if s.len() == 2 {
let router_key = format!("{layer_prefix}.block_sparse_moe.gate.weight");
tensors.insert(
router_key,
Array2::from_shape_vec((s[0], s[1]), data)
.map_err(|e| ModelError::Parse(e.to_string()))?
.into_shared(),
);
}
}
}
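
// MXFP4 background (a summary per the OCP microscaling spec, not code from
// this diff): each 32-element group is stored as 16 packed FP4 (E2M1)
// nibbles plus one shared E8M0 scale, which is why in_features = groups * 32
// above. The 16 E2M1 code points decode to:
//
//     const FP4_E2M1: [f32; 16] = [
//         0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
//         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
//     ];
//
// dequantize_all_experts is assumed to apply the shared per-group scale on
// top of these values.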
@@ -336,6 +417,53 @@ fn normalize_key(key: &str, prefixes: &[&str]) -> String {
key.to_string()
}
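
// For example (illustrative prefixes): with prefixes = ["model.", "transformer."],
// normalize_key("model.layers.0.mlp.w1.weight", prefixes) yields
// "layers.0.mlp.w1.weight".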

fn dequantize_mlx_affine_tensor(
st: &safetensors::SafeTensors,
name: &str,
view: &safetensors::tensor::TensorView<'_>,
group_size: Option<usize>,
) -> Result<(Vec<f32>, usize, usize), ModelError> {
let group_size = group_size.ok_or_else(|| {
ModelError::Parse(format!(
"missing MLX affine group_size for quantized tensor: {name}"
))
})?;

let shape = view.shape();
if shape.len() != 2 {
return Err(ModelError::UnsupportedDtype(format!("{:?}", view.dtype())));
}

let stem = name
.strip_suffix(".weight")
.ok_or_else(|| ModelError::UnsupportedDtype(format!("{:?}", view.dtype())))?;
let scales_name = format!("{stem}.scales");
let biases_name = format!("{stem}.biases");

let scales_view = st
.tensor(&scales_name)
.map_err(|e| ModelError::Parse(format!("MLX affine scales {scales_name}: {e}")))?;
let biases_view = st.tensor(&biases_name).ok();

let scales = tensor_to_f32(&scales_view)?;
let biases = match biases_view {
Some(biases_view) => Some(tensor_to_f32(&biases_view)?),
None => None,
};

let (data, cols) = crate::quant::mlx_affine::dequantize_u32_matrix_bytes(
view.data(),
shape[0],
shape[1],
&scales,
biases.as_deref(),
group_size,
)
.map_err(ModelError::Parse)?;

Ok((data, shape[0], cols))
}
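
// A minimal sketch of the per-row affine dequantization delegated to
// crate::quant::mlx_affine above (assumptions: 4-bit MLX packing with eight
// values per u32, least-significant nibble first, and one scale/bias per
// group_size output columns; the real helper also validates lengths and
// supports other bit widths):
#[allow(dead_code)]
fn dequantize_u32_row_4bit_sketch(
packed: &[u32],         // one packed row of quantized weights
scales: &[f32],         // one scale per group in this row
biases: Option<&[f32]>, // optional per-group bias
group_size: usize,
) -> Vec<f32> {
let mut out = Vec::with_capacity(packed.len() * 8);
for (i, word) in packed.iter().enumerate() {
for k in 0..8 {
let q = ((word >> (4 * k)) & 0xF) as f32; // unpack one nibble
let g = (i * 8 + k) / group_size; // group index for this column
let b = biases.map_or(0.0, |bs| bs[g]);
out.push(scales[g] * q + b); // w ~= scale * q + bias
}
}
out
}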

fn tensor_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>, ModelError> {
use crate::quant::half;
match view.dtype() {