From 5f993bc80dcd62df3580966aa257609741a9a851 Mon Sep 17 00:00:00 2001
From: jiang
Date: Sat, 28 Feb 2026 16:53:32 +0800
Subject: [PATCH 1/5] test: add rerank model

---
 src/memos/api/handlers/search_handler.py | 181 +++++++++++++++++++++--
 1 file changed, 172 insertions(+), 9 deletions(-)

diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 267d1bb28..965b05191 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -20,10 +20,15 @@
 from memos.multi_mem_cube.composite_cube import CompositeCubeView
 from memos.multi_mem_cube.single_cube import SingleCubeView
 from memos.multi_mem_cube.views import MemCubeView
+from memos.reranker.http_bge import HTTPBGEReranker
 
 
 logger = get_logger(__name__)
 
 
+# Temporary test endpoint override for HTTP reranker.
+_TEST_RERANKER_URL = "https://0f31618ba0cf91b57d.gradio.live"
+_TEST_RERANKER_MODEL = "zerank-2"
+
 class SearchHandler(BaseHandler):
     """
@@ -43,6 +48,13 @@ def __init__(self, dependencies: HandlerDependencies):
         self._validate_dependencies(
             "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent"
         )
+        # Use a dedicated HTTP reranker for this handler to avoid dependency
+        # conflicts when global config uses a different reranker backend.
+        self.http_test_reranker = HTTPBGEReranker(
+            reranker_url=_TEST_RERANKER_URL,
+            model=_TEST_RERANKER_MODEL,
+            timeout=20,
+        )
 
     def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
         """
@@ -62,9 +74,14 @@
         # Use deepcopy to avoid modifying the original request object
         search_req_local = copy.deepcopy(search_req)
 
-        # Expand top_k for deduplication (5x to ensure enough candidates)
+        # Expand retrieval top_k for dedup candidate recall.
+        # Keep MMR output smaller than recall set to preserve filtering value.
+        recall_multiplier = 5
+        mmr_output_multiplier = 2
+        dedup_multiplier = 1
         if search_req_local.dedup in ("sim", "mmr"):
-            search_req_local.top_k = search_req_local.top_k * 5
+            dedup_multiplier = recall_multiplier
+            search_req_local.top_k = search_req_local.top_k * dedup_multiplier
 
         # Search and deduplicate
         cube_view = self._build_cube_view(search_req_local)
@@ -75,21 +92,52 @@
         results = self._apply_relativity_threshold(results, search_req_local.relativity)
 
         if search_req_local.dedup == "sim":
-            results = self._dedup_text_memories(results, search_req.top_k)
-            self._strip_embeddings(results)
+            # Keep a larger candidate pool for downstream reranking.
+            results = self._dedup_text_memories(results, search_req_local.top_k)
         elif search_req_local.dedup == "mmr":
-            pref_top_k = getattr(search_req_local, "pref_top_k", 6)
-            results = self._mmr_dedup_text_memories(results, search_req.top_k, pref_top_k)
-            self._strip_embeddings(results)
+            pref_top_k_final = getattr(
+                search_req, "pref_top_k", getattr(search_req_local, "pref_top_k", 6)
+            )
+            mmr_text_top_k = max(search_req.top_k * mmr_output_multiplier, search_req.top_k)
+            mmr_pref_top_k = max(pref_top_k_final * mmr_output_multiplier, pref_top_k_final)
+            # MMR keeps an intermediate candidate pool (smaller than recall set),
+            # then reranker reduces to final top_k.
+            results = self._mmr_dedup_text_memories(
+                results,
+                mmr_text_top_k,
+                mmr_pref_top_k,
+            )
+
+        text_rerank_pool_top_k = search_req_local.top_k
+        if search_req_local.dedup == "mmr":
+            text_rerank_pool_top_k = max(
+                search_req.top_k * mmr_output_multiplier, search_req.top_k
+            )
 
         text_mem = results["text_mem"]
         results["text_mem"] = rerank_knowledge_mem(
-            self.reranker,
+            self.http_test_reranker,
             query=search_req.query,
             text_mem=text_mem,
-            top_k=search_req_local.top_k,
+            top_k=text_rerank_pool_top_k,
             file_mem_proportion=0.5,
         )
+        pref_top_k_final = getattr(
+            search_req, "pref_top_k", getattr(search_req_local, "pref_top_k", 6)
+        )
+        reranked_text_mem, reranked_pref_mem = self._rerank_text_and_pref_memories(
+            search_req.query,
+            results.get("text_mem", []),
+            results.get("pref_mem", []),
+            text_top_k=search_req.top_k,
+            pref_top_k=pref_top_k_final,
+        )
+        results["text_mem"] = reranked_text_mem
+        results["pref_mem"] = reranked_pref_mem
+
+        # Remove embeddings only after reranking so reranker still has full signals.
+        if search_req_local.dedup in ("sim", "mmr"):
+            self._strip_embeddings(results)
 
         self.logger.info(
             f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -414,6 +462,121 @@ def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | N
             embeddings.append(embedding)
         return embeddings
 
+    def _rerank_text_and_pref_memories(
+        self,
+        query: str,
+        text_groups: list[dict[str, Any]],
+        pref_groups: list[dict[str, Any]],
+        text_top_k: int,
+        pref_top_k: int,
+    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+        if text_top_k <= 0 and pref_top_k <= 0:
+            return [], []
+
+        candidates: list[tuple[str, str, dict[str, Any]]] = []
+
+        def extend_candidates(mem_type: str, groups: list[dict[str, Any]]) -> None:
+            if not isinstance(groups, list):
+                return
+            for group in groups:
+                if not isinstance(group, dict):
+                    continue
+                cube_id = group.get("cube_id")
+                memories = group.get("memories", [])
+                if cube_id is None or not isinstance(memories, list):
+                    continue
+                for mem in memories:
+                    if isinstance(mem, dict):
+                        candidates.append((mem_type, str(cube_id), mem))
+
+        extend_candidates("text", text_groups)
+        extend_candidates("preference", pref_groups)
+        if not candidates:
+            return [], []
+
+        flat_memories = [mem for _, _, mem in candidates]
+        ranked_flat: list[dict[str, Any]] = []
+        if self.http_test_reranker:
+            try:
+                total_top_k = max(0, text_top_k) + max(0, pref_top_k)
+                rerank_res = self.http_test_reranker.rerank(
+                    query,
+                    flat_memories,
+                    top_k=min(total_top_k, len(flat_memories)),
+                )
+                ranked_flat = [item for item, _score in rerank_res if isinstance(item, dict)]
+            except Exception as exc:
+                self.logger.warning(
+                    f"[SearchHandler] joint reranker failed, fallback to score sort: {exc}"
+                )
+
+        if not ranked_flat:
+            def fallback_score(mem_type: str, mem: dict[str, Any]) -> float:
+                meta = mem.get("metadata")
+                if not isinstance(meta, dict):
+                    return 0.0
+                keys = ("relativity", "score") if mem_type == "text" else ("score", "relativity")
+                for key in keys:
+                    value = meta.get(key)
+                    if value is None:
+                        continue
+                    try:
+                        return float(value)
+                    except (TypeError, ValueError):
+                        pass
+                return 0.0
+
+            ranked_candidates = sorted(
+                candidates,
+                key=lambda item: fallback_score(item[0], item[2]),
+                reverse=True,
+            )
+            ranked_flat = [mem for _, _, mem in ranked_candidates]
+
+        candidate_type_cube: dict[str, tuple[str, str]] = {}
+        for mem_type, cube_id, mem in candidates:
+            mem_id = mem.get("id")
+            if mem_id:
+                candidate_type_cube[mem_id] = (mem_type, cube_id)
+
+        text_selected = 0
+        pref_selected = 0
+        seen_ids: set[str] = set()
+        text_by_cube: dict[str, list[dict[str, Any]]] = {}
+        pref_by_cube: dict[str, list[dict[str, Any]]] = {}
+
+        for mem in ranked_flat:
+            mem_id = mem.get("id")
+            if not mem_id or mem_id in seen_ids:
+                continue
+            type_cube = candidate_type_cube.get(mem_id)
+            if not type_cube:
+                continue
+            mem_type, cube_id = type_cube
+            if mem_type == "text":
+                if text_selected >= text_top_k:
+                    continue
+                text_by_cube.setdefault(cube_id, []).append(mem)
+                text_selected += 1
+            else:
+                if pref_selected >= pref_top_k:
+                    continue
+                pref_by_cube.setdefault(cube_id, []).append(mem)
+                pref_selected += 1
+            seen_ids.add(mem_id)
+            if text_selected >= text_top_k and pref_selected >= pref_top_k:
+                break
+
+        text_res = [
+            {"cube_id": cube_id, "memories": memories, "total_nodes": len(memories)}
+            for cube_id, memories in text_by_cube.items()
+        ]
+        pref_res = [
+            {"cube_id": cube_id, "memories": memories, "total_nodes": len(memories)}
+            for cube_id, memories in pref_by_cube.items()
+        ]
+        return text_res, pref_res
+
     @staticmethod
     def _strip_embeddings(results: dict[str, Any]) -> None:
         for _mem_type, mem_results in results.items():

From 51385b4bcfbce37df3bfd4980ca149441a937372 Mon Sep 17 00:00:00 2001
From: jiang
Date: Tue, 3 Mar 2026 19:41:43 +0800
Subject: [PATCH 2/5] test: test reranker model

---
 src/memos/api/handlers/search_handler.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 58121776e..ba1c50b07 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -120,10 +120,7 @@ def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> d
             if not isinstance(mem, dict):
                 continue
             meta = mem.get("metadata", {})
-            if key == "text_mem":
-                score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0
-            else:
-                score = meta.get("score", 1.0) if isinstance(meta, dict) else 1.0
+            score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0
             try:
                 score_val = float(score) if score is not None else 1.0
             except (TypeError, ValueError):

From 6dde166eff4f2c1437215d9aa93b972fcdb5625f Mon Sep 17 00:00:00 2001
From: jiang
Date: Tue, 3 Mar 2026 19:53:36 +0800
Subject: [PATCH 3/5] test: test reranker model

---
 src/memos/memories/textual/tree_text_memory/retrieve/searcher.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index b4994671f..9e7cc4f11 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -936,6 +936,7 @@ def _retrieve_from_preference_memory(
             id_filter=id_filter,
             use_fast_graph=self.use_fast_graph,
         )
+        logger.info(f"[test preference memory] preference memory len: {len(items)}")
 
         return self.reranker.rerank(
             query=query,

From 7a9629ba970b6f339a815df84fc567639659f91f Mon Sep 17 00:00:00 2001
From: jiang
Date: Tue, 3 Mar 2026 20:50:12 +0800
Subject: [PATCH 4/5] test: delete useless log

---
 .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 9e7cc4f11..278debe28 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -936,8 +936,6 @@ def _retrieve_from_preference_memory(
             id_filter=id_filter,
             use_fast_graph=self.use_fast_graph,
         )
-        logger.info(f"[test preference memory] preference memory len: {len(items)}")
-
         return self.reranker.rerank(
             query=query,
             query_embedding=query_embedding[0],

From c1b2528ba32b593f2a43aacdf7c5c5713d2a8a27 Mon Sep 17 00:00:00 2001
From: jiang
Date: Tue, 3 Mar 2026 20:53:31 +0800
Subject: [PATCH 5/5] test: reformat

---
 src/memos/memories/textual/tree_text_memory/retrieve/searcher.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 278debe28..b4994671f 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -936,6 +936,7 @@ def _retrieve_from_preference_memory(
             id_filter=id_filter,
             use_fast_graph=self.use_fast_graph,
         )
+
         return self.reranker.rerank(
             query=query,
             query_embedding=query_embedding[0],
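
Note on patch 1: the quota-filling loop at the end of _rerank_text_and_pref_memories is the subtle part of the series, so below is a minimal standalone sketch of just that step. Everything in the sketch is illustrative only: the memory dicts, cube ids, and pre-ranked order are invented, and in the handler the ordering comes from HTTPBGEReranker.rerank() (or the score-sort fallback), not a hand-written list.

# Illustrative sketch only (not part of the patch series): the quota-based
# selection from patch 1, extracted so it can be run standalone.
from typing import Any


def select_by_quota(
    ranked_flat: list[dict[str, Any]],
    type_cube_by_id: dict[str, tuple[str, str]],
    text_top_k: int,
    pref_top_k: int,
) -> tuple[dict[str, list[dict[str, Any]]], dict[str, list[dict[str, Any]]]]:
    """Walk one jointly ranked list, filling text and preference quotas independently."""
    text_by_cube: dict[str, list[dict[str, Any]]] = {}
    pref_by_cube: dict[str, list[dict[str, Any]]] = {}
    text_selected = pref_selected = 0
    seen_ids: set[str] = set()
    for mem in ranked_flat:
        mem_id = mem.get("id")
        if not mem_id or mem_id in seen_ids:
            continue  # skip duplicates and memories without ids
        type_cube = type_cube_by_id.get(mem_id)
        if not type_cube:
            continue  # not a known candidate
        mem_type, cube_id = type_cube
        if mem_type == "text":
            if text_selected >= text_top_k:
                continue  # text quota full; keep scanning for preference hits
            text_by_cube.setdefault(cube_id, []).append(mem)
            text_selected += 1
        else:
            if pref_selected >= pref_top_k:
                continue  # preference quota full; keep scanning for text hits
            pref_by_cube.setdefault(cube_id, []).append(mem)
            pref_selected += 1
        seen_ids.add(mem_id)
        if text_selected >= text_top_k and pref_selected >= pref_top_k:
            break  # both quotas full; stop early
    return text_by_cube, pref_by_cube


if __name__ == "__main__":
    ranked = [{"id": f"m{i}"} for i in range(5)]  # pretend reranker output
    lookup = {
        "m0": ("text", "cubeA"),
        "m1": ("preference", "cubeA"),
        "m2": ("text", "cubeB"),
        "m3": ("text", "cubeA"),
        "m4": ("preference", "cubeB"),
    }
    text, pref = select_by_quota(ranked, lookup, text_top_k=2, pref_top_k=1)
    print(text)  # {'cubeA': [{'id': 'm0'}], 'cubeB': [{'id': 'm2'}]}
    print(pref)  # {'cubeA': [{'id': 'm1'}]}

Because both quotas are filled from one jointly ranked list, a strong preference memory can rank above text memories without consuming a text slot, and vice versa; over-quota items are skipped rather than ending the scan, so the other quota can still fill from lower-ranked candidates.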