feat(embed): route BGE through CLS pooling, keep E5/MiniLM on mean (P1-E3)#113

Closed
ohdearquant wants to merge 1 commit into pr-embedperf-08-role-aware-prompts from pr-embedperf-09-bge-cls-pooling

@ohdearquant
Owner

Layer

L2 — embed pooling routing (PR9 of 11)

What

  • New BertPooling enum (Mean | CLS) in crates/inference/src/pool.rs, re-exported via inference/src/lib.rs.
  • BertModel gains pooling: BertPooling field (default Mean) with set_pooling() / pooling() accessors. Single pool() dispatch used by both encode() and encode_batch().
  • bert_pooling() method on EmbeddingModel (feature-gated #[cfg(feature = "native")]): BGE small/base/large → CLS; E5 / MiniLM → Mean; Qwen / remote → None (already routed via QwenModel::last_token).
  • NativeEmbeddingService::load_model_sync() calls bert.set_pooling(pooling) after loading each BERT-family model.
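
The routing described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the `EmbeddingModel` variant names, the `BertModel::new()` constructor, and the `load_bert()` helper are hypothetical stand-ins, and the `#[cfg(feature = "native")]` gate is omitted so the snippet compiles on its own.

```rust
/// Pooling strategy for BERT-family encoders (names follow the PR text).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BertPooling {
    Mean,
    CLS,
}

/// Hypothetical variant names; the real `EmbeddingModel` enum may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmbeddingModel {
    BgeSmall,
    BgeBase,
    BgeLarge,
    E5Multilingual,
    MiniLm,
    Qwen3,
    Remote,
}

impl EmbeddingModel {
    /// `Some(strategy)` for BERT-family models, `None` for models that do not
    /// go through the BERT pooling path (Qwen, remote).
    pub fn bert_pooling(self) -> Option<BertPooling> {
        match self {
            Self::BgeSmall | Self::BgeBase | Self::BgeLarge => Some(BertPooling::CLS),
            Self::E5Multilingual | Self::MiniLm => Some(BertPooling::Mean),
            Self::Qwen3 | Self::Remote => None,
        }
    }
}

/// Stand-in for the real model type: only the pooling field is modelled here.
pub struct BertModel {
    pooling: BertPooling,
}

impl BertModel {
    /// Hypothetical constructor; defaults to mean pooling as the PR describes.
    pub fn new() -> Self {
        Self { pooling: BertPooling::Mean }
    }

    pub fn set_pooling(&mut self, pooling: BertPooling) {
        self.pooling = pooling;
    }

    pub fn pooling(&self) -> BertPooling {
        self.pooling
    }
}

/// Hypothetical loader: applies the routed strategy right after "loading".
/// Returns `None` for models that are not BERT-family.
pub fn load_bert(model: EmbeddingModel) -> Option<BertModel> {
    let pooling = model.bert_pooling()?;
    let mut bert = BertModel::new();
    bert.set_pooling(pooling);
    Some(bert)
}
```

Returning `Option<BertPooling>` keeps the non-BERT case explicit at the call site, so a loader cannot silently apply mean pooling to a model that should not be pooled this way.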

Why

The BGE model card recipe specifies first-token (CLS) pooling followed by L2 normalization. Before this PR, lattice mean-pooled every BERT-family model, BGE included, so BGE sentence vectors did not follow the model card recipe.
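
To make the difference between the two strategies concrete, here is a minimal sketch of a single `pool()` dispatch over a row-major hidden-state buffer. The function signature, the flat `&[f32]` layout, and the `u8` attention mask are assumptions made for illustration; the real kernel in `bert.rs` may be shaped differently.

```rust
/// Pooling strategy; `Mean` is the default, as in the PR description.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum BertPooling {
    #[default]
    Mean,
    CLS,
}

/// L2-normalize in place; an all-zero vector is left unchanged.
fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

/// Pool `hidden` (one row of `dim` values per token, row-major) into a single
/// L2-normalized vector. `mask[i]` is 1 for a real token and 0 for padding.
pub fn pool(hidden: &[f32], mask: &[u8], dim: usize, pooling: BertPooling) -> Vec<f32> {
    assert!(dim > 0 && !mask.is_empty(), "need at least one token and one dim");
    assert_eq!(hidden.len(), mask.len() * dim, "hidden must be seq_len * dim");

    let mut out = vec![0.0f32; dim];
    match pooling {
        // CLS: take the hidden state at position 0 as the sentence vector.
        BertPooling::CLS => out.copy_from_slice(&hidden[..dim]),
        // Mean: average only the rows whose mask entry is non-zero.
        BertPooling::Mean => {
            let mut count = 0.0f32;
            for (row, &m) in hidden.chunks_exact(dim).zip(mask) {
                if m != 0 {
                    for (o, x) in out.iter_mut().zip(row) {
                        *o += x;
                    }
                    count += 1.0;
                }
            }
            if count > 0.0 {
                for o in out.iter_mut() {
                    *o /= count;
                }
            }
        }
    }
    // Normalization happens after pooling on both paths.
    l2_normalize(&mut out);
    out
}
```

For a two-token input whose rows differ, `pool(.., BertPooling::CLS)` and `pool(.., BertPooling::Mean)` return different unit vectors, which is why routing BGE through the wrong branch changes the embedding.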

Result

  • 4 pooling kernel tests (bert.rs): CLS extracts position-0 + L2, Mean averages masked tokens + L2, CLS≠Mean for same input, Mean respects padding mask
  • 5 bert_pooling() routing tests (model.rs): BGE→CLS, E5→Mean, MiniLM→Mean, Qwen/remote→None, BGE≠E5
  • L2 normalization stays post-pool for all paths
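
The kernel tests listed above might look roughly like the following. This is a sketch under assumptions: the fixture values, the fixed-size `Hidden` type (2 tokens by 4 dimensions), and the helper names `cls_pool`, `mean_pool`, and `norm` are invented for illustration and are not taken from `bert.rs`.

```rust
const DIM: usize = 4;
/// Fixture shape: 2 tokens x 4 hidden dimensions.
type Hidden = [[f32; DIM]; 2];

/// Euclidean norm of a pooled vector.
fn norm(v: &[f32; DIM]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Return an L2-normalized copy (all-zero input is returned unchanged).
fn l2(mut v: [f32; DIM]) -> [f32; DIM] {
    let n = norm(&v);
    if n > 0.0 {
        v.iter_mut().for_each(|x| *x /= n);
    }
    v
}

/// CLS pooling: position 0, then L2.
fn cls_pool(h: &Hidden) -> [f32; DIM] {
    l2(h[0])
}

/// Mean pooling over unmasked tokens, then L2.
fn mean_pool(h: &Hidden, mask: [bool; 2]) -> [f32; DIM] {
    let mut sum = [0.0f32; DIM];
    let mut n = 0.0f32;
    for (row, keep) in h.iter().zip(mask) {
        if keep {
            for (s, x) in sum.iter_mut().zip(row) {
                *s += x;
            }
            n += 1.0;
        }
    }
    if n > 0.0 {
        sum.iter_mut().for_each(|s| *s /= n);
    }
    l2(sum)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative fixture values, chosen so the norms are easy to check by hand.
    const H: Hidden = [[3.0, 0.0, 4.0, 0.0], [0.0, 5.0, 0.0, 12.0]];

    #[test]
    fn cls_extracts_position_zero_and_is_unit() {
        let v = cls_pool(&H);
        assert!((v[0] - 0.6).abs() < 1e-6 && (v[2] - 0.8).abs() < 1e-6);
        assert!((norm(&v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn mean_is_unit_and_differs_from_cls() {
        let m = mean_pool(&H, [true, true]);
        assert!((norm(&m) - 1.0).abs() < 1e-5);
        assert_ne!(m, cls_pool(&H));
    }

    #[test]
    fn mean_respects_padding_mask() {
        // With the second token masked out, the mean collapses to token 0.
        assert_eq!(mean_pool(&H, [true, false]), cls_pool(&H));
    }
}
```

Hand-picked values such as `[3, 0, 4, 0]` (norm 5) keep the expected outputs checkable without a reference implementation.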

Stack

Base: #112 (PR8 role-aware prompts)
Umbrella: #104

🤖 Generated with Claude Code

…1-E3)

Add BertPooling enum (Mean | CLS) to pool.rs and re-export via inference
lib. Add bert_pooling() method on EmbeddingModel (feature-gated on
"native") that returns CLS for BGE v1.5 small/base/large, Mean for E5
multilingual and MiniLM family, and None for Qwen3/remote models.
Update load_model_sync in NativeEmbeddingService to call set_pooling()
on every BERT model after loading so BGE flows through CLS pooling.
L2 normalization stays post-pool for all paths.

Add deterministic pooling unit tests using fixed 2x4 hidden-state
tensors in bert.rs: CLS extracts position-0 + L2 produces unit vector;
mean averages masked tokens + L2 produces unit vector; CLS and mean
produce distinct embeddings for the same input (key correctness check).
Add bert_pooling() routing tests in model.rs confirming all model
families map to the correct strategy.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@ohdearquant
Owner Author

Subsumed by #104 merge (umbrella PR brought all 11 PRs' content to main in one merge commit after stacked-PR base branches collapsed). Codex round-1 findings tracked in #116. Closing as superseded.
