Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions crates/embed/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,39 @@ impl EmbeddingModel {
)
}

/// **Stable**: selects the pooling mode that BERT-family inference should apply
/// for this embedding model.
///
/// The mapping is:
///
/// - BGE v1.5 (small / base / large) → `BertPooling::CLS`, i.e. first-token
///   pooling as given on their HuggingFace model cards (`model_output[0][:, 0]`).
/// - Multilingual E5 (small / base) and the MiniLM family → `BertPooling::Mean`
///   (masked mean pooling).
/// - Qwen3 and the remote OpenAI model → `None`; these are not BERT-family and
///   pool through their own code paths.
///
/// The match is deliberately exhaustive with no wildcard arm, so adding a new
/// `EmbeddingModel` variant forces a decision here at compile time.
///
/// Compiled only with the `native` feature, because the return type comes from
/// `lattice-inference`.
#[cfg(feature = "native")]
#[inline]
pub const fn bert_pooling(&self) -> Option<lattice_inference::BertPooling> {
    match self {
        // Not BERT-family: pooling is handled elsewhere.
        EmbeddingModel::Qwen3Embedding0_6B
        | EmbeddingModel::Qwen3Embedding4B
        | EmbeddingModel::TextEmbedding3Small => None,

        // E5 multilingual (model card) and MiniLM (sentence-transformers
        // convention) both pool with a masked mean.
        EmbeddingModel::MultilingualE5Small
        | EmbeddingModel::MultilingualE5Base
        | EmbeddingModel::AllMiniLmL6V2
        | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
            Some(lattice_inference::BertPooling::Mean)
        }

        // BGE v1.5 takes the CLS token per its model card.
        EmbeddingModel::BgeSmallEnV15
        | EmbeddingModel::BgeBaseEnV15
        | EmbeddingModel::BgeLargeEnV15 => Some(lattice_inference::BertPooling::CLS),
    }
}

/// **Stable**: embedding key revision string for this model family.
#[inline]
pub const fn key_version(&self) -> &'static str {
Expand Down Expand Up @@ -689,4 +722,99 @@ mod tests {
assert!(result.is_err());
assert!(result.unwrap_err().contains("unknown embedding model"));
}

// -------------------------------------------------------------------------
// bert_pooling() routing tests (P1-E3) — require `native` feature
// -------------------------------------------------------------------------

/// Every BGE v1.5 size (small, base, large) must route to CLS pooling, as
/// their HF model cards specify.
#[cfg(feature = "native")]
#[test]
fn test_bge_models_use_cls_pooling() {
    use lattice_inference::BertPooling;

    // (model under test, failure message) — one row per BGE variant.
    let cases = [
        (
            EmbeddingModel::BgeSmallEnV15,
            "BgeSmallEnV15 must use CLS pooling",
        ),
        (
            EmbeddingModel::BgeBaseEnV15,
            "BgeBaseEnV15 must use CLS pooling",
        ),
        (
            EmbeddingModel::BgeLargeEnV15,
            "BgeLargeEnV15 must use CLS pooling",
        ),
    ];

    for (model, why) in cases {
        assert_eq!(model.bert_pooling(), Some(BertPooling::CLS), "{}", why);
    }
}

/// Both multilingual E5 variants must route to mean pooling, as their HF
/// model cards specify.
#[cfg(feature = "native")]
#[test]
fn test_e5_models_use_mean_pooling() {
    use lattice_inference::BertPooling;

    // (model under test, failure message) — one row per E5 variant.
    let cases = [
        (
            EmbeddingModel::MultilingualE5Small,
            "MultilingualE5Small must use mean pooling",
        ),
        (
            EmbeddingModel::MultilingualE5Base,
            "MultilingualE5Base must use mean pooling",
        ),
    ];

    for (model, why) in cases {
        assert_eq!(model.bert_pooling(), Some(BertPooling::Mean), "{}", why);
    }
}

/// Both MiniLM variants must route to mean pooling, following the
/// sentence-transformers convention.
#[cfg(feature = "native")]
#[test]
fn test_minilm_models_use_mean_pooling() {
    use lattice_inference::BertPooling;

    // (model under test, failure message) — one row per MiniLM variant.
    let cases = [
        (
            EmbeddingModel::AllMiniLmL6V2,
            "AllMiniLmL6V2 must use mean pooling",
        ),
        (
            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2,
            "ParaphraseMultilingualMiniLmL12V2 must use mean pooling",
        ),
    ];

    for (model, why) in cases {
        assert_eq!(model.bert_pooling(), Some(BertPooling::Mean), "{}", why);
    }
}

/// Models outside the BERT family (Qwen3 and the remote model) must report
/// `None`, since their pooling is handled on separate paths.
#[cfg(feature = "native")]
#[test]
fn test_non_bert_models_return_none_pooling() {
    // (model under test, failure message) — one row per non-BERT variant.
    let cases = [
        (
            EmbeddingModel::Qwen3Embedding0_6B,
            "Qwen model must return None for bert_pooling()",
        ),
        (
            EmbeddingModel::Qwen3Embedding4B,
            "Qwen model must return None for bert_pooling()",
        ),
        (
            EmbeddingModel::TextEmbedding3Small,
            "Remote model must return None for bert_pooling()",
        ),
    ];

    for (model, why) in cases {
        assert_eq!(model.bert_pooling(), None, "{}", why);
    }
}

/// Guards the key correctness distinction: a BGE model and an E5 model must
/// NOT resolve to the same pooling strategy.
#[cfg(feature = "native")]
#[test]
fn test_bge_and_e5_use_different_pooling() {
    let bge_pooling = EmbeddingModel::BgeSmallEnV15.bert_pooling();
    let e5_pooling = EmbeddingModel::MultilingualE5Small.bert_pooling();

    assert_ne!(
        bge_pooling, e5_pooling,
        "BGE and E5 must use different pooling strategies"
    );
}
}
7 changes: 6 additions & 1 deletion crates/embed/src/service/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,12 @@ fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel
_ => unreachable!(),
};
info!(model = model_name, "loading native BERT embedding model");
let bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
// Route each model family through its correct pooling strategy.
// BGE uses CLS pooling; E5 and MiniLM use mean pooling.
if let Some(pooling) = model_config.model.bert_pooling() {
bert.set_pooling(pooling);
}
Ok(LoadedModel::Bert(Arc::new(bert)))
}
EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
Expand Down
1 change: 1 addition & 0 deletions crates/inference/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub use crate::error::InferenceError;
pub use crate::model::{
BertConfig, BertModel, CrossEncoderModel, LayerTimings, ProfileTimings, QwenConfig, QwenModel,
};
pub use crate::pool::BertPooling;
pub use crate::tokenizer::{
BpeTokenizer, SentencePieceTokenizer, TokenizedInput, Tokenizer, WordPieceTokenizer,
load_tokenizer,
Expand Down
Loading