From 6aeaefeb45954c0fed019d6dfe820c58829fe44b Mon Sep 17 00:00:00 2001 From: r266-tech Date: Thu, 26 Feb 2026 04:42:10 +0800 Subject: [PATCH] fix: search API now iterates all memory_types instead of only [0] Fixes #78 The search API accepted a list of memory_types but only used memory_types[0], silently ignoring all other types. This caused: - Silent data loss: only the first type was searched - Errors when unsupported types (e.g. profile) were first in the list Changes: - get_keyword_search_results: loop over all memory_types, search ES for each supported type, skip unsupported with info log, merge & dedup - get_vector_search_results: loop over all memory_types, embed query once, search Milvus for each supported type, merge & dedup - Add MILVUS_REPO_MAP for cleaner type-to-repo mapping - Add _memory_types_label() helper for metric labels - Update all metric references from memory_types[0].value to comma-joined label for proper observability --- src/agentic_layer/memory_manager.py | 270 ++++++++++++++-------------- 1 file changed, 136 insertions(+), 134 deletions(-) diff --git a/src/agentic_layer/memory_manager.py b/src/agentic_layer/memory_manager.py index df42b05a..e38ebfe5 100644 --- a/src/agentic_layer/memory_manager.py +++ b/src/agentic_layer/memory_manager.py @@ -95,6 +95,20 @@ } +MILVUS_REPO_MAP = { + MemoryType.FORESIGHT: ForesightMilvusRepository, + MemoryType.EVENT_LOG: EventLogMilvusRepository, + MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository, +} + + +def _memory_types_label(memory_types) -> str: + """Comma-joined memory type values for metric labels.""" + if not memory_types: + return 'unknown' + return ','.join(mt.value for mt in memory_types) + + @dataclass class EventLogCandidate: """Event Log candidate object (used for retrieval from atomic_fact)""" @@ -298,11 +312,7 @@ async def retrieve_mem_keyword( ) -> RetrieveMemResponse: """Keyword-based memory retrieval""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self.get_keyword_search_results( @@ -339,11 +349,7 @@ async def get_keyword_search_results( ) -> List[Dict[str, Any]]: """Keyword search with stage-level metrics""" stage_start = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: # Get parameters from Request @@ -375,32 +381,41 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") - return [] + # Iterate over ALL memory_types and merge results + all_results = [] + seen_ids = set() + for mem_type in memory_types: + repo_class = ES_REPO_MAP.get(mem_type) + if not repo_class: + logger.info(f"Skipping unsupported memory type for keyword search: {mem_type}") + continue - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_id=group_id, - size=top_k, - from_=0, - date_range=date_range, - ) + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_id=group_id, + size=top_k, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + if results: + for r in results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 0.0) # Unified score field + # Deduplicate by id + rid = r.get('id', '') + if rid and rid in seen_ids: + continue + if rid: + seen_ids.add(rid) + all_results.append(r) # Record stage metrics record_retrieve_stage( @@ -410,7 +425,7 @@ async def get_keyword_search_results( duration_seconds=time.perf_counter() - stage_start, ) - return results or [] + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -433,11 +448,7 @@ async def retrieve_mem_vector( ) -> RetrieveMemResponse: """Vector-based memory retrieval""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self.get_vector_search_results( @@ -473,11 +484,7 @@ async def get_vector_search_results( retrieve_method: str = RetrieveMethod.VECTOR.value, ) -> List[Dict[str, Any]]: """Vector search with stage-level metrics (embedding + milvus_search)""" - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: # Get parameters from Request @@ -497,7 +504,6 @@ async def get_vector_search_results( top_k = retrieve_mem_request.top_k start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] logger.debug( f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" @@ -506,7 +512,7 @@ async def get_vector_search_results( # Get vectorization service vectorize_service = get_vectorize_service() - # Convert query text to vector (embedding stage) + # Convert query text to vector (embedding stage) — only once logger.debug(f"Starting to vectorize query text: {query}") embedding_start = time.perf_counter() query_vector = await vectorize_service.get_embedding(query) @@ -521,74 +527,87 @@ async def get_vector_search_results( f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.EVENT_LOG: - milvus_repo = get_bean_by_type(EventLogMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: {mem_type}") + # Iterate over ALL memory_types and merge results + all_results = [] + seen_ids = set() + milvus_start = time.perf_counter() - # Handle time range filter conditions - start_time_dt = None - end_time_dt = None - current_time_dt = None + for mem_type in retrieve_mem_request.memory_types: + milvus_repo_class = MILVUS_REPO_MAP.get(mem_type) + if not milvus_repo_class: + logger.info(f"Skipping unsupported memory type for vector search: {mem_type}") + continue - if start_time is not None: - start_time_dt = ( - from_iso_format(start_time) - if isinstance(start_time, str) - else start_time - ) + milvus_repo = get_bean_by_type(milvus_repo_class) - if end_time is not None: - if isinstance(end_time, str): - end_time_dt = from_iso_format(end_time) - # If date only format, set to end of day - if len(end_time) == 10: - end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + # Handle time range filter conditions + start_time_dt = None + end_time_dt = None + current_time_dt = None + + if start_time is not None: + start_time_dt = ( + from_iso_format(start_time) + if isinstance(start_time, str) + else start_time + ) + + if end_time is not None: + if isinstance(end_time, str): + end_time_dt = from_iso_format(end_time) + # If date only format, set to end of day + if len(end_time) == 10: + end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + else: + end_time_dt = end_time + + # Handle foresight time range (only valid for foresight) + if mem_type == MemoryType.FORESIGHT: + if retrieve_mem_request.start_time: + start_time_dt = from_iso_format(retrieve_mem_request.start_time) + if retrieve_mem_request.end_time: + end_time_dt = from_iso_format(retrieve_mem_request.end_time) + if retrieve_mem_request.current_time: + current_time_dt = from_iso_format(retrieve_mem_request.current_time) + + # Call Milvus vector search (pass different parameters based on memory type) + if mem_type == MemoryType.FORESIGHT: + # Foresight: supports time range and validity filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + current_time=current_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) else: - end_time_dt = end_time - - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - if retrieve_mem_request.current_time: - current_time_dt = from_iso_format(retrieve_mem_request.current_time) - - # Call Milvus vector search (pass different parameters based on memory type) - milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - current_time=current_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) - else: - # Episodic memory and event log: use timestamp filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) + # Episodic memory and event log: use timestamp filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + + for r in search_results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + # Deduplicate by id + rid = r.get('id', '') + if rid and rid in seen_ids: + continue + if rid: + seen_ids.add(rid) + all_results.append(r) + record_retrieve_stage( retrieve_method=retrieve_method, stage='milvus_search', @@ -596,12 +615,7 @@ async def get_vector_search_results( duration_seconds=time.perf_counter() - milvus_start, ) - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename - - return search_results + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -624,11 +638,7 @@ async def retrieve_mem_hybrid( ) -> RetrieveMemResponse: """Hybrid memory retrieval: keyword + vector + rerank""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self._search_hybrid( @@ -699,9 +709,7 @@ async def _search_hybrid( retrieve_method: str = RetrieveMethod.HYBRID.value, ) -> List[Dict]: """Core hybrid search: keyword + vector + rerank, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' - ) + memory_type = _memory_types_label(request.memory_types) # Run keyword and vector search concurrently kw_results, vec_results = await asyncio.gather( self.get_keyword_search_results(request, retrieve_method=retrieve_method), @@ -722,9 +730,7 @@ async def _search_rrf( retrieve_method: str = RetrieveMethod.RRF.value, ) -> List[Dict]: """Core RRF search: keyword + vector + RRF fusion, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' - ) + memory_type = _memory_types_label(request.memory_types) # Run keyword and vector search concurrently kw, vec = await asyncio.gather( @@ -766,7 +772,7 @@ async def _to_response( """Convert flat hits list to grouped RetrieveMemResponse""" user_id = req.user_id if req else "" source_type = req.retrieve_method.value - memory_type = req.memory_types[0].value + memory_type = _memory_types_label(req.memory_types) if not hits: return RetrieveMemResponse( @@ -808,11 +814,7 @@ async def retrieve_mem_rrf( ) -> RetrieveMemResponse: """RRF-based memory retrieval: keyword + vector + RRF fusion""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' - ) + memory_type = _memory_types_label(retrieve_mem_request.memory_types) try: hits = await self._search_rrf( @@ -855,7 +857,7 @@ async def retrieve_mem_agentic( req = retrieve_mem_request # alias top_k = req.top_k config = AgenticConfig() - memory_type = req.memory_types[0].value if req.memory_types else 'unknown' + memory_type = _memory_types_label(req.memory_types) try: llm_provider = LLMProvider(