Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/khive-pack-memory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ description = "Memory verb pack — remember/recall semantics with decay-aware r
[dependencies]
khive-types = { version = "0.2.2", path = "../khive-types", features = ["serde"] }
khive-runtime = { version = "0.2.2", path = "../khive-runtime" }
khive-retrieval = { version = "0.2.2", path = "../khive-retrieval" }
khive-pack-brain = { version = "0.2.2", path = "../khive-pack-brain" }
inventory = { workspace = true }
khive-storage = { version = "0.2.2", path = "../khive-storage" }
Expand Down
303 changes: 296 additions & 7 deletions crates/khive-pack-memory/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ use serde::Deserialize;
use serde_json::{json, Value};
use uuid::Uuid;

use khive_runtime::fusion::fuse_with_strategy;
use khive_runtime::{NamespaceToken, RuntimeError, SearchHit, SearchSource, VerbRegistry};
use khive_retrieval::{
fuse_search_results, FusionStrategy as RetrievalFusionStrategy, HybridConfig,
};
use khive_runtime::{
FusionStrategy as RuntimeFusionStrategy, NamespaceToken, RuntimeError, SearchHit, SearchSource,
VerbRegistry,
};
use khive_storage::types::{
TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
VectorSearchRequest,
Expand All @@ -32,6 +37,19 @@ fn validate_memory_type(mt: &str) -> Result<(), RuntimeError> {
}
}

/// Translates the request-level `fusion_strategy` string into a runtime
/// fusion strategy populated with built-in default parameters.
///
/// Accepted names are `"rrf"` (with `k = 60`), `"weighted"` (with weights
/// `[0.3, 0.7]`) and `"union"`.
///
/// # Errors
///
/// Returns [`RuntimeError::InvalidInput`] for any other string; the message
/// echoes the rejected value and lists the accepted names.
fn parse_fusion_strategy_str(s: &str) -> Result<RuntimeFusionStrategy, RuntimeError> {
    let strategy = match s {
        "rrf" => RuntimeFusionStrategy::Rrf { k: 60 },
        "weighted" => RuntimeFusionStrategy::Weighted {
            weights: vec![0.3, 0.7],
        },
        "union" => RuntimeFusionStrategy::Union,
        other => {
            // Unknown name: bail out early with a message naming the valid options.
            return Err(RuntimeError::InvalidInput(format!(
                "invalid fusion_strategy {other:?}: must be one of \"rrf\", \"weighted\", \"union\""
            )));
        }
    };
    Ok(strategy)
}

#[derive(Deserialize)]
struct RememberParams {
content: String,
Expand All @@ -53,6 +71,9 @@ struct RecallParams {
min_score: Option<f64>,
min_salience: Option<f64>,
config: Option<RecallConfig>,
top_k: Option<usize>,
fusion_strategy: Option<String>,
score_floor: Option<f32>,
}

impl RecallParams {
Expand Down Expand Up @@ -138,22 +159,118 @@ fn search_source_label(source: SearchSource) -> &'static str {
}
}

/// Per-candidate bookkeeping gathered while the text and vector hit lists
/// are flattened into `(id, score)` pairs for fusion.
///
/// The derived `Default` yields "seen in neither source, no title, no snippet".
#[derive(Default)]
struct CandidateMeta {
    /// Set once the candidate has appeared in the text-search hits.
    in_text: bool,
    /// Set once the candidate has appeared in the vector-search hits.
    in_vector: bool,
    /// Title carried over from a text hit, when one was provided.
    title: Option<String>,
    /// Snippet carried over from a text hit, when one was provided.
    snippet: Option<String>,
}

/// Maps a runtime fusion strategy onto the retrieval crate's equivalent variant.
///
/// `Rrf` keeps its `k`; `Union` and `VectorOnly` map one-to-one. For `Weighted`
/// the runtime weight list is intentionally not forwarded here (an empty list
/// is produced); `retrieval_hybrid_config` writes the weights onto the
/// `HybridConfig` fields instead.
fn to_retrieval_fusion_strategy(strategy: &RuntimeFusionStrategy) -> RetrievalFusionStrategy {
    match *strategy {
        RuntimeFusionStrategy::Rrf { k } => RetrievalFusionStrategy::Rrf { k },
        RuntimeFusionStrategy::Union => RetrievalFusionStrategy::Union,
        RuntimeFusionStrategy::VectorOnly => RetrievalFusionStrategy::VectorOnly,
        RuntimeFusionStrategy::Weighted { .. } => {
            let weights = Vec::new();
            RetrievalFusionStrategy::Weighted { weights }
        }
    }
}

