diff --git a/crates/embed/src/model.rs b/crates/embed/src/model.rs
index cfc2f4bb..4d38dc90 100644
--- a/crates/embed/src/model.rs
+++ b/crates/embed/src/model.rs
@@ -312,6 +312,39 @@ impl EmbeddingModel {
         )
     }
 
+    /// **Stable**: the pooling strategy this model expects from BERT-family inference.
+    ///
+    /// BGE v1.5 models use CLS-token pooling (first token) as documented on their
+    /// HuggingFace model cards (`model_output[0][:, 0]`). All other BERT-family
+    /// models (E5, MiniLM) use masked mean pooling.
+    ///
+    /// Returns `None` for non-BERT models (Qwen3, OpenAI remote) which have their
+    /// own pooling paths.
+    ///
+    /// Only available when the `native` feature is enabled (requires `lattice-inference`).
+    #[cfg(feature = "native")]
+    #[inline]
+    pub const fn bert_pooling(&self) -> Option<lattice_inference::BertPooling> {
+        match self {
+            // BGE v1.5 — CLS pooling per model card
+            EmbeddingModel::BgeSmallEnV15
+            | EmbeddingModel::BgeBaseEnV15
+            | EmbeddingModel::BgeLargeEnV15 => Some(lattice_inference::BertPooling::CLS),
+            // E5 multilingual — masked mean pooling per model card
+            EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
+                Some(lattice_inference::BertPooling::Mean)
+            }
+            // MiniLM family — masked mean pooling per sentence-transformers convention
+            EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
+                Some(lattice_inference::BertPooling::Mean)
+            }
+            // Qwen and remote models — not BERT-family, pooling handled separately
+            EmbeddingModel::Qwen3Embedding0_6B
+            | EmbeddingModel::Qwen3Embedding4B
+            | EmbeddingModel::TextEmbedding3Small => None,
+        }
+    }
+
     /// **Stable**: embedding key revision string for this model family.
     #[inline]
     pub const fn key_version(&self) -> &'static str {
@@ -689,4 +722,99 @@ mod tests {
         assert!(result.is_err());
         assert!(result.unwrap_err().contains("unknown embedding model"));
     }
