Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions crates/embed/tests/embed_parity_vs_hf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,158 @@ async fn e5_small_parity_vs_hf() {
);
}

// ---------------------------------------------------------------------------
// all-MiniLM-L6-v2 parity test
// ---------------------------------------------------------------------------

#[tokio::test]
async fn all_minilm_l6_v2_parity_vs_hf() {
let Some(goldens) = load_fixture("all_minilm_l6_v2.json") else {
return;
};

let model_dir =
PathBuf::from(std::env::var("HOME").unwrap()).join(".lattice/models/all-minilm-l6-v2");
if !model_dir.join("model.safetensors").exists() {
eprintln!(
"SKIP all_minilm_l6_v2_parity_vs_hf: model weights not found at {}",
model_dir.display()
);
return;
}

let model = EmbeddingModel::AllMiniLmL6V2;
let service = NativeEmbeddingService::with_model(model);

let mut failures = 0;
let mut min_cos = 1.0_f64;
let mut max_diff = 0.0_f64;

for golden in &goldens {
// Golden was generated without any prompt prefix (MiniLM has none).
assert_eq!(
golden.prompt_prefix, "",
"all-MiniLM-L6-v2 golden must have empty prompt_prefix; got {:?}",
golden.prompt_prefix
);
let lattice_vec = embed_text(&service, &golden.input, model).await;

assert_eq!(
lattice_vec.len(),
golden.embedding_dim,
"all-MiniLM-L6-v2 dimension mismatch: got {}, want {}",
lattice_vec.len(),
golden.embedding_dim
);

let cos = cosine_sim(&lattice_vec, &golden.embedding);
let diff = max_abs_diff(&lattice_vec, &golden.embedding);
min_cos = min_cos.min(cos);
max_diff = max_diff.max(diff);

if cos < COS_SIM_MIN_F32 {
failures += 1;
eprintln!(
"PARITY FAIL [all-minilm-l6] input={:?}\n cosine={:.6} (need ≥ {COS_SIM_MIN_F32})\n max_abs_diff={diff:.2e}\n pooling={}, prompt_prefix={}",
golden.input, cos, golden.pooling, golden.prompt_prefix,
);
} else {
println!(
" [all-minilm-l6] '{:.40}' cosine={:.6} max_diff={:.2e}",
golden.input, cos, diff
);
}
}

println!(
"[all-minilm-l6] aggregate: min_cosine={min_cos:.6} max_abs_diff={max_diff:.2e} failures={failures}/{}",
goldens.len()
);

assert_eq!(
failures,
0,
"[all-minilm-l6] {failures}/{} parity checks failed — see stderr",
goldens.len()
);
}

// ---------------------------------------------------------------------------
// paraphrase-multilingual-MiniLM-L12-v2 parity test
// ---------------------------------------------------------------------------

#[tokio::test]
async fn paraphrase_multilingual_minilm_l12_v2_parity_vs_hf() {
let Some(goldens) = load_fixture("paraphrase_multilingual_minilm_l12_v2.json") else {
return;
};

let model_dir = PathBuf::from(std::env::var("HOME").unwrap())
.join(".lattice/models/paraphrase-multilingual-minilm-l12-v2");
if !model_dir.join("model.safetensors").exists() {
eprintln!(
"SKIP paraphrase_multilingual_minilm_l12_v2_parity_vs_hf: model weights not found at {}",
model_dir.display()
);
return;
}

let model = EmbeddingModel::ParaphraseMultilingualMiniLmL12V2;
let service = NativeEmbeddingService::with_model(model);

let mut failures = 0;
let mut min_cos = 1.0_f64;
let mut max_diff = 0.0_f64;

for golden in &goldens {
// Golden was generated without any prompt prefix (paraphrase-multilingual has none).
assert_eq!(
golden.prompt_prefix, "",
"paraphrase-multilingual golden must have empty prompt_prefix; got {:?}",
golden.prompt_prefix
);
let lattice_vec = embed_text(&service, &golden.input, model).await;

assert_eq!(
lattice_vec.len(),
golden.embedding_dim,
"paraphrase-multilingual dimension mismatch: got {}, want {}",
lattice_vec.len(),
golden.embedding_dim
);

let cos = cosine_sim(&lattice_vec, &golden.embedding);
let diff = max_abs_diff(&lattice_vec, &golden.embedding);
min_cos = min_cos.min(cos);
max_diff = max_diff.max(diff);

if cos < COS_SIM_MIN_F32 {
failures += 1;
eprintln!(
"PARITY FAIL [paraphrase-multilingual-minilm-l12] input={:?}\n cosine={:.6} (need ≥ {COS_SIM_MIN_F32})\n max_abs_diff={diff:.2e}\n pooling={}, prompt_prefix={}",
golden.input, cos, golden.pooling, golden.prompt_prefix,
);
} else {
println!(
" [paraphrase-multilingual-minilm-l12] '{:.40}' cosine={:.6} max_diff={:.2e}",
golden.input, cos, diff
);
}
}

println!(
"[paraphrase-multilingual-minilm-l12] aggregate: min_cosine={min_cos:.6} max_abs_diff={max_diff:.2e} failures={failures}/{}",
goldens.len()
);

assert_eq!(
failures,
0,
"[paraphrase-multilingual-minilm-l12] {failures}/{} parity checks failed — see stderr",
goldens.len()
);
}

// ---------------------------------------------------------------------------
// Qwen3-Embedding-0.6B parity test
// ---------------------------------------------------------------------------
Expand Down
Loading