Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/khive-pack-memory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ description = "Memory verb pack — remember/recall semantics with decay-aware r
[dependencies]
khive-types = { version = "0.2.2", path = "../khive-types", features = ["serde"] }
khive-runtime = { version = "0.2.2", path = "../khive-runtime" }
khive-retrieval = { version = "0.2.2", path = "../khive-retrieval" }
khive-pack-brain = { version = "0.2.2", path = "../khive-pack-brain" }
inventory = { workspace = true }
khive-storage = { version = "0.2.2", path = "../khive-storage" }
Expand Down
111 changes: 106 additions & 5 deletions crates/khive-pack-memory/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ use serde::Deserialize;
use serde_json::{json, Value};
use uuid::Uuid;

use khive_runtime::fusion::fuse_with_strategy;
use khive_runtime::{NamespaceToken, RuntimeError, SearchHit, SearchSource, VerbRegistry};
use khive_retrieval::{
fuse_search_results, FusionStrategy as RetrievalFusionStrategy, HybridConfig,
};
use khive_runtime::{
FusionStrategy as RuntimeFusionStrategy, NamespaceToken, RuntimeError, SearchHit, SearchSource,
VerbRegistry,
};
use khive_storage::types::{
TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
VectorSearchRequest,
Expand Down Expand Up @@ -138,22 +143,118 @@ fn search_source_label(source: SearchSource) -> &'static str {
}
}

/// Bookkeeping kept per candidate id while search hits are reduced to
/// `(id, score)` pairs for fusion.
///
/// `Default` yields "seen in no source, no title, no snippet".
#[derive(Default)]
struct CandidateMeta {
    /// Title carried over from a text hit, if one supplied it.
    title: Option<String>,
    /// Snippet carried over from a text hit, if one supplied it.
    snippet: Option<String>,
    /// Set once the id shows up among the text hits.
    in_text: bool,
    /// Set once the id shows up among the vector hits.
    in_vector: bool,
}

/// Translates the runtime crate's fusion strategy into the retrieval crate's
/// equivalent variant.
///
/// The `Weighted` variant is mapped with an empty weight list: the runtime
/// weights are not copied here.
fn to_retrieval_fusion_strategy(strategy: &RuntimeFusionStrategy) -> RetrievalFusionStrategy {
    match *strategy {
        RuntimeFusionStrategy::VectorOnly => RetrievalFusionStrategy::VectorOnly,
        RuntimeFusionStrategy::Union => RetrievalFusionStrategy::Union,
        RuntimeFusionStrategy::Weighted { .. } => {
            let weights = Vec::new();
            RetrievalFusionStrategy::Weighted { weights }
        }
        RuntimeFusionStrategy::Rrf { k } => RetrievalFusionStrategy::Rrf { k },
    }
}

/// Builds the `HybridConfig` used for fusion from a runtime strategy and a
/// result limit.
///
/// `limit` is used both as the config's limit and as its pool size. For the
/// `Weighted` strategy the first runtime weight becomes `keyword_weight` and
/// the second becomes `vector_weight`; a missing entry counts as `0.0` and
/// negative entries are raised to `0.0`.
fn retrieval_hybrid_config(strategy: &RuntimeFusionStrategy, limit: usize) -> HybridConfig {
    let mut hybrid = HybridConfig::new(limit)
        .with_pool_size(limit)
        .with_fusion_strategy(to_retrieval_fusion_strategy(strategy));

    if let RuntimeFusionStrategy::Weighted { weights } = strategy {
        // Runtime weights are ordered [text, vector]; HybridConfig names them
        // keyword/vector. The fields are assigned directly rather than through
        // with_weights() so that arbitrary positive scales are kept unclamped.
        let weight_at = |slot: usize| weights.get(slot).copied().unwrap_or(0.0).max(0.0);
        hybrid.keyword_weight = weight_at(0);
        hybrid.vector_weight = weight_at(1);
    }

    hybrid
}