+
+    // -------------------------------------------------------------------------
+    // bert_pooling() routing tests (P1-E3) — require `native` feature
+    // -------------------------------------------------------------------------
+
+    /// BGE small/base/large must use CLS pooling per their HF model cards.
+    #[cfg(feature = "native")]
+    #[test]
+    fn test_bge_models_use_cls_pooling() {
+        use lattice_inference::BertPooling;
+
+        assert_eq!(
+            EmbeddingModel::BgeSmallEnV15.bert_pooling(),
+            Some(BertPooling::CLS),
+            "BgeSmallEnV15 must use CLS pooling"
+        );
+        assert_eq!(
+            EmbeddingModel::BgeBaseEnV15.bert_pooling(),
+            Some(BertPooling::CLS),
+            "BgeBaseEnV15 must use CLS pooling"
+        );
+        assert_eq!(
+            EmbeddingModel::BgeLargeEnV15.bert_pooling(),
+            Some(BertPooling::CLS),
+            "BgeLargeEnV15 must use CLS pooling"
+        );
+    }
+
+    /// E5 models must use mean pooling per their HF model cards.
+    #[cfg(feature = "native")]
+    #[test]
+    fn test_e5_models_use_mean_pooling() {
+        use lattice_inference::BertPooling;
+
+        assert_eq!(
+            EmbeddingModel::MultilingualE5Small.bert_pooling(),
+            Some(BertPooling::Mean),
+            "MultilingualE5Small must use mean pooling"
+        );
+        assert_eq!(
+            EmbeddingModel::MultilingualE5Base.bert_pooling(),
+            Some(BertPooling::Mean),
+            "MultilingualE5Base must use mean pooling"
+        );
+    }
+
+    /// MiniLM models must use mean pooling per sentence-transformers convention.
+    #[cfg(feature = "native")]
+    #[test]
+    fn test_minilm_models_use_mean_pooling() {
+        use lattice_inference::BertPooling;
+
+        assert_eq!(
+            EmbeddingModel::AllMiniLmL6V2.bert_pooling(),
+            Some(BertPooling::Mean),
+            "AllMiniLmL6V2 must use mean pooling"
+        );
+        assert_eq!(
+            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2.bert_pooling(),
+            Some(BertPooling::Mean),
+            "ParaphraseMultilingualMiniLmL12V2 must use mean pooling"
+        );
+    }
+
+    /// Qwen and remote models return None — they have separate pooling paths.
+    #[cfg(feature = "native")]
+    #[test]
+    fn test_non_bert_models_return_none_pooling() {
+        assert_eq!(
+            EmbeddingModel::Qwen3Embedding0_6B.bert_pooling(),
+            None,
+            "Qwen model must return None for bert_pooling()"
+        );
+        assert_eq!(
+            EmbeddingModel::Qwen3Embedding4B.bert_pooling(),
+            None,
+            "Qwen model must return None for bert_pooling()"
+        );
+        assert_eq!(
+            EmbeddingModel::TextEmbedding3Small.bert_pooling(),
+            None,
+            "Remote model must return None for bert_pooling()"
+        );
+    }
+
+    /// BGE and E5 use DIFFERENT pooling strategies — this is the key correctness distinction.
+    #[cfg(feature = "native")]
+    #[test]
+    fn test_bge_and_e5_use_different_pooling() {
+        assert_ne!(
+            EmbeddingModel::BgeSmallEnV15.bert_pooling(),
+            EmbeddingModel::MultilingualE5Small.bert_pooling(),
+            "BGE and E5 must use different pooling strategies"
+        );
+    }
 }
diff --git a/crates/embed/src/service/native.rs b/crates/embed/src/service/native.rs
index e96a4e6b..f8608c45 100644
--- a/crates/embed/src/service/native.rs
+++ b/crates/embed/src/service/native.rs
@@ -212,7 +212,12 @@ fn load_model_sync(model_config: ModelConfig) -> std::result::Result
                 _ => unreachable!(),
             };
             info!(model = model_name, "loading native BERT embedding model");
-            let bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
+            let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
+            // Route each model family through its correct pooling strategy.
+            // BGE uses CLS pooling; E5 and MiniLM use mean pooling.
+            if let Some(pooling) = model_config.model.bert_pooling() {
+                bert.set_pooling(pooling);
+            }
             Ok(LoadedModel::Bert(Arc::new(bert)))
         }
         EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
diff --git a/crates/inference/src/lib.rs b/crates/inference/src/lib.rs
index 2ef3e283..c91f5e08 100644
--- a/crates/inference/src/lib.rs
+++ b/crates/inference/src/lib.rs
@@ -68,6 +68,7 @@ pub use crate::error::InferenceError;
 pub use crate::model::{
     BertConfig, BertModel, CrossEncoderModel, LayerTimings, ProfileTimings, QwenConfig, QwenModel,
 };
+pub use crate::pool::BertPooling;
 pub use crate::tokenizer::{
     BpeTokenizer, SentencePieceTokenizer, TokenizedInput, Tokenizer, WordPieceTokenizer,
     load_tokenizer,
diff --git a/crates/inference/src/model/bert.rs b/crates/inference/src/model/bert.rs
index 777e6079..ba109375 100644
--- a/crates/inference/src/model/bert.rs
+++ b/crates/inference/src/model/bert.rs
@@ -9,7 +9,7 @@ use crate::download::ensure_model_files;
 use crate::error::InferenceError;
 use crate::forward::cpu::{add_bias, gelu, layer_norm, matmul_bt};
 use crate::lora_hook::{LoraHook, NoopLoraHook};
-use crate::pool::{l2_normalize, mean_pool};
+use crate::pool::{BertPooling, cls_pool, l2_normalize, mean_pool};
 use crate::tokenizer::common::{Tokenizer, load_tokenizer};
 use crate::weights::{BertWeights, SafetensorsFile};
 use std::fs;
@@ -113,6 +113,9 @@ pub struct BertModel {
     // See the struct-level doc comment for the full safety argument.
     weights: BertWeights<'static>,
     _safetensors: Box<SafetensorsFile>,
+    /// Pooling strategy used to reduce hidden states to a single embedding vector.
+    /// Defaults to `BertPooling::Mean` for backwards compatibility.
+    pooling: BertPooling,
 }
 
 impl BertModel {
@@ -166,6 +169,7 @@ impl BertModel {
             tokenizer,
             weights,
             _safetensors: safetensors,
+            pooling: BertPooling::default(),
         })
     }
 
