From aa969142eed4f1904c7308028a80de85547b827c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=86=A0=E8=BE=B0?= Date: Tue, 26 May 2026 20:49:47 +0800 Subject: [PATCH] feat(recall): add optional remote rerank --- README.md | 2 + README_CN.md | 2 + openclaw.plugin.json | 14 +- src/config.ts | 26 +++ src/core/hooks/auto-recall.rerank.test.ts | 101 +++++++++++ src/core/hooks/auto-recall.ts | 67 ++++++- src/core/recall/reranker.test.ts | 132 ++++++++++++++ src/core/recall/reranker.ts | 189 ++++++++++++++++++++ src/core/tdai-core.ts | 1 + src/core/tools/memory-search.rerank.test.ts | 67 +++++++ src/core/tools/memory-search.ts | 23 ++- 11 files changed, 611 insertions(+), 13 deletions(-) create mode 100644 src/core/hooks/auto-recall.rerank.test.ts create mode 100644 src/core/recall/reranker.test.ts create mode 100644 src/core/recall/reranker.ts create mode 100644 src/core/tools/memory-search.rerank.test.ts diff --git a/README.md b/README.md index 58aa74f..d1ae219 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,7 @@ docker exec -it hermes-memory hermes | `pipeline.l1IdleTimeoutSeconds` | `600` | Trigger L1 after the user has been idle for this many seconds | | `pipeline.l2MinIntervalSeconds` | `900` | Minimum interval between two L2 passes within the same session | | `recall.timeoutMs` | `5000` | Recall timeout; on timeout, skip injection without blocking the conversation | +| `recall.rerank.enabled` | `false` | Optional remote rerank for L1 recall candidates; falls back to the original order on timeout or API failure | | `extraction.enableDedup` | `true` | L1 vector dedup / conflict detection | | `capture.excludeAgents` | `[]` | Glob patterns to exclude specific agents (e.g. `bench-judge-*`) | | `capture.l0l1RetentionDays` | `0` | Local retention days for L0 / L1 files; `0` = never clean up | @@ -291,6 +292,7 @@ docker exec -it hermes-memory hermes For all fields, types, and constraints see [`openclaw.plugin.json`](./openclaw.plugin.json)。 - `embedding.*` — remote embedding service (OpenAI-compatible API) +- `recall.rerank.*` — remote rerank service compatible with `/rerank` APIs - `llm.*` — standalone LLM mode (bypass OpenClaw's built-in model and run L1/L2/L3 with a designated API) - `offload.backendUrl / backendApiKey` — offload the L1/L1.5/L2/L4 flow to a backend service - `report.*` — metrics reporting diff --git a/README_CN.md b/README_CN.md index 16ef697..7249959 100644 --- a/README_CN.md +++ b/README_CN.md @@ -279,6 +279,7 @@ docker exec -it hermes-memory hermes | `pipeline.l1IdleTimeoutSeconds` | `600` | 用户停止对话多久后触发 L1 | | `pipeline.l2MinIntervalSeconds` | `900` | 同 session 两次 L2 之间的最小间隔 | | `recall.timeoutMs` | `5000` | 召回超时阈值,超时跳过注入不阻塞对话 | +| `recall.rerank.enabled` | `false` | 可选远程 rerank,用于重排 L1 召回候选;超时或 API 失败时回退原排序 | | `extraction.enableDedup` | `true` | L1 向量去重 / 冲突检测 | | `capture.excludeAgents` | `[]` | Glob 模式排除特定 Agent(如 `bench-judge-*`) | | `capture.l0l1RetentionDays` | `0` | L0/L1 本地文件保留天数,`0` = 永不清理 | @@ -295,6 +296,7 @@ docker exec -it hermes-memory hermes 完整字段、类型、约束见 [`openclaw.plugin.json`](./openclaw.plugin.json) 。 - `embedding.*` — 远程 embedding 服务(OpenAI 兼容 API) +- `recall.rerank.*` — 兼容 `/rerank` API 的远程 rerank 服务 - `llm.*` — 独立 LLM 模式(绕过 OpenClaw 内置模型,用指定 API 跑 L1/L2/L3) - `offload.backendUrl / backendApiKey` — 将 L1/L1.5/L2/L4 offload 流程卸载到后端服务 - `report.*` — 指标上报 diff --git a/openclaw.plugin.json b/openclaw.plugin.json index f6ea5fd..e8388e2 100644 --- a/openclaw.plugin.json +++ b/openclaw.plugin.json @@ -71,7 +71,19 @@ "maxResults": { "type": "number", "default": 5, "description": "召回最大结果数" }, "scoreThreshold": { "type": "number", "default": 0.3, "description": "最低分数阈值" }, "strategy": { "type": "string", "enum": ["embedding", "keyword", "hybrid"], "default": "hybrid", "description": "搜索策略:keyword(关键词)、embedding(向量)、hybrid(混合RRF融合,推荐)" }, - "timeoutMs": { "type": "number", "default": 5000, "description": "召回整体超时(毫秒),超时后跳过记忆注入并打印警告日志" } + "timeoutMs": { "type": "number", "default": 5000, "description": "召回整体超时(毫秒),超时后跳过记忆注入并打印警告日志" }, + "rerank": { + "type": "object", + "description": "远程 rerank 设置。默认关闭;开启后会先多召回候选,再调用兼容 /rerank 的云端 API 重排,失败时自动回退原排序", + "properties": { + "enabled": { "type": "boolean", "default": false, "description": "是否启用远程 rerank" }, + "baseUrl": { "type": "string", "description": "Rerank API Base URL;如果未以 /rerank 结尾,会自动追加 /rerank" }, + "apiKey": { "type": "string", "description": "Rerank API Key" }, + "model": { "type": "string", "description": "Rerank 模型名称" }, + "timeoutMs": { "type": "number", "default": 3000, "description": "Rerank 请求超时(毫秒),超时后回退原排序" }, + "candidateMultiplier": { "type": "number", "default": 3, "description": "Rerank 前候选召回倍数,相对于 recall.maxResults;范围建议 1-10" } + } + } } }, "embedding": { diff --git a/src/config.ts b/src/config.ts index e09cff5..2344f4c 100644 --- a/src/config.ts +++ b/src/config.ts @@ -75,6 +75,21 @@ export interface PipelineTriggerConfig { } /** Recall settings — controls memory retrieval for context injection. */ +export interface RecallRerankConfig { + /** Enable remote rerank for L1 recall candidates (default: false). */ + enabled: boolean; + /** OpenAI-compatible rerank API base URL. "/rerank" is appended when omitted. */ + baseUrl?: string; + /** API key for the rerank provider. */ + apiKey?: string; + /** Rerank model name. */ + model?: string; + /** Request timeout in milliseconds (default: 3000). */ + timeoutMs: number; + /** Candidate multiplier before rerank, relative to recall.maxResults (default: 3). */ + candidateMultiplier: number; +} + export interface RecallConfig { /** Enable auto-recall (default: true) */ enabled: boolean; @@ -86,6 +101,8 @@ export interface RecallConfig { strategy: "embedding" | "keyword" | "hybrid"; /** Overall recall timeout in milliseconds (default: 5000). When exceeded, recall is skipped with a warning. */ timeoutMs: number; + /** Optional remote rerank configuration. Disabled by default. */ + rerank: RecallRerankConfig; } /** Embedding service configuration for vector search. */ @@ -322,6 +339,7 @@ export function parseConfig(raw: Record | undefined): MemoryTda // --- Recall --- const recallGroup = obj(c, "recall"); + const rerankGroup = obj(recallGroup, "rerank"); // --- Embedding --- const embeddingGroup = obj(c, "embedding"); @@ -489,6 +507,14 @@ export function parseConfig(raw: Record | undefined): MemoryTda scoreThreshold: num(recallGroup, "scoreThreshold") ?? 0.3, strategy: validateStrategy(str(recallGroup, "strategy")) ?? "hybrid", timeoutMs: num(recallGroup, "timeoutMs") ?? 5000, + rerank: { + enabled: bool(rerankGroup, "enabled") ?? false, + baseUrl: str(rerankGroup, "baseUrl"), + apiKey: str(rerankGroup, "apiKey"), + model: str(rerankGroup, "model"), + timeoutMs: num(rerankGroup, "timeoutMs") ?? 3000, + candidateMultiplier: num(rerankGroup, "candidateMultiplier") ?? 3, + }, }, embedding: { enabled: embeddingEnabled, diff --git a/src/core/hooks/auto-recall.rerank.test.ts b/src/core/hooks/auto-recall.rerank.test.ts new file mode 100644 index 0000000..f6b6ab0 --- /dev/null +++ b/src/core/hooks/auto-recall.rerank.test.ts @@ -0,0 +1,101 @@ +import { mkdtempSync, rmSync } from "node:fs"; +import { tmpdir } from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it, vi } from "vitest"; + +import { parseConfig } from "../../config.js"; +import type { IMemoryStore, L1FtsResult } from "../store/types.js"; +import { performAutoRecall } from "./auto-recall.js"; + +describe("performAutoRecall rerank", () => { + let dataDir: string | undefined; + + afterEach(() => { + vi.unstubAllGlobals(); + if (dataDir) { + rmSync(dataDir, { recursive: true, force: true }); + dataDir = undefined; + } + }); + + it("reranks over-retrieved L1 candidates before injecting top results", async () => { + dataDir = mkdtempSync(path.join(tmpdir(), "memory-tdai-rerank-")); + vi.stubGlobal("fetch", vi.fn(async () => new Response( + JSON.stringify({ + results: [ + { index: 2, relevance_score: 0.95 }, + { index: 0, relevance_score: 0.52 }, + ], + }), + { status: 200 }, + ))); + + const store = { + isFtsAvailable: () => true, + searchL1Fts: vi.fn(async (_query: string, limit?: number): Promise => { + expect(limit).toBeGreaterThanOrEqual(6); + return [ + makeFtsResult("a", "无关的天气记录", 0.9), + makeFtsResult("b", "用户喜欢 Python", 0.89), + makeFtsResult("c", "用户明确偏好 TypeScript", 0.88), + ]; + }), + } as unknown as IMemoryStore; + + const cfg = parseConfig({ + recall: { + strategy: "keyword", + maxResults: 2, + scoreThreshold: 0, + rerank: { + enabled: true, + baseUrl: "https://api.example.com/v1", + apiKey: "test-key", + model: "bge-reranker-v2-m3", + candidateMultiplier: 3, + }, + }, + }); + + const result = await performAutoRecall({ + userText: "TypeScript 偏好", + actorId: "user", + sessionKey: "session", + cfg, + pluginDataDir: dataDir, + vectorStore: store, + }); + + const injected = extractRelevantMemoryLines(result?.prependContext); + expect(injected).toContain("用户明确偏好 TypeScript"); + expect(injected).toContain("无关的天气记录"); + expect(injected).not.toContain("用户喜欢 Python"); + expect(injected.indexOf("用户明确偏好 TypeScript")).toBeLessThan( + injected.indexOf("无关的天气记录"), + ); + }); +}); + +function makeFtsResult(id: string, content: string, score: number): L1FtsResult { + return { + record_id: id, + content, + type: "episodic", + priority: 80, + scene_name: "test", + score, + timestamp_str: "", + timestamp_start: "", + timestamp_end: "", + session_key: "session", + session_id: "session-1", + metadata_json: "{}", + }; +} + +function extractRelevantMemoryLines(prependContext: string | undefined): string { + const match = prependContext?.match( + /[\s\S]*?\n\n([\s\S]*?)\n<\/relevant-memories>/, + ); + return match?.[1] ?? ""; +} diff --git a/src/core/hooks/auto-recall.ts b/src/core/hooks/auto-recall.ts index cccb864..1094ba7 100644 --- a/src/core/hooks/auto-recall.ts +++ b/src/core/hooks/auto-recall.ts @@ -20,6 +20,7 @@ import type { IMemoryStore, L1SearchResult, L1FtsResult } from "../store/types.j import { buildFtsQuery } from "../store/sqlite.js"; import type { EmbeddingService, EmbeddingCallOptions } from "../store/embedding.js"; import { sanitizeText } from "../../utils/sanitize.js"; +import { getRerankCandidateLimit, rerankTextCandidates } from "../recall/reranker.js"; const TAG = "[memory-tdai] [recall]"; @@ -331,7 +332,8 @@ async function searchMemories( ); } - const maxResults = cfg.recall.maxResults ?? 5; + const maxResults = normalizePositiveInt(cfg.recall.maxResults, 5); + const candidateLimit = getRerankCandidateLimit(maxResults, cfg.recall.rerank); const threshold = cfg.recall.scoreThreshold ?? 0.3; const embeddingAvailable = !!vectorStore && !!embeddingService; @@ -340,7 +342,7 @@ async function searchMemories( `${TAG} [searchMemories] strategy=${strategy}, embeddingAvailable=${embeddingAvailable}, ` + `vectorStore=${vectorStore ? "available" : "UNAVAILABLE"}, ` + `embeddingService=${embeddingService ? "available" : "UNAVAILABLE"}, ` + - `maxResults=${maxResults}, threshold=${threshold}`, + `maxResults=${maxResults}, candidateLimit=${candidateLimit}, threshold=${threshold}`, ); // Determine effective strategy (fall back to keyword if embedding not available) @@ -362,14 +364,26 @@ async function searchMemories( try { if (effectiveStrategy === "keyword") { const tFts = performance.now(); - const lines = await searchByKeyword(cleanText, pluginDataDir, maxResults, threshold, logger, vectorStore); - return { lines, timing: { ftsMs: performance.now() - tFts, embeddingMs: 0, ftsHits: lines.length, embeddingHits: 0 } }; + const lines = await searchByKeyword(cleanText, pluginDataDir, candidateLimit, threshold, logger, vectorStore); + return await finalizeSearchResult( + { lines, timing: { ftsMs: performance.now() - tFts, embeddingMs: 0, ftsHits: lines.length, embeddingHits: 0 } }, + cleanText, + cfg, + logger, + maxResults, + ); } if (effectiveStrategy === "embedding") { const tEmb = performance.now(); - const lines = await searchByEmbedding(cleanText, maxResults, threshold, vectorStore!, embeddingService!, logger, embeddingCallOpts); - return { lines, timing: { ftsMs: 0, embeddingMs: performance.now() - tEmb, ftsHits: 0, embeddingHits: lines.length } }; + const lines = await searchByEmbedding(cleanText, candidateLimit, threshold, vectorStore!, embeddingService!, logger, embeddingCallOpts); + return await finalizeSearchResult( + { lines, timing: { ftsMs: 0, embeddingMs: performance.now() - tEmb, ftsHits: 0, embeddingHits: lines.length } }, + cleanText, + cfg, + logger, + maxResults, + ); } // Hybrid: if the store natively supports hybrid search (e.g. TCVDB does @@ -377,21 +391,51 @@ async function searchMemories( // to avoid a redundant second HTTP request and a wasted local embed(). if (vectorStore?.getCapabilities().nativeHybridSearch) { const tNative = performance.now(); - const results = await vectorStore.searchL1Hybrid({ query: cleanText, topK: maxResults }); + const results = await vectorStore.searchL1Hybrid({ query: cleanText, topK: candidateLimit }); const nativeMs = performance.now() - tNative; logger?.debug?.(`${TAG} [hybrid-native] Single-call hybrid: ${results.length} results in ${nativeMs.toFixed(0)}ms`); const lines = results.map((r) => formatMemoryLine(vectorResultToFormatable(r))); - return { lines, timing: { ftsMs: 0, embeddingMs: nativeMs, ftsHits: 0, embeddingHits: results.length } }; + return await finalizeSearchResult( + { lines, timing: { ftsMs: 0, embeddingMs: nativeMs, ftsHits: 0, embeddingHits: results.length } }, + cleanText, + cfg, + logger, + maxResults, + ); } // Fallback: run keyword + embedding in parallel, merge with client-side RRF (SQLite path) - return await searchHybrid(cleanText, pluginDataDir, maxResults, threshold, vectorStore!, embeddingService!, logger, embeddingCallOpts); + return await finalizeSearchResult( + await searchHybrid(cleanText, pluginDataDir, candidateLimit, threshold, vectorStore!, embeddingService!, logger, embeddingCallOpts), + cleanText, + cfg, + logger, + maxResults, + ); } catch (err) { logger?.warn?.(`${TAG} Memory search failed (strategy=${effectiveStrategy}): ${err instanceof Error ? err.message : String(err)}`); return emptyResult; } } +async function finalizeSearchResult( + result: SearchResult, + query: string, + cfg: MemoryTdaiConfig, + logger: Logger | undefined, + maxResults: number, +): Promise { + if (result.lines.length === 0) return result; + const lines = await rerankTextCandidates({ + query, + documents: result.lines, + topN: maxResults, + config: cfg.recall.rerank, + logger, + }); + return { ...result, lines }; +} + // ============================ // Strategy: Keyword (FTS5 BM25, no in-memory fallback) // ============================ @@ -725,6 +769,11 @@ function formatTimestamp(ts: string | undefined): string | undefined { return `${datePart} ${timePart}`; } +function normalizePositiveInt(value: number | undefined, fallback: number): number { + if (value == null || !Number.isFinite(value) || value <= 0) return fallback; + return Math.floor(value); +} + /** * Build a FormatableMemory from a full MemoryRecord (keyword search path). * Handles empty metadata, empty timestamps array gracefully. diff --git a/src/core/recall/reranker.test.ts b/src/core/recall/reranker.test.ts new file mode 100644 index 0000000..baa7d82 --- /dev/null +++ b/src/core/recall/reranker.test.ts @@ -0,0 +1,132 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +import { + getRerankCandidateLimit, + isRerankConfigured, + rerankCandidates, +} from "./reranker.js"; + +describe("recall reranker", () => { + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("keeps original top results when rerank is disabled", async () => { + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + const candidates = [ + { id: "a", text: "用户喜欢 TypeScript" }, + { id: "b", text: "用户喜欢 Python" }, + { id: "c", text: "用户喜欢 Rust" }, + ]; + + const result = await rerankCandidates({ + query: "TypeScript 偏好", + candidates, + topN: 2, + config: { enabled: false }, + getDocumentText: (item) => item.text, + }); + + expect(result.map((item) => item.id)).toEqual(["a", "b"]); + expect(fetchMock).not.toHaveBeenCalled(); + expect(isRerankConfigured({ enabled: false })).toBe(false); + expect(getRerankCandidateLimit(5, { enabled: false })).toBe(5); + }); + + it("reorders candidates with remote relevance scores", async () => { + const fetchMock = vi.fn(async (_url: string, init: RequestInit) => { + expect(JSON.parse(init.body as string)).toEqual({ + model: "bge-reranker-v2-m3", + query: "TypeScript 偏好", + documents: [ + "无关的天气记录", + "用户喜欢 Python", + "用户明确偏好 TypeScript", + ], + top_n: 2, + }); + + return new Response( + JSON.stringify({ + results: [ + { index: 2, relevance_score: 0.92 }, + { index: 1, relevance_score: 0.41 }, + ], + }), + { status: 200 }, + ); + }); + vi.stubGlobal("fetch", fetchMock); + + const candidates = [ + { id: "a", text: "无关的天气记录" }, + { id: "b", text: "用户喜欢 Python" }, + { id: "c", text: "用户明确偏好 TypeScript" }, + ]; + + const result = await rerankCandidates({ + query: "TypeScript 偏好", + candidates, + topN: 2, + config: { + enabled: true, + baseUrl: "https://api.example.com/v1", + apiKey: "test-key", + model: "bge-reranker-v2-m3", + timeoutMs: 1000, + candidateMultiplier: 4, + }, + getDocumentText: (item) => item.text, + }); + + expect(fetchMock).toHaveBeenCalledWith( + "https://api.example.com/v1/rerank", + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + Authorization: "Bearer test-key", + "Content-Type": "application/json", + }), + }), + ); + expect(result.map((item) => item.id)).toEqual(["c", "b"]); + expect(getRerankCandidateLimit(5, { + enabled: true, + baseUrl: "https://api.example.com/v1", + apiKey: "test-key", + model: "bge-reranker-v2-m3", + candidateMultiplier: 4, + })).toBe(20); + }); + + it("falls back to original top results when remote rerank fails", async () => { + vi.stubGlobal("fetch", vi.fn(async () => new Response("bad gateway", { status: 502 }))); + const warn = vi.fn(); + + const candidates = [ + { id: "a", text: "第一条" }, + { id: "b", text: "第二条" }, + { id: "c", text: "第三条" }, + ]; + + const result = await rerankCandidates({ + query: "查询", + candidates, + topN: 2, + config: { + enabled: true, + baseUrl: "https://api.example.com/v1/rerank", + apiKey: "test-key", + model: "reranker", + timeoutMs: 1000, + }, + getDocumentText: (item) => item.text, + logger: { warn, info: vi.fn(), error: vi.fn() }, + }); + + expect(result.map((item) => item.id)).toEqual(["a", "b"]); + expect(warn).toHaveBeenCalledWith(expect.stringContaining("Remote rerank failed")); + }); +}); diff --git a/src/core/recall/reranker.ts b/src/core/recall/reranker.ts new file mode 100644 index 0000000..9ca37db --- /dev/null +++ b/src/core/recall/reranker.ts @@ -0,0 +1,189 @@ +import type { RecallRerankConfig } from "../../config.js"; + +const TAG = "[memory-tdai] [recall-rerank]"; +const DEFAULT_MAX_RESULTS = 5; +const DEFAULT_CANDIDATE_MULTIPLIER = 3; +const MAX_CANDIDATE_MULTIPLIER = 10; + +interface Logger { + debug?: (message: string) => void; + info: (message: string) => void; + warn: (message: string) => void; + error: (message: string) => void; +} + +export interface RerankCandidatesOptions { + query: string; + candidates: T[]; + topN: number; + config?: RecallRerankConfig; + getDocumentText: (candidate: T) => string; + logger?: Logger; +} + +interface RemoteRerankResult { + index: number; + score: number; +} + +export function isRerankConfigured(config: RecallRerankConfig | undefined): boolean { + return !!( + config?.enabled && + config.baseUrl?.trim() && + config.apiKey?.trim() && + config.model?.trim() + ); +} + +export function getRerankCandidateLimit( + maxResults: number | undefined, + config: RecallRerankConfig | undefined, +): number { + const topN = normalizePositiveInt(maxResults, DEFAULT_MAX_RESULTS); + if (!isRerankConfigured(config)) return topN; + return topN * normalizeCandidateMultiplier(config?.candidateMultiplier); +} + +export async function rerankTextCandidates(options: { + query: string; + documents: string[]; + topN: number; + config?: RecallRerankConfig; + logger?: Logger; +}): Promise { + return rerankCandidates({ + query: options.query, + candidates: options.documents, + topN: options.topN, + config: options.config, + getDocumentText: (document) => document, + logger: options.logger, + }); +} + +export async function rerankCandidates(options: RerankCandidatesOptions): Promise { + const topN = normalizePositiveInt(options.topN, DEFAULT_MAX_RESULTS); + const fallback = options.candidates.slice(0, topN); + + if (options.candidates.length <= 1) return fallback; + if (!isRerankConfigured(options.config)) return fallback; + + const config = options.config; + const documents = options.candidates.map(options.getDocumentText); + const requestTopN = Math.min(topN, documents.length); + + try { + const results = await callRemoteRerank({ + query: options.query, + documents, + topN: requestTopN, + config, + }); + + if (results.length === 0) return fallback; + + const selected: T[] = []; + const seen = new Set(); + + for (const result of results) { + if (result.index < 0 || result.index >= options.candidates.length) continue; + if (seen.has(result.index)) continue; + selected.push(options.candidates[result.index]); + seen.add(result.index); + if (selected.length >= topN) break; + } + + for (let index = 0; index < options.candidates.length && selected.length < topN; index++) { + if (seen.has(index)) continue; + selected.push(options.candidates[index]); + } + + options.logger?.debug?.( + `${TAG} Reranked ${options.candidates.length} candidates to top ${selected.length}`, + ); + return selected; + } catch (err) { + options.logger?.warn?.( + `${TAG} Remote rerank failed; using original recall order: ${err instanceof Error ? err.message : String(err)}`, + ); + return fallback; + } +} + +async function callRemoteRerank(params: { + query: string; + documents: string[]; + topN: number; + config: RecallRerankConfig; +}): Promise { + const timeoutMs = normalizePositiveInt(params.config.timeoutMs, 3000); + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), timeoutMs); + + try { + const response = await fetch(buildRerankUrl(params.config.baseUrl!), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${params.config.apiKey!}`, + }, + body: JSON.stringify({ + model: params.config.model, + query: params.query, + documents: params.documents, + top_n: params.topN, + }), + signal: controller.signal, + }); + + if (!response.ok) { + const body = await response.text().catch(() => ""); + throw new Error(`HTTP ${response.status} ${response.statusText}${body ? `: ${body.slice(0, 200)}` : ""}`); + } + + const payload = await response.json() as unknown; + return parseRemoteRerankResults(payload); + } finally { + clearTimeout(timeoutId); + } +} + +function parseRemoteRerankResults(payload: unknown): RemoteRerankResult[] { + const obj = payload && typeof payload === "object" ? payload as Record : {}; + const rawResults = Array.isArray(obj.results) + ? obj.results + : Array.isArray(obj.data) + ? obj.data + : []; + + return rawResults + .map((item): RemoteRerankResult | undefined => { + if (!item || typeof item !== "object") return undefined; + const record = item as Record; + const index = typeof record.index === "number" ? record.index : Number.NaN; + const scoreValue = typeof record.relevance_score === "number" + ? record.relevance_score + : typeof record.score === "number" + ? record.score + : Number.NaN; + if (!Number.isInteger(index) || !Number.isFinite(scoreValue)) return undefined; + return { index, score: scoreValue }; + }) + .filter((item): item is RemoteRerankResult => !!item) + .sort((a, b) => b.score - a.score); +} + +function buildRerankUrl(baseUrl: string): string { + const trimmed = baseUrl.trim().replace(/\/+$/, ""); + return trimmed.endsWith("/rerank") ? trimmed : `${trimmed}/rerank`; +} + +function normalizeCandidateMultiplier(value: number | undefined): number { + const normalized = normalizePositiveInt(value, DEFAULT_CANDIDATE_MULTIPLIER); + return Math.min(Math.max(normalized, 1), MAX_CANDIDATE_MULTIPLIER); +} + +function normalizePositiveInt(value: number | undefined, fallback: number): number { + if (value == null || !Number.isFinite(value) || value <= 0) return fallback; + return Math.floor(value); +} diff --git a/src/core/tdai-core.ts b/src/core/tdai-core.ts index 977d4a2..06fc105 100644 --- a/src/core/tdai-core.ts +++ b/src/core/tdai-core.ts @@ -295,6 +295,7 @@ export class TdaiCore { scene: params.scene, vectorStore: this.vectorStore, embeddingService: this.embeddingService, + rerank: this.cfg.recall.rerank, logger: this.logger, }); diff --git a/src/core/tools/memory-search.rerank.test.ts b/src/core/tools/memory-search.rerank.test.ts new file mode 100644 index 0000000..05e0c08 --- /dev/null +++ b/src/core/tools/memory-search.rerank.test.ts @@ -0,0 +1,67 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +import type { IMemoryStore, L1FtsResult } from "../store/types.js"; +import { executeMemorySearch } from "./memory-search.js"; + +describe("executeMemorySearch rerank", () => { + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("reranks tool search candidates before trimming to the requested limit", async () => { + vi.stubGlobal("fetch", vi.fn(async () => new Response( + JSON.stringify({ + results: [ + { index: 2, score: 0.93 }, + { index: 0, score: 0.42 }, + ], + }), + { status: 200 }, + ))); + + const store = { + isFtsAvailable: () => true, + searchL1Fts: vi.fn(async (_query: string, limit?: number): Promise => { + expect(limit).toBeGreaterThanOrEqual(6); + return [ + makeFtsResult("a", "无关的天气记录", 0.9), + makeFtsResult("b", "用户喜欢 Python", 0.89), + makeFtsResult("c", "用户明确偏好 TypeScript", 0.88), + ]; + }), + } as unknown as IMemoryStore; + + const result = await executeMemorySearch({ + query: "TypeScript 偏好", + limit: 2, + vectorStore: store, + rerank: { + enabled: true, + baseUrl: "https://api.example.com/v1", + apiKey: "test-key", + model: "bge-reranker-v2-m3", + timeoutMs: 1000, + candidateMultiplier: 3, + }, + }); + + expect(result.results.map((item) => item.id)).toEqual(["c", "a"]); + }); +}); + +function makeFtsResult(id: string, content: string, score: number): L1FtsResult { + return { + record_id: id, + content, + type: "episodic", + priority: 80, + scene_name: "test", + score, + timestamp_str: "", + timestamp_start: "", + timestamp_end: "", + session_key: "session", + session_id: "session-1", + metadata_json: "{}", + }; +} diff --git a/src/core/tools/memory-search.ts b/src/core/tools/memory-search.ts index dc9d2c2..cba7f2c 100644 --- a/src/core/tools/memory-search.ts +++ b/src/core/tools/memory-search.ts @@ -13,6 +13,8 @@ import type { IMemoryStore, L1SearchResult } from "../store/types.js"; import { buildFtsQuery } from "../store/sqlite.js"; import type { EmbeddingService } from "../store/embedding.js"; +import type { RecallRerankConfig } from "../../config.js"; +import { getRerankCandidateLimit, rerankCandidates } from "../recall/reranker.js"; // ============================ // Types @@ -92,6 +94,7 @@ export async function executeMemorySearch(params: { scene?: string; vectorStore?: IMemoryStore; embeddingService?: EmbeddingService; + rerank?: RecallRerankConfig; logger?: Logger; }): Promise { const { @@ -101,6 +104,7 @@ export async function executeMemorySearch(params: { scene: sceneFilter, vectorStore, embeddingService, + rerank, logger, } = params; @@ -139,7 +143,8 @@ export async function executeMemorySearch(params: { } // ── Over-retrieve for later filtering and RRF merging ── - const candidateK = limit * 3; + const rerankCandidateLimit = getRerankCandidateLimit(limit, rerank); + const candidateK = Math.max(limit * 3, rerankCandidateLimit); // ── Run available search strategies in parallel ── const [ftsItems, vecItems] = await Promise.all([ @@ -245,8 +250,15 @@ export async function executeMemorySearch(params: { logger?.debug?.(`${TAG} After scene filter "${sceneFilter}": ${results.length}/${preFilterCount}`); } - // ── Trim to requested limit ── - const trimmed = results.slice(0, limit); + // ── Optional remote rerank, then trim to requested limit ── + const trimmed = await rerankCandidates({ + query, + candidates: results.slice(0, rerankCandidateLimit), + topN: limit, + config: rerank, + getDocumentText: formatRerankDocument, + logger, + }); logger?.debug?.( `${TAG} RESULT (strategy=${strategy}): returning ${trimmed.length} memories ` + @@ -260,6 +272,11 @@ export async function executeMemorySearch(params: { }; } +function formatRerankDocument(item: MemorySearchResultItem): string { + const scene = item.scene_name ? ` scene=${item.scene_name}` : ""; + return `[${item.type}${scene}] ${item.content}`; +} + // ============================ // Tool response formatter // ============================