Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 163 additions & 22 deletions crates/embed/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//! no-ops with no locking or hashing work.

use crate::model::ModelConfig;
use crate::service::EmbeddingRole;
use lru::LruCache;
use parking_lot::RwLock;
use std::num::NonZeroUsize;
Expand Down Expand Up @@ -113,11 +114,16 @@ impl CacheShard {
///
/// ```rust
/// use lattice_embed::{EmbeddingCache, EmbeddingModel, ModelConfig};
/// use lattice_embed::service::EmbeddingRole;
///
/// let cache = EmbeddingCache::new(1000);
///
/// // Cache miss - no embedding stored yet
/// let key = cache.compute_key("Hello, world!", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
/// let key = cache.compute_key(
/// "Hello, world!",
/// ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
/// EmbeddingRole::Generic,
/// );
/// assert!(cache.get(&key).is_none());
///
/// // Store embedding
Expand Down Expand Up @@ -187,15 +193,25 @@ impl EmbeddingCache {
/// Uses Blake3 hashing for fast, collision-resistant keys. The key includes the model
/// name, revision, and active dimension from the `ModelConfig`, so different MRL truncations
/// produce different cache keys.
pub fn compute_key(&self, text: &str, model_config: ModelConfig) -> CacheKey {
///
/// The role is also included so that `embed_query("hello")` and `embed_passage("hello")`
/// produce different cache entries even when the raw text and model config are identical.
/// Use `EmbeddingRole::Generic` for the backwards-compatible `embed()` path.
pub fn compute_key(
&self,
text: &str,
model_config: ModelConfig,
role: EmbeddingRole,
) -> CacheKey {
let mut hasher = blake3::Hasher::new();
hasher.update(text.as_bytes());
// Unique identifier for the model config and role: "model_name:version:dims:role_tag"
let model_key = format!(
"{}:{}:{}",
"{}:{}:{}:{}",
model_config.model,
model_config.model.key_version(),
model_config.dimensions(),
role.cache_tag(),
);
hasher.update(model_key.as_bytes());
*hasher.finalize().as_bytes()
Expand Down Expand Up @@ -400,7 +416,11 @@ mod tests {
#[test]
fn test_cache_basic_operations() {
let cache = EmbeddingCache::new(100);
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
"hello",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);

// Miss
assert!(cache.get(&key).is_none());
Expand All @@ -427,7 +447,11 @@ mod tests {
let mut keys = Vec::new();
for i in 0..32u32 {
let text = format!("text_{}", i);
let key = cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
&text,
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
keys.push(key);
cache.put(key, vec![i as f32]);
}
Expand All @@ -448,14 +472,18 @@ mod tests {
let target_shard;

// Find the first key's shard and collect 3 keys for it.
let first_key =
cache.compute_key("probe_0", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let first_key = cache.compute_key(
"probe_0",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
target_shard = shard_index(&first_key);

loop {
let key = cache.compute_key(
&format!("lru_test_{}", i),
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
if shard_index(&key) == target_shard {
same_shard_keys.push((key, i));
Expand Down Expand Up @@ -492,8 +520,16 @@ mod tests {
fn test_cache_different_models_different_keys() {
let cache = EmbeddingCache::new(100);

let key_small = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key_base = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeBaseEnV15));
let key_small = cache.compute_key(
"text",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
let key_base = cache.compute_key(
"text",
ModelConfig::new(EmbeddingModel::BgeBaseEnV15),
EmbeddingRole::Generic,
);

// Same text, different models = different keys
assert_ne!(key_small, key_base);
Expand All @@ -502,7 +538,11 @@ mod tests {
#[test]
fn test_cache_stats() {
let cache = EmbeddingCache::new(100);
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
"hello",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);

cache.get(&key); // Miss
cache.put(key, vec![0.1]);
Expand All @@ -520,9 +560,21 @@ mod tests {
// Use capacity large enough that no shard evicts.
let cache = EmbeddingCache::new(100);

let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key3 = cache.compute_key("three", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key1 = cache.compute_key(
"one",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
let key2 = cache.compute_key(
"two",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
let key3 = cache.compute_key(
"three",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);

cache.put(key1, vec![1.0]);
cache.put(key3, vec![3.0]);
Expand All @@ -539,8 +591,16 @@ mod tests {
// Use capacity large enough that no shard evicts (ceil(100/16) = 7 per shard).
let cache = EmbeddingCache::new(100);

let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key1 = cache.compute_key(
"one",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
let key2 = cache.compute_key(
"two",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);

cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);

Expand All @@ -553,7 +613,11 @@ mod tests {
#[test]
fn test_cache_clear() {
let cache = EmbeddingCache::new(100);
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
"hello",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);

cache.put(key, vec![0.1]);
assert!(cache.get(&key).is_some());
Expand All @@ -574,7 +638,11 @@ mod tests {
let cache = EmbeddingCache::new(0);
assert!(!cache.is_enabled());

let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
"hello",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
cache.put(key, vec![0.1]);
assert!(cache.get(&key).is_none());

Expand All @@ -599,8 +667,11 @@ mod tests {
for i in 0..100 {
// Each thread uses unique keys to avoid contention on same entry.
let text = format!("thread_{}_item_{}", t, i);
let key =
cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
&text,
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
let embedding = vec![t as f32; 384];
cache.put(key, embedding.clone());

Expand Down Expand Up @@ -632,6 +703,7 @@ mod tests {
let key = cache.compute_key(
&format!("item_{}", i),
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
cache.put(key, vec![i as f32]);
}
Expand Down Expand Up @@ -670,8 +742,16 @@ mod tests {
let cache = EmbeddingCache::new(100);

// Insert a few entries and access them.
let key1 = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key2 = cache.compute_key("world", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key1 = cache.compute_key(
"hello",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
let key2 = cache.compute_key(
"world",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);

cache.put(key1, vec![1.0]);
cache.put(key2, vec![2.0]);
Expand All @@ -697,8 +777,69 @@ mod tests {
let cache = EmbeddingCache::new(3);
assert!(cache.is_enabled());

let key = cache.compute_key("x", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key = cache.compute_key(
"x",
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
EmbeddingRole::Generic,
);
cache.put(key, vec![42.0]);
assert!(cache.get(&key).is_some());
}

// -------------------------------------------------------------------------
// Role-aware cache key tests (P0-E2)
// -------------------------------------------------------------------------

/// The same raw text hashed under different roles must never share a cache key:
/// `embed_query("hello")` and `embed_passage("hello")` need separate entries, and
/// both must also differ from the `Generic` role key.
#[test]
fn test_role_query_vs_passage_different_keys() {
    let cache = EmbeddingCache::new(100);
    let config = ModelConfig::new(EmbeddingModel::MultilingualE5Small);
    let raw = "hello world";

    // Derive one key per role, evaluated in the order query -> passage -> generic.
    let [key_query, key_passage, key_generic] = [
        EmbeddingRole::Query,
        EmbeddingRole::Passage,
        EmbeddingRole::Generic,
    ]
    .map(|role| cache.compute_key(raw, config, role));

    // Every pair of roles has to yield a distinct key.
    assert_ne!(key_query, key_passage, "query vs passage must differ");
    assert_ne!(key_query, key_generic, "query vs generic must differ");
    assert_ne!(key_passage, key_generic, "passage vs generic must differ");
}

/// Key derivation is a pure function of (text, model config, role): hashing the
/// same triple twice must yield the same key.
#[test]
fn test_role_key_deterministic() {
    let cache = EmbeddingCache::new(100);
    let config = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);

    // Single definition of the inputs, invoked twice.
    let derive = || cache.compute_key("test", config, EmbeddingRole::Query);

    let first = derive();
    let second = derive();
    assert_eq!(first, second, "identical inputs must produce identical key");
}

/// An embedding stored under the query-role key must be invisible through the
/// passage-role key for the same text and model.
#[test]
fn test_role_cache_isolation() {
    let cache = EmbeddingCache::new(100);
    let config = ModelConfig::new(EmbeddingModel::MultilingualE5Small);

    // Same text and model; only the role differs between the two keys.
    let [key_query, key_passage] = [EmbeddingRole::Query, EmbeddingRole::Passage]
        .map(|role| cache.compute_key("embed me", config, role));

    // Populate the query-role entry only.
    cache.put(key_query, vec![1.0, 2.0]);

    // Look up the passage key first: it was never written, so it must miss.
    let passage_lookup = cache.get(&key_passage);
    assert!(
        passage_lookup.is_none(),
        "passage key must miss after storing under query key"
    );

    // The query key is the one that was written, so it must hit.
    let query_lookup = cache.get(&key_query);
    assert!(
        query_lookup.is_some(),
        "query key must hit after storing under query key"
    );
}
}
4 changes: 2 additions & 2 deletions crates/embed/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ mod cache;
mod error;
pub mod migration;
mod model;
mod service;
pub mod service;
pub mod simd;
pub mod types;

pub use cache::{CacheStats, DEFAULT_CACHE_CAPACITY, EmbeddingCache, ShardStats};
pub use error::{EmbedError, Result};
pub use model::{EmbeddingModel, MIN_MRL_OUTPUT_DIM, ModelConfig, ModelProvenance};
pub use service::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
pub use service::{DEFAULT_MAX_BATCH_SIZE, EmbeddingRole, EmbeddingService, MAX_TEXT_CHARS};
pub use simd::{SimdConfig, simd_config};

#[cfg(feature = "native")]
Expand Down
15 changes: 14 additions & 1 deletion crates/embed/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,22 @@ impl EmbeddingModel {
/// Some models use different prompts for documents vs queries.
/// Returns `Some(prefix)` if the document text should be wrapped as
/// `"{prefix}{text}"` before embedding at storage time.
///
/// - **E5 models**: trained with `"passage: "` prefix on document/passage inputs.
/// Omitting the prefix on the document side degrades retrieval quality because
/// the model's embedding space was conditioned on this asymmetry during fine-tuning.
/// - **BGE / MiniLM**: no document prefix required (contrastive training on raw text).
/// - **Qwen3-Embedding**: raw passage text is used without an instruction prefix;
/// only the query side carries the task instruction.
#[inline]
pub const fn document_instruction(&self) -> Option<&'static str> {
    // The body is a single `match`: the earlier unconditional `None` tail expression
    // is gone, so the E5 arm below is actually reachable.
    match self {
        // E5 asymmetric retrieval: documents/passages carry the "passage: " prefix.
        EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
            Some("passage: ")
        }
        // All remaining models embed the raw document text with no prefix.
        _ => None,
    }
}

/// **Stable**: get the model identifier (HuggingFace ID or provider/model).
Expand Down
Loading