/// Builds the `HybridConfig` used to fuse recall candidates.
///
/// Both the result limit and the pool size are set to `limit`, and the fusion
/// strategy is the retrieval-side counterpart of `strategy`. For a `Weighted`
/// strategy the first runtime weight becomes `keyword_weight` and the second
/// becomes `vector_weight`; a missing entry counts as `0.0` and negative
/// values are raised to `0.0`.
fn retrieval_hybrid_config(strategy: &RuntimeFusionStrategy, limit: usize) -> HybridConfig {
    let mut hybrid = HybridConfig::new(limit)
        .with_pool_size(limit)
        .with_fusion_strategy(to_retrieval_fusion_strategy(strategy));

    if let RuntimeFusionStrategy::Weighted { weights } = strategy {
        // Runtime order is [text, vector]; HybridConfig names them keyword/vector.
        // The fields are assigned directly rather than through with_weights() so
        // that arbitrary positive scales survive without clamping.
        let weight_at = |slot: usize| weights.get(slot).copied().unwrap_or(0.0).max(0.0);
        hybrid.keyword_weight = weight_at(0);
        hybrid.vector_weight = weight_at(1);
    }

    hybrid
}

/// Derives the `SearchSource` label for a fused candidate from the sources
/// it was seen in.
///
/// A candidate seen in both lists is `Both`, one seen only in the vector list
/// is `Vector`, and everything else (text-only, or seen in neither list) is
/// labelled `Text`.
fn source_from_meta(meta: &CandidateMeta) -> SearchSource {
    if !meta.in_vector {
        // Covers text-only hits as well as the "seen nowhere" fallback.
        return SearchSource::Text;
    }
    if meta.in_text {
        SearchSource::Both
    } else {
        SearchSource::Vector
    }
}

/// Fuses text and vector search hits into a single list of `SearchHit`s.
///
/// Hits whose `subject_id` is not in `memory_ids` are dropped. The remaining
/// hits are reduced to `(id, score)` pairs, fused by `fuse_search_results`
/// under a config derived from `cfg.fuse_strategy` and `limit`, and then
/// re-labelled with source/title/snippet recorded while flattening.
///
/// With the `VectorOnly` strategy only the vector list is fused and every
/// result is reported as `SearchSource::Vector` with no title or snippet.
fn fuse_candidates(
text_hits: Vec<TextSearchHit>,
vector_hits: Vec<VectorSearchHit>,
memory_ids: &HashSet<Uuid>,
cfg: &RecallConfig,
limit: usize,
) -> Vec<SearchHit> {
// NOTE(review): the next line is an unterminated statement that the code
// below never completes — it looks like the removed side of the diff.
// Confirm before compiling.
let text: Vec<TextSearchHit> = text_hits
// Side table keyed by candidate id: which sources saw it, plus title/snippet.
let mut meta = HashMap::<Uuid, CandidateMeta>::new();

// Flatten text hits to (id, score), recording title/snippet on first sight.
let text_source: Vec<_> = text_hits
.into_iter()
.filter(|h| memory_ids.contains(&h.subject_id))
.map(|h| {
let TextSearchHit {
subject_id,
score,
title,
snippet,
..
} = h;
let entry = meta.entry(subject_id).or_default();
entry.in_text = true;
// Keep the first title/snippet recorded for an id; later hits do not overwrite.
if entry.title.is_none() {
entry.title = title;
}
if entry.snippet.is_none() {
entry.snippet = snippet;
}
(subject_id, score)
})
.collect();
// NOTE(review): the next line is also unterminated and `vec` is only used by
// the `fuse_with_strategy` call further down — presumably diff residue as well.
let vec: Vec<VectorSearchHit> = vector_hits

// Flatten vector hits to (id, score), marking each id as seen by vector search.
let vector_source: Vec<_> = vector_hits
.into_iter()
.filter(|h| memory_ids.contains(&h.subject_id))
.map(|h| {
let entry = meta.entry(h.subject_id).or_default();
entry.in_vector = true;
(h.subject_id, h.score)
})
.collect();
// NOTE(review): this call has no trailing `;` and is followed by more
// statements; it appears to be the previous implementation's return
// expression — confirm it is not part of the current code.
fuse_with_strategy(text, vec, &cfg.fuse_strategy, limit)

let vector_only = matches!(&cfg.fuse_strategy, RuntimeFusionStrategy::VectorOnly);
let sources = if vector_only {
vec![vector_source]
} else {
// HybridConfig weighted convention: vector first, keyword second.
vec![vector_source, text_source]
};

let retrieval_cfg = retrieval_hybrid_config(&cfg.fuse_strategy, limit);
fuse_search_results(sources, &retrieval_cfg)
.into_iter()
.map(|(id, score)| {
// Ids missing from the side table fall back to an empty CandidateMeta.
let m = meta.remove(&id).unwrap_or_default();
let (source, title, snippet) = if vector_only {
(SearchSource::Vector, None, None)
} else {
(source_from_meta(&m), m.title, m.snippet)
};
SearchHit {
entity_id: id,
score,
source,
title,
snippet,
}
})
.collect()
}

