From 44150ce491b7d580c5972ffc80c1a00694d4fb5f Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Mon, 25 May 2026 14:43:02 -0400 Subject: [PATCH] feat(embed): role-aware prompts + cache key distinguishment (P0-E2) Fix document_instruction() to return "passage: " for E5 multilingual variants (was unconditionally None). Add EmbeddingRole enum (Query, Passage, Generic) and embed_query/embed_passage trait methods that apply model-specific prompt prefixes before forwarding. Extend CacheKey hash inputs with role.cache_tag() so query and passage embeddings of the same raw text are stored as separate cache entries. CachedEmbeddingService overrides both role-aware methods with prompt-application + role-keyed cache logic. Existing embed() uses Generic role for backwards compat. Co-Authored-By: Claude Opus 4.7 --- crates/embed/src/cache.rs | 185 +++++++++++++++++++++++++---- crates/embed/src/lib.rs | 4 +- crates/embed/src/model.rs | 15 ++- crates/embed/src/service/cached.rs | 63 ++++++++-- crates/embed/src/service/mod.rs | 83 +++++++++++++ crates/embed/src/service/tests.rs | 88 ++++++++++++++ 6 files changed, 401 insertions(+), 37 deletions(-) diff --git a/crates/embed/src/cache.rs b/crates/embed/src/cache.rs index 13ebaf24..ddec96e3 100644 --- a/crates/embed/src/cache.rs +++ b/crates/embed/src/cache.rs @@ -17,6 +17,7 @@ //! no-ops with no locking or hashing work. use crate::model::ModelConfig; +use crate::service::EmbeddingRole; use lru::LruCache; use parking_lot::RwLock; use std::num::NonZeroUsize; @@ -113,11 +114,16 @@ impl CacheShard { /// /// ```rust /// use lattice_embed::{EmbeddingCache, EmbeddingModel, ModelConfig}; +/// use lattice_embed::service::EmbeddingRole; /// /// let cache = EmbeddingCache::new(1000); /// /// // Cache miss - no embedding stored yet -/// let key = cache.compute_key("Hello, world!", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); +/// let key = cache.compute_key( +/// "Hello, world!", +/// ModelConfig::new(EmbeddingModel::BgeSmallEnV15), +/// EmbeddingRole::Generic, +/// ); /// assert!(cache.get(&key).is_none()); /// /// // Store embedding @@ -187,15 +193,25 @@ impl EmbeddingCache { /// Uses Blake3 hashing for fast, collision-resistant keys. The key includes the model /// name, revision, and active dimension from the `ModelConfig`, so different MRL truncations /// produce different cache keys. - pub fn compute_key(&self, text: &str, model_config: ModelConfig) -> CacheKey { + /// + /// The role is also included so that `embed_query("hello")` and `embed_passage("hello")` + /// produce different cache entries even when the raw text and model config are identical. + /// Use `EmbeddingRole::Generic` for the backwards-compatible `embed()` path. + pub fn compute_key( + &self, + text: &str, + model_config: ModelConfig, + role: EmbeddingRole, + ) -> CacheKey { let mut hasher = blake3::Hasher::new(); hasher.update(text.as_bytes()); // Unique identifier for the model config: "model_name:version:dims" let model_key = format!( - "{}:{}:{}", + "{}:{}:{}:{}", model_config.model, model_config.model.key_version(), model_config.dimensions(), + role.cache_tag(), ); hasher.update(model_key.as_bytes()); *hasher.finalize().as_bytes() @@ -400,7 +416,11 @@ mod tests { #[test] fn test_cache_basic_operations() { let cache = EmbeddingCache::new(100); - let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + "hello", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); // Miss assert!(cache.get(&key).is_none()); @@ -427,7 +447,11 @@ mod tests { let mut keys = Vec::new(); for i in 0..32u32 { let text = format!("text_{}", i); - let key = cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + &text, + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); keys.push(key); cache.put(key, vec![i as f32]); } @@ -448,14 +472,18 @@ mod tests { let target_shard; // Find the first key's shard and collect 3 keys for it. - let first_key = - cache.compute_key("probe_0", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let first_key = cache.compute_key( + "probe_0", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); target_shard = shard_index(&first_key); loop { let key = cache.compute_key( &format!("lru_test_{}", i), ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, ); if shard_index(&key) == target_shard { same_shard_keys.push((key, i)); @@ -492,8 +520,16 @@ mod tests { fn test_cache_different_models_different_keys() { let cache = EmbeddingCache::new(100); - let key_small = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); - let key_base = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeBaseEnV15)); + let key_small = cache.compute_key( + "text", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); + let key_base = cache.compute_key( + "text", + ModelConfig::new(EmbeddingModel::BgeBaseEnV15), + EmbeddingRole::Generic, + ); // Same text, different models = different keys assert_ne!(key_small, key_base); @@ -502,7 +538,11 @@ mod tests { #[test] fn test_cache_stats() { let cache = EmbeddingCache::new(100); - let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + "hello", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.get(&key); // Miss cache.put(key, vec![0.1]); @@ -520,9 +560,21 @@ mod tests { // Use capacity large enough that no shard evicts. let cache = EmbeddingCache::new(100); - let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); - let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); - let key3 = cache.compute_key("three", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key1 = cache.compute_key( + "one", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); + let key2 = cache.compute_key( + "two", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); + let key3 = cache.compute_key( + "three", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.put(key1, vec![1.0]); cache.put(key3, vec![3.0]); @@ -539,8 +591,16 @@ mod tests { // Use capacity large enough that no shard evicts (ceil(100/16) = 7 per shard). let cache = EmbeddingCache::new(100); - let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); - let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key1 = cache.compute_key( + "one", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); + let key2 = cache.compute_key( + "two", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]); @@ -553,7 +613,11 @@ mod tests { #[test] fn test_cache_clear() { let cache = EmbeddingCache::new(100); - let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + "hello", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.put(key, vec![0.1]); assert!(cache.get(&key).is_some()); @@ -574,7 +638,11 @@ mod tests { let cache = EmbeddingCache::new(0); assert!(!cache.is_enabled()); - let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + "hello", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.put(key, vec![0.1]); assert!(cache.get(&key).is_none()); @@ -599,8 +667,11 @@ mod tests { for i in 0..100 { // Each thread uses unique keys to avoid contention on same entry. let text = format!("thread_{}_item_{}", t, i); - let key = - cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + &text, + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); let embedding = vec![t as f32; 384]; cache.put(key, embedding.clone()); @@ -632,6 +703,7 @@ mod tests { let key = cache.compute_key( &format!("item_{}", i), ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, ); cache.put(key, vec![i as f32]); } @@ -670,8 +742,16 @@ mod tests { let cache = EmbeddingCache::new(100); // Insert a few entries and access them. - let key1 = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); - let key2 = cache.compute_key("world", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key1 = cache.compute_key( + "hello", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); + let key2 = cache.compute_key( + "world", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.put(key1, vec![1.0]); cache.put(key2, vec![2.0]); @@ -697,8 +777,69 @@ mod tests { let cache = EmbeddingCache::new(3); assert!(cache.is_enabled()); - let key = cache.compute_key("x", ModelConfig::new(EmbeddingModel::BgeSmallEnV15)); + let key = cache.compute_key( + "x", + ModelConfig::new(EmbeddingModel::BgeSmallEnV15), + EmbeddingRole::Generic, + ); cache.put(key, vec![42.0]); assert!(cache.get(&key).is_some()); } + + // ------------------------------------------------------------------------- + // Role-aware cache key tests (P0-E2) + // ------------------------------------------------------------------------- + + /// Query and passage roles must produce different cache keys for the same raw text + /// so that embed_query("hello") and embed_passage("hello") are stored separately. + #[test] + fn test_role_query_vs_passage_different_keys() { + let cache = EmbeddingCache::new(100); + let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small); + let text = "hello world"; + + let key_query = cache.compute_key(text, model, EmbeddingRole::Query); + let key_passage = cache.compute_key(text, model, EmbeddingRole::Passage); + let key_generic = cache.compute_key(text, model, EmbeddingRole::Generic); + + assert_ne!(key_query, key_passage, "query vs passage must differ"); + assert_ne!(key_query, key_generic, "query vs generic must differ"); + assert_ne!(key_passage, key_generic, "passage vs generic must differ"); + } + + /// Role keys are consistent: same inputs always produce same key. + #[test] + fn test_role_key_deterministic() { + let cache = EmbeddingCache::new(100); + let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15); + + let k1 = cache.compute_key("test", model, EmbeddingRole::Query); + let k2 = cache.compute_key("test", model, EmbeddingRole::Query); + assert_eq!(k1, k2, "identical inputs must produce identical key"); + } + + /// Storing under one role does not pollute the other role's key. + #[test] + fn test_role_cache_isolation() { + let cache = EmbeddingCache::new(100); + let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small); + + let key_query = cache.compute_key("embed me", model, EmbeddingRole::Query); + let key_passage = cache.compute_key("embed me", model, EmbeddingRole::Passage); + + // Store under Query role. + cache.put(key_query, vec![1.0, 2.0]); + + // Passage role key must still be a miss. + assert!( + cache.get(&key_passage).is_none(), + "passage key must miss after storing under query key" + ); + + // Query role key must hit. + assert!( + cache.get(&key_query).is_some(), + "query key must hit after storing under query key" + ); + } } diff --git a/crates/embed/src/lib.rs b/crates/embed/src/lib.rs index 7e9591b5..fe62080a 100644 --- a/crates/embed/src/lib.rs +++ b/crates/embed/src/lib.rs @@ -80,14 +80,14 @@ mod cache; mod error; pub mod migration; mod model; -mod service; +pub mod service; pub mod simd; pub mod types; pub use cache::{CacheStats, DEFAULT_CACHE_CAPACITY, EmbeddingCache, ShardStats}; pub use error::{EmbedError, Result}; pub use model::{EmbeddingModel, MIN_MRL_OUTPUT_DIM, ModelConfig, ModelProvenance}; -pub use service::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS}; +pub use service::{DEFAULT_MAX_BATCH_SIZE, EmbeddingRole, EmbeddingService, MAX_TEXT_CHARS}; pub use simd::{SimdConfig, simd_config}; #[cfg(feature = "native")] diff --git a/crates/embed/src/model.rs b/crates/embed/src/model.rs index c71517f7..cfc2f4bb 100644 --- a/crates/embed/src/model.rs +++ b/crates/embed/src/model.rs @@ -266,9 +266,22 @@ impl EmbeddingModel { /// Some models use different prompts for documents vs queries. /// Returns `Some(prefix)` if the document text should be wrapped as /// `"{prefix}{text}"` before embedding at storage time. + /// + /// - **E5 models**: trained with `"passage: "` prefix on document/passage inputs. + /// Omitting the prefix on the document side degrades retrieval quality because + /// the model's embedding space was conditioned on this asymmetry during fine-tuning. + /// - **BGE / MiniLM**: no document prefix required (contrastive training on raw text). + /// - **Qwen3-Embedding**: raw passage text is used without an instruction prefix; + /// only the query side carries the task instruction. #[inline] pub const fn document_instruction(&self) -> Option<&'static str> { - None + match self { + EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => { + // E5 asymmetric retrieval: "passage: " prefix for documents/passages. + Some("passage: ") + } + _ => None, + } } /// **Stable**: get the model identifier (HuggingFace ID or provider/model). diff --git a/crates/embed/src/service/cached.rs b/crates/embed/src/service/cached.rs index d5037109..d6695c3b 100644 --- a/crates/embed/src/service/cached.rs +++ b/crates/embed/src/service/cached.rs @@ -1,6 +1,6 @@ //! Caching wrapper for embedding services. -use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS}; +use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingRole, EmbeddingService, MAX_TEXT_CHARS}; use crate::error::Result; use crate::model::EmbeddingModel; use async_trait::async_trait; @@ -80,6 +80,53 @@ impl CachedEmbeddingService { #[async_trait] impl EmbeddingService for CachedEmbeddingService { async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result>> { + // Generic role: cache key does NOT include a role tag, maintaining + // backwards compatibility with any on-disk cache entries written before + // role-aware keys were introduced. + self.embed_with_role(texts, model, EmbeddingRole::Generic) + .await + } + + /// Override: apply query prompt then cache with `Query` role key. + async fn embed_query(&self, texts: &[String], model: EmbeddingModel) -> Result>> { + let prefix = model.query_instruction(); + let prompted = super::apply_prefix(texts, prefix); + self.embed_with_role(&prompted, model, EmbeddingRole::Query) + .await + } + + /// Override: apply passage prompt then cache with `Passage` role key. + async fn embed_passage( + &self, + texts: &[String], + model: EmbeddingModel, + ) -> Result>> { + let prefix = model.document_instruction(); + let prompted = super::apply_prefix(texts, prefix); + self.embed_with_role(&prompted, model, EmbeddingRole::Passage) + .await + } + + fn supports_model(&self, model: EmbeddingModel) -> bool { + self.inner.supports_model(model) + } + + fn name(&self) -> &'static str { + "cached-embedding" + } +} + +impl CachedEmbeddingService { + /// Core cache-and-embed implementation shared by `embed`, `embed_query`, and + /// `embed_passage`. `texts` must already have the prompt prefix applied; `role` + /// is used only as part of the cache key so that different roles produce separate + /// cache entries for the same raw text. + async fn embed_with_role( + &self, + texts: &[String], + model: EmbeddingModel, + role: EmbeddingRole, + ) -> Result>> { use crate::error::EmbedError; // Validate inputs before any cache interaction so callers always get @@ -108,11 +155,11 @@ impl EmbeddingService for CachedEmbeddingService< return self.inner.embed(texts, model).await; } - // Compute cache keys — include the active dimension (for MRL models). + // Compute cache keys — include the active dimension (for MRL models) and role. let model_config = self.inner.model_config(model); let keys: Vec<_> = texts .iter() - .map(|t| self.cache.compute_key(t, model_config)) + .map(|t| self.cache.compute_key(t, model_config, role)) .collect(); // Check cache for all texts — returns Arc<[f32]> refs (O(1) per hit) @@ -144,7 +191,7 @@ impl EmbeddingService for CachedEmbeddingService< to_embed.len() ); - // Embed missing texts + // Embed missing texts (after prompt is already applied in texts) let texts_to_embed: Vec = to_embed.iter().map(|(_, t)| (*t).clone()).collect(); let new_embeddings = self.inner.embed(&texts_to_embed, model).await?; @@ -172,14 +219,6 @@ impl EmbeddingService for CachedEmbeddingService< // - Non-cached items were assigned via results[i] = Some(embedding) in the loop above Ok(results.into_iter().flatten().collect()) } - - fn supports_model(&self, model: EmbeddingModel) -> bool { - self.inner.supports_model(model) - } - - fn name(&self) -> &'static str { - "cached-embedding" - } } // Suppress dead code warnings for constants that are used by other modules diff --git a/crates/embed/src/service/mod.rs b/crates/embed/src/service/mod.rs index 76ea5a4f..fe3a5a5f 100644 --- a/crates/embed/src/service/mod.rs +++ b/crates/embed/src/service/mod.rs @@ -30,6 +30,42 @@ pub const DEFAULT_MAX_BATCH_SIZE: usize = 1000; /// 32KB is sufficient for most embedding use cases while preventing abuse. pub const MAX_TEXT_CHARS: usize = 32768; +/// **Stable**: role of text in asymmetric retrieval. +/// +/// Models trained with asymmetric objectives (E5, Qwen3-Embedding) use different +/// prompt prefixes for queries vs documents. Providing the wrong role causes the +/// embedding to land in the wrong region of the model's retrieval space, degrading +/// retrieval quality. +/// +/// Use [`EmbeddingService::embed_query`] / [`EmbeddingService::embed_passage`] to +/// apply the correct prefix automatically. The role is also included in the cache +/// key so that `embed_query("hello")` and `embed_passage("hello")` are stored as +/// separate entries even when the raw text is identical. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EmbeddingRole { + /// Query / question text — may receive a query-side prompt prefix. + Query, + /// Document / passage text — may receive a passage-side prompt prefix. + Passage, + /// Generic text with no role-specific prefix (backwards-compatible default). + Generic, +} + +impl EmbeddingRole { + /// Short ASCII tag included in the cache key hash. + /// + /// Distinct strings ensure that role changes affect the Blake3 hash even + /// when the raw text and model config are identical. + #[inline] + pub(crate) const fn cache_tag(self) -> &'static str { + match self { + EmbeddingRole::Query => "role:query", + EmbeddingRole::Passage => "role:passage", + EmbeddingRole::Generic => "role:generic", + } + } +} + /// **Stable**: external consumers may depend on this; breaking changes require a SemVer bump. /// /// Trait for embedding generation services. @@ -56,6 +92,8 @@ pub trait EmbeddingService: Send + Sync { /// **Stable**: generate embeddings for multiple texts. /// /// Returns a vector of embeddings, one for each input text, in the same order. + /// Applies no role-specific prompt prefix (equivalent to `Generic` role). + /// Use [`embed_query`] / [`embed_passage`] for asymmetric retrieval models. async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result>>; /// **Stable**: generate an embedding for a single text. @@ -69,6 +107,38 @@ pub trait EmbeddingService: Send + Sync { .ok_or_else(|| EmbedError::Internal("no embedding generated".into())) } + /// **Stable**: embed query texts with model-specific query prompt prefix applied. + /// + /// For models that use asymmetric prompts (E5, Qwen3-Embedding), this prepends the + /// `query_instruction()` prefix before calling the model forward. For models with + /// no query prefix (BGE, MiniLM), this is equivalent to `embed()`. + /// + /// Cache keys produced by this method are distinct from those produced by + /// `embed_passage()` and `embed()` even when the raw text is identical. + async fn embed_query(&self, texts: &[String], model: EmbeddingModel) -> Result>> { + let prefix = model.query_instruction(); + let prompted = apply_prefix(texts, prefix); + self.embed(&prompted, model).await + } + + /// **Stable**: embed document/passage texts with model-specific document prompt prefix applied. + /// + /// For models that use asymmetric prompts (E5), this prepends the + /// `document_instruction()` prefix before calling the model forward. For models with + /// no document prefix (BGE, MiniLM, Qwen3), this is equivalent to `embed()`. + /// + /// Cache keys produced by this method are distinct from those produced by + /// `embed_query()` and `embed()` even when the raw text is identical. + async fn embed_passage( + &self, + texts: &[String], + model: EmbeddingModel, + ) -> Result>> { + let prefix = model.document_instruction(); + let prompted = apply_prefix(texts, prefix); + self.embed(&prompted, model).await + } + /// **Unstable**: returns the effective `ModelConfig` for a given model on this service. /// /// The default returns a config with no MRL truncation. `NativeEmbeddingService` @@ -84,3 +154,16 @@ pub trait EmbeddingService: Send + Sync { /// **Stable**: get the name/identifier of this service. fn name(&self) -> &'static str; } + +/// Apply an optional prompt prefix to each text. +/// +/// Returns a new `Vec` with the prefix prepended where the prefix is +/// `Some`, or a cloned vec of the original texts when the prefix is `None`. +/// This is a free function (not a method) so it can be called from default +/// trait method bodies without going through `self`. +pub(crate) fn apply_prefix(texts: &[String], prefix: Option<&str>) -> Vec { + match prefix { + None => texts.to_vec(), + Some(p) => texts.iter().map(|t| format!("{p}{t}")).collect(), + } +} diff --git a/crates/embed/src/service/tests.rs b/crates/embed/src/service/tests.rs index 2c567514..706d6f68 100644 --- a/crates/embed/src/service/tests.rs +++ b/crates/embed/src/service/tests.rs @@ -1,6 +1,7 @@ //! Tests for embedding services. use super::*; +use crate::model::EmbeddingModel; #[test] fn test_max_batch_size_constant() { @@ -199,3 +200,90 @@ mod native_tests { ); } } + +// --------------------------------------------------------------------------- +// Role-aware prompt tests (P0-E2) +// --------------------------------------------------------------------------- + +/// E5 query_instruction returns "query: ", document_instruction returns "passage: ". +#[test] +fn test_e5_query_instruction() { + assert_eq!( + EmbeddingModel::MultilingualE5Small.query_instruction(), + Some("query: "), + "E5 small must return 'query: ' prefix" + ); + assert_eq!( + EmbeddingModel::MultilingualE5Base.query_instruction(), + Some("query: "), + "E5 base must return 'query: ' prefix" + ); +} + +/// E5 document_instruction returns "passage: " (P0-E2 fix — was None before). +#[test] +fn test_e5_document_instruction() { + assert_eq!( + EmbeddingModel::MultilingualE5Small.document_instruction(), + Some("passage: "), + "E5 small must return 'passage: ' document prefix" + ); + assert_eq!( + EmbeddingModel::MultilingualE5Base.document_instruction(), + Some("passage: "), + "E5 base must return 'passage: ' document prefix" + ); +} + +/// BGE and MiniLM models must NOT have document_instruction (they use raw text). +#[test] +fn test_bge_minilm_no_document_instruction() { + assert_eq!( + EmbeddingModel::BgeSmallEnV15.document_instruction(), + None, + "BGE small must not have document prefix" + ); + assert_eq!(EmbeddingModel::BgeBaseEnV15.document_instruction(), None); + assert_eq!(EmbeddingModel::BgeLargeEnV15.document_instruction(), None); + assert_eq!(EmbeddingModel::AllMiniLmL6V2.document_instruction(), None); + assert_eq!( + EmbeddingModel::ParaphraseMultilingualMiniLmL12V2.document_instruction(), + None + ); +} + +/// Qwen document_instruction returns None (raw passage, instruction only for queries). +#[test] +fn test_qwen_no_document_instruction() { + assert_eq!( + EmbeddingModel::Qwen3Embedding0_6B.document_instruction(), + None, + "Qwen document side uses raw text" + ); +} + +/// apply_prefix prepends the prefix when Some, returns clone when None. +#[test] +fn test_apply_prefix_some() { + let texts = vec!["hello".to_string(), "world".to_string()]; + let result = apply_prefix(&texts, Some("query: ")); + assert_eq!(result, vec!["query: hello", "query: world"]); +} + +#[test] +fn test_apply_prefix_none() { + let texts = vec!["hello".to_string()]; + let result = apply_prefix(&texts, None); + assert_eq!(result, texts); +} + +/// EmbeddingRole cache tags are distinct strings. +#[test] +fn test_embedding_role_cache_tags_distinct() { + let q = EmbeddingRole::Query.cache_tag(); + let p = EmbeddingRole::Passage.cache_tag(); + let g = EmbeddingRole::Generic.cache_tag(); + assert_ne!(q, p); + assert_ne!(q, g); + assert_ne!(p, g); +}