diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index e1bdd54e9..9b686a131 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -181,7 +181,7 @@ def search(self, query, user_id, top_k): "mem_cube_id": user_id, "conversation_id": "", "top_k": top_k, - "mode": "mixture", + "mode": "fast", "handle_pref_mem": False, }, ensure_ascii=False, diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 6de013313..92df1ecf8 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -23,7 +23,7 @@ def get_openai_config() -> dict[str, Any]: return { "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"), "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")), - "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")), + "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")), "top_p": float(os.getenv("MOS_TOP_P", "0.9")), "top_k": int(os.getenv("MOS_TOP_K", "50")), "remove_think_prefix": True, @@ -672,6 +672,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "mode": os.getenv("ASYNC_MODE", "sync"), }, }, "act_mem": {} diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index bb98f04ba..e9df292ad 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,9 +1,13 @@ +import json import os +import time import traceback +from datetime import datetime from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse from memos.api.config import APIConfig from memos.api.product_models import ( @@ -32,8 +36,12 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas 
import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, SearchMode, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, ExtractorConfigFactory, @@ -233,6 +241,7 @@ def init_server(): chat_llm=llm, process_llm=mem_reader.llm, db_engine=BaseDBManager.create_default_sqlite_engine(), + mem_reader=mem_reader, ) mem_scheduler.current_mem_cube = naive_mem_cube mem_scheduler.start() @@ -477,6 +486,13 @@ def add_memories(add_req: APIADDRequest): if not target_session_id: target_session_id = "default_session" + # If text memory backend works in async mode, submit tasks to scheduler + try: + sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") + except Exception: + sync_mode = "sync" + logger.info(f"Add sync_mode is: {sync_mode}") + def _process_text_mem() -> list[dict[str, str]]: memories_local = mem_reader.get_memory( [add_req.messages], type="chat", info={ "user_id": add_req.user_id, "session_id": target_session_id, }, + mode="fast" if sync_mode == "async" else "fine", ) flattened_local = [mm for m in memories_local for mm in m] logger.info(f"Memory extraction completed for user {add_req.user_id}") @@ -496,6 +513,34 @@ def _process_text_mem() -> list[dict[str, str]]: f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " f"in session {add_req.session_id}: {mem_ids_local}" ) + if sync_mode == "async": + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_read]) + logger.info(f"Submitted MEM_READ task to scheduler: {json.dumps(mem_ids_local)}") + except Exception as e: + logger.error(f"Failed to submit async 
memory tasks: {e}", exc_info=True) + else: + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_add]) return [ { "memory": memory.memory, @@ -508,27 +553,46 @@ def _process_text_mem() -> list[dict[str, str]]: def _process_pref_mem() -> list[dict[str, str]]: if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - pref_memories_local = naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) - pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) - logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] + # Follow async behavior similar to core.py: enqueue when async + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + mem_scheduler.submit_messages(messages=[message_item_pref]) + logger.info("Submitted preference add to scheduler (async mode)") + except Exception as e: + logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) + return [] + else: + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": 
add_req.user_id, + "session_id": target_session_id, + }, + ) + pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_process_text_mem) @@ -542,6 +606,155 @@ def _process_pref_mem() -> list[dict[str, str]]: ) +@router.get("/scheduler/status", summary="Get scheduler running task count") +def scheduler_status(): + """ + Return current running tasks count from scheduler dispatcher. + Shape is consistent with /scheduler/wait. + """ + try: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + now_ts = time.time() + + return { + "message": "ok", + "data": { + "running_tasks": running_count, + "timestamp": now_ts, + }, + } + + except Exception as err: + logger.error("Failed to get scheduler status: %s", traceback.format_exc()) + + raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + + +@router.post("/scheduler/wait", summary="Wait until scheduler is idle") +def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2): + """ + Block until scheduler has no running tasks, or timeout. + We return a consistent structured payload so callers can + tell whether this was a clean flush or a timeout. 
+ + Args: + timeout_seconds: max seconds to wait + poll_interval: seconds between polls + """ + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + elapsed = time.time() - start + + # success -> scheduler is idle + if running_count == 0: + return { + "message": "idle", + "data": { + "running_tasks": 0, + "waited_seconds": round(elapsed, 3), + "timed_out": False, + }, + } + + # timeout check + if elapsed > timeout_seconds: + return { + "message": "timeout", + "data": { + "running_tasks": running_count, + "waited_seconds": round(elapsed, 3), + "timed_out": True, + }, + } + + time.sleep(poll_interval) + + except Exception as err: + logger.error( + "Failed while waiting for scheduler: %s", + traceback.format_exc(), + ) + raise HTTPException( + status_code=500, + detail="Failed while waiting for scheduler", + ) from err + + +@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)") +def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2): + """ + Stream scheduler progress via Server-Sent Events (SSE). + + Contract: + - We emit periodic heartbeat frames while tasks are still running. + - Each heartbeat frame is JSON, prefixed with "data: ". + - On final frame, we include status = "idle" or "timeout" and timed_out flag, + with the same semantics as /scheduler/wait. 
+ + Example curl: + curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5" + """ + + def event_generator(): + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + elapsed = time.time() - start + + # heartbeat frame + heartbeat_payload = { + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "running" if running_count > 0 else "idle", + } + yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n" + + # scheduler is idle -> final frame + if running_count == 0: + final_payload = { + "running_tasks": 0, + "elapsed_seconds": round(elapsed, 3), + "status": "idle", + "timed_out": False, + } + yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + break + + # timeout -> final frame + if elapsed > timeout_seconds: + final_payload = { + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "timeout", + "timed_out": True, + } + yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + break + + time.sleep(poll_interval) + + except Exception as e: + err_payload = { + "status": "error", + "detail": "stream_failed", + "exception": str(e), + } + logger.error( + "Failed streaming scheduler wait: %s: %s", + e, + traceback.format_exc(), + ) + yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + @router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") def chat_complete(chat_req: APIChatCompleteRequest): """Chat with MemOS for a specific user. 
Returns complete response (non-streaming).""" diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 00bd04e6d..89b58f417 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -439,6 +439,7 @@ def remove_oldest_memory( Args: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name(str): optional user_name. """ try: user_name = user_name if user_name else self.config.user_name @@ -685,8 +686,7 @@ def get_node( Returns: dict: Node properties as key-value pairs, or None if not found. """ - user_name = user_name if user_name else self.config.user_name - filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' + filter_clause = f'n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) @@ -730,16 +730,13 @@ def get_nodes( """ if not ids: return [] - - user_name = user_name if user_name else self.config.user_name - where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.id IN [{id_list}] {where_user} + WHERE n.id IN [{id_list}] RETURN {return_fields} """ nodes = [] @@ -1497,10 +1494,10 @@ def _ensure_space_exists(cls, tmp_client, cfg): return try: - res = tmp_client.execute("SHOW GRAPHS;") + res = tmp_client.execute("SHOW GRAPHS") existing = {row.values()[0].as_string() for row in res} if db_name not in existing: - tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;") + tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type") logger.info(f"✅ Graph `{db_name}` created before session binding.") else: logger.debug(f"Graph `{db_name}` already exists.") diff --git a/src/memos/graph_dbs/neo4j.py 
b/src/memos/graph_dbs/neo4j.py index fd3a1ba22..f3a36a887 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -149,6 +149,7 @@ def remove_oldest_memory( Args: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name(str): optional user_name. """ user_name = user_name if user_name else self.config.user_name query = f""" diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 939b0c68d..97ff9879f 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -779,16 +779,16 @@ def process_textual_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + else: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) def process_preference_memory(): if ( @@ -878,15 +878,16 @@ def process_preference_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + else: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) # user doc 
input if ( diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 3958ee382..0360396af 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -134,6 +134,7 @@ def initialize_modules( chat_llm: BaseLLM, process_llm: BaseLLM | None = None, db_engine: Engine | None = None, + mem_reader=None, ): if process_llm is None: process_llm = chat_llm @@ -150,6 +151,9 @@ def initialize_modules( self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + if mem_reader: + self.mem_reader = mem_reader + if self.enable_parallel_dispatch: self.dispatcher_monitor.initialize(dispatcher=self.dispatcher) self.dispatcher_monitor.start() diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 434cef3e9..6840adc2b 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -250,6 +250,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id mem_cube = message.mem_cube content = message.content + user_name = message.user_name # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -273,6 +274,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, mem_cube=mem_cube, text_mem=text_mem, + user_name=user_name, ) logger.info( @@ -297,6 +299,7 @@ def _process_memories_with_reader( mem_cube_id: str, mem_cube: GeneralMemCube, text_mem: TreeTextMemory, + user_name: str, ) -> None: """ Process memories using mem_reader for enhanced memory processing. 
@@ -330,6 +333,18 @@ def _process_memories_with_reader( logger.warning("No valid memory items found for processing") return + # parse working_binding ids from the *original* memory_items (the raw items created in /add) + # these still carry metadata.background with "[working_binding:...]" so we can know + # which WorkingMemory clones should be cleaned up later. + from memos.memories.textual.tree_text_memory.organize.manager import ( + extract_working_binding_ids, + ) + + bindings_to_delete = extract_working_binding_ids(memory_items) + logger.info( + f"Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" + ) + # Use mem_reader to process the memories logger.info(f"Processing {len(memory_items)} memories with mem_reader") @@ -353,7 +368,7 @@ def _process_memories_with_reader( # Add the enhanced memories back to the memory system if flattened_memories: - enhanced_mem_ids = text_mem.add(flattened_memories) + enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) logger.info( f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" ) @@ -362,9 +377,26 @@ def _process_memories_with_reader( else: logger.info("mem_reader returned no processed memories") - text_mem.delete(mem_ids) - logger.info("Delete raw mem_ids") - text_mem.memory_manager.remove_and_refresh_memory() + # build full delete list: + # - original raw mem_ids (temporary fast memories) + # - any bound working memories referenced by the enhanced memories + delete_ids = list(mem_ids) + if bindings_to_delete: + delete_ids.extend(list(bindings_to_delete)) + # deduplicate + delete_ids = list(dict.fromkeys(delete_ids)) + if delete_ids: + try: + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + f"Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + ) + except Exception as e: + logger.warning(f"Failed to delete some mem_ids {delete_ids}: {e}") + else: + logger.info("No mem_ids to delete (nothing to cleanup)") + 
+ text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) logger.info("Remove and Refresh Memories") logger.debug(f"Finished add {user_id} memory: {mem_ids}") @@ -382,6 +414,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id mem_cube = message.mem_cube content = message.content + user_name = message.user_name # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -405,6 +438,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, mem_cube=mem_cube, text_mem=text_mem, + user_name=user_name, ) logger.info( @@ -429,6 +463,7 @@ def _process_memories_with_reorganize( mem_cube_id: str, mem_cube: GeneralMemCube, text_mem: TreeTextMemory, + user_name: str, ) -> None: """ Process memories using mem_reorganize for enhanced memory processing. @@ -455,7 +490,7 @@ def _process_memories_with_reorganize( memory_item = text_mem.get(mem_id) memory_items.append(memory_item) except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}") + logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}") continue if not memory_items: @@ -464,7 +499,7 @@ def _process_memories_with_reorganize( # Use mem_reader to process the memories logger.info(f"Processing {len(memory_items)} memories with mem_reader") - text_mem.memory_manager.remove_and_refresh_memory() + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) logger.info("Remove and Refresh Memories") logger.debug(f"Finished add {user_id} memory: {mem_ids}") diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index bd3155a96..9cdb6823d 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -42,6 +42,10 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): timestamp: datetime = Field( default_factory=get_utc_now, 
description="submit time for schedule_messages" ) + user_name: str | None = Field( + default=None, + description="user name / display name (optional)", + ) # Pydantic V2 model configuration model_config = ConfigDict( @@ -60,6 +64,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "mem_cube": "obj of GeneralMemCube", # Added mem_cube example "content": "sample content", # Example message content "timestamp": "2024-07-22T12:00:00Z", # Added timestamp example + "user_name": "Alice", # Added username example } }, ) @@ -81,6 +86,7 @@ def to_dict(self) -> dict: "cube": "Not Applicable", # Custom cube serialization "content": self.content, "timestamp": self.timestamp.isoformat(), + "user_name": self.user_name, } @classmethod @@ -94,6 +100,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), + user_name=data.get("user_name"), ) diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 52bf62c6d..8d07522cd 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -44,6 +44,8 @@ def __init__( """Initialize memory with the given configuration.""" time_start = time.time() self.config: TreeTextMemoryConfig = config + self.mode = self.config.mode + logger.info(f"Tree mode is {self.mode}") self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") @@ -79,20 +81,6 @@ def __init__( logger.info("No internet retriever configured") logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") - def add( - self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None - ) -> list[str]: - """Add memories. - Args: - memories: List of TextualMemoryItem objects or dictionaries to add. 
- Later: - memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] - metadata = extract_metadata(memory_items, self.extractor_llm) - plan = plan_memory_operations(memory_items, metadata, self.graph_store) - execute_plan(memory_items, metadata, plan, self.graph_store) - """ - return self.memory_manager.add(memories, user_name=user_name) - def replace_working_memory( self, memories: list[TextualMemoryItem], user_name: str | None = None ) -> None: @@ -271,17 +259,6 @@ def get(self, memory_id: str) -> TextualMemoryItem: def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self) -> dict: - """Get all memories. - Returns: - list[TextualMemoryItem]: List of all memories. - """ - all_items = self.graph_store.export_graph() - return all_items - - def delete(self, memory_ids: list[str]) -> None: - raise NotImplementedError - def delete_all(self) -> None: """Delete all memories and their relationships from the graph store.""" try: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index fccd83fa6..472bed219 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -34,6 +34,8 @@ def __init__(self, config: TreeTextMemoryConfig): """Initialize memory with the given configuration.""" # Set mode from class default or override if needed self.mode = config.mode + logger.info(f"Tree mode is {self.mode}") + self.config: TreeTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm @@ -81,12 +83,18 @@ def __init__(self, config: TreeTextMemoryConfig): else: logger.info("No internet retriever configured") - def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + user_name: str | None = None, + **kwargs, + ) -> list[str]: """Add memories. 
Args: memories: List of TextualMemoryItem objects or dictionaries to add. + user_name: optional user_name """ - return self.memory_manager.add(memories, mode=self.mode) + return self.memory_manager.add(memories, user_name=user_name, mode=self.mode) def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: self.memory_manager.replace_working_memory(memories) @@ -262,21 +270,21 @@ def get(self, memory_id: str) -> TextualMemoryItem: def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self) -> dict: + def get_all(self, user_name: str | None = None) -> dict: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ - all_items = self.graph_store.export_graph() + all_items = self.graph_store.export_graph(user_name=user_name) return all_items - def delete(self, memory_ids: list[str]) -> None: + def delete(self, memory_ids: list[str], user_name: str | None = None) -> None: """Hard delete: permanently remove nodes and their edges from the graph.""" if not memory_ids: return for mid in memory_ids: try: - self.graph_store.delete_node(mid) + self.graph_store.delete_node(mid, user_name=user_name) except Exception as e: logger.warning(f"TreeTextMemory.delete_hard: failed to delete {mid}: {e}") diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 54776134b..47cbf4ed1 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -1,3 +1,4 @@ +import re import traceback import uuid @@ -19,6 +20,37 @@ logger = get_logger(__name__) +def extract_working_binding_ids(mem_items: list[TextualMemoryItem]) -> set[str]: + """ + Scan enhanced memory items for background hints like + "[working_binding:]" and collect those working memory IDs. 
+ + We store the working<->long binding inside metadata.background when + initially adding memories in async mode, so we can later clean up + the temporary WorkingMemory nodes after mem_reader produces the + final LongTermMemory/UserMemory. + + Args: + mem_items: list of TextualMemoryItem we just added (enhanced memories) + + Returns: + A set of working memory IDs (as strings) that should be deleted. + """ + bindings: set[str] = set() + pattern = re.compile(r"\[working_binding:([0-9a-fA-F-]{36})\]") + for item in mem_items: + try: + bg = getattr(item.metadata, "background", "") or "" + except Exception: + bg = "" + if not isinstance(bg, str): + continue + match = pattern.search(bg) + if match: + bindings.add(match.group(1)) + return bindings + + class MemoryManager: def __init__( self, @@ -129,15 +161,28 @@ def _refresh_memory_size(self, user_name: str | None = None) -> None: def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): """ - Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). - This method runs asynchronously to process each memory item. + Process and add memory to different memory types. + + Behavior: + 1. Always create a WorkingMemory node from `memory` and get its node id. + 2. If `memory.metadata.memory_type` is "LongTermMemory" or "UserMemory", + also create a corresponding long/user node. + - In async mode, that long/user node's metadata will include + `working_binding` in `background` which records the WorkingMemory + node id created in step 1. + 3. Return ONLY the ids of the long/user nodes (NOT the working node id), + which preserves the previous external contract of `add()`. 
""" ids: list[str] = [] futures = [] + working_id = str(uuid.uuid4()) + with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: - f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) - futures.append(f_working) + f_working = ex.submit( + self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id + ) + futures.append(("working", f_working)) if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"): f_graph = ex.submit( @@ -145,13 +190,14 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name, + working_binding=working_id, ) - futures.append(f_graph) + futures.append(("long", f_graph)) - for fut in as_completed(futures): + for kind, fut in futures: try: res = fut.result() - if isinstance(res, str) and res: + if kind != "working" and isinstance(res, str) and res: ids.append(res) except Exception: logger.warning("Parallel memory processing failed:\n%s", traceback.format_exc()) @@ -159,39 +205,51 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non return ids def _add_memory_to_db( - self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + self, + memory: TextualMemoryItem, + memory_type: str, + user_name: str | None = None, + forced_id: str | None = None, ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. + If forced_id is provided, use that as the node id. 
""" metadata = memory.metadata.model_copy(update={"memory_type": memory_type}).model_dump( exclude_none=True ) metadata["updated_at"] = datetime.now().isoformat() - working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) - + node_id = forced_id or str(uuid.uuid4()) + working_memory = TextualMemoryItem(id=node_id, memory=memory.memory, metadata=metadata) # Insert node into graph self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) + return node_id def _add_to_graph_memory( - self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + self, + memory: TextualMemoryItem, + memory_type: str, + user_name: str | None = None, + working_binding: str | None = None, ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). - - Parameters: - - memory: memory item to insert - - memory_type: "LongTermMemory" | "UserMemory" - - similarity_threshold: deduplication threshold - - topic_summary_prefix: summary node id prefix if applicable - - enable_summary_link: whether to auto-link to a summary node """ node_id = str(uuid.uuid4()) # Step 2: Add new node to graph + metadata_dict = memory.metadata.model_dump(exclude_none=True) + tags = metadata_dict.get("tags") or [] + if working_binding and ("mode:fast" in tags): + prev_bg = metadata_dict.get("background", "") or "" + binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" + if prev_bg: + metadata_dict["background"] = prev_bg + " || " + binding_line + else: + metadata_dict["background"] = binding_line self.graph_store.add_node( node_id, memory.memory, - memory.metadata.model_dump(exclude_none=True), + metadata_dict, user_name=user_name, ) self.reorganizer.add_message( @@ -282,11 +340,11 @@ def _ensure_structure_path( # Step 3: Return this structure node ID as the parent_id return node_id - def remove_and_refresh_memory(self): - self._cleanup_memories_if_needed() - 
self._refresh_memory_size() + def remove_and_refresh_memory(self, user_name: str | None = None): + self._cleanup_memories_if_needed(user_name=user_name) + self._refresh_memory_size(user_name=user_name) - def _cleanup_memories_if_needed(self) -> None: + def _cleanup_memories_if_needed(self, user_name: str | None = None) -> None: """ Only clean up memories if we're close to or over the limit. This reduces unnecessary database operations. @@ -301,7 +359,7 @@ def _cleanup_memories_if_needed(self) -> None: if current_count >= threshold: try: self.graph_store.remove_oldest_memory( - memory_type=memory_type, keep_latest=limit + memory_type=memory_type, keep_latest=limit, user_name=user_name ) logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") except Exception: diff --git a/tests/memories/textual/test_tree.py b/tests/memories/textual/test_tree.py index 772a79d78..a72709ec5 100644 --- a/tests/memories/textual/test_tree.py +++ b/tests/memories/textual/test_tree.py @@ -66,7 +66,9 @@ def test_add_calls_manager(mock_tree_text_memory): metadata=TreeNodeTextualMemoryMetadata(updated_at=None), ) mock_tree_text_memory.add([mock_item]) - mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + [mock_item], user_name=None, mode="sync" + ) def test_get_working_memory_sorted(mock_tree_text_memory): @@ -161,4 +163,6 @@ def test_add_returns_ids(mock_tree_text_memory): result = mock_tree_text_memory.add(mock_items) assert result == dummy_ids - mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + mock_items, user_name=None, mode="sync" + )