@@ -191,6 +195,19 @@ impl BertModel {
         self.config.hidden_size
     }
 
+    /// **Stable** (provisional): set the pooling strategy.
+    ///
+    /// Must be called before any encoding. The `NativeEmbeddingService` uses this to
+    /// route BGE models through CLS pooling and E5/MiniLM through mean pooling.
+    pub fn set_pooling(&mut self, pooling: BertPooling) {
+        self.pooling = pooling;
+    }
+
+    /// **Unstable**: pooling strategy accessor for testing.
+    pub fn pooling(&self) -> BertPooling {
+        self.pooling
+    }
+
     /// **Stable**: single-text encoding entry point; consumed by `lattice-embed`.
     pub fn encode(&self, text: &str) -> Result<Vec<f32>, InferenceError> {
         let input = self.tokenizer.tokenize(text);
@@ -210,12 +227,7 @@ impl BertModel {
             &mut buffers,
         );
 
-        let mut pooled = mean_pool(
-            &hidden_states,
-            &input.attention_mask[..seq_len],
-            seq_len,
-            self.config.hidden_size,
-        );
+        let mut pooled = self.pool(&hidden_states, &input.attention_mask[..seq_len], seq_len);
         l2_normalize(&mut pooled);
         Ok(pooled)
     }
@@ -250,12 +262,7 @@ impl BertModel {
                 seq_len,
                 &mut buffers,
             );
-            let mut pooled = mean_pool(
-                &hidden_states,
-                &input.attention_mask[..seq_len],
-                seq_len,
-                self.config.hidden_size,
-            );
+            let mut pooled = self.pool(&hidden_states, &input.attention_mask[..seq_len], seq_len);
             l2_normalize(&mut pooled);
             outputs.push(pooled);
         }
@@ -263,6 +270,22 @@ impl BertModel {
         Ok(outputs)
     }
 
+    /// Apply the configured pooling strategy to `hidden_states`.
+    ///
+    /// Both `encode` and `encode_batch` delegate here so the pooling branch
+    /// is in one place. L2 normalization is applied by the caller.
+    fn pool(&self, hidden_states: &[f32], attention_mask: &[u32], seq_len: usize) -> Vec<f32> {
+        match self.pooling {
+            BertPooling::Mean => mean_pool(
+                hidden_states,
+                attention_mask,
+                seq_len,
+                self.config.hidden_size,
+            ),
+            BertPooling::CLS => cls_pool(hidden_states, seq_len, self.config.hidden_size),
+        }
+    }
+
     /// Forward pass for a pre-tokenized input; used by `CrossEncoderModel`.
     pub(crate) fn forward_tokenized(
         &self,
@@ -567,6 +590,7 @@ fn infer_num_attention_heads(hidden_size: usize) -> Result()).sqrt();
         assert_relative_eq!(norm, 1.0, epsilon = 1e-4);
     }
