diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 267d1bb28..8e7785ad5 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -64,7 +64,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Expand top_k for deduplication (5x to ensure enough candidates) if search_req_local.dedup in ("sim", "mmr"): - search_req_local.top_k = search_req_local.top_k * 5 + search_req_local.top_k = search_req_local.top_k * 3 # Search and deduplicate cube_view = self._build_cube_view(search_req_local) @@ -152,9 +152,6 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di return results embeddings = self._extract_embeddings([mem for _, mem, _ in flat]) - if embeddings is None: - documents = [mem.get("memory", "") for _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) similarity_matrix = cosine_similarity_matrix(embeddings) @@ -235,12 +232,39 @@ def _mmr_dedup_text_memories( if len(flat) <= 1: return results + total_by_type: dict[str, int] = {"text": 0, "preference": 0} + existing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_indices: list[int] = [] + for idx, (mem_type, _, mem, _) in enumerate(flat): + if mem_type not in total_by_type: + total_by_type[mem_type] = 0 + existing_by_type[mem_type] = 0 + missing_by_type[mem_type] = 0 + total_by_type[mem_type] += 1 + + embedding = mem.get("metadata", {}).get("embedding") + if embedding: + existing_by_type[mem_type] += 1 + else: + missing_by_type[mem_type] += 1 + missing_indices.append(idx) + + self.logger.info( + "[SearchHandler] MMR embedding metadata scan: total=%s total_by_type=%s existing_by_type=%s missing_by_type=%s", + len(flat), + total_by_type, + existing_by_type, + missing_by_type, + ) + if missing_indices: + self.logger.warning( + "[SearchHandler] MMR embedding metadata missing; will compute missing embeddings: missing_total=%s", + len(missing_indices), + ) + # Get or compute embeddings embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat]) - if embeddings is None: - self.logger.warning("[SearchHandler] Embedding is missing; recomputing embeddings") - documents = [mem.get("memory", "") for _, _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) # Compute similarity matrix using NumPy-optimized method # Returns numpy array but compatible with list[i][j] indexing @@ -404,14 +428,32 @@ def _max_similarity( return 0.0 return max(similarity_matrix[index][j] for j in selected_indices) - @staticmethod - def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None: + def _extract_embeddings(self, memories: list[dict[str, Any]]) -> list[list[float]]: embeddings: list[list[float]] = [] - for mem in memories: - embedding = mem.get("metadata", {}).get("embedding") - if not embedding: - return None - embeddings.append(embedding) + missing_indices: list[int] = [] + missing_documents: list[str] = [] + + for idx, mem in enumerate(memories): + metadata = mem.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + mem["metadata"] = metadata + + embedding = metadata.get("embedding") + if embedding: + embeddings.append(embedding) + continue + + embeddings.append([]) + missing_indices.append(idx) + missing_documents.append(mem.get("memory", "")) + + if missing_indices: + computed = self.searcher.embedder.embed(missing_documents) + for idx, embedding in zip(missing_indices, computed, strict=False): + embeddings[idx] = embedding + memories[idx]["metadata"]["embedding"] = embedding + return embeddings @staticmethod diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 6352d5840..8483a5151 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -124,25 +124,45 @@ def retrieve( explicit_prefs.sort(key=lambda x: x.score, reverse=True) implicit_prefs.sort(key=lambda x: x.score, reverse=True) - explicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + explicit_prefs_mem = [] + for pref in explicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + explicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in explicit_prefs - if pref.payload.get("preference", None) - ] - implicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + implicit_prefs_mem = [] + for pref in implicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + implicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in implicit_prefs - if pref.payload.get("preference", None) - ] reranker_map = { "naive": self._naive_reranker, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9dcbe8c56..cc269e8c4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -524,7 +524,7 @@ def _retrieve_from_keyword( user_name=user_name, tsquery_config="jiebaqry", ) - except Exception as e: + except Exception: logger.warning( f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}" ) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 1678d9d15..d890c77bf 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -443,7 +443,10 @@ def _search_pref( }, search_filter=search_req.filter, ) - formatted_results = self._postformat_memories(results, user_context.mem_cube_id) + include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true" + formatted_results = self._postformat_memories( + results, user_context.mem_cube_id, include_embedding=include_embedding + ) # For each returned item, tackle with metadata.info project_id / # operation / manager_user_id