impl MemoryPack {
Expand Down Expand Up @@ -335,10 +452,35 @@ impl MemoryPack {
validate_memory_type(mt)?;
}

let cfg = p.effective_config(self.active_config());
if let Some(ref fs) = p.fusion_strategy {
parse_fusion_strategy_str(fs)?;
}

let mut cfg = p.effective_config(self.active_config());
if let Some(ref fs) = p.fusion_strategy {
let mut new_strategy = parse_fusion_strategy_str(fs)?;
// "weighted" in the request means "use weighted fusion" — the actual
// weight values come from pack config, not the request (ADR-033 §6.1).
if let (
RuntimeFusionStrategy::Weighted {
weights: ref mut new_w,
},
RuntimeFusionStrategy::Weighted {
weights: ref existing_w,
},
) = (&mut new_strategy, &cfg.fuse_strategy)
{
*new_w = existing_w.clone();
}
cfg.fuse_strategy = new_strategy;
}
cfg.validate()?;

let limit = p.limit.unwrap_or(10).min(100);
let limit = if let Some(k) = p.top_k {
u32::try_from(k.min(100)).unwrap_or(100)
} else {
p.limit.unwrap_or(10).min(100)
};
let candidate_limit = recall_candidate_count(&cfg, limit);
let candidates = self
.collect_recall_candidates(&p.query, token, candidate_limit)
Expand Down Expand Up @@ -392,6 +534,11 @@ impl MemoryPack {
if final_score < cfg.min_score {
continue;
}
if let Some(floor) = p.score_floor {
if final_score < floor as f64 {
continue;
}
}
ranked.push((id, final_score, breakdown, note));
}

Expand Down Expand Up @@ -661,6 +808,9 @@ mod tests {
min_score: None,
min_salience: None,
config: None,
top_k: None,
fusion_strategy: None,
score_floor: None,
};
let cfg = p.effective_config(RecallConfig::default());
assert!((cfg.relevance_weight - 0.70).abs() < 1e-12);
Expand All @@ -677,6 +827,9 @@ mod tests {
min_score: Some(0.5),
min_salience: Some(0.3),
config: None,
top_k: None,
fusion_strategy: None,
score_floor: None,
};
let cfg = p.effective_config(RecallConfig::default());
assert!((cfg.min_score - 0.5).abs() < 1e-12);
Expand All @@ -695,13 +848,149 @@ mod tests {
relevance_weight: 0.50,
..RecallConfig::default()
}),
top_k: None,
fusion_strategy: None,
score_floor: None,
};
let cfg = p.effective_config(RecallConfig::default());
assert!((cfg.relevance_weight - 0.50).abs() < 1e-12);
// legacy min_score overrides config's default
assert!((cfg.min_score - 0.1).abs() < 1e-12);
}