+
+    // -------------------------------------------------------------------------
+    // Deterministic pooling tests (P1-E3)
+    //
+    // These tests use fixed hidden-state tensors — no model weights needed.
+    // They validate the pooling routing at the kernel level: CLS extracts
+    // position 0, mean computes an attention-mask-weighted average, and L2
+    // normalisation produces a unit vector in both cases.
+    // -------------------------------------------------------------------------
+
+    /// Fixed 2-token, 4-dim hidden-state tensor.
+    ///
+    /// Token 0 (CLS): [1.0, 0.0, 0.0, 0.0]
+    /// Token 1 (word): [0.0, 1.0, 0.0, 0.0]
+    /// Both tokens are real (attention_mask = [1, 1]).
+    fn hidden_2x4() -> (Vec<f32>, Vec<u32>) {
+        let hidden = vec![
+            1.0_f32, 0.0, 0.0, 0.0, // token 0 (CLS)
+            0.0_f32, 1.0, 0.0, 0.0, // token 1 (word)
+        ];
+        let mask = vec![1_u32, 1];
+        (hidden, mask)
+    }
+
+    /// CLS pooling returns the first-token hidden state ([1,0,0,0]), then L2 normalises.
+    ///
+    /// The CLS token is already unit-length here, so after L2 it stays [1,0,0,0].
+    /// This matches the BGE model-card recipe: `model_output[0][:, 0]` + L2.
+    #[test]
+    fn test_cls_pool_extracts_first_token_and_l2_unit_norm() {
+        let (hidden, _mask) = hidden_2x4();
+        let seq_len = 2;
+        let hidden_size = 4;
+
+        let mut pooled = cls_pool(&hidden, seq_len, hidden_size);
+
+        // Before L2: should be the CLS row [1,0,0,0].
+        assert_eq!(
+            pooled,
+            vec![1.0, 0.0, 0.0, 0.0],
+            "CLS row mismatch before L2"
+        );
+
+        l2_normalize(&mut pooled);
+
+        // CLS row is already unit-length → unchanged.
+        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
+        assert_relative_eq!(norm, 1.0, epsilon = 1e-6);
+        assert_relative_eq!(pooled[0], 1.0, epsilon = 1e-6);
+        assert_relative_eq!(pooled[1], 0.0, epsilon = 1e-6);
+    }
+
+    /// Mean pooling with uniform mask averages all tokens, then L2 normalises.
+    ///
+    /// With hidden = [[1,0,0,0],[0,1,0,0]] and mask [1,1],
+    /// mean = [0.5, 0.5, 0, 0]. After L2: [1/√2, 1/√2, 0, 0] ≈ [0.7071, 0.7071, 0, 0].
+    ///
+    /// This matches the E5/MiniLM model-card recipe: masked mean pooling + L2.
+    #[test]
+    fn test_mean_pool_averages_masked_tokens_and_l2_unit_norm() {
+        let (hidden, mask) = hidden_2x4();
+        let seq_len = 2;
+        let hidden_size = 4;
+
+        let mut pooled = mean_pool(&hidden, &mask, seq_len, hidden_size);
+
+        // Before L2: mean of [1,0,0,0] and [0,1,0,0] = [0.5, 0.5, 0, 0].
+        assert_relative_eq!(pooled[0], 0.5, epsilon = 1e-6);
+        assert_relative_eq!(pooled[1], 0.5, epsilon = 1e-6);
+        assert_relative_eq!(pooled[2], 0.0, epsilon = 1e-6);
+        assert_relative_eq!(pooled[3], 0.0, epsilon = 1e-6);
+
+        l2_normalize(&mut pooled);
+
+        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
+        assert_relative_eq!(norm, 1.0, epsilon = 1e-6);
+
+        // L2 of [0.5, 0.5, 0, 0]: magnitude = √0.5, so normalised = [1/√2, 1/√2, 0, 0].
+        let inv_sqrt2 = std::f32::consts::FRAC_1_SQRT_2;
+        assert_relative_eq!(pooled[0], inv_sqrt2, epsilon = 1e-5);
+        assert_relative_eq!(pooled[1], inv_sqrt2, epsilon = 1e-5);
+    }
+
+    /// CLS and mean pooling of the same hidden states produce DIFFERENT vectors.
+    ///
+    /// This is the key correctness guarantee for P1-E3: using the wrong pooling
+    /// strategy for a model produces a meaningfully different embedding.
+    #[test]
+    fn test_cls_and_mean_produce_different_embeddings() {
+        let (hidden, mask) = hidden_2x4();
+        let seq_len = 2;
+        let hidden_size = 4;
+
+        let mut cls = cls_pool(&hidden, seq_len, hidden_size);
+        let mut mean = mean_pool(&hidden, &mask, seq_len, hidden_size);
+
+        l2_normalize(&mut cls);
+        l2_normalize(&mut mean);
+
+        // CLS = [1, 0, 0, 0], mean = [1/√2, 1/√2, 0, 0] — these differ.
+        assert_ne!(
+            cls, mean,
+            "CLS and mean pooling must produce different unit vectors"
+        );
+    }
+
+    /// Mean pooling with a padding mask ignores masked positions.
+    ///
+    /// With hidden = [[1,0,0,0],[0,1,0,0]] and mask [1, 0],
+    /// only token 0 contributes: mean = [1, 0, 0, 0].
+    #[test]
+    fn test_mean_pool_respects_padding_mask() {
+        let hidden = vec![
+            1.0_f32, 0.0, 0.0, 0.0, // token 0 (real)
+            0.0_f32, 1.0, 0.0, 0.0, // token 1 (pad, mask=0)
+        ];
+        let mask = vec![1_u32, 0]; // second token is padding
+        let seq_len = 2;
+        let hidden_size = 4;
+
+        let pooled = mean_pool(&hidden, &mask, seq_len, hidden_size);
+
+        // Only token 0 is unmasked → mean = [1,0,0,0].
+        assert_relative_eq!(pooled[0], 1.0, epsilon = 1e-6);
+        assert_relative_eq!(pooled[1], 0.0, epsilon = 1e-6);
+    }
 }
 
 /// Compile-time guard for the struct field drop-order invariant.
