diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index ee481d028..9b41477e1 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -5,21 +5,13 @@ using dependency injection for better modularity and testability. """ -import json -import os - from datetime import datetime from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APIADDRequest, MemoryResponse -from memos.context.context import ContextThreadPoolExecutor -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.types import UserContext +from memos.multi_mem_cube.composite_cube import CompositeCubeView +from memos.multi_mem_cube.single_cube import SingleCubeView +from memos.multi_mem_cube.views import MemCubeView class AddHandler(BaseHandler): @@ -52,33 +44,69 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: Returns: MemoryResponse with added memory information """ - # Create UserContext object - user_context = UserContext( - user_id=add_req.user_id, - mem_cube_id=add_req.mem_cube_id, - session_id=add_req.session_id or "default_session", - ) + self.logger.info(f"[AddHandler] Add Req is: {add_req}") - self.logger.info(f"Add Req is: {add_req}") - if (not add_req.messages) and add_req.memory_content: + if (not add_req.messages) and getattr(add_req, "memory_content", None): add_req.messages = self._convert_content_messsage(add_req.memory_content) - self.logger.info(f"Converted Add Req content to messages: {add_req.messages}") - # Process text and preference memories in parallel - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._process_text_mem, add_req, user_context) - pref_future = executor.submit(self._process_pref_mem, add_req, user_context) + 
self.logger.info(f"[AddHandler] Converted content to messages: {add_req.messages}") - text_response_data = text_future.result() - pref_response_data = pref_future.result() + cube_view = self._build_cube_view(add_req) - self.logger.info(f"add_memories Text response data: {text_response_data}") - self.logger.info(f"add_memories Pref response data: {pref_response_data}") + results = cube_view.add_memories(add_req) + + self.logger.info(f"[AddHandler] Final add results count={len(results)}") return MemoryResponse( message="Memory added successfully", - data=text_response_data + pref_response_data, + data=results, ) + def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: + """ + Normalize target cube ids from add_req. + Priority: + 1) writable_cube_ids + 2) mem_cube_id + 3) fallback to user_id + """ + if getattr(add_req, "writable_cube_ids", None): + return list(dict.fromkeys(add_req.writable_cube_ids)) + + if add_req.mem_cube_id: + return [add_req.mem_cube_id] + + return [add_req.user_id] + + def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView: + cube_ids = self._resolve_cube_ids(add_req) + + if len(cube_ids) == 1: + cube_id = cube_ids[0] + return SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + searcher=None, + ) + else: + single_views = [ + SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + searcher=None, + ) + for cube_id in cube_ids + ] + return CompositeCubeView( + cube_views=single_views, + logger=self.logger, + ) + def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]: """ Convert content string to list of message dictionaries. 
@@ -98,197 +126,3 @@ def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]] ] # for only user-str input and convert message return messages_list - - def _process_text_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - ) -> list[dict[str, str]]: - """ - Process and add text memories. - - Extracts memories from messages and adds them to the text memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted memory responses - """ - target_session_id = add_req.session_id or "default_session" - - # Determine sync mode - sync_mode = add_req.async_mode or self._get_sync_mode() - - self.logger.info(f"Processing text memory with mode: {sync_mode}") - - # Extract memories - memories_local = self.mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode="fast" if sync_mode == "async" else "fine", - ) - flattened_local = [mm for m in memories_local for mm in m] - self.logger.info(f"Memory extraction completed for user {add_req.user_id}") - - # Add memories to text_mem - mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( - flattened_local, - user_name=user_context.mem_cube_id, - ) - self.logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - - # Schedule async/sync tasks - self._schedule_memory_tasks( - add_req=add_req, - user_context=user_context, - mem_ids=mem_ids_local, - sync_mode=sync_mode, - ) - - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) - ] - - def _process_pref_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - ) -> list[dict[str, str]]: - """ - Process and add preference 
memories. - - Extracts preferences from messages and adds them to the preference memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted preference responses - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - # Determine sync mode - sync_mode = add_req.async_mode or self._get_sync_mode() - target_session_id = add_req.session_id or "default_session" - - # Follow async behavior: enqueue when async - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=PREF_ADD_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref]) - self.logger.info("Submitted preference add to scheduler (async mode)") - except Exception as e: - self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) - return [] - else: - # Sync mode: process immediately - pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": add_req.mem_cube_id, - }, - ) - pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) - self.logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - - def _get_sync_mode(self) -> str: - """ - Get synchronization mode from memory cube. 
- - Returns: - Sync mode string ("sync" or "async") - """ - try: - return getattr(self.naive_mem_cube.text_mem, "mode", "sync") - except Exception: - return "sync" - - def _schedule_memory_tasks( - self, - add_req: APIADDRequest, - user_context: UserContext, - mem_ids: list[str], - sync_mode: str, - ) -> None: - """ - Schedule memory processing tasks based on sync mode. - - Args: - add_req: Add memory request - user_context: User context - mem_ids: List of memory IDs - sync_mode: Synchronization mode - """ - target_session_id = add_req.session_id or "default_session" - - if sync_mode == "async": - # Async mode: submit MEM_READ_LABEL task - try: - message_item_read = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=MEM_READ_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read]) - self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") - except Exception as e: - self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) - else: - # Sync mode: submit ADD_LABEL task - message_item_add = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_add]) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 7d7d52dc4..8a2c21aad 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,21 +5,12 @@ using dependency injection for better modularity and testability. 
""" -import os -import traceback - -from typing import Any - from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies -from memos.api.handlers.formatters_handler import ( - format_memory_item, - post_process_pref_mem, -) from memos.api.product_models import APISearchRequest, SearchResponse -from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import FINE_STRATEGY, FineStrategy, SearchMode -from memos.types import MOSSearchResult, UserContext +from memos.multi_mem_cube.composite_cube import CompositeCubeView +from memos.multi_mem_cube.single_cube import SingleCubeView +from memos.multi_mem_cube.views import MemCubeView logger = get_logger(__name__) @@ -55,274 +46,58 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse Returns: SearchResponse with formatted results """ - # Create UserContext object - user_context = UserContext( - user_id=search_req.user_id, - mem_cube_id=search_req.mem_cube_id, - session_id=search_req.session_id or "default_session", - ) - self.logger.info(f"Search Req is: {search_req}") - - memories_result: MOSSearchResult = { - "text_mem": [], - "act_mem": [], - "para_mem": [], - "pref_mem": [], - "pref_note": "", - } - - # Determine search mode - search_mode = self._get_search_mode(search_req.mode) + self.logger.info(f"[SearchHandler] Search Req is: {search_req}") - # Execute search in parallel for text and preference memories - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._search_text, search_req, user_context, search_mode) - pref_future = executor.submit(self._search_pref, search_req, user_context) + cube_view = self._build_cube_view(search_req) - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() - - # Build result - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": 
text_formatted_memories, - } - ) - - memories_result = post_process_pref_mem( - memories_result, - pref_formatted_memories, - search_req.mem_cube_id, - search_req.include_preference, - ) + results = cube_view.search_memories(search_req) - self.logger.info(f"Search memories result: {memories_result}") + self.logger.info(f"[SearchHandler] Search memories result: {results}") return SearchResponse( - message="Search completed successfully", - data=memories_result, + message="Memory searched successfully", + data=results, ) - def _get_search_mode(self, mode: str) -> str: - return mode - - def _search_text( - self, - search_req: APISearchRequest, - user_context: UserContext, - search_mode: str, - ) -> list[dict[str, Any]]: - """ - Search text memories based on mode. - - Args: - search_req: Search request - user_context: User context - search_mode: Search mode (FAST, FINE, or MIXTURE) - - Returns: - List of formatted memory items - """ - try: - if search_mode == SearchMode.FAST: - text_memories = self._fast_search(search_req, user_context) - elif search_mode == SearchMode.FINE: - text_memories = self._fine_search(search_req, user_context) - elif search_mode == SearchMode.MIXTURE: - text_memories = self._mix_search(search_req, user_context) - else: - self.logger.error(f"Unsupported search mode: {search_mode}") - return [] - - return text_memories - - except Exception as e: - self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _search_pref( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list[dict[str, Any]]: + def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: """ - Search preference memories. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of formatted preference memory items + Normalize target cube ids from search_req. 
+ Priority: + 1) readable_cube_ids + 2) mem_cube_id + 3) fallback to user_id """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - try: - results = self.naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, + if getattr(search_req, "readable_cube_ids", None): + return list(dict.fromkeys(search_req.readable_cube_ids)) + + if search_req.mem_cube_id: + return [search_req.mem_cube_id] + + return [search_req.user_id] + + def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: + cube_ids = self._resolve_cube_ids(search_req) + + if len(cube_ids) == 1: + cube_id = cube_ids[0] + return SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + searcher=self.searcher, ) - return [format_memory_item(data) for data in results] - except Exception as e: - self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _fast_search( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list: - """ - Fast search using vector database. 
- - Args: - search_req: Search request - user_context: User context - - Returns: - List of search results - """ - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - search_results = self.naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - - formatted_memories = [format_memory_item(data) for data in search_results] - - return formatted_memories - - def _deep_search( - self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int - ) -> list: - logger.error("waiting to be implemented") - return [] - - def _fine_search( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list[str]: - """ - Fine-grained search with query enhancement. 
- - Args: - search_req: Search request - user_context: User context - - Returns: - List of enhanced search results - """ - if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: - return self._deep_search( - search_req=search_req, user_context=user_context, max_thinking_depth=3 - ) - - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } - - # Fine retrieve - raw_retrieved_memories = self.searcher.retrieve( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FINE, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info=info, - ) - - # Post retrieve - raw_memories = self.searcher.post_retrieve( - retrieved_results=raw_retrieved_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - - # Enhance with query - enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query( - query_history=[search_req.query], - memories=raw_memories, - ) - - if len(enhanced_memories) < len(raw_memories): - logger.info( - f"Enhanced memories ({len(enhanced_memories)}) are less than raw memories ({len(raw_memories)}). Recalling for more." 
- ) - missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories( - query=search_req.query, - memories=raw_memories, - ) - retrieval_size = len(raw_memories) - len(enhanced_memories) - logger.info(f"Retrieval size: {retrieval_size}") - if trigger: - logger.info(f"Triggering additional search with hint: {missing_info_hint}") - additional_memories = self.searcher.search( - query=missing_info_hint, - user_name=user_context.mem_cube_id, - top_k=retrieval_size, - mode=SearchMode.FAST, - memory_type="All", - search_filter=search_filter, - info=info, + else: + single_views = [ + SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + searcher=self.searcher, ) - else: - logger.info("Not triggering additional search, using fast memories.") - additional_memories = raw_memories[:retrieval_size] - - enhanced_memories += additional_memories - logger.info( - f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}" - ) - formatted_memories = [format_memory_item(data) for data in enhanced_memories] - - logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}") - - return formatted_memories - - def _mix_search( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list: - """ - Mix search combining fast and fine-grained approaches. 
- - Args: - search_req: Search request - user_context: User context - - Returns: - List of formatted search results - """ - return self.mem_scheduler.mix_search_memories( - search_req=search_req, - user_context=user_context, - ) + for cube_id in cube_ids + ] + return CompositeCubeView(cube_views=single_views, logger=self.logger) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index f7f0304c7..cb72011a3 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -73,6 +73,12 @@ class ChatRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") + readable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can read for multi-cube chat" + ) + writable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube chat" + ) history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(True, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") @@ -172,6 +178,9 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") + readable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can read for multi-cube search" + ) mode: SearchMode = Field( os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture" ) @@ -191,7 +200,10 @@ class APIADDRequest(BaseRequest): """Request model for creating memories.""" user_id: str = Field(None, description="User ID") - mem_cube_id: str = Field(..., description="Cube ID") + mem_cube_id: str | None = Field(None, description="Cube ID") + writable_cube_ids: 
list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube add" + ) messages: list[MessageDict] | None = Field(None, description="List of messages to store.") memory_content: str | None = Field(None, description="Memory content to store") doc_path: str | None = Field(None, description="Path to document to store") @@ -212,6 +224,12 @@ class APIChatCompleteRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") + readable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can read for multi-cube chat" + ) + writable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube chat" + ) history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(True, description="Whether to use MemOSCube") diff --git a/src/memos/multi_mem_cube/__init__.py b/src/memos/multi_mem_cube/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py new file mode 100644 index 000000000..8f892d60d --- /dev/null +++ b/src/memos/multi_mem_cube/composite_cube.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from memos.multi_mem_cube.views import MemCubeView + + +if TYPE_CHECKING: + from memos.api.product_models import APIADDRequest, APISearchRequest + from memos.multi_mem_cube.single_cube import SingleCubeView + + +@dataclass +class CompositeCubeView(MemCubeView): + """ + A composite view over multiple logical cubes. + + For now (fast mode), it simply fan-out writes to all cubes; + later we can add smarter routing / slow mode here. 
+ """ + + cube_views: list[SingleCubeView] + logger: Any + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + all_results: list[dict[str, Any]] = [] + + # fast mode: for each cube view, add memories + # maybe add more strategies in add_req.async_mode + for view in self.cube_views: + self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}") + results = view.add_memories(add_req) + all_results.extend(results) + + return all_results + + def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + # aggregated MOSSearchResult + merged_results: dict[str, Any] = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + for view in self.cube_views: + self.logger.info(f"[CompositeCubeView] fan-out search to cube={view.cube_id}") + cube_result = view.search_memories(search_req) + merged_results["text_mem"].extend(cube_result.get("text_mem", [])) + merged_results["act_mem"].extend(cube_result.get("act_mem", [])) + merged_results["para_mem"].extend(cube_result.get("para_mem", [])) + merged_results["pref_mem"].extend(cube_result.get("pref_mem", [])) + + note = cube_result.get("pref_note") + if note: + if merged_results["pref_note"]: + merged_results["pref_note"] += " | " + note + else: + merged_results["pref_note"] = note + + return merged_results diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py new file mode 100644 index 000000000..f34cad1ef --- /dev/null +++ b/src/memos/multi_mem_cube/single_cube.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +import json +import os +import traceback + +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from memos.api.handlers.formatters_handler import ( + format_memory_item, + post_process_pref_mem, +) +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from 
memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + FINE_STRATEGY, + MEM_READ_LABEL, + PREF_ADD_LABEL, + FineStrategy, + SearchMode, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.multi_mem_cube.views import MemCubeView +from memos.types import MOSSearchResult, UserContext + + +logger = get_logger(__name__) + + +if TYPE_CHECKING: + from memos.api.product_models import APIADDRequest, APISearchRequest + + +@dataclass +class SingleCubeView(MemCubeView): + cube_id: str + naive_mem_cube: Any + mem_reader: Any + mem_scheduler: Any + logger: Any + searcher: Any + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + """ + This is basically your current handle_add_memories logic, + but scoped to a single cube_id. + """ + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=self.cube_id, + session_id=add_req.session_id or "default_session", + ) + + target_session_id = add_req.session_id or "default_session" + sync_mode = add_req.async_mode or self._get_sync_mode() + + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} " + f"Processing add with mode={sync_mode}, session={target_session_id}" + ) + + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) + pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) + + text_results = text_future.result() + pref_results = pref_future.result() + + self.logger.info( + f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " + f"pref_results={len(pref_results)}" + ) + + for item in text_results: + item["cube_id"] = self.cube_id + for item in pref_results: + item["cube_id"] = self.cube_id + + return text_results + pref_results + + def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + # Create UserContext object + user_context = UserContext( + 
user_id=search_req.user_id, + mem_cube_id=self.cube_id, + session_id=search_req.session_id or "default_session", + ) + self.logger.info(f"Search Req is: {search_req}") + + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + # Determine search mode + search_mode = self._get_search_mode(search_req.mode) + + # Execute search in parallel for text and preference memories + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._search_text, search_req, user_context, search_mode) + pref_future = executor.submit(self._search_pref, search_req, user_context) + + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() + + # Build result + memories_result["text_mem"].append( + { + "cube_id": self.cube_id, + "memories": text_formatted_memories, + } + ) + + memories_result = post_process_pref_mem( + memories_result, + pref_formatted_memories, + self.cube_id, + search_req.include_preference, + ) + + self.logger.info(f"Search memories result: {memories_result}") + + return memories_result + + def _get_search_mode(self, mode: str) -> str: + """ + Get search mode with environment variable fallback. + + Args: + mode: Requested search mode + + Returns: + Search mode string + """ + return mode + + def _search_text( + self, + search_req: APISearchRequest, + user_context: UserContext, + search_mode: str, + ) -> list[dict[str, Any]]: + """ + Search text memories based on mode. 
+ + Args: + search_req: Search request + user_context: User context + search_mode: Search mode (FAST, FINE, or MIXTURE) + + Returns: + List of formatted memory items + """ + try: + if search_mode == SearchMode.FAST: + text_memories = self._fast_search(search_req, user_context) + elif search_mode == SearchMode.FINE: + text_memories = self._fine_search(search_req, user_context) + elif search_mode == SearchMode.MIXTURE: + text_memories = self._mix_search(search_req, user_context) + else: + self.logger.error(f"Unsupported search mode: {search_mode}") + return [] + + return text_memories + + except Exception as e: + self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _search_pref( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list[dict[str, Any]]: + """ + Search preference memories. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted preference memory items + TODO: ADD CUBE ID IN PREFERENCE MEMORY + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + try: + results = self.naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [format_memory_item(data) for data in results] + except Exception as e: + self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _fast_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fast search using vector database. 
+
+        Args:
+            search_req: Search request
+            user_context: User context
+
+        Returns:
+            List of search results
+        """
+        target_session_id = search_req.session_id or "default_session"
+        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+        search_results = self.naive_mem_cube.text_mem.search(
+            query=search_req.query,
+            user_name=user_context.mem_cube_id,
+            top_k=search_req.top_k,
+            mode=SearchMode.FAST,
+            manual_close_internet=not search_req.internet_search,
+            moscube=search_req.moscube,
+            search_filter=search_filter,
+            info={
+                "user_id": search_req.user_id,
+                "session_id": target_session_id,
+                "chat_history": search_req.chat_history,
+            },
+        )
+
+        formatted_memories = [format_memory_item(data) for data in search_results]
+
+        return formatted_memories
+
+    def _deep_search(
+        self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int
+    ) -> list:
+        logger.error("waiting to be implemented")
+        return []
+
+    def _fine_search(
+        self,
+        search_req: APISearchRequest,
+        user_context: UserContext,
+    ) -> list:
+        """
+        Fine-grained search with query enhancement.
+
+        Args:
+            search_req: Search request
+            user_context: User context
+
+        Returns:
+            List of enhanced search results
+        """
+        if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
+            return self._deep_search(
+                search_req=search_req, user_context=user_context, max_thinking_depth=3
+            )
+
+        target_session_id = search_req.session_id or "default_session"
+        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+        info = {
+            "user_id": search_req.user_id,
+            "session_id": target_session_id,
+            "chat_history": search_req.chat_history,
+        }
+
+        # Fast retrieve
+        fast_retrieved_memories = self.searcher.retrieve(
+            query=search_req.query,
+            user_name=user_context.mem_cube_id,
+            top_k=search_req.top_k,
+            mode=SearchMode.FINE,
+            manual_close_internet=not search_req.internet_search,
+            moscube=search_req.moscube,
+            search_filter=search_filter,
+            info=info,
+        )
+
+        # Post retrieve
+        raw_memories = self.searcher.post_retrieve(
+            retrieved_results=fast_retrieved_memories,
+            top_k=search_req.top_k,
+            user_name=user_context.mem_cube_id,
+            info=info,
+        )
+
+        # Enhance with query
+        enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query(
+            query_history=[search_req.query],
+            memories=raw_memories,
+        )
+
+        if len(enhanced_memories) < len(raw_memories):
+            logger.info(
+                f"Enhanced memories ({len(enhanced_memories)}) are less than raw memories ({len(raw_memories)}). Recalling for more."
+            )
+            missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories(
+                query=search_req.query,
+                memories=raw_memories,
+            )
+            retrieval_size = len(raw_memories) - len(enhanced_memories)
+            logger.info(f"Retrieval size: {retrieval_size}")
+            if trigger:
+                logger.info(f"Triggering additional search with hint: {missing_info_hint}")
+                additional_memories = self.searcher.search(
+                    query=missing_info_hint,
+                    user_name=user_context.mem_cube_id,
+                    top_k=retrieval_size,
+                    mode=SearchMode.FAST,
+                    memory_type="All",
+                    search_filter=search_filter,
+                    info=info,
+                )
+            else:
+                logger.info("Not triggering additional search, using fast memories.")
+                additional_memories = raw_memories[:retrieval_size]
+
+            enhanced_memories += additional_memories
+            logger.info(
+                f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}"
+            )
+        formatted_memories = [format_memory_item(data) for data in enhanced_memories]
+
+        logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
+
+        return formatted_memories
+
+    def _mix_search(
+        self,
+        search_req: APISearchRequest,
+        user_context: UserContext,
+    ) -> list:
+        """
+        Mix search combining fast and fine-grained approaches.
+
+        Args:
+            search_req: Search request
+            user_context: User context
+
+        Returns:
+            List of formatted search results
+        """
+        return self.mem_scheduler.mix_search_memories(
+            search_req=search_req,
+            user_context=user_context,
+        )
+
+    def _get_sync_mode(self) -> str:
+        """
+        Get synchronization mode from memory cube.
+
+        Returns:
+            Sync mode string ("sync" or "async")
+        """
+        try:
+            return getattr(self.naive_mem_cube.text_mem, "mode", "sync")
+        except Exception:
+            return "sync"
+
+    def _schedule_memory_tasks(
+        self,
+        add_req: APIADDRequest,
+        user_context: UserContext,
+        mem_ids: list[str],
+        sync_mode: str,
+    ) -> None:
+        """
+        Schedule memory processing tasks based on sync mode.
+
+        Args:
+            add_req: Add memory request
+            user_context: User context
+            mem_ids: List of memory IDs
+            sync_mode: Synchronization mode
+        """
+        target_session_id = add_req.session_id or "default_session"
+
+        if sync_mode == "async":
+            # Async mode: submit MEM_READ_LABEL task
+            try:
+                message_item_read = ScheduleMessageItem(
+                    user_id=add_req.user_id,
+                    session_id=target_session_id,
+                    mem_cube_id=self.cube_id,
+                    mem_cube=self.naive_mem_cube,
+                    label=MEM_READ_LABEL,
+                    content=json.dumps(mem_ids),
+                    timestamp=datetime.utcnow(),
+                    user_name=self.cube_id,
+                )
+                self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read])
+                self.logger.info(
+                    f"[SingleCubeView] cube={self.cube_id} Submitted async MEM_READ: {json.dumps(mem_ids)}"
+                )
+            except Exception as e:
+                self.logger.error(
+                    f"[SingleCubeView] cube={self.cube_id} Failed to submit async memory tasks: {e}",
+                    exc_info=True,
+                )
+        else:
+            message_item_add = ScheduleMessageItem(
+                user_id=add_req.user_id,
+                session_id=target_session_id,
+                mem_cube_id=self.cube_id,
+                mem_cube=self.naive_mem_cube,
+                label=ADD_LABEL,
+                content=json.dumps(mem_ids),
+                timestamp=datetime.utcnow(),
+                user_name=self.cube_id,
+            )
+            self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_add])
+
+    def _process_pref_mem(
+        self,
+        add_req: APIADDRequest,
+        user_context: UserContext,
+        sync_mode: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Process and add preference memories.
+
+        Extracts preferences from messages and adds them to the preference memory system.
+        Handles both sync and async modes.
+
+        Args:
+            add_req: Add memory request
+            user_context: User context with IDs
+
+        Returns:
+            List of formatted preference responses
+        """
+        if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
+            return []
+
+        target_session_id = add_req.session_id or "default_session"
+
+        if sync_mode == "async":
+            try:
+                messages_list = [add_req.messages]
+                message_item_pref = ScheduleMessageItem(
+                    user_id=add_req.user_id,
+                    session_id=target_session_id,
+                    mem_cube_id=self.cube_id,
+                    mem_cube=self.naive_mem_cube,
+                    label=PREF_ADD_LABEL,
+                    content=json.dumps(messages_list),
+                    timestamp=datetime.utcnow(),
+                )
+                self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref])
+                self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async")
+            except Exception as e:
+                self.logger.error(
+                    f"[SingleCubeView] cube={self.cube_id} Failed to submit PREF_ADD: {e}",
+                    exc_info=True,
+                )
+            return []
+        else:
+            pref_memories_local = self.naive_mem_cube.pref_mem.get_memory(
+                [add_req.messages],
+                type="chat",
+                info={
+                    "user_id": add_req.user_id,
+                    "session_id": target_session_id,
+                    "mem_cube_id": self.cube_id,
+                },
+            )
+            pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local)
+            self.logger.info(
+                f"[SingleCubeView] cube={self.cube_id} "
+                f"added {len(pref_ids_local)} preferences for user {add_req.user_id}: {pref_ids_local}"
+            )
+
+            return [
+                {
+                    "memory": memory.memory,
+                    "memory_id": memory_id,
+                    "memory_type": memory.metadata.preference_type,
+                }
+                for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
+            ]
+
+    def _process_text_mem(
+        self,
+        add_req: APIADDRequest,
+        user_context: UserContext,
+        sync_mode: str,
+    ) -> list[dict[str, Any]]:
+        """
+        Process and add text memories.
+
+        Extracts memories from messages and adds them to the text memory system.
+        Handles both sync and async modes.
+
+        Args:
+            add_req: Add memory request
+            user_context: User context with IDs
+
+        Returns:
+            List of formatted memory responses
+        """
+        target_session_id = add_req.session_id or "default_session"
+
+        self.logger.info(
+            f"[SingleCubeView] cube={user_context.mem_cube_id} "
+            f"Processing text memory with mode: {sync_mode}"
+        )
+
+        # Extract memories
+        memories_local = self.mem_reader.get_memory(
+            [add_req.messages],
+            type="chat",
+            info={
+                "user_id": add_req.user_id,
+                "session_id": target_session_id,
+            },
+            mode="fast" if sync_mode == "async" else "fine",
+        )
+        flattened_local = [mm for m in memories_local for mm in m]
+        self.logger.info(f"Memory extraction completed for user {add_req.user_id}")
+
+        # Add memories to text_mem
+        mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add(
+            flattened_local,
+            user_name=user_context.mem_cube_id,
+        )
+        self.logger.info(
+            f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
+            f"in session {add_req.session_id}: {mem_ids_local}"
+        )
+
+        # Schedule async/sync tasks
+        self._schedule_memory_tasks(
+            add_req=add_req,
+            user_context=user_context,
+            mem_ids=mem_ids_local,
+            sync_mode=sync_mode,
+        )
+
+        return [
+            {
+                "memory": memory.memory,
+                "memory_id": memory_id,
+                "memory_type": memory.metadata.memory_type,
+            }
+            for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False)
+        ]
diff --git a/src/memos/multi_mem_cube/views.py b/src/memos/multi_mem_cube/views.py
new file mode 100644
index 000000000..baf5e80e1
--- /dev/null
+++ b/src/memos/multi_mem_cube/views.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Protocol
+
+
+if TYPE_CHECKING:
+    from memos.api.product_models import APIADDRequest, APISearchRequest
+
+
+class MemCubeView(Protocol):
+    """
+    A high-level cube view used by AddHandler.
+    It may wrap a single logical cube or multiple cubes,
+    but exposes a unified add_memories interface.
+ """ + + def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: + """ + Process add_req, extract memories and write them into one or more cubes. + + Returns: + A list of memory dicts, each item should at least contain: + - memory + - memory_id + - memory_type + - cube_id + """ + ... + + def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + """ + Process search_req, read memories from one or more cubes and search them. + + Returns: + A list of memory dicts, each item should at least contain: + - memory + - memory_id + - memory_type + - cube_id + """ + ...