/// A request that only says `fusion_strategy = "weighted"` must keep the
/// weights already configured on the pack instead of the parser's defaults.
#[test]
fn test_weighted_strategy_preserves_pack_weights() {
    use khive_runtime::FusionStrategy as RuntimeFusionStrategy;

    // Pack-level configuration carrying non-default weights [0.8, 0.2].
    let pack_config = RecallConfig {
        fuse_strategy: RuntimeFusionStrategy::Weighted {
            weights: vec![0.8, 0.2],
        },
        ..RecallConfig::default()
    };

    // The request names the strategy but supplies no weights of its own.
    let params = RecallParams {
        query: "test".to_string(),
        limit: None,
        memory_type: None,
        min_score: None,
        min_salience: None,
        config: None,
        top_k: None,
        fusion_strategy: Some("weighted".to_string()),
        score_floor: None,
    };

    // Replays the override steps: parse the requested strategy, and when both
    // it and the effective config are Weighted, carry the existing weights over.
    let mut effective = params.effective_config(pack_config);
    if let Some(requested) = params.fusion_strategy.as_deref() {
        let mut override_strategy = parse_fusion_strategy_str(requested).unwrap();
        if let RuntimeFusionStrategy::Weighted {
            weights: pack_weights,
        } = &effective.fuse_strategy
        {
            if let RuntimeFusionStrategy::Weighted {
                weights: requested_weights,
            } = &mut override_strategy
            {
                *requested_weights = pack_weights.clone();
            }
        }
        effective.fuse_strategy = override_strategy;
    }

    match effective.fuse_strategy {
        RuntimeFusionStrategy::Weighted { weights } => {
            assert_eq!(
                weights,
                vec![0.8, 0.2],
                "fusion_strategy=weighted must preserve pack weights [0.8, 0.2], not override with [0.3, 0.7]"
            );
        }
        other => panic!("expected Weighted strategy, got {other:?}"),
    }
}

/// Switching the fusion strategy must change the fused ordering, not merely
/// pass validation (Codex Medium 2, PR #406).
///
/// The fixture is deterministic and built so that rank-based fusion (RRF) and
/// score-based fusion (Weighted) are expected to disagree.
#[test]
fn fusion_strategy_change_produces_observable_ordering_difference() {
    use khive_runtime::FusionStrategy as RuntimeFusionStrategy;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use std::collections::HashSet;
    use uuid::Uuid;

    let id_a = Uuid::from_u128(0xAAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA);
    let id_b = Uuid::from_u128(0xBBBB_BBBB_BBBB_BBBB_BBBB_BBBB_BBBB_BBBB);
    let id_c = Uuid::from_u128(0xCCCC_CCCC_CCCC_CCCC_CCCC_CCCC_CCCC_CCCC);

    // Text search: A first (0.9), B second (0.5).
    let text_hits = vec![
        TextSearchHit {
            subject_id: id_a,
            score: 0.9_f64.into(),
            rank: 1,
            title: None,
            snippet: None,
        },
        TextSearchHit {
            subject_id: id_b,
            score: 0.5_f64.into(),
            rank: 2,
            title: None,
            snippet: None,
        },
    ];
    // Vector search: C first (0.95), A second (0.3).
    let vector_hits = vec![
        VectorSearchHit {
            subject_id: id_c,
            score: 0.95_f64.into(),
            rank: 1,
        },
        VectorSearchHit {
            subject_id: id_a,
            score: 0.3_f64.into(),
            rank: 2,
        },
    ];
    let memory_ids = HashSet::from([id_a, id_b, id_c]);

    // Runs the fusion under one strategy and returns the resulting id order.
    let fused_order = |strategy: RuntimeFusionStrategy| -> Vec<Uuid> {
        let cfg = RecallConfig {
            fuse_strategy: strategy,
            ..RecallConfig::default()
        };
        fuse_candidates(text_hits.clone(), vector_hits.clone(), &memory_ids, &cfg, 10)
            .into_iter()
            .map(|hit| hit.entity_id)
            .collect()
    };

    let rrf_order = fused_order(RuntimeFusionStrategy::Rrf { k: 60 });
    let weighted_order = fused_order(RuntimeFusionStrategy::Weighted {
        weights: vec![0.1, 0.9],
    });

    // Fixture rationale: under RRF, A appears in both lists and collects the
    // largest combined rank score, while C (vector rank 1) and B (text rank 2)
    // sit close together around 0.0161-0.0164. Under Weighted [0.1, 0.9], C
    // leads (0.95 * 0.9 = 0.855), A falls behind it (0.9 * 0.1 + 0.3 * 0.9 = 0.36)
    // and B is last (0.5 * 0.1 = 0.05). The two orderings therefore have to
    // differ, which is the discriminating check below.
    assert_ne!(
        rrf_order, weighted_order,
        "fusion_strategy change must affect ordering; RRF and Weighted produced identical: {rrf_order:?}"
    );
}

#[test]
fn compute_score_default_config_reproduces_legacy() {
let cfg = RecallConfig::default();
Expand Down
Loading
Loading