diff --git a/crates/inference/src/pool.rs b/crates/inference/src/pool.rs
index 756436dd..d585bab9 100644
--- a/crates/inference/src/pool.rs
+++ b/crates/inference/src/pool.rs
@@ -1,5 +1,34 @@
 use crate::forward::cpu::simd_config;
 
+/// **Stable** (provisional): pooling strategy for BERT-family encoder models.
+///
+/// Selects how the per-token hidden states produced by the final transformer layer
+/// are reduced to a single fixed-size embedding vector.
+///
+/// | Family | Strategy | Reference |
+/// |--------|----------|-----------|
+/// | BGE small/base/large-en-v1.5 | `CLS` | HF card: `model_output[0][:, 0]` + L2 |
+/// | E5 multilingual small/base | `Mean` | HF card: masked average pooling + L2 |
+/// | all-MiniLM-L6-v2 | `Mean` | sentence-transformers: mean pooling |
+/// | paraphrase-multilingual-MiniLM-L12-v2 | `Mean` | sentence-transformers: mean pooling |
+///
+/// Qwen3-Embedding models use `last_token_pool` (decoder-style) and do not use
+/// this enum — they have a separate inference path in `QwenModel`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+pub enum BertPooling {
+    /// Masked mean pooling (attention-mask-weighted average over all real tokens).
+    ///
+    /// Correct for E5 and MiniLM families. This is the legacy default: models loaded
+    /// without an explicit pooling selection (previously including BGE) use this path.
+    #[default]
+    Mean,
+    /// CLS token pooling (hidden state at position 0).
+    ///
+    /// Correct for BGE v1.5 families. The BGE model cards document their inference
+    /// recipe as `model_output[0][:, 0]` (CLS hidden state) followed by L2 normalization.
+    CLS,
+}
+
 /// **Unstable**: internal pooling kernel; SIMD dispatch details may change.
 ///
 /// Dispatches to SIMD (NEON/AVX2) when available, falls back to scalar.