/// Derives the `SearchSource` label for a candidate from its membership flags.
///
/// A candidate flagged in neither source is labelled `Text`, the same as a
/// text-only candidate.
fn source_from_meta(meta: &CandidateMeta) -> SearchSource {
    if !meta.in_vector {
        // Covers both "text only" and "seen in neither source".
        return SearchSource::Text;
    }
    if meta.in_text {
        SearchSource::Both
    } else {
        SearchSource::Vector
    }
}

/// Fuses text and vector hits into a list of `SearchHit`s via
/// `fuse_search_results`.
///
/// Hits whose `subject_id` is not contained in `memory_ids` are dropped before
/// fusion. While the remaining hits are reduced to `(id, score)` pairs, `meta`
/// records which source(s) each id appeared in and keeps the first title and
/// snippet that a text hit provided for it.
///
/// When `cfg.fuse_strategy` is `VectorOnly`, only the vector pairs are handed
/// to fusion and every returned hit is labelled `SearchSource::Vector` with no
/// title or snippet. Otherwise both sources are fused and each hit's label,
/// title and snippet come from `meta`.
///
/// `limit` is forwarded to `retrieval_hybrid_config`.
fn fuse_candidates(
text_hits: Vec<TextSearchHit>,
vector_hits: Vec<VectorSearchHit>,
memory_ids: &HashSet<Uuid>,
cfg: &RecallConfig,
limit: usize,
) -> Vec<SearchHit> {
// NOTE(review): the next line has no continuation before the following `let`;
// it looks like a pre-change line left over from the diff view — confirm.
let text: Vec<TextSearchHit> = text_hits
// Side table keyed by candidate id; filled while the hits are converted below.
let mut meta = HashMap::<Uuid, CandidateMeta>::new();

// Text hits: keep only known memory ids, note membership and the first
// title/snippet seen, and reduce each hit to an (id, score) pair.
let text_source: Vec<_> = text_hits
.into_iter()
.filter(|h| memory_ids.contains(&h.subject_id))
.map(|h| {
let TextSearchHit {
subject_id,
score,
title,
snippet,
..
} = h;
let entry = meta.entry(subject_id).or_default();
entry.in_text = true;
// Only fill title/snippet when nothing has been stored for this id yet.
if entry.title.is_none() {
entry.title = title;
}
if entry.snippet.is_none() {
entry.snippet = snippet;
}
(subject_id, score)
})
.collect();
// NOTE(review): the next line also lacks a continuation; presumably another
// pre-change line from the diff view — confirm.
let vec: Vec<VectorSearchHit> = vector_hits

// Vector hits: same filtering, but only the membership flag is recorded.
let vector_source: Vec<_> = vector_hits
.into_iter()
.filter(|h| memory_ids.contains(&h.subject_id))
.map(|h| {
let entry = meta.entry(h.subject_id).or_default();
entry.in_vector = true;
(h.subject_id, h.score)
})
.collect();
// NOTE(review): this call uses `text` and `vec` from the two suspect lines
// above; presumably the pre-change return expression — confirm.
fuse_with_strategy(text, vec, &cfg.fuse_strategy, limit)

// VectorOnly skips the text source entirely.
let vector_only = matches!(&cfg.fuse_strategy, RuntimeFusionStrategy::VectorOnly);
let sources = if vector_only {
vec![vector_source]
} else {
// HybridConfig weighted convention: vector first, keyword second.
vec![vector_source, text_source]
};

let retrieval_cfg = retrieval_hybrid_config(&cfg.fuse_strategy, limit);
fuse_search_results(sources, &retrieval_cfg)
.into_iter()
.map(|(id, score)| {
// An id without a recorded entry falls back to the default (empty) metadata.
let m = meta.remove(&id).unwrap_or_default();
let (source, title, snippet) = if vector_only {
(SearchSource::Vector, None, None)
} else {
(source_from_meta(&m), m.title, m.snippet)
};
SearchHit {
entity_id: id,
score,
source,
title,
snippet,
}
})
.collect()
}

impl MemoryPack {
Expand Down
117 changes: 117 additions & 0 deletions crates/khive-pack-memory/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,123 @@ async fn test_recall_fuse_source_field_is_plain_string() {
);
}

