From 40645879f613da8fcda865a2572c27f60a9bb6cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 17 Nov 2025 18:49:00 +0800 Subject: [PATCH 1/3] feat: abstract CubeView to Add Handler --- src/memos/api/handlers/add_handler.py | 276 +++++------------------ src/memos/api/product_models.py | 8 +- src/memos/multi_mem_cube/__init__.py | 0 src/memos/multi_mem_cube/views.py | 313 ++++++++++++++++++++++++++ 4 files changed, 373 insertions(+), 224 deletions(-) create mode 100644 src/memos/multi_mem_cube/__init__.py create mode 100644 src/memos/multi_mem_cube/views.py diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 48db7ae6e..23dd7b4d4 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -5,21 +5,11 @@ using dependency injection for better modularity and testability. """ -import json -import os - from datetime import datetime from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APIADDRequest, MemoryResponse -from memos.context.context import ContextThreadPoolExecutor -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.types import UserContext +from memos.multi_mem_cube.views import CompositeCubeView, MemCubeView, SingleCubeView class AddHandler(BaseHandler): @@ -52,33 +42,67 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: Returns: MemoryResponse with added memory information """ - # Create UserContext object - user_context = UserContext( - user_id=add_req.user_id, - mem_cube_id=add_req.mem_cube_id, - session_id=add_req.session_id or "default_session", - ) + self.logger.info(f"[AddHandler] Add Req is: {add_req}") - self.logger.info(f"Add Req is: {add_req}") - if (not add_req.messages) and add_req.memory_content: + if (not add_req.messages) and getattr(add_req, "memory_content", None): add_req.messages = self._convert_content_messsage(add_req.memory_content) - self.logger.info(f"Converted Add Req content to messages: {add_req.messages}") - # Process text and preference memories in parallel - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._process_text_mem, add_req, user_context) - pref_future = executor.submit(self._process_pref_mem, add_req, user_context) + self.logger.info(f"[AddHandler] Converted content to messages: {add_req.messages}") + + cube_view = self._build_cube_view(add_req) - text_response_data = text_future.result() - pref_response_data = pref_future.result() + results = cube_view.add_memories(add_req) - self.logger.info(f"add_memories Text response data: {text_response_data}") - self.logger.info(f"add_memories Pref response data: {pref_response_data}") + self.logger.info(f"[AddHandler] Final add results count={len(results)}") return MemoryResponse( message="Memory added successfully", - data=text_response_data + pref_response_data, + data=results, ) + def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: + """ + Normalize target cube ids from add_req. + Priority: + 1) writable_cube_ids + 2) mem_cube_id + 3) fallback to user_id + """ + if getattr(add_req, "writable_cube_ids", None): + return list(dict.fromkeys(add_req.writable_cube_ids)) + + if add_req.mem_cube_id: + return [add_req.mem_cube_id] + + return [add_req.user_id] + + def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView: + cube_ids = self._resolve_cube_ids(add_req) + + if len(cube_ids) == 1: + cube_id = cube_ids[0] + return SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + ) + else: + single_views = [ + SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + ) + for cube_id in cube_ids + ] + return CompositeCubeView( + cube_views=single_views, + logger=self.logger, + ) + def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]: """ Convert content string to list of message dictionaries. @@ -98,197 +122,3 @@ def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]] ] # for only user-str input and convert message return messages_list - - def _process_text_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - ) -> list[dict[str, str]]: - """ - Process and add text memories. - - Extracts memories from messages and adds them to the text memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted memory responses - """ - target_session_id = add_req.session_id or "default_session" - - # Determine sync mode - sync_mode = add_req.async_mode or self._get_sync_mode() - - self.logger.info(f"Processing text memory with mode: {sync_mode}") - - # Extract memories - memories_local = self.mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode="fast" if sync_mode == "async" else "fine", - ) - flattened_local = [mm for m in memories_local for mm in m] - self.logger.info(f"Memory extraction completed for user {add_req.user_id}") - - # Add memories to text_mem - mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( - flattened_local, - user_name=user_context.mem_cube_id, - ) - self.logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - - # Schedule async/sync tasks - self._schedule_memory_tasks( - add_req=add_req, - user_context=user_context, - mem_ids=mem_ids_local, - sync_mode=sync_mode, - ) - - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) - ] - - def _process_pref_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - ) -> list[dict[str, str]]: - """ - Process and add preference memories. - - Extracts preferences from messages and adds them to the preference memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted preference responses - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - # Determine sync mode - sync_mode = add_req.async_mode or self._get_sync_mode() - target_session_id = add_req.session_id or "default_session" - - # Follow async behavior: enqueue when async - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=PREF_ADD_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item_pref]) - self.logger.info("Submitted preference add to scheduler (async mode)") - except Exception as e: - self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) - return [] - else: - # Sync mode: process immediately - pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": add_req.mem_cube_id, - }, - ) - pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) - self.logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - - def _get_sync_mode(self) -> str: - """ - Get synchronization mode from memory cube. - - Returns: - Sync mode string ("sync" or "async") - """ - try: - return getattr(self.naive_mem_cube.text_mem, "mode", "sync") - except Exception: - return "sync" - - def _schedule_memory_tasks( - self, - add_req: APIADDRequest, - user_context: UserContext, - mem_ids: list[str], - sync_mode: str, - ) -> None: - """ - Schedule memory processing tasks based on sync mode. - - Args: - add_req: Add memory request - user_context: User context - mem_ids: List of memory IDs - sync_mode: Synchronization mode - """ - target_session_id = add_req.session_id or "default_session" - - if sync_mode == "async": - # Async mode: submit MEM_READ_LABEL task - try: - message_item_read = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=MEM_READ_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - self.mem_scheduler.submit_messages(messages=[message_item_read]) - self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") - except Exception as e: - self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) - else: - # Sync mode: submit ADD_LABEL task - message_item_add = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - self.mem_scheduler.submit_messages(messages=[message_item_add]) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 892d2d436..a03134525 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -171,6 +171,9 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") + readable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can read for multi-cube search" + ) mode: SearchMode = Field( SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture" ) @@ -190,7 +193,10 @@ class APIADDRequest(BaseRequest): """Request model for creating memories.""" user_id: str = Field(None, description="User ID") - mem_cube_id: str = Field(..., description="Cube ID") + mem_cube_id: str | None = Field(None, description="Cube ID") + writable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube add" + ) messages: list[MessageDict] | None = Field(None, description="List of messages to store.") memory_content: str | None = Field(None, description="Memory content to store") doc_path: str | None = Field(None, description="Path to document to store") diff --git a/src/memos/multi_mem_cube/__init__.py b/src/memos/multi_mem_cube/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/multi_mem_cube/views.py b/src/memos/multi_mem_cube/views.py new file mode 100644 index 000000000..bebce8082 --- /dev/null +++ b/src/memos/multi_mem_cube/views.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import json +import os + +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, Protocol + +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.types import UserContext + + +if TYPE_CHECKING: + from memos.api.product_models import APIADDRequest + + +class MemCubeView(Protocol): + """ + A high-level cube view used by AddHandler. + It may wrap a single logical cube or multiple cubes, + but exposes a unified add_memories interface. + """ + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + """ + Process add_req, extract memories and write them into one or more cubes. + + Returns: + A list of memory dicts, each item should at least contain: + - memory + - memory_id + - memory_type + - cube_id + """ + ... + + +@dataclass +class SingleCubeView(MemCubeView): + cube_id: str + naive_mem_cube: Any + mem_reader: Any + mem_scheduler: Any + logger: Any + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + """ + This is basically your current handle_add_memories logic, + but scoped to a single cube_id. + """ + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=self.cube_id, + session_id=add_req.session_id or "default_session", + ) + + target_session_id = add_req.session_id or "default_session" + sync_mode = add_req.async_mode or self._get_sync_mode() + + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} " + f"Processing add with mode={sync_mode}, session={target_session_id}" + ) + + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) + pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) + + text_results = text_future.result() + pref_results = pref_future.result() + + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " + f"pref_results={len(pref_results)}" + ) + + for item in text_results: + item["cube_id"] = self.cube_id + for item in pref_results: + item["cube_id"] = self.cube_id + + return text_results + pref_results + + def _get_sync_mode(self) -> str: + """ + Get synchronization mode from memory cube. + + Returns: + Sync mode string ("sync" or "async") + """ + try: + return getattr(self.naive_mem_cube.text_mem, "mode", "sync") + except Exception: + return "sync" + + def _schedule_memory_tasks( + self, + add_req: APIADDRequest, + user_context: UserContext, + mem_ids: list[str], + sync_mode: str, + ) -> None: + """ + Schedule memory processing tasks based on sync mode. + + Args: + add_req: Add memory request + user_context: User context + mem_ids: List of memory IDs + sync_mode: Synchronization mode + """ + target_session_id = add_req.session_id or "default_session" + + if sync_mode == "async": + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=self.cube_id, + mem_cube=self.naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=self.cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_read]) + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} Submitted async MEM_READ: {json.dumps(mem_ids)}" + ) + except Exception as e: + self.logger.error( + f"[SingleCubeView] cube={self.cube_id} Failed to submit async memory tasks: {e}", + exc_info=True, + ) + else: + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=self.cube_id, + mem_cube=self.naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=self.cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_add]) + + def _process_pref_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + sync_mode: str, + ) -> list[dict[str, Any]]: + """ + Process and add preference memories. + + Extracts preferences from messages and adds them to the preference memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted preference responses + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + target_session_id = add_req.session_id or "default_session" + + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=self.cube_id, + mem_cube=self.naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item_pref]) + self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") + except Exception as e: + self.logger.error( + f"[SingleCubeView] cube={self.cube_id} Failed to submit PREF_ADD: {e}", + exc_info=True, + ) + return [] + else: + pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + "mem_cube_id": self.cube_id, + }, + ) + pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} " + f"added {len(pref_ids_local)} preferences for user {add_req.user_id}: {pref_ids_local}" + ) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + def _process_text_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + sync_mode: str, + ) -> list[dict[str, Any]]: + """ + Process and add text memories. + + Extracts memories from messages and adds them to the text memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted memory responses + """ + target_session_id = add_req.session_id or "default_session" + + self.logger.info( + f"[SingleCubeView] cube={user_context.mem_cube_id} " + f"Processing text memory with mode: {sync_mode}" + ) + + # Extract memories + memories_local = self.mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode="fast" if sync_mode == "async" else "fine", + ) + flattened_local = [mm for m in memories_local for mm in m] + self.logger.info(f"Memory extraction completed for user {add_req.user_id}") + + # Add memories to text_mem + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + self.logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + + # Schedule async/sync tasks + self._schedule_memory_tasks( + add_req=add_req, + user_context=user_context, + mem_ids=mem_ids_local, + sync_mode=sync_mode, + ) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] + + +@dataclass +class CompositeCubeView(MemCubeView): + """ + A composite view over multiple logical cubes. + + For now (fast mode), it simply fan-out writes to all cubes; + later we can add smarter routing / slow mode here. + """ + + cube_views: list[SingleCubeView] + logger: Any + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + all_results: list[dict[str, Any]] = [] + + # fast mode: for each cube view, add memories + # maybe add more strategies in add_req.async_mode + for view in self.cube_views: + self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}") + results = view.add_memories(add_req) + all_results.extend(results) + + return all_results From 949bd010c9cbae0360a7a45bbffd41c504a2b8f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Mon, 17 Nov 2025 20:52:55 +0800 Subject: [PATCH 2/3] feat: add readable and writable memcube-ids --- src/memos/api/handlers/add_handler.py | 4 +- src/memos/api/product_models.py | 12 + src/memos/multi_mem_cube/composite_cube.py | 36 +++ src/memos/multi_mem_cube/single_cube.py | 268 +++++++++++++++++++ src/memos/multi_mem_cube/views.py | 285 --------------------- 5 files changed, 319 insertions(+), 286 deletions(-) create mode 100644 src/memos/multi_mem_cube/composite_cube.py create mode 100644 src/memos/multi_mem_cube/single_cube.py diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 23dd7b4d4..67bd7b314 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -9,7 +9,9 @@ from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APIADDRequest, MemoryResponse -from memos.multi_mem_cube.views import CompositeCubeView, MemCubeView, SingleCubeView +from memos.multi_mem_cube.composite_cube import CompositeCubeView +from memos.multi_mem_cube.single_cube import SingleCubeView +from memos.multi_mem_cube.views import MemCubeView class AddHandler(BaseHandler): diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index a03134525..221e045cf 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -72,6 +72,12 @@ class ChatRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") + readable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can read for multi-cube chat" + ) + writable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube chat" + ) history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(True, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") @@ -217,6 +223,12 @@ class APIChatCompleteRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") + readable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can read for multi-cube chat" + ) + writable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube chat" + ) history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(True, description="Whether to use MemOSCube") diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py new file mode 100644 index 000000000..f49945043 --- /dev/null +++ b/src/memos/multi_mem_cube/composite_cube.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from memos.multi_mem_cube.views import MemCubeView + + +if TYPE_CHECKING: + from memos.api.product_models import APIADDRequest + from memos.multi_mem_cube.single_cube import SingleCubeView + + +@dataclass +class CompositeCubeView(MemCubeView): + """ + A composite view over multiple logical cubes. + + For now (fast mode), it simply fan-out writes to all cubes; + later we can add smarter routing / slow mode here. + """ + + cube_views: list[SingleCubeView] + logger: Any + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + all_results: list[dict[str, Any]] = [] + + # fast mode: for each cube view, add memories + # maybe add more strategies in add_req.async_mode + for view in self.cube_views: + self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}") + results = view.add_memories(add_req) + all_results.extend(results) + + return all_results diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py new file mode 100644 index 000000000..6cf0a4ef8 --- /dev/null +++ b/src/memos/multi_mem_cube/single_cube.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import json +import os + +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.multi_mem_cube.views import MemCubeView +from memos.types import UserContext + + +if TYPE_CHECKING: + from memos.api.product_models import APIADDRequest + + +@dataclass +class SingleCubeView(MemCubeView): + cube_id: str + naive_mem_cube: Any + mem_reader: Any + mem_scheduler: Any + logger: Any + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + """ + This is basically your current handle_add_memories logic, + but scoped to a single cube_id. + """ + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=self.cube_id, + session_id=add_req.session_id or "default_session", + ) + + target_session_id = add_req.session_id or "default_session" + sync_mode = add_req.async_mode or self._get_sync_mode() + + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} " + f"Processing add with mode={sync_mode}, session={target_session_id}" + ) + + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) + pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) + + text_results = text_future.result() + pref_results = pref_future.result() + + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " + f"pref_results={len(pref_results)}" + ) + + for item in text_results: + item["cube_id"] = self.cube_id + for item in pref_results: + item["cube_id"] = self.cube_id + + return text_results + pref_results + + def _get_sync_mode(self) -> str: + """ + Get synchronization mode from memory cube. + + Returns: + Sync mode string ("sync" or "async") + """ + try: + return getattr(self.naive_mem_cube.text_mem, "mode", "sync") + except Exception: + return "sync" + + def _schedule_memory_tasks( + self, + add_req: APIADDRequest, + user_context: UserContext, + mem_ids: list[str], + sync_mode: str, + ) -> None: + """ + Schedule memory processing tasks based on sync mode. + + Args: + add_req: Add memory request + user_context: User context + mem_ids: List of memory IDs + sync_mode: Synchronization mode + """ + target_session_id = add_req.session_id or "default_session" + + if sync_mode == "async": + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=self.cube_id, + mem_cube=self.naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=self.cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_read]) + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} Submitted async MEM_READ: {json.dumps(mem_ids)}" + ) + except Exception as e: + self.logger.error( + f"[SingleCubeView] cube={self.cube_id} Failed to submit async memory tasks: {e}", + exc_info=True, + ) + else: + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=self.cube_id, + mem_cube=self.naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=self.cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_add]) + + def _process_pref_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + sync_mode: str, + ) -> list[dict[str, Any]]: + """ + Process and add preference memories. + + Extracts preferences from messages and adds them to the preference memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted preference responses + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + target_session_id = add_req.session_id or "default_session" + + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=self.cube_id, + mem_cube=self.naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item_pref]) + self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") + except Exception as e: + self.logger.error( + f"[SingleCubeView] cube={self.cube_id} Failed to submit PREF_ADD: {e}", + exc_info=True, + ) + return [] + else: + pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + "mem_cube_id": self.cube_id, + }, + ) + pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} " + f"added {len(pref_ids_local)} preferences for user {add_req.user_id}: {pref_ids_local}" + ) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + def _process_text_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + sync_mode: str, + ) -> list[dict[str, Any]]: + """ + Process and add text memories. + + Extracts memories from messages and adds them to the text memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted memory responses + """ + target_session_id = add_req.session_id or "default_session" + + self.logger.info( + f"[SingleCubeView] cube={user_context.mem_cube_id} " + f"Processing text memory with mode: {sync_mode}" + ) + + # Extract memories + memories_local = self.mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode="fast" if sync_mode == "async" else "fine", + ) + flattened_local = [mm for m in memories_local for mm in m] + self.logger.info(f"Memory extraction completed for user {add_req.user_id}") + + # Add memories to text_mem + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + self.logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + + # Schedule async/sync tasks + self._schedule_memory_tasks( + add_req=add_req, + user_context=user_context, + mem_ids=mem_ids_local, + sync_mode=sync_mode, + ) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] diff --git a/src/memos/multi_mem_cube/views.py b/src/memos/multi_mem_cube/views.py index bebce8082..7c58aa469 100644 --- a/src/memos/multi_mem_cube/views.py +++ b/src/memos/multi_mem_cube/views.py @@ -1,21 +1,7 @@ from __future__ import annotations -import json -import os - -from dataclasses import dataclass -from datetime import datetime from typing import TYPE_CHECKING, Any, Protocol -from memos.context.context import ContextThreadPoolExecutor -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.types import UserContext - if TYPE_CHECKING: from memos.api.product_models import APIADDRequest @@ -40,274 +26,3 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: - cube_id """ ... - - -@dataclass -class SingleCubeView(MemCubeView): - cube_id: str - naive_mem_cube: Any - mem_reader: Any - mem_scheduler: Any - logger: Any - - def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: - """ - This is basically your current handle_add_memories logic, - but scoped to a single cube_id. - """ - user_context = UserContext( - user_id=add_req.user_id, - mem_cube_id=self.cube_id, - session_id=add_req.session_id or "default_session", - ) - - target_session_id = add_req.session_id or "default_session" - sync_mode = add_req.async_mode or self._get_sync_mode() - - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} " - f"Processing add with mode={sync_mode}, session={target_session_id}" - ) - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) - pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) - - text_results = text_future.result() - pref_results = pref_future.result() - - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " - f"pref_results={len(pref_results)}" - ) - - for item in text_results: - item["cube_id"] = self.cube_id - for item in pref_results: - item["cube_id"] = self.cube_id - - return text_results + pref_results - - def _get_sync_mode(self) -> str: - """ - Get synchronization mode from memory cube. - - Returns: - Sync mode string ("sync" or "async") - """ - try: - return getattr(self.naive_mem_cube.text_mem, "mode", "sync") - except Exception: - return "sync" - - def _schedule_memory_tasks( - self, - add_req: APIADDRequest, - user_context: UserContext, - mem_ids: list[str], - sync_mode: str, - ) -> None: - """ - Schedule memory processing tasks based on sync mode. - - Args: - add_req: Add memory request - user_context: User context - mem_ids: List of memory IDs - sync_mode: Synchronization mode - """ - target_session_id = add_req.session_id or "default_session" - - if sync_mode == "async": - try: - message_item_read = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=self.cube_id, - mem_cube=self.naive_mem_cube, - label=MEM_READ_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - user_name=self.cube_id, - ) - self.mem_scheduler.submit_messages(messages=[message_item_read]) - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} Submitted async MEM_READ: {json.dumps(mem_ids)}" - ) - except Exception as e: - self.logger.error( - f"[SingleCubeView] cube={self.cube_id} Failed to submit async memory tasks: {e}", - exc_info=True, - ) - else: - message_item_add = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=self.cube_id, - mem_cube=self.naive_mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - user_name=self.cube_id, - ) - self.mem_scheduler.submit_messages(messages=[message_item_add]) - - def _process_pref_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - sync_mode: str, - ) -> list[dict[str, Any]]: - """ - Process and add preference memories. - - Extracts preferences from messages and adds them to the preference memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted preference responses - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - target_session_id = add_req.session_id or "default_session" - - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=self.cube_id, - mem_cube=self.naive_mem_cube, - label=PREF_ADD_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item_pref]) - self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") - except Exception as e: - self.logger.error( - f"[SingleCubeView] cube={self.cube_id} Failed to submit PREF_ADD: {e}", - exc_info=True, - ) - return [] - else: - pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": self.cube_id, - }, - ) - pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} " - f"added {len(pref_ids_local)} preferences for user {add_req.user_id}: {pref_ids_local}" - ) - - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - - def _process_text_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - sync_mode: str, - ) -> list[dict[str, Any]]: - """ - Process and add text memories. - - Extracts memories from messages and adds them to the text memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted memory responses - """ - target_session_id = add_req.session_id or "default_session" - - self.logger.info( - f"[SingleCubeView] cube={user_context.mem_cube_id} " - f"Processing text memory with mode: {sync_mode}" - ) - - # Extract memories - memories_local = self.mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode="fast" if sync_mode == "async" else "fine", - ) - flattened_local = [mm for m in memories_local for mm in m] - self.logger.info(f"Memory extraction completed for user {add_req.user_id}") - - # Add memories to text_mem - mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( - flattened_local, - user_name=user_context.mem_cube_id, - ) - self.logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - - # Schedule async/sync tasks - self._schedule_memory_tasks( - add_req=add_req, - user_context=user_context, - mem_ids=mem_ids_local, - sync_mode=sync_mode, - ) - - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) - ] - - -@dataclass -class CompositeCubeView(MemCubeView): - """ - A composite view over multiple logical cubes. - - For now (fast mode), it simply fan-out writes to all cubes; - later we can add smarter routing / slow mode here. - """ - - cube_views: list[SingleCubeView] - logger: Any - - def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: - all_results: list[dict[str, Any]] = [] - - # fast mode: for each cube view, add memories - # maybe add more strategies in add_req.async_mode - for view in self.cube_views: - self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}") - results = view.add_memories(add_req) - all_results.extend(results) - - return all_results From 1143d1e80e29bacdd0d6b894f2925e1b4a59d9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Fri, 21 Nov 2025 16:08:23 +0800 Subject: [PATCH 3/3] feat: multi-cube search router --- src/memos/api/handlers/search_handler.py | 276 ++++----------------- src/memos/multi_mem_cube/composite_cube.py | 29 ++- src/memos/multi_mem_cube/single_cube.py | 245 +++++++++++++++++- src/memos/multi_mem_cube/views.py | 15 +- 4 files changed, 327 insertions(+), 238 deletions(-) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index e8e4e07d6..60317a155 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,20 +5,11 @@ using dependency injection for better modularity and testability. """ -import os -import traceback - -from typing import Any - from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies -from memos.api.handlers.formatters_handler import ( - format_memory_item, - post_process_pref_mem, -) from memos.api.product_models import APISearchRequest, SearchResponse -from memos.context.context import ContextThreadPoolExecutor -from memos.mem_scheduler.schemas.general_schemas import SearchMode -from memos.types import MOSSearchResult, UserContext +from memos.multi_mem_cube.composite_cube import CompositeCubeView +from memos.multi_mem_cube.single_cube import SingleCubeView +from memos.multi_mem_cube.views import MemCubeView class SearchHandler(BaseHandler): @@ -51,239 +42,56 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse Returns: SearchResponse with formatted results """ - # Create UserContext object - user_context = UserContext( - user_id=search_req.user_id, - mem_cube_id=search_req.mem_cube_id, - session_id=search_req.session_id or "default_session", - ) - self.logger.info(f"Search Req is: {search_req}") - - memories_result: MOSSearchResult = { - "text_mem": [], - "act_mem": [], - "para_mem": [], - "pref_mem": [], - "pref_note": "", - } + self.logger.info(f"[SearchHandler] Search Req is: {search_req}") - # Determine search mode - search_mode = self._get_search_mode(search_req.mode) + cube_view = self._build_cube_view(search_req) - # Execute search in parallel for text and preference memories - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._search_text, search_req, user_context, search_mode) - pref_future = executor.submit(self._search_pref, search_req, user_context) + results = cube_view.search_memories(search_req) - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() - - # Build result - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": text_formatted_memories, - } - ) - - memories_result = post_process_pref_mem( - memories_result, - pref_formatted_memories, - search_req.mem_cube_id, - search_req.include_preference, - ) - - self.logger.info(f"Search memories result: {memories_result}") + self.logger.info(f"[AddHandler] Final add results count={len(results)}") return SearchResponse( - message="Search completed successfully", - data=memories_result, + message="Memory searched successfully", + data=results, ) - def _get_search_mode(self, mode: str) -> str: - """ - Get search mode with environment variable fallback. - - Args: - mode: Requested search mode - - Returns: - Search mode string - """ - if mode == SearchMode.NOT_INITIALIZED: - return os.getenv("SEARCH_MODE", SearchMode.FAST) - return mode - - def _search_text( - self, - search_req: APISearchRequest, - user_context: UserContext, - search_mode: str, - ) -> list[dict[str, Any]]: + def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: """ - Search text memories based on mode. - - Args: - search_req: Search request - user_context: User context - search_mode: Search mode (FAST, FINE, or MIXTURE) - - Returns: - List of formatted memory items + Normalize target cube ids from search_req. + Priority: + 1) readable_cube_ids + 2) mem_cube_id + 3) fallback to user_id """ - try: - if search_mode == SearchMode.FAST: - memories = self._fast_search(search_req, user_context) - elif search_mode == SearchMode.FINE: - memories = self._fine_search(search_req, user_context) - elif search_mode == SearchMode.MIXTURE: - memories = self._mix_search(search_req, user_context) - else: - self.logger.error(f"Unsupported search mode: {search_mode}") - return [] - - return [format_memory_item(data) for data in memories] + if getattr(search_req, "readable_cube_ids", None): + return list(dict.fromkeys(search_req.readable_cube_ids)) - except Exception as e: - self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) - return [] + if search_req.mem_cube_id: + return [search_req.mem_cube_id] - def _search_pref( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list[dict[str, Any]]: - """ - Search preference memories. + return [search_req.user_id] - Args: - search_req: Search request - user_context: User context + def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: + cube_ids = self._resolve_cube_ids(search_req) - Returns: - List of formatted preference memory items - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - try: - results = self.naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, + if len(cube_ids) == 1: + cube_id = cube_ids[0] + return SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, ) - return [format_memory_item(data) for data in results] - except Exception as e: - self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _fast_search( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list: - """ - Fast search using vector database. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of search results - """ - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - return self.naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - - def _fine_search( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list: - """ - Fine-grained search with query enhancement. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of enhanced search results - """ - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - searcher = self.mem_scheduler.searcher - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } - - # Fast retrieve - fast_retrieved_memories = searcher.retrieve( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FINE, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info=info, - ) - - # Post retrieve - fast_memories = searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - - # Enhance with query - enhanced_results, _ = self.mem_scheduler.retriever.enhance_memories_with_query( - query_history=[search_req.query], - memories=fast_memories, - ) - - return enhanced_results - - def _mix_search( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list: - """ - Mix search combining fast and fine-grained approaches. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of formatted search results - """ - return self.mem_scheduler.mix_search_memories( - search_req=search_req, - user_context=user_context, - ) + else: + single_views = [ + SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + ) + for cube_id in cube_ids + ] + return CompositeCubeView(cube_views=single_views, logger=self.logger) diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py index f49945043..8f892d60d 100644 --- a/src/memos/multi_mem_cube/composite_cube.py +++ b/src/memos/multi_mem_cube/composite_cube.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: - from memos.api.product_models import APIADDRequest + from memos.api.product_models import APIADDRequest, APISearchRequest from memos.multi_mem_cube.single_cube import SingleCubeView @@ -34,3 +34,30 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: all_results.extend(results) return all_results + + def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + # aggregated MOSSearchResult + merged_results: dict[str, Any] = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + for view in self.cube_views: + self.logger.info(f"[CompositeCubeView] fan-out search to cube={view.cube_id}") + cube_result = view.search_memories(search_req) + merged_results["text_mem"].extend(cube_result.get("text_mem", [])) + merged_results["act_mem"].extend(cube_result.get("act_mem", [])) + merged_results["para_mem"].extend(cube_result.get("para_mem", [])) + merged_results["pref_mem"].extend(cube_result.get("pref_mem", [])) + + note = cube_result.get("pref_note") + if note: + if merged_results["pref_note"]: + merged_results["pref_note"] += " | " + note + else: + merged_results["pref_note"] = note + + return merged_results diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 6cf0a4ef8..3f205048e 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -2,24 +2,30 @@ import json import os +import traceback from dataclasses import dataclass from datetime import datetime from typing import TYPE_CHECKING, Any +from memos.api.handlers.formatters_handler import ( + format_memory_item, + post_process_pref_mem, +) from memos.context.context import ContextThreadPoolExecutor from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, MEM_READ_LABEL, PREF_ADD_LABEL, + SearchMode, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.multi_mem_cube.views import MemCubeView -from memos.types import UserContext +from memos.types import MOSSearchResult, UserContext if TYPE_CHECKING: - from memos.api.product_models import APIADDRequest + from memos.api.product_models import APIADDRequest, APISearchRequest @dataclass @@ -68,6 +74,241 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: return text_results + pref_results + def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + # Create UserContext object + user_context = UserContext( + user_id=search_req.user_id, + mem_cube_id=search_req.mem_cube_id, + session_id=search_req.session_id or "default_session", + ) + self.logger.info(f"Search Req is: {search_req}") + + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + # Determine search mode + search_mode = self._get_search_mode(search_req.mode) + + # Execute search in parallel for text and preference memories + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._search_text, search_req, user_context, search_mode) + pref_future = executor.submit(self._search_pref, search_req, user_context) + + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() + + # Build result + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": text_formatted_memories, + } + ) + + memories_result = post_process_pref_mem( + memories_result, + pref_formatted_memories, + search_req.mem_cube_id, + search_req.include_preference, + ) + + self.logger.info(f"Search memories result: {memories_result}") + + return memories_result + + def _get_search_mode(self, mode: str) -> str: + """ + Get search mode with environment variable fallback. + + Args: + mode: Requested search mode + + Returns: + Search mode string + """ + if mode == SearchMode.NOT_INITIALIZED: + return os.getenv("SEARCH_MODE", SearchMode.FAST) + return mode + + def _search_text( + self, + search_req: APISearchRequest, + user_context: UserContext, + search_mode: str, + ) -> list[dict[str, Any]]: + """ + Search text memories based on mode. + + Args: + search_req: Search request + user_context: User context + search_mode: Search mode (FAST, FINE, or MIXTURE) + + Returns: + List of formatted memory items + """ + try: + if search_mode == SearchMode.FAST: + memories = self._fast_search(search_req, user_context) + elif search_mode == SearchMode.FINE: + memories = self._fine_search(search_req, user_context) + elif search_mode == SearchMode.MIXTURE: + memories = self._mix_search(search_req, user_context) + else: + self.logger.error(f"Unsupported search mode: {search_mode}") + return [] + + return [format_memory_item(data) for data in memories] + + except Exception as e: + self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _search_pref( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list[dict[str, Any]]: + """ + Search preference memories. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted preference memory items + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + try: + results = self.naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [format_memory_item(data) for data in results] + except Exception as e: + self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _fast_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fast search using vector database. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + return self.naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + + def _fine_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fine-grained search with query enhancement. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of enhanced search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + searcher = self.mem_scheduler.searcher + + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + # Fast retrieve + fast_retrieved_memories = searcher.retrieve( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FINE, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info=info, + ) + + # Post retrieve + fast_memories = searcher.post_retrieve( + retrieved_results=fast_retrieved_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + + # Enhance with query + enhanced_results, _ = self.mem_scheduler.retriever.enhance_memories_with_query( + query_history=[search_req.query], + memories=fast_memories, + ) + + return enhanced_results + + def _mix_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Mix search combining fast and fine-grained approaches. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted search results + """ + return self.mem_scheduler.mix_search_memories( + search_req=search_req, + user_context=user_context, + ) + def _get_sync_mode(self) -> str: """ Get synchronization mode from memory cube. diff --git a/src/memos/multi_mem_cube/views.py b/src/memos/multi_mem_cube/views.py index 7c58aa469..baf5e80e1 100644 --- a/src/memos/multi_mem_cube/views.py +++ b/src/memos/multi_mem_cube/views.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: - from memos.api.product_models import APIADDRequest + from memos.api.product_models import APIADDRequest, APISearchRequest class MemCubeView(Protocol): @@ -26,3 +26,16 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: - cube_id """ ... + + def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + """ + Process search_req, read memories from one or more cubes and search them. + + Returns: + A list of memory dicts, each item should at least contain: + - memory + - memory_id + - memory_type + - cube_id + """ + ...