/// Checks that a non-default fusion config (Rrf with k=1) supplied to
/// `recall.fuse` reaches khive_retrieval::fuse_search_results.
///
/// The top candidate's fused score must match the RRF formula for k=1 at
/// rank 1: 1/(k + rank) = 1/(1 + 1) = 0.5. With the default k=60 the score
/// would instead be 1/61 ≈ 0.0164, so the wide gap between the two values
/// shows whether k=1 was passed through to khive_retrieval::HybridConfig.
#[tokio::test]
async fn test_recall_fuse_rrf_k1_uses_retrieval_adapter() {
    let registry = make_registry(make_runtime());

    let memory = json!({ "content": "retrieval adapter rrf k1 probe memory" });
    registry
        .dispatch("remember", memory)
        .await
        .expect("remember");

    let request = json!({
        "query": "retrieval adapter rrf k1 probe",
        "config": {
            "fuse_strategy": { "rrf": { "k": 1 } }
        }
    });
    let response = registry
        .dispatch("recall.fuse", request)
        .await
        .expect("recall.fuse with Rrf k=1");

    let candidates = response["fused_candidates"]
        .as_array()
        .expect("fused array");
    assert!(
        !candidates.is_empty(),
        "recall.fuse must return at least one candidate"
    );

    let score = candidates[0]["fused_score"]
        .as_f64()
        .expect("fused_score is f64");
    // Rank 1 in a single text source with k=1 gives RRF = 1/(1+1) = 0.5.
    // A score near 0.0164 would mean k=60 was used instead.
    let expected = 0.5_f64;
    let deviation = (score - expected).abs();
    assert!(
        deviation < 1e-6,
        "RRF k=1, rank 1 → fused_score must be 0.5; got {score:.6} \
         (≈0.0164 means the adapter passed k=60 instead of k=1)"
    );
}

/// Regression guard for the recall.fuse response shape after khive-retrieval
/// was wired into fuse_candidates.
///
/// The response must still carry top-level `strategy` and `candidate_limit`,
/// each candidate must carry `note_id`, `fused_score` and `source`, and the
/// full-recall fields (`content`, `salience`) must stay absent.
#[tokio::test]
async fn test_recall_fuse_shape_preserved_after_retrieval_wiring() {
    let registry = make_registry(make_runtime());

    let memory = json!({ "content": "shape regression check after retrieval wiring" });
    registry
        .dispatch("remember", memory)
        .await
        .expect("remember");

    let request = json!({ "query": "shape regression retrieval wiring" });
    let response = registry
        .dispatch("recall.fuse", request)
        .await
        .expect("recall.fuse");

    // Top-level fields.
    let has_strategy = response.get("strategy").is_some();
    assert!(
        has_strategy,
        "strategy field must be present in recall.fuse response"
    );
    let limit_is_integer = response["candidate_limit"].as_u64().is_some();
    assert!(
        limit_is_integer,
        "candidate_limit must be a non-negative integer"
    );

    let candidates = response["fused_candidates"]
        .as_array()
        .expect("fused_candidates array");
    assert!(!candidates.is_empty(), "fused_candidates must be non-empty");

    // Per-candidate fields, checked on the first candidate.
    let first = &candidates[0];
    assert!(
        first["note_id"].as_str().is_some(),
        "note_id must be a string UUID"
    );
    assert!(
        first["fused_score"].as_f64().is_some(),
        "fused_score must be a float"
    );
    let source = first["source"]
        .as_str()
        .expect("source must be a plain string");
    let is_plain_label = ["text", "vector", "both"].contains(&source);
    assert!(
        is_plain_label,
        "source must be a plain label, got {source:?}"
    );

    // Fields that belong to full recall must not appear here.
    assert!(
        first.get("content").is_none(),
        "content must be absent from recall.fuse output"
    );
    assert!(
        first.get("salience").is_none(),
        "salience must be absent from recall.fuse output"
    );
}

/// When include_breakdown is true, breakdown.total() must equal the hit's composite score.
#[tokio::test]
async fn test_recall_breakdown_total_matches_composite_score() {
Expand Down
2 changes: 1 addition & 1 deletion crates/khive-retrieval/src/graph/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Unit tests for graph traversal module.

use super::compat::{test_context, EntityRef, MockLinkStore};
use super::compat::{test_context, EntityRef, LinkStore, MockLinkStore};

use crate::graph::types::{
Direction, PathNode, TraversalOptions, MAX_TRAVERSAL_DEPTH, MAX_TRAVERSAL_RESULTS,
Expand Down
2 changes: 1 addition & 1 deletion crates/khive-retrieval/src/persist/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
//!
//! ```rust,no_run
//! use khive_retrieval::persist::RetrievalPersistence;
//! use khive_retrieval::hnsw::HnswIndex;
//! use khive_retrieval::HnswIndex;
//! use rusqlite::Connection;
//! use std::sync::Arc;
//! use tokio::sync::Mutex;
Expand Down
1 change: 1 addition & 0 deletions crates/khive-retrieval/src/persist/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::NodeId;
use khive_bm25::Bm25Index;
use khive_hnsw::HnswIndex;
use rusqlite::Connection;
Expand Down
21 changes: 18 additions & 3 deletions crates/khive-retrieval/src/replay/engine_replay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,11 +844,26 @@ pub mod metrics {
#[cfg(test)]
mod tests {
use super::*;
use khive_db::SqliteStore;

/// Test helper: opens an in-memory SQLite connection, creates the
/// `weight_events` table on it, and returns it wrapped in `Arc<Mutex<_>>`.
fn make_conn() -> Arc<Mutex<Connection>> {
// NOTE(review): the next two lines build the connection through `SqliteStore`
// and are followed by a second, explicit setup; presumably they are the
// pre-change body left over from the diff view — confirm.
let store = SqliteStore::memory().expect("in-memory store");
store.conn()
let conn = Connection::open_in_memory().expect("open in-memory db");
// Only the `weight_events` table is created here.
conn.execute_batch(
r#"
CREATE TABLE weight_events (
namespace TEXT NOT NULL,
atom_id TEXT NOT NULL,
delta REAL NOT NULL,
weight_after REAL NOT NULL,
channel TEXT NOT NULL,
eta REAL NOT NULL,
event_id TEXT,
context_id TEXT,
ts INTEGER NOT NULL
);
"#,
)
.expect("init replay test schema");
Arc::new(Mutex::new(conn))
}

fn insert_weight_event(
Expand Down
31 changes: 26 additions & 5 deletions crates/khive-retrieval/src/weights/engine_weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,35 @@ pub async fn batch_load_weights(
#[cfg(test)]
mod tests {
use super::*;
use khive_db::SqliteStore;
use std::sync::Arc;

/// Test helper: opens an in-memory SQLite connection, creates the
/// `atom_weights` and `weight_events` tables on it, and returns it wrapped in
/// `Arc<Mutex<_>>`.
fn make_conn() -> Arc<Mutex<Connection>> {
// Open an in-memory SQLite DB and run migrations so atom_weights and
// weight_events tables exist.
// NOTE(review): the comment above and the two `SqliteStore` lines below are
// followed by a second, explicit setup that creates the tables directly with
// SQL; presumably they are the pre-change body left over from the diff
// view — confirm.
let store = SqliteStore::memory().expect("in-memory store");
store.conn()
let conn = Connection::open_in_memory().expect("open in-memory db");
// Both tables are created by this one batch.
conn.execute_batch(
r#"
CREATE TABLE atom_weights (
namespace TEXT NOT NULL,
atom_id TEXT NOT NULL,
weight REAL NOT NULL,
updated_at INTEGER NOT NULL,
version INTEGER NOT NULL DEFAULT 1,
PRIMARY KEY(namespace, atom_id)
);
CREATE TABLE weight_events (
namespace TEXT NOT NULL,
atom_id TEXT NOT NULL,
delta REAL NOT NULL,
weight_after REAL NOT NULL,
channel TEXT NOT NULL,
eta REAL NOT NULL,
event_id TEXT,
context_id TEXT,
ts INTEGER NOT NULL
);
"#,
)
.expect("init weight test schema");
Arc::new(Mutex::new(conn))
}

// -------------------------------------------------------------------------
Expand Down
Loading
Loading