From d9f863e9aa13bbfb6517a24028a86cc1643ed80c Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:05:53 +0800 Subject: [PATCH 01/64] fix: sqlite list users error (#384) fix: sqlite users error --- src/memos/mem_user/persistent_user_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/mem_user/persistent_user_manager.py b/src/memos/mem_user/persistent_user_manager.py index e3c476262..d6f7b3155 100644 --- a/src/memos/mem_user/persistent_user_manager.py +++ b/src/memos/mem_user/persistent_user_manager.py @@ -177,7 +177,7 @@ def delete_user_config(self, user_id: str) -> bool: finally: session.close() - def list_user_configs(self) -> dict[str, MOSConfig]: + def list_user_configs(self, limit: int = 1) -> dict[str, MOSConfig]: """List all user configurations. Returns: @@ -185,7 +185,7 @@ def list_user_configs(self) -> dict[str, MOSConfig]: """ session = self._get_session() try: - user_configs = session.query(UserConfig).all() + user_configs = session.query(UserConfig).limit(limit).all() result = {} for user_config in user_configs: From b5ea7e61aa94594d038d7cb19b031d6b040d01cc Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 23 Oct 2025 16:13:08 +0800 Subject: [PATCH 02/64] feat: introduce async memory add for TreeTextMemory using MemScheduler (#373) * feat: define mem-read schedular message&consumer; add async mem-reader mode in core; * feat: add fast/fine mode in mem-reader; * feat: add mem-reader in scheduler * feat: change async remove * feat: modify async-add in core.py * feat: add 'remove and refresh memory in schedular' * feat: add naive fast mode in mem-reader * feat: finish fast mode in mem-reader * feat: add token-based window splitting and concurrency improvements * feat: add split chunker into mode in simple struct mem reader * feat: update async-mode add * chore: update gitignore * feat: improve database note write performance * feat: fix mem-read scheduler * fix: nebula group-by bug * fix: bug in adding mem scheduler * fix: nebula index; mem-reader chat-time; * format: searcher * fix: some bug in shceduler and mem-reader * feat: add mem-organize in scheduler * feat: add tree.mode to config; modify scheduler config * fix: test bug --- .gitignore | 1 + examples/mem_reader/reader.py | 400 +++++++++++++++++- src/memos/chunkers/sentence_chunker.py | 2 +- src/memos/configs/mem_scheduler.py | 2 - src/memos/configs/memory.py | 5 + src/memos/graph_dbs/nebular.py | 39 +- src/memos/graph_dbs/neo4j.py | 2 +- src/memos/mem_os/core.py | 46 +- src/memos/mem_reader/base.py | 9 +- src/memos/mem_reader/simple_struct.py | 332 ++++++++++++--- src/memos/mem_scheduler/base_scheduler.py | 156 ++++++- .../mem_scheduler/general_modules/misc.py | 23 +- src/memos/mem_scheduler/general_scheduler.py | 250 ++++++++++- .../mem_scheduler/schemas/general_schemas.py | 2 + src/memos/memories/textual/base.py | 3 + src/memos/memories/textual/general.py | 2 + src/memos/memories/textual/naive.py | 2 + src/memos/memories/textual/tree.py | 40 +- .../tree_text_memory/organize/manager.py | 89 ++-- .../tree_text_memory/retrieve/recall.py | 2 + .../tree_text_memory/retrieve/searcher.py | 6 +- tests/memories/textual/test_tree.py | 4 +- 22 files changed, 1241 insertions(+), 176 deletions(-) diff --git a/.gitignore b/.gitignore index ae7bdc4d6..8319a4d2f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ evaluation/.env !evaluation/configs-example/*.json evaluation/configs/* **tree_textual_memory_locomo** +**script.py** .env evaluation/scripts/personamem diff --git a/examples/mem_reader/reader.py b/examples/mem_reader/reader.py index e26d00a67..3da5d5e76 100644 --- a/examples/mem_reader/reader.py +++ b/examples/mem_reader/reader.py @@ -2,6 +2,11 @@ from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) def main(): @@ -11,7 +16,7 @@ def main(): ) reader = SimpleStructMemReader(reader_config) - # 3. Define scene data + # 2. Define scene data scene_data = [ [ {"role": "user", "chat_time": "3 May 2025", "content": "I'm feeling a bit down today."}, @@ -187,32 +192,389 @@ def main(): ], ] - # 4. Acquiring memories + print("=== Mem-Reader Fast vs Fine Mode Comparison ===\n") + + # 3. Test Fine Mode (default) + print("🔄 Testing FINE mode (default, with LLM processing)...") + start_time = time.time() + fine_memory = reader.get_memory( + scene_data, type="chat", info={"user_id": "user1", "session_id": "session1"}, mode="fine" + ) + fine_time = time.time() - start_time + print(f"✅ Fine mode completed in {fine_time:.2f} seconds") + print(f"📊 Fine mode generated {sum(len(mem_list) for mem_list in fine_memory)} memory items") + + # 4. Test Fast Mode + print("\n⚡ Testing FAST mode (quick processing, no LLM calls)...") start_time = time.time() - chat_memory = reader.get_memory( - scene_data, type="chat", info={"user_id": "user1", "session_id": "session1"} + fast_memory = reader.get_memory( + scene_data, type="chat", info={"user_id": "user1", "session_id": "session1"}, mode="fast" ) - print("\nChat Memory:\n", chat_memory) + fast_time = time.time() - start_time + print(f"✅ Fast mode completed in {fast_time:.2f} seconds") + print(f"📊 Fast mode generated {sum(len(mem_list) for mem_list in fast_memory)} memory items") + + # 5. Performance Comparison + print("\n📈 Performance Comparison:") + print(f" Fine mode: {fine_time:.2f}s") + print(f" Fast mode: {fast_time:.2f}s") + print(f" Speed improvement: {fine_time / fast_time:.1f}x faster") + + # 6. Show sample results from both modes + print("\n🔍 Sample Results Comparison:") + print("\n--- FINE Mode Results (first 3 items) ---") + for i, mem_list in enumerate(fine_memory[:3]): + for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list + print(f" [{i}][{j}] {mem_item.memory[:100]}...") - # 5. Example of processing documents - print("\n=== Processing Documents ===") + print("\n--- FAST Mode Results (first 3 items) ---") + for i, mem_list in enumerate(fast_memory[:3]): + for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list + print(f" [{i}][{j}] {mem_item.memory[:100]}...") + + # 7. Example of transfer fast mode result into fine result + fast_mode_memories = [ + TextualMemoryItem( + id="4553141b-3a33-4548-b779-e677ec797a9f", + memory="user: Nate:Oh cool! I might check that one out some time soon! I do love watching classics.\nassistant: Joanna:Yep, that movie is awesome. I first watched it around 3 years ago. I even went out and got a physical copy!\nuser: Nate:Sounds cool! Have you seen it a lot? sounds like you know the movie well!\nassistant: Joanna:A few times. It's one of my favorites! I really like the idea and the acting.\nuser: Nate:Cool! I'll definitely check it out. Thanks for the recommendation!\nassistant: Joanna:No problem, Nate! Let me know if you like it!\n", + metadata=TreeNodeTextualMemoryMetadata( + user_id="nate_test", + session_id="root_session", + status="activated", + type="fact", + key="user: Nate:Oh cool", + confidence=0.9900000095367432, + source=None, + tags=["mode:fast", "lang:en", "role:assistant", "role:user"], + visibility=None, + updated_at="2025-10-16T17:16:30.094877+08:00", + memory_type="LongTermMemory", + sources=[ + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=0, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=1, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=2, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=3, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=4, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=5, + ), + ], + embedding=None, + created_at="2025-10-16T17:16:30.094919+08:00", + usage=[], + background="", + ), + ), + TextualMemoryItem( + id="752e42fa-92b6-491a-a430-6864a7730fba", + memory="user: Nate:It was! How about you? Do you have any hobbies you love?\nassistant: Joanna:Yeah! Besides writing, I also enjoy reading, watching movies, and exploring nature. Anything else you enjoy doing, Nate?\nuser: Nate:Playing video games and watching movies are my main hobbies.\nassistant: Joanna:Cool, Nate! So we both have similar interests. What type of movies do you like best?\nuser: Nate:I love action and sci-fi movies, the effects are so cool! What about you, what's your favorite genre?\nassistant: Joanna:I'm all about dramas and romcoms. I love getting immersed in the feelings and plots.\nuser: Nate:Wow, movies can be so powerful! Do you have any recommendations for me?\nassistant: Joanna:Yeah, totally! Have you seen this romantic drama that's all about memory and relationships? It's such a good one.\nuser: Nate:Oh cool! I might check that one out some time soon! I do love watching classics.\nassistant: Joanna:Yep, that movie is awesome. I first watched it around 3 years ago. I even went out and got a physical copy!\n", + metadata=TreeNodeTextualMemoryMetadata( + user_id="nate_test", + session_id="root_session", + status="activated", + type="fact", + key="user: Nate:It was", + confidence=0.9900000095367432, + source=None, + tags=["mode:fast", "lang:en", "role:assistant", "role:user"], + visibility=None, + updated_at="2025-10-16T17:16:30.095726+08:00", + memory_type="LongTermMemory", + sources=[ + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=0, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=1, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=2, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=3, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=4, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=5, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=6, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=7, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=8, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=9, + ), + ], + embedding=None, + created_at="2025-10-16T17:16:30.095767+08:00", + usage=[], + background="", + ), + ), + TextualMemoryItem( + id="c9cf448c-deee-43a8-bafd-eb15fde535b2", + memory="user: Nate:Hey Joanna! Long time no see! What's up? Anything fun going on?\nassistant: Joanna:Hey Nate! Long time no see! I've been working on a project lately - it's been pretty cool. What about you - any fun projects or hobbies?\nuser: Nate:Hey Joanna! That's cool! I won my first video game tournament last week - so exciting!\nassistant: Joanna:Wow Nate! Congrats on winning! Tell me more - what game was it?\nuser: Nate:Thanks! it's a team shooter game.\nassistant: Joanna:Wow, great job! What was is called?\nuser: Nate:The game was called Counter-Strike: Global Offensive, and me and my team had a blast to the very end!\nassistant: Joanna:Cool, Nate! Sounds like a fun experience, even if I'm not into games.\nuser: Nate:It was! How about you? Do you have any hobbies you love?\nassistant: Joanna:Yeah! Besides writing, I also enjoy reading, watching movies, and exploring nature. Anything else you enjoy doing, Nate?\n", + metadata=TreeNodeTextualMemoryMetadata( + user_id="nate_test", + session_id="root_session", + status="activated", + type="fact", + key="user: Nate:Hey Joanna", + confidence=0.9900000095367432, + source=None, + tags=["mode:fast", "lang:en", "role:assistant", "role:user"], + visibility=None, + updated_at="2025-10-16T17:16:30.098208+08:00", + memory_type="LongTermMemory", + sources=[ + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=0, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=1, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=2, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=3, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=4, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=5, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=6, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=7, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=8, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=9, + ), + ], + embedding=None, + created_at="2025-10-16T17:16:30.098246+08:00", + usage=[], + background="", + ), + ), + ] + fine_memories = reader.fine_transfer_simple_mem(fast_mode_memories, type="chat") + print("\n--- Transfer Mode Results (first 3 items) ---") + for i, mem_list in enumerate(fine_memories[:3]): + for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list + print(f" [{i}][{j}] {mem_item.memory[:100]}...") + + # 7. Example of processing documents (only in fine mode) + print("\n=== Processing Documents (Fine Mode Only) ===") # Example document paths (you should replace these with actual document paths) doc_paths = [ "examples/mem_reader/text1.txt", "examples/mem_reader/text2.txt", ] - # 6. Acquiring memories from documents - doc_memory = reader.get_memory( - doc_paths, - "doc", - info={ - "user_id": "1111", - "session_id": "2222", - }, - ) - print("\nDocument Memory:\n", doc_memory) - end_time = time.time() - print(f"The runtime is {end_time - start_time} seconds.") + + try: + # 6. Acquiring memories from documents + doc_memory = reader.get_memory( + doc_paths, + "doc", + info={ + "user_id": "1111", + "session_id": "2222", + }, + mode="fine", + ) + print( + f"\n📄 Document Memory generated {sum(len(mem_list) for mem_list in doc_memory)} items" + ) + except Exception as e: + print(f"⚠️ Document processing failed: {e}") + print(" (This is expected if document files don't exist)") + + print("\n🎯 Summary:") + print(f" • Fast mode: {fast_time:.2f}s - Quick processing, no LLM calls") + print(f" • Fine mode: {fine_time:.2f}s - Full LLM processing for better understanding") + print(" • Use fast mode for: Real-time applications, high-throughput scenarios") + print(" • Use fine mode for: Quality analysis, detailed memory extraction") if __name__ == "__main__": diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index 4de0cf32b..080962482 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -28,7 +28,7 @@ def __init__(self, config: SentenceChunkerConfig): ) logger.info(f"Initialized SentenceChunker with config: {config}") - def chunk(self, text: str) -> list[Chunk]: + def chunk(self, text: str) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" chonkie_chunks = self.chunker.chunk(text) diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 39586081c..2d6155ec2 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -28,13 +28,11 @@ class BaseSchedulerConfig(BaseConfig): thread_pool_max_workers: int = Field( default=DEFAULT_THREAD_POOL_MAX_WORKERS, gt=1, - lt=20, description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD_POOL_MAX_WORKERS})", ) consume_interval_seconds: float = Field( default=DEFAULT_CONSUME_INTERVAL_SECONDS, gt=0, - le=60, description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})", ) auth_config_path: str | None = Field( diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 237450e15..2c3a715f7 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -179,6 +179,11 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) + mode: str | None = Field( + default="sync", + description=("whether use asynchronous mode in memory add"), + ) + class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): """Simple tree text memory configuration class.""" diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 9a74373d7..12b493e58 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -440,20 +440,22 @@ def remove_oldest_memory( memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. """ - optional_condition = "" - - user_name = user_name if user_name else self.config.user_name - - optional_condition = f"AND n.user_name = '{user_name}'" - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {int(keep_latest)} - DETACH DELETE n - """ - self.execute_query(query) + try: + user_name = user_name if user_name else self.config.user_name + optional_condition = f"AND n.user_name = '{user_name}'" + count = self.count_nodes(memory_type, user_name) + if count > keep_latest: + delete_query = f""" + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + WHERE n.memory_type = '{memory_type}' + {optional_condition} + ORDER BY n.updated_at DESC + OFFSET {int(keep_latest)} + DETACH DELETE n + """ + self.execute_query(delete_query) + except Exception as e: + logger.warning(f"Delete old mem error: {e}") @timed def add_node( @@ -1175,7 +1177,6 @@ def get_grouped_counts( MATCH (n /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", ".join(return_fields)}, COUNT(n) AS count - GROUP BY {", ".join(group_by_fields)} """ result = self.execute_query(gql) # Pure GQL string execution @@ -1620,7 +1621,13 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. """ - fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] + fields = [ + "status", + "memory_type", + "created_at", + "updated_at", + "user_name", + ] for field in fields: index_name = f"idx_memory_{field}" diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 55db60ed2..f51b3465d 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -669,7 +669,7 @@ def search_by_embedding( vector (list[float]): The embedding vector representing query semantics. top_k (int): Number of top similar nodes to retrieve. scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory'). - status (str, optional): Node status filter (e.g., 'active', 'archived'). + status (str, optional): Node status filter (e.g., 'activated', 'archived'). If provided, restricts results to nodes with matching status. threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 958cc140c..0010897c0 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -17,6 +17,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, + MEM_READ_LABEL, QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -70,6 +71,7 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None): if self.enable_mem_scheduler: self._mem_scheduler = self._initialize_mem_scheduler() self._mem_scheduler.mem_cubes = self.mem_cubes + self._mem_scheduler.mem_reader = self.mem_reader else: self._mem_scheduler: GeneralScheduler = None @@ -681,6 +683,12 @@ def add( logger.info( f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}" ) + sync_mode = self.mem_cubes[mem_cube_id].text_mem.mode + if sync_mode == "async": + assert self.mem_scheduler is not None, ( + "Mem-Scheduler must be working when use asynchronous memory adding." + ) + logger.debug(f"Mem-reader mode is: {sync_mode}") time_start_1 = time.time() if ( (messages is not None) @@ -690,6 +698,7 @@ def add( logger.info( f"time add: messages is not None and enable_textual_memory and text_mem is not None time user_id: {target_user_id} time is: {time.time() - time_start_1}" ) + if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": add_memory = [] metadata = TextualMemoryMetadata( @@ -707,21 +716,30 @@ def add( messages_list, type="chat", info={"user_id": target_user_id, "session_id": target_session_id}, + mode="fast" if sync_mode == "async" else "fine", ) logger.info( f"time add: get mem_reader time user_id: {target_user_id} time is: {time.time() - time_start_2}" ) - mem_ids = [] - for mem in memories: - mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) - mem_ids.extend(mem_id_list) - logger.info( - f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_id_list}" - ) - + memories_flatten = [m for m_list in memories for m in m_list] + mem_ids: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(memories_flatten) + logger.info( + f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_ids}" + ) # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "async": + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, @@ -749,10 +767,12 @@ def add( messages_list = [ [{"role": "user", "content": memory_content}] ] # for only user-str input and convert message + memories = self.mem_reader.get_memory( messages_list, type="chat", info={"user_id": target_user_id, "session_id": target_session_id}, + mode="fast" if sync_mode == "async" else "fine", ) mem_ids = [] @@ -766,6 +786,16 @@ def add( # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "async": + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, diff --git a/src/memos/mem_reader/base.py b/src/memos/mem_reader/base.py index f092c3870..3095a0bc6 100644 --- a/src/memos/mem_reader/base.py +++ b/src/memos/mem_reader/base.py @@ -18,10 +18,17 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: @abstractmethod def get_memory( - self, scene_data: list, type: str, info: dict[str, Any] + self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast" ) -> list[list[TextualMemoryItem]]: """Various types of memories extracted from scene_data""" @abstractmethod def transform_memreader(self, data: dict) -> list[TextualMemoryItem]: """Transform the memory data into a list of TextualMemoryItem objects.""" + + @abstractmethod + def fine_transfer_simple_mem( + self, input_memories: list[list[TextualMemoryItem]], type: str + ) -> list[list[TextualMemoryItem]]: + """Fine Transform TextualMemoryItem List into another list of + TextualMemoryItem objects via calling llm to better understand users.""" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b439cb2b2..9f5eb9832 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -3,6 +3,7 @@ import json import os import re +import traceback from abc import ABC from typing import Any @@ -41,6 +42,26 @@ "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, } +try: + import tiktoken + + try: + _ENC = tiktoken.encoding_for_model("gpt-4o-mini") + except Exception: + _ENC = tiktoken.get_encoding("cl100k_base") + + def _count_tokens_text(s: str) -> int: + return len(_ENC.encode(s or "")) +except Exception: + # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars + def _count_tokens_text(s: str) -> int: + if not s: + return 0 + zh_chars = re.findall(r"[\u4e00-\u9fff]", s) + zh = len(zh_chars) + rest = len(s) - zh + return zh + max(1, rest // 4) + def detect_lang(text): try: @@ -112,6 +133,14 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder return None +def _derive_key(text: str, max_len: int = 80) -> str: + """default key when without LLM: first max_len words""" + if not text: + return "" + sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] + return (sent[:max_len]).strip() + + class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" @@ -126,27 +155,50 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.llm = LLMFactory.from_config(config.llm) self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) + self.memory_max_length = 8000 + # Use token-based windowing; default to ~5000 tokens if not configured + self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) + self._count_tokens = _count_tokens_text + + def _make_memory_item( + self, + value: str, + info: dict, + memory_type: str, + tags: list[str] | None = None, + key: str | None = None, + sources: list | None = None, + background: str = "", + type_: str = "fact", + confidence: float = 0.99, + ) -> TextualMemoryItem: + """construct memory item""" + return TextualMemoryItem( + memory=value, + metadata=TreeNodeTextualMemoryMetadata( + user_id=info.get("user_id", ""), + session_id=info.get("session_id", ""), + memory_type=memory_type, + status="activated", + tags=tags or [], + key=key if key is not None else _derive_key(value), + embedding=self.embedder.embed([value])[0], + usage=[], + sources=sources or [], + background=background, + confidence=confidence, + type=type_, + ), + ) - @timed - def _process_chat_data(self, scene_data_info, info): - mem_list = [] - for item in scene_data_info: - if "chat_time" in item: - mem = item["role"] + ": " + f"[{item['chat_time']}]: " + item["content"] - mem_list.append(mem) - else: - mem = item["role"] + ":" + item["content"] - mem_list.append(mem) - lang = detect_lang("\n".join(mem_list)) + def _get_llm_response(self, mem_str: str) -> dict: + lang = detect_lang(mem_str) template = PROMPT_DICT["chat"][lang] examples = PROMPT_DICT["chat"][f"{lang}_example"] - - prompt = template.replace("${conversation}", "\n".join(mem_list)) + prompt = template.replace("${conversation}", mem_str) if self.config.remove_prompt_example: prompt = prompt.replace(examples, "") - messages = [{"role": "user", "content": prompt}] - try: response_text = self.llm.generate(messages) response_json = self.parse_json_result(response_text) @@ -155,15 +207,111 @@ def _process_chat_data(self, scene_data_info, info): response_json = { "memory list": [ { - "key": "\n".join(mem_list)[:10], + "key": mem_str[:10], "memory_type": "UserMemory", - "value": "\n".join(mem_list), + "value": mem_str, "tags": [], } ], - "summary": "\n".join(mem_list), + "summary": mem_str, } + return response_json + def _iter_chat_windows(self, scene_data_info, max_tokens=None, overlap=200): + """ + use token counter to get a slide window generator + """ + max_tokens = max_tokens or self.chat_window_max_tokens + buf, sources, start_idx = [], [], 0 + cur_text = "" + + for idx, item in enumerate(scene_data_info): + role = item.get("role", "") + content = item.get("content", "") + chat_time = item.get("chat_time", None) + parts = [] + if role and str(role).lower() != "mix": + parts.append(f"{role}: ") + if chat_time: + parts.append(f"[{chat_time}]: ") + prefix = "".join(parts) + line = f"{prefix}{content}\n" + + if self._count_tokens(cur_text + line) > max_tokens and cur_text: + text = "".join(buf) + yield {"text": text, "sources": sources.copy(), "start_idx": start_idx} + while buf and self._count_tokens("".join(buf)) > overlap: + buf.pop(0) + sources.pop(0) + start_idx = idx + cur_text = "".join(buf) + + buf.append(line) + sources.append({"type": "chat", "index": idx, "role": role, "chat_time": chat_time}) + cur_text = "".join(buf) + + if buf: + yield {"text": "".join(buf), "sources": sources.copy(), "start_idx": start_idx} + + @timed + def _process_chat_data(self, scene_data_info, info, **kwargs): + mode = kwargs.get("mode", "fine") + windows = list(self._iter_chat_windows(scene_data_info)) + + if mode == "fast": + logger.debug("Using unified Fast Mode") + + def _build_fast_node(w): + text = w["text"] + roles = {s.get("role", "") for s in w["sources"] if s.get("role")} + mem_type = "UserMemory" if roles == {"user"} else "LongTermMemory" + tags = ["mode:fast"] + return self._make_memory_item( + value=text, info=info, memory_type=mem_type, tags=tags, sources=w["sources"] + ) + + with ContextThreadPoolExecutor(max_workers=8) as ex: + futures = {ex.submit(_build_fast_node, w): i for i, w in enumerate(windows)} + results = [None] * len(futures) + for fut in concurrent.futures.as_completed(futures): + i = futures[fut] + try: + node = fut.result() + if node: + results[i] = node + except Exception as e: + logger.error(f"[ChatFast] error: {e}") + chat_nodes = [r for r in results if r] + return chat_nodes + else: + logger.debug("Using unified Fine Mode") + chat_read_nodes = [] + for w in windows: + resp = self._get_llm_response(w["text"]) + for m in resp.get("memory list", []): + try: + memory_type = ( + m.get("memory_type", "LongTermMemory") + .replace("长期记忆", "LongTermMemory") + .replace("用户记忆", "UserMemory") + ) + node = self._make_memory_item( + value=m.get("value", ""), + info=info, + memory_type=memory_type, + tags=m.get("tags", []), + key=m.get("key", ""), + sources=w["sources"], + background=resp.get("summary", ""), + ) + chat_read_nodes.append(node) + except Exception as e: + logger.error(f"[ChatFine] parse error: {e}") + return chat_read_nodes + + def _process_transfer_chat_data(self, raw_node: TextualMemoryItem): + raw_memory = raw_node.memory + response_json = self._get_llm_response(raw_memory) chat_read_nodes = [] for memory_i_raw in response_json.get("memory list", []): try: @@ -172,28 +320,23 @@ def _process_chat_data(self, scene_data_info, info): .replace("长期记忆", "LongTermMemory") .replace("用户记忆", "UserMemory") ) - if memory_type not in ["LongTermMemory", "UserMemory"]: memory_type = "LongTermMemory" - - node_i = TextualMemoryItem( - memory=memory_i_raw.get("value", ""), - metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id"), - session_id=info.get("session_id"), - memory_type=memory_type, - status="activated", - tags=memory_i_raw.get("tags", []) - if type(memory_i_raw.get("tags", [])) is list - else [], - key=memory_i_raw.get("key", ""), - embedding=self.embedder.embed([memory_i_raw.get("value", "")])[0], - usage=[], - sources=scene_data_info, - background=response_json.get("summary", ""), - confidence=0.99, - type="fact", - ), + node_i = self._make_memory_item( + value=memory_i_raw.get("value", ""), + info={ + "user_id": raw_node.metadata.user_id, + "session_id": raw_node.metadata.session_id, + }, + memory_type=memory_type, + tags=memory_i_raw.get("tags", []) + if isinstance(memory_i_raw.get("tags", []), list) + else [], + key=memory_i_raw.get("key", ""), + sources=raw_node.metadata.sources, + background=response_json.get("summary", ""), + type_="fact", + confidence=0.99, ) chat_read_nodes.append(node_i) except Exception as e: @@ -202,7 +345,7 @@ def _process_chat_data(self, scene_data_info, info): return chat_read_nodes def get_memory( - self, scene_data: list, type: str, info: dict[str, Any] + self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fine" ) -> list[list[TextualMemoryItem]]: """ Extract and classify memory content from scene_data. @@ -219,6 +362,8 @@ def get_memory( - topic_chunk_overlap: Overlap for large topic chunks (default: 100) - chunk_size: Size for small chunks (default: 256) - chunk_overlap: Overlap for small chunks (default: 50) + mode: mem-reader mode, fast for quick process while fine for + better understanding via calling llm Returns: list[list[TextualMemoryItem]] containing memory content with summaries as keys and original text as values Raises: @@ -253,13 +398,48 @@ def get_memory( # Process Q&A pairs concurrently with context propagation with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(processing_func, scene_data_info, info) + executor.submit(processing_func, scene_data_info, info, mode=mode) for scene_data_info in list_scene_data_info ] for future in concurrent.futures.as_completed(futures): - res_memory = future.result() - memory_list.append(res_memory) + try: + res_memory = future.result() + if res_memory is not None: + memory_list.append(res_memory) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + logger.error(traceback.format_exc()) + return memory_list + + def fine_transfer_simple_mem( + self, input_memories: list[TextualMemoryItem], type: str + ) -> list[list[TextualMemoryItem]]: + if not input_memories: + return [] + + memory_list = [] + if type == "chat": + processing_func = self._process_transfer_chat_data + elif type == "doc": + processing_func = self._process_transfer_doc_data + else: + processing_func = self._process_transfer_doc_data + + # Process Q&A pairs concurrently with context propagation + with ContextThreadPoolExecutor() as executor: + futures = [ + executor.submit(processing_func, scene_data_info) + for scene_data_info in input_memories + ] + for future in concurrent.futures.as_completed(futures): + try: + res_memory = future.result() + if res_memory is not None: + memory_list.append(res_memory) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + logger.error(traceback.format_exc()) return memory_list def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: @@ -275,13 +455,6 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: List of strings containing the processed scene data """ results = [] - parser_config = ParserConfigFactory.model_validate( - { - "backend": "markitdown", - "config": {}, - } - ) - parser = ParserFactory.from_config(parser_config) if type == "chat": for items in scene_data: @@ -299,6 +472,13 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: if result: results.append(result) elif type == "doc": + parser_config = ParserConfigFactory.model_validate( + { + "backend": "markitdown", + "config": {}, + } + ) + parser = ParserFactory.from_config(parser_config) for item in scene_data: try: if os.path.exists(item): @@ -317,6 +497,9 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: return results def _process_doc_data(self, scene_data_info, info, **kwargs): + mode = kwargs.get("mode", "fine") + if mode == "fast": + raise NotImplementedError chunks = self.chunker.chunk(scene_data_info["text"]) messages = [] for chunk in chunks: @@ -357,19 +540,48 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): logger.error(f"[DocReader] Future task failed: {e}") return doc_nodes - def parse_json_result(self, response_text): + def _process_transfer_doc_data(self, raw_node: TextualMemoryItem): + raise NotImplementedError + + def parse_json_result(self, response_text: str) -> dict: + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + try: - json_start = response_text.find("{") - response_text = response_text[json_start:] - response_text = response_text.replace("```", "").strip() - if not response_text.endswith("}"): - response_text += "}" - return json.loads(response_text) + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) except json.JSONDecodeError as e: - logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw:\n{response_text}") - return {} - except Exception as e: - logger.error(f"[JSONParse] Unexpected error: {e}") + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.error( + f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) return {} def transform_memreader(self, data: dict) -> list[TextualMemoryItem]: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 4f8b0719b..1e8b042b1 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -76,6 +76,7 @@ def __init__(self, config: BaseSchedulerConfig): self.db_engine: Engine | None = None self.monitor: SchedulerGeneralMonitor | None = None self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None + self.mem_reader = None # Will be set by MOSCore self.dispatcher = SchedulerDispatcher( config=self.config, max_workers=self.thread_pool_max_workers, @@ -87,7 +88,7 @@ def __init__(self, config: BaseSchedulerConfig): # internal message queue self.max_internal_message_queue_size = self.config.get( - "max_internal_message_queue_size", 100 + "max_internal_message_queue_size", 10000 ) self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( maxsize=self.max_internal_message_queue_size @@ -138,12 +139,17 @@ def initialize_modules( self.dispatcher_monitor.start() # initialize with auth_config - if self.auth_config_path is not None and Path(self.auth_config_path).exists(): - self.auth_config = AuthConfig.from_local_config(config_path=self.auth_config_path) - elif AuthConfig.default_config_exists(): - self.auth_config = AuthConfig.from_local_config() - else: - self.auth_config = AuthConfig.from_local_env() + try: + if self.auth_config_path is not None and Path(self.auth_config_path).exists(): + self.auth_config = AuthConfig.from_local_config( + config_path=self.auth_config_path + ) + elif AuthConfig.default_config_exists(): + self.auth_config = AuthConfig.from_local_config() + else: + self.auth_config = AuthConfig.from_local_env() + except Exception: + pass if self.auth_config is not None: self.rabbitmq_config = self.auth_config.rabbitmq @@ -730,3 +736,139 @@ def _cleanup_queues(self) -> None: self._web_log_message_queue.get_nowait() except queue.Empty: pass + + def mem_scheduler_wait( + self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01 + ) -> bool: + """ + Uses EWMA throughput, detects leaked `unfinished_tasks`, and waits for dispatcher. + """ + deadline = time.monotonic() + timeout + + # --- helpers (local, no external deps) --- + def _unfinished() -> int: + """Prefer `unfinished_tasks`; fallback to `qsize()`.""" + try: + u = getattr(self.memos_message_queue, "unfinished_tasks", None) + if u is not None: + return int(u) + except Exception: + pass + try: + return int(self.memos_message_queue.qsize()) + except Exception: + return 0 + + def _fmt_eta(seconds: float | None) -> str: + """Format seconds to human-readable string.""" + if seconds is None or seconds != seconds or seconds == float("inf"): + return "unknown" + s = max(0, int(seconds)) + h, s = divmod(s, 3600) + m, s = divmod(s, 60) + if h > 0: + return f"{h:d}h{m:02d}m{s:02d}s" + if m > 0: + return f"{m:d}m{s:02d}s" + return f"{s:d}s" + + # --- EWMA throughput state (tasks/s) --- + alpha = 0.3 + rate = 0.0 + last_t = None # type: float | None + last_done = 0 + + # --- dynamic totals & stuck detection --- + init_unfinished = _unfinished() + done_total = 0 + last_unfinished = None + stuck_ticks = 0 + next_log = 0.0 + + while True: + # 1) read counters + curr_unfinished = _unfinished() + try: + qsz = int(self.memos_message_queue.qsize()) + except Exception: + qsz = -1 + + pend = run = 0 + stats_fn = getattr(self.dispatcher, "stats", None) + if self.enable_parallel_dispatch and self.dispatcher is not None and callable(stats_fn): + try: + st = ( + stats_fn() + ) # expected: {'pending':int,'running':int,'done':int?,'rate':float?} + pend = int(st.get("pending", 0)) + run = int(st.get("running", 0)) + except Exception: + pass + + # 2) dynamic total (allows new tasks queued while waiting) + total_now = max(init_unfinished, done_total + curr_unfinished) + done_total = max(0, total_now - curr_unfinished) + + # 3) update EWMA throughput + now = time.monotonic() + if last_t is None: + last_t = now + else: + dt = max(1e-6, now - last_t) + dc = max(0, done_total - last_done) + inst = dc / dt + rate = inst if rate == 0.0 else alpha * inst + (1 - alpha) * rate + last_t = now + last_done = done_total + + eta = None if rate <= 1e-9 else (curr_unfinished / rate) + + # 4) progress log (throttled) + if now >= next_log: + print( + f"[mem_scheduler_wait] remaining≈{curr_unfinished} | throughput≈{rate:.2f} msg/s | ETA≈{_fmt_eta(eta)} " + f"| qsize={qsz} pending={pend} running={run}" + ) + next_log = now + max(0.2, log_every) + + # 5) exit / stuck detection + idle_dispatcher = ( + (pend == 0 and run == 0) + if (self.enable_parallel_dispatch and self.dispatcher is not None) + else True + ) + if curr_unfinished == 0: + break + if curr_unfinished > 0 and qsz == 0 and idle_dispatcher: + if last_unfinished == curr_unfinished: + stuck_ticks += 1 + else: + stuck_ticks = 0 + else: + stuck_ticks = 0 + last_unfinished = curr_unfinished + + if stuck_ticks >= 3: + logger.warning( + "mem_scheduler_wait: detected leaked 'unfinished_tasks' -> treating queue as drained" + ) + break + + if now >= deadline: + logger.warning("mem_scheduler_wait: queue did not drain before timeout") + return False + + time.sleep(poll) + + # 6) wait dispatcher (second stage) + remaining = max(0.0, deadline - time.monotonic()) + if self.enable_parallel_dispatch and self.dispatcher is not None: + try: + ok = self.dispatcher.join(timeout=remaining if remaining > 0 else 0) + except TypeError: + ok = self.dispatcher.join() + if not ok: + logger.warning("mem_scheduler_wait: dispatcher did not complete before timeout") + return False + + return True diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index 7dda25a29..6f05bf72f 100644 --- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -205,7 +205,9 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non """Put an item into the queue. If the queue is full, the oldest item will be automatically removed to make space. - This operation is thread-safe. + IMPORTANT: When we drop an item we also call `task_done()` to keep + the internal `unfinished_tasks` counter consistent (the dropped task + will never be processed). Args: item: The item to be put into the queue @@ -216,19 +218,34 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non # First try non-blocking put super().put(item, block=block, timeout=timeout) except Full: + # Remove oldest item and mark it done to avoid leaking unfinished_tasks with suppress(Empty): - self.get_nowait() # Remove oldest item + _ = self.get_nowait() + # If the removed item had previously incremented unfinished_tasks, + # we must decrement here since it will never be processed. + with suppress(ValueError): + self.task_done() # Retry putting the new item super().put(item, block=block, timeout=timeout) def get_queue_content_without_pop(self) -> list[T]: """Return a copy of the queue's contents without modifying it.""" - return list(self.queue) + # Ensure a consistent snapshot by holding the mutex + with self.mutex: + return list(self.queue) def clear(self) -> None: """Remove all items from the queue. This operation is thread-safe. + IMPORTANT: We also decrement `unfinished_tasks` by the number of + items cleared, since those tasks will never be processed. """ with self.mutex: + dropped = len(self.queue) self.queue.clear() + # Call task_done() outside of the mutex to avoid deadlocks because + # Queue.task_done() acquires the same condition bound to `self.mutex`. + for _ in range(dropped): + with suppress(ValueError): + self.task_done() diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 25c7b78fd..f47cc0cc5 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,4 +1,6 @@ +import concurrent.futures import json +import traceback from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger @@ -8,6 +10,8 @@ ADD_LABEL, ANSWER_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, + MEM_ORGANIZE_LABEL, + MEM_READ_LABEL, QUERY_LABEL, WORKING_MEMORY_TYPE, MemCubeID, @@ -34,6 +38,8 @@ def __init__(self, config: GeneralSchedulerConfig): QUERY_LABEL: self._query_message_consumer, ANSWER_LABEL: self._answer_message_consumer, ADD_LABEL: self._add_message_consumer, + MEM_READ_LABEL: self._mem_read_message_consumer, + MEM_ORGANIZE_LABEL: self._mem_reorganize_message_consumer, } self.dispatcher.register_handlers(handlers) @@ -180,7 +186,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) self.validate_schedule_messages(messages=messages, label=ADD_LABEL) try: @@ -203,7 +209,15 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: mem_cube = msg.mem_cube for memory_id in userinput_memory_ids: - mem_item: TextualMemoryItem = mem_cube.text_mem.get(memory_id=memory_id) + try: + mem_item: TextualMemoryItem = mem_cube.text_mem.get( + memory_id=memory_id + ) + except Exception: + logger.warning( + f"This MemoryItem {memory_id} has already been deleted." + ) + continue mem_type = mem_item.metadata.memory_type mem_content = mem_item.memory @@ -222,6 +236,238 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: except Exception as e: logger.error(f"Error: {e}", exc_info=True) + def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = message.mem_cube + content = message.content + + # Parse the memory IDs from content + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" + ) + + # Get the text memory from the mem_cube + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") + return + + # Use mem_reader to process the memories + self._process_memories_with_reader( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + text_mem=text_mem, + ) + + logger.info( + f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}" + ) + + except Exception as e: + logger.error(f"Error processing mem_read message: {e}", exc_info=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Thread task failed: {e}", exc_info=True) + + def _process_memories_with_reader( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + text_mem: TreeTextMemory, + ) -> None: + """ + Process memories using mem_reader for enhanced memory processing. + + Args: + mem_ids: List of memory IDs to process + user_id: User ID + mem_cube_id: Memory cube ID + mem_cube: Memory cube instance + text_mem: Text memory instance + """ + try: + # Get the mem_reader from the parent MOSCore + if not hasattr(self, "mem_reader") or self.mem_reader is None: + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) + return + + # Get the original memory items + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id) + memory_items.append(memory_item) + except Exception as e: + logger.warning(f"Failed to get memory {mem_id}: {e}") + continue + + if not memory_items: + logger.warning("No valid memory items found for processing") + return + + # Use mem_reader to process the memories + logger.info(f"Processing {len(memory_items)} memories with mem_reader") + + # Extract memories using mem_reader + try: + processed_memories = self.mem_reader.fine_transfer_simple_mem( + memory_items, + type="chat", + ) + except Exception as e: + logger.warning(f"{e}: Fail to transfer mem: {memory_items}") + processed_memories = [] + + if processed_memories and len(processed_memories) > 0: + # Flatten the results (mem_reader returns list of lists) + flattened_memories = [] + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + + logger.info(f"mem_reader processed {len(flattened_memories)} enhanced memories") + + # Add the enhanced memories back to the memory system + if flattened_memories: + enhanced_mem_ids = text_mem.add(flattened_memories) + logger.info( + f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" + ) + else: + logger.info("No enhanced memories generated by mem_reader") + else: + logger.info("mem_reader returned no processed memories") + + text_mem.delete(mem_ids) + logger.info("Delete raw mem_ids") + text_mem.memory_manager.remove_and_refresh_memory() + logger.info("Remove and Refresh Memories") + logger.debug(f"Finished add {user_id} memory: {mem_ids}") + + except Exception: + logger.error( + f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True + ) + + def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = message.mem_cube + content = message.content + + # Parse the memory IDs from content + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" + ) + + # Get the text memory from the mem_cube + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") + return + + # Use mem_reader to process the memories + self._process_memories_with_reorganize( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + text_mem=text_mem, + ) + + logger.info( + f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}" + ) + + except Exception as e: + logger.error(f"Error processing mem_read message: {e}", exc_info=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Thread task failed: {e}", exc_info=True) + + def _process_memories_with_reorganize( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + text_mem: TreeTextMemory, + ) -> None: + """ + Process memories using mem_reorganize for enhanced memory processing. + + Args: + mem_ids: List of memory IDs to process + user_id: User ID + mem_cube_id: Memory cube ID + mem_cube: Memory cube instance + text_mem: Text memory instance + """ + try: + # Get the mem_reader from the parent MOSCore + if not hasattr(self, "mem_reader") or self.mem_reader is None: + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) + return + + # Get the original memory items + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id) + memory_items.append(memory_item) + except Exception as e: + logger.warning(f"Failed to get memory {mem_id}: {e}") + continue + + if not memory_items: + logger.warning("No valid memory items found for processing") + return + + # Use mem_reader to process the memories + logger.info(f"Processing {len(memory_items)} memories with mem_reader") + text_mem.memory_manager.remove_and_refresh_memory() + logger.info("Remove and Refresh Memories") + logger.debug(f"Finished add {user_id} memory: {mem_ids}") + + except Exception: + logger.error( + f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True + ) + def process_session_turn( self, queries: str | list[str], diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index d0d83091b..248c42e80 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -8,6 +8,8 @@ QUERY_LABEL = "query" ANSWER_LABEL = "answer" ADD_LABEL = "add" +MEM_READ_LABEL = "mem_read" +MEM_ORGANIZE_LABEL = "mem_organize" TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 82dad4486..8a6113345 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -10,6 +10,9 @@ class BaseTextMemory(BaseMemory): """Base class for all textual memory implementations.""" + # Default mode configuration - can be overridden by subclasses + mode: str = "sync" # Default mode: 'async' or 'sync' + @abstractmethod def __init__(self, config: BaseTextMemoryConfig): """Initialize memory with the given configuration.""" diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index 9793224b5..d71a86d2e 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -26,6 +26,8 @@ class GeneralTextMemory(BaseTextMemory): def __init__(self, config: GeneralTextMemoryConfig): """Initialize memory with the given configuration.""" + # Set mode from class default or override if needed + self.mode = getattr(self.__class__, "mode", "sync") self.config: GeneralTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm diff --git a/src/memos/memories/textual/naive.py b/src/memos/memories/textual/naive.py index f8684729a..7bc49e767 100644 --- a/src/memos/memories/textual/naive.py +++ b/src/memos/memories/textual/naive.py @@ -61,6 +61,8 @@ class NaiveTextMemory(BaseTextMemory): def __init__(self, config: NaiveTextMemoryConfig): """Initialize memory with the given configuration.""" + # Set mode from class default or override if needed + self.mode = getattr(self.__class__, "mode", "sync") self.config = config self.extractor_llm = LLMFactory.from_config(config.extractor_llm) self.memories = [] diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 0048f4a59..fccd83fa6 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -2,7 +2,6 @@ import os import shutil import tempfile -import time from datetime import datetime from pathlib import Path @@ -33,28 +32,17 @@ class TreeTextMemory(BaseTextMemory): def __init__(self, config: TreeTextMemoryConfig): """Initialize memory with the given configuration.""" - time_start = time.time() + # Set mode from class default or override if needed + self.mode = config.mode self.config: TreeTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm ) - logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") - - time_start_ex = time.time() self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.dispatcher_llm ) - logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}") - - time_start_em = time.time() self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) - logger.info(f"time init: embedder time is: {time.time() - time_start_em}") - - time_start_gs = time.time() self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) - logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") - - time_start_rr = time.time() if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( { @@ -68,10 +56,7 @@ def __init__(self, config: TreeTextMemoryConfig): self.reranker = RerankerFactory.from_config(default_cfg) else: self.reranker = RerankerFactory.from_config(config.reranker) - logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") self.is_reorganize = config.reorganize - - time_start_mm = time.time() self.memory_manager: MemoryManager = MemoryManager( self.graph_store, self.embedder, @@ -84,8 +69,6 @@ def __init__(self, config: TreeTextMemoryConfig): }, is_reorganize=self.is_reorganize, ) - logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}") - time_start_ir = time.time() # Create internet retriever if configured self.internet_retriever = None if config.internet_retriever is not None: @@ -97,19 +80,13 @@ def __init__(self, config: TreeTextMemoryConfig): ) else: logger.info("No internet retriever configured") - logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: """Add memories. Args: memories: List of TextualMemoryItem objects or dictionaries to add. - Later: - memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] - metadata = extract_metadata(memory_items, self.extractor_llm) - plan = plan_memory_operations(memory_items, metadata, self.graph_store) - execute_plan(memory_items, metadata, plan, self.graph_store) """ - return self.memory_manager.add(memories) + return self.memory_manager.add(memories, mode=self.mode) def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: self.memory_manager.replace_working_memory(memories) @@ -294,7 +271,14 @@ def get_all(self) -> dict: return all_items def delete(self, memory_ids: list[str]) -> None: - raise NotImplementedError + """Hard delete: permanently remove nodes and their edges from the graph.""" + if not memory_ids: + return + for mid in memory_ids: + try: + self.graph_store.delete_node(mid) + except Exception as e: + logger.warning(f"TreeTextMemory.delete_hard: failed to delete {mid}: {e}") def delete_all(self) -> None: """Delete all memories and their relationships from the graph store.""" diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 3e1609cb7..54776134b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -52,13 +52,15 @@ def __init__( ) self._merged_threshold = merged_threshold - def add(self, memories: list[TextualMemoryItem], user_name: str | None = None) -> list[str]: + def add( + self, memories: list[TextualMemoryItem], user_name: str | None = None, mode: str = "sync" + ) -> list[str]: """ - Add new memories in parallel to different memory types (WorkingMemory, LongTermMemory, UserMemory). + Add new memories in parallel to different memory types. """ added_ids: list[str] = [] - with ContextThreadPoolExecutor(max_workers=8) as executor: + with ContextThreadPoolExecutor(max_workers=20) as executor: futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} for future in as_completed(futures, timeout=60): try: @@ -67,17 +69,18 @@ def add(self, memories: list[TextualMemoryItem], user_name: str | None = None) - except Exception as e: logger.exception("Memory processing error: ", exc_info=e) - for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: - try: - self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", - keep_latest=self.memory_size[mem_type], - user_name=user_name, - ) - except Exception: - logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}") - - self._refresh_memory_size(user_name=user_name) + if mode == "sync": + for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]: + try: + self.graph_store.remove_oldest_memory( + memory_type="WorkingMemory", + keep_latest=self.memory_size[mem_type], + user_name=user_name, + ) + except Exception: + logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}") + + self._refresh_memory_size(user_name=user_name) return added_ids def replace_working_memory( @@ -129,17 +132,29 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). This method runs asynchronously to process each memory item. """ - ids = [] - - # Add to WorkingMemory do not return working_id - self._add_memory_to_db(memory, "WorkingMemory", user_name) + ids: list[str] = [] + futures = [] + + with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: + f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) + futures.append(f_working) + + if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"): + f_graph = ex.submit( + self._add_to_graph_memory, + memory=memory, + memory_type=memory.metadata.memory_type, + user_name=user_name, + ) + futures.append(f_graph) - # Add to LongTermMemory and UserMemory - if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: - added_id = self._add_to_graph_memory( - memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name - ) - ids.append(added_id) + for fut in as_completed(futures): + try: + res = fut.result() + if isinstance(res, str) and res: + ids.append(res) + except Exception: + logger.warning("Parallel memory processing failed:\n%s", traceback.format_exc()) return ids @@ -157,7 +172,6 @@ def _add_memory_to_db( # Insert node into graph self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) - return working_memory.id def _add_to_graph_memory( self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None @@ -268,6 +282,31 @@ def _ensure_structure_path( # Step 3: Return this structure node ID as the parent_id return node_id + def remove_and_refresh_memory(self): + self._cleanup_memories_if_needed() + self._refresh_memory_size() + + def _cleanup_memories_if_needed(self) -> None: + """ + Only clean up memories if we're close to or over the limit. + This reduces unnecessary database operations. + """ + cleanup_threshold = 0.8 # Clean up when 80% full + + for memory_type, limit in self.memory_size.items(): + current_count = self.current_memory_size.get(memory_type, 0) + threshold = int(limit * cleanup_threshold) + + # Only clean up if we're at or above the threshold + if current_count >= threshold: + try: + self.graph_store.remove_oldest_memory( + memory_type=memory_type, keep_latest=limit + ) + logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") + except Exception: + logger.warning(f"Remove {memory_type} error: {traceback.format_exc()}") + def wait_reorganizer(self): """ Wait for the reorganizer to finish processing all messages. diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index d4cfcf501..c1ade3021 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -197,6 +197,7 @@ def _vector_recall( memory_scope: str, top_k: int = 20, max_num: int = 3, + status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, user_name: str | None = None, @@ -213,6 +214,7 @@ def search_single(vec, filt=None): self.graph_store.search_by_embedding( vector=vec, top_k=top_k, + status=status, scope=memory_scope, cube_name=cube_name, search_filter=filt, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 05db56f53..96c6c97f1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -134,7 +134,11 @@ def _parse_task( related_nodes = [ self.graph_store.get_node(n["id"]) for n in self.graph_store.search_by_embedding( - query_embedding, top_k=top_k, search_filter=search_filter, user_name=user_name + query_embedding, + top_k=top_k, + status="activated", + search_filter=search_filter, + user_name=user_name, ) ] memories = [] diff --git a/tests/memories/textual/test_tree.py b/tests/memories/textual/test_tree.py index f3e662992..772a79d78 100644 --- a/tests/memories/textual/test_tree.py +++ b/tests/memories/textual/test_tree.py @@ -66,7 +66,7 @@ def test_add_calls_manager(mock_tree_text_memory): metadata=TreeNodeTextualMemoryMetadata(updated_at=None), ) mock_tree_text_memory.add([mock_item]) - mock_tree_text_memory.memory_manager.add.assert_called_once() + mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync") def test_get_working_memory_sorted(mock_tree_text_memory): @@ -161,4 +161,4 @@ def test_add_returns_ids(mock_tree_text_memory): result = mock_tree_text_memory.add(mock_items) assert result == dummy_ids - mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items) + mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync") From 6efe419c7eecbc6fc224688d94ab6da5d4fbd2fb Mon Sep 17 00:00:00 2001 From: Hao <42795704+Nyakult@users.noreply.github.com> Date: Fri, 24 Oct 2025 14:35:15 +0800 Subject: [PATCH 03/64] add pm and pref eval scripts (#385) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: check nodes existence * feat: use different template for different language input * feat: use different template for different language input * fix: eval script * feat: memos-api eval scripts * feat: mem reader * feat: 实现äºprefeval memos-api evaluation scripts * refactor:format code * feat: add PersonaMem eval scripts * docs(evaluation): update PersonaMem eval readme * feat:memos-api ingest batch message * feat: refactor search * feat: refactor search * update: add api for memory * feat: add memory api return memory and memory type * refactor(server):重构服务器路由模块以优化内存管理 * format: ruff format code * feat(server): 增加LLM最大令牌数 * test * fix: user query embedding for search * count memory_size by user * fix(server):修复记忆读取逻辑中的列表展开问题 * feat(nebular):优化图数据库查询性能 * refactor(memory): - 移除了对 `_refresh_memory_size` 方法的调用- 保留原有逻辑以便后续恢复或重构 * feat: remove user idx_memory_user_name * feat(graph):优化Nebula图数据库查询性能 * feat: rollback remove_oldest_memory * feat:nebula gql add index * feat: align code * feat: update memos_api * feat: update memos_api * feat: 更新默认选项 * feat:memory client * feat:refactor lme * feat: memu & supermemory client * feat: locomo memu * feat: locomo supermemory * New 'add' and 'process' modes. * feat: lme supermemory & memu * feat: default args * api and local * api and local * memobase fix * memos fix * default args * fix memos-api search data * prefeval pipeline * fix lme memos-api * personamem pipeline * personamem pipeline * lme scrips * align dev * format code * refactor: remove old files * format code * pm and prefeval pipeline * format code * format code * pm and prefeval pipeline * pm and prefeval pipeline * pm and prefeval pipeline * format code * format code * pref pipeline * add search response mode * add search response mode * update readme and example * update mem0 api * pm mem0 * fix MEMOBASE api * update pm and prefeval pipepline for frames * update pm and prefeval readme * format code * fix memobase api * fix memobase api * format code * format code * fix format * fix format * fix format --------- Co-authored-by: 2Rant Co-authored-by: fridayL --- evaluation/.env-example | 21 +- evaluation/README.md | 21 +- evaluation/scripts/PrefEval/pref_eval.py | 46 ++- evaluation/scripts/PrefEval/pref_mem0.py | 295 ++++++++++++++++ evaluation/scripts/PrefEval/pref_memobase.py | 306 ++++++++++++++++ evaluation/scripts/PrefEval/pref_memos.py | 135 +++++-- evaluation/scripts/PrefEval/pref_memu.py | 301 ++++++++++++++++ .../scripts/PrefEval/pref_supermemory.py | 334 ++++++++++++++++++ evaluation/scripts/PrefEval/pref_zep.py | 307 ++++++++++++++++ evaluation/scripts/locomo/locomo_ingestion.py | 10 +- evaluation/scripts/locomo/locomo_search.py | 8 +- .../scripts/longmemeval/lme_ingestion.py | 6 +- evaluation/scripts/longmemeval/lme_search.py | 4 - evaluation/scripts/personamem/pm_ingestion.py | 166 +++++---- evaluation/scripts/personamem/pm_metric.py | 203 ++++++----- evaluation/scripts/personamem/pm_responses.py | 56 ++- evaluation/scripts/personamem/pm_search.py | 214 +++++------ evaluation/scripts/run_pm_eval.sh | 82 +++-- evaluation/scripts/run_prefeval_eval.sh | 102 ++++-- evaluation/scripts/utils/client.py | 154 ++++++-- evaluation/scripts/utils/mirix_utils.py | 81 +++++ 21 files changed, 2397 insertions(+), 455 deletions(-) create mode 100644 evaluation/scripts/PrefEval/pref_mem0.py create mode 100644 evaluation/scripts/PrefEval/pref_memobase.py create mode 100644 evaluation/scripts/PrefEval/pref_memu.py create mode 100644 evaluation/scripts/PrefEval/pref_supermemory.py create mode 100644 evaluation/scripts/PrefEval/pref_zep.py create mode 100644 evaluation/scripts/utils/mirix_utils.py diff --git a/evaluation/.env-example b/evaluation/.env-example index fc57344da..4b2b9311f 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -3,21 +3,28 @@ MODEL="gpt-4o-mini" OPENAI_API_KEY="sk-***REDACTED***" OPENAI_BASE_URL="http://***.***.***.***:3000/v1" -MEM0_API_KEY="m0-***REDACTED***" - -ZEP_API_KEY="z_***REDACTED***" # response model CHAT_MODEL="gpt-4o-mini" CHAT_MODEL_BASE_URL="http://***.***.***.***:3000/v1" CHAT_MODEL_API_KEY="sk-***REDACTED***" +# memos MEMOS_KEY="Token mpg-xxxxx" -MEMOS_URL="https://apigw-pre.memtensor.cn/api/openmem/v1" -PRE_SPLIT_CHUNK=false # pre split chunk in client end +MEMOS_URL="http://127.0.0.1:8001" +MEMOS_ONLINE_URL="https://memos.memtensor.cn/api/openmem/v1" + +# other memory agents +MEM0_API_KEY="m0-xxx" +ZEP_API_KEY="z_xxx" +MEMU_API_KEY="mu_xxx" +SUPERMEMORY_API_KEY="sm_xxx" +MEMOBASE_API_KEY="xxx" +MEMOBASE_PROJECT_URL="http://***.***.***.***:8019" + +# eval settings +PRE_SPLIT_CHUNK=false -MEMOBASE_API_KEY="xxxxx" -MEMOBASE_PROJECT_URL="http://xxx.xxx.xxx.xxx:8019" # Configuration Only For Scheduler # RabbitMQ Configuration diff --git a/evaluation/README.md b/evaluation/README.md index 16752c075..f0bd166e1 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -21,7 +21,14 @@ This repository provides tools and scripts for evaluating the LoCoMo dataset usi 2. Copy the `configs-example/` directory to a new directory named `configs/`, and modify the configuration files inside it as needed. This directory contains model and API-specific settings. +## Setup MemOS +```bash +#start server +uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8001 --workers 8 +# modify .env file +MEMOS_URL="http://127.0.0.1:8001" +``` ## Evaluation Scripts ### LoCoMo Evaluation @@ -45,10 +52,20 @@ First prepare the dataset `longmemeval_s` from https://huggingface.co/datasets/x ./scripts/run_lme_eval.sh ``` -### prefEval Evaluation +### PrefEval Evaluation +To evaluate the **Prefeval** dataset using one of the supported memory frameworks — `memos`, `mem0`, or `zep` — run the following [script](./scripts/run_prefeval_eval.sh): -### personaMem Evaluation +```bash +# Edit the configuration in ./scripts/run_prefeval_eval.sh +# Specify the model and memory backend you want to use (e.g., mem0, zep, etc.) +./scripts/run_prefeval_eval.sh +``` + +### PersonaMem Evaluation get `questions_32k.csv` and `shared_contexts_32k.jsonl` from https://huggingface.co/datasets/bowen-upenn/PersonaMem and save them at `data/personamem/` ```bash +# Edit the configuration in ./scripts/run_pm_eval.sh +# Specify the model and memory backend you want to use (e.g., mem0, zep, etc.) +# If you want to use MIRIX, edit the the configuration in ./scripts/personamem/config.yaml ./scripts/run_pm_eval.sh ``` diff --git a/evaluation/scripts/PrefEval/pref_eval.py b/evaluation/scripts/PrefEval/pref_eval.py index cd9c5dde2..10cf41bf3 100644 --- a/evaluation/scripts/PrefEval/pref_eval.py +++ b/evaluation/scripts/PrefEval/pref_eval.py @@ -15,10 +15,6 @@ API_KEY = os.getenv("OPENAI_API_KEY") API_URL = os.getenv("OPENAI_BASE_URL") -INPUT_FILE = "./results/prefeval/pref_memos_process.jsonl" -OUTPUT_FILE = "./results/prefeval/eval_pref_memos.jsonl" -OUTPUT_EXCEL_FILE = "./results/prefeval/eval_pref_memos_summary.xlsx" - async def call_gpt4o_mini_async(client: OpenAI, prompt: str) -> str: messages = [{"role": "user", "content": prompt}] @@ -255,9 +251,10 @@ def generate_excel_summary( avg_search_time: float, avg_context_tokens: float, avg_add_time: float, + output_excel_file: str, model_name: str = "gpt-4o-mini", ): - print(f"Generating Excel summary at {OUTPUT_EXCEL_FILE}...") + print(f"Generating Excel summary at {output_excel_file}...") def get_pct(key): return summary_results.get(key, {}).get("percentage", 0) @@ -282,7 +279,7 @@ def get_pct(key): df = pd.DataFrame(data) - with pd.ExcelWriter(OUTPUT_EXCEL_FILE, engine="xlsxwriter") as writer: + with pd.ExcelWriter(output_excel_file, engine="xlsxwriter") as writer: df.to_excel(writer, index=False, sheet_name="Summary") workbook = writer.book @@ -300,10 +297,10 @@ def get_pct(key): bold_pct_format = workbook.add_format({"num_format": "0.0%", "bold": True}) worksheet.set_column("F:F", 18, bold_pct_format) - print(f"Successfully saved summary to {OUTPUT_EXCEL_FILE}") + print(f"Successfully saved summary to {output_excel_file}") -async def main(concurrency_limit: int): +async def main(concurrency_limit: int, input_file: str, output_file: str, output_excel_file: str): semaphore = asyncio.Semaphore(concurrency_limit) error_counter = Counter() @@ -313,17 +310,17 @@ async def main(concurrency_limit: int): total_add_time = 0 print(f"Starting evaluation with a concurrency limit of {concurrency_limit}...") - print(f"Input file: {INPUT_FILE}") - print(f"Output JSONL: {OUTPUT_FILE}") - print(f"Output Excel: {OUTPUT_EXCEL_FILE}") + print(f"Input file: {input_file}") + print(f"Output JSONL: {output_file}") + print(f"Output Excel: {output_excel_file}") client = OpenAI(api_key=API_KEY, base_url=API_URL) try: - with open(INPUT_FILE, "r", encoding="utf-8") as f: + with open(input_file, "r", encoding="utf-8") as f: lines = f.readlines() except FileNotFoundError: - print(f"Error: Input file not found at '{INPUT_FILE}'") + print(f"Error: Input file not found at '{input_file}'") return if not lines: @@ -332,7 +329,7 @@ async def main(concurrency_limit: int): tasks = [process_line(line, client, semaphore) for line in lines] - with open(OUTPUT_FILE, "w", encoding="utf-8") as outfile: + with open(output_file, "w", encoding="utf-8") as outfile: pbar = tqdm( asyncio.as_completed(tasks), total=len(tasks), @@ -382,6 +379,7 @@ async def main(concurrency_limit: int): avg_search_time, avg_context_tokens, avg_add_time, + output_excel_file, ) except Exception as e: print(f"\nFailed to generate Excel file: {e}") @@ -389,6 +387,11 @@ async def main(concurrency_limit: int): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate assistant responses from a JSONL file.") + + parser.add_argument( + "--input", type=str, required=True, help="Path to the input JSONL file from pref_memos.py." + ) + parser.add_argument( "--concurrency-limit", type=int, @@ -397,4 +400,17 @@ async def main(concurrency_limit: int): ) args = parser.parse_args() - asyncio.run(main(concurrency_limit=args.concurrency_limit)) + input_path = args.input + output_dir = os.path.dirname(input_path) + + output_jsonl_path = os.path.join(output_dir, "eval_pref_memos.jsonl") + output_excel_path = os.path.join(output_dir, "eval_pref_memos_summary.xlsx") + + asyncio.run( + main( + concurrency_limit=args.concurrency_limit, + input_file=input_path, + output_file=output_jsonl_path, + output_excel_file=output_excel_path, + ) + ) diff --git a/evaluation/scripts/PrefEval/pref_mem0.py b/evaluation/scripts/PrefEval/pref_mem0.py new file mode 100644 index 000000000..416d8045f --- /dev/null +++ b/evaluation/scripts/PrefEval/pref_mem0.py @@ -0,0 +1,295 @@ +import argparse +import concurrent.futures +import json +import os +import sys +import time +import tiktoken +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + +from irrelevant_conv import irre_10, irre_300 + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) +load_dotenv() +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +BASE_URL = os.getenv("OPENAI_BASE_URL") +MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") +tokenizer = tiktoken.get_encoding("cl100k_base") +os.environ["MEM0_API_KEY"] = os.getenv("MEM0_API_KEY") + + +def add_memory_for_line( + line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str +) -> dict: + """ + Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. + """ + i, line = line_data + user_id = f"{lib}_user_pref_eval_{i}_{version}" + + try: + original_data = json.loads(line) + conversation = original_data.get("conversation", []) + + if num_irrelevant_turns == 10: + conversation = conversation + irre_10 + elif num_irrelevant_turns == 300: + conversation = conversation + irre_300 + + turns_add = 5 + start_time_add = time.monotonic() + if conversation: + for chunk_start in range(0, len(conversation), turns_add * 2): + chunk = conversation[chunk_start : chunk_start + turns_add * 2] + timestamp_add = int(time.time() * 100) + mem_client.add(messages=chunk, user_id=user_id, timestamp=timestamp_add) + end_time_add = time.monotonic() + add_duration = end_time_add - start_time_add + + original_data["user_id"] = user_id + original_data["metrics"] = {"add_memories_duration_seconds": add_duration} + return original_data + + except Exception as e: + print(f"Error adding memory for line {i + 1} (user_id: {user_id}): {e}") + return None + + +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: + """ + Processes a single line of data, searching memory based on the question. + """ + i, line = line_data + try: + original_data = json.loads(line) + + user_id = original_data.get("user_id") + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + + if not user_id: + original_data["error"] = ( + "Error: user_id not found in this line. Please run 'add' mode first." + ) + return original_data + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + start_time_search = time.monotonic() + relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) + search_memories_duration = time.monotonic() - start_time_search + memory_list = relevant_memories.get("results", []) + memories_str = "\n".join(f"- {entry['memory']}" for entry in memory_list) + + memory_tokens_used = len(tokenizer.encode(memories_str)) + + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": memories_str, + } + ) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + # Check for None, as an empty string (no memories found) is a valid result + if memories_str is None: + original_data["error"] = ( + "Error: retrieved_memories_text not found in metrics. " + "Please run 'search' mode first." + ) + return original_data + + system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + response = openai_client.chat.completions.create(model=MODEL_NAME, messages=messages) + assistant_response = response.choices[0].message.content + original_data["response"] = assistant_response + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Process conversations with MemOS. Run 'add', then 'search', then 'response'." + ) + parser.add_argument( + "mode", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", + ) + parser.add_argument("--input", required=True, help="Path to the input JSONL file.") + parser.add_argument("--output", required=True, help="Path to the output JSONL file.") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) + parser.add_argument( + "--add-turn", + type=int, + choices=[0, 10, 300], + default=0, + help="Number of irrelevant turns to add (used in 'add' mode).", + ) + parser.add_argument( + "--lib", + type=str, + choices=["mem0", "mem0_graph"], + default="mem0", + help="Which Mem0 library to use (used in 'add' mode).", + ) + parser.add_argument( + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", + ) + parser.add_argument( + "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." + ) + + args = parser.parse_args() + + try: + with open(args.input, "r", encoding="utf-8") as infile: + lines = infile.readlines() + except FileNotFoundError: + print(f"Error: Input file '{args.input}' not found") + return + + from utils.client import Mem0Client + + mem_client = Mem0Client(enable_graph="graph" in args.lib) + + if args.mode == "add": + print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") + print(f"Adding {args.add_turn} irrelevant turns.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit( + add_memory_for_line, + (i, line), + mem_client, + args.add_turn, + args.lib, + args.version, + ) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Adding memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.") + + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") + print(f"Retrieving top {args.top_k} memories for each query.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) + + elif args.mode == "response": + print(f"Running in 'response' mode. Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") + openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(generate_response_for_line, (i, line), openai_client) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'response' mode complete! Final results written to '{args.output}'.") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/PrefEval/pref_memobase.py b/evaluation/scripts/PrefEval/pref_memobase.py new file mode 100644 index 000000000..34d3ea86f --- /dev/null +++ b/evaluation/scripts/PrefEval/pref_memobase.py @@ -0,0 +1,306 @@ +import argparse +import concurrent.futures +import json +import os +import sys +import time +import tiktoken +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm +import time +from irrelevant_conv import irre_10, irre_300 + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) +load_dotenv() +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +BASE_URL = os.getenv("OPENAI_BASE_URL") +MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") +tokenizer = tiktoken.get_encoding("cl100k_base") + + +def add_memory_for_line( + line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str +) -> dict: + """ + Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. + """ + i, line = line_data + user_id = f"{lib}_user_pref_eval_{i}_{version}" + mem_client.delete_user(user_id) + user_id = mem_client.client.add_user({"user_id": user_id}) + print("user_id:", user_id) + try: + original_data = json.loads(line) + conversation = original_data.get("conversation", []) + + if num_irrelevant_turns == 10: + conversation = conversation + irre_10 + elif num_irrelevant_turns == 300: + conversation = conversation + irre_300 + + start_time_add = time.monotonic() + if conversation: + messages = [] + + for chunk_start in range(0, len(conversation)): + chunk = conversation[chunk_start : chunk_start + 1] + timestamp_add = str(int(time.time() * 100)) + time.sleep(0.001) # Ensure unique timestamp + + messages.append( + { + "role": chunk[0]["role"], + "content": chunk[0]["content"][:8000], + "created_at": timestamp_add, + } + ) + mem_client.add(messages=messages, user_id=user_id) + + end_time_add = time.monotonic() + add_duration = end_time_add - start_time_add + + original_data["user_id"] = user_id + original_data["metrics"] = {"add_memories_duration_seconds": add_duration} + return original_data + + except Exception as e: + print(f"Error adding memory for line {i + 1} (user_id: {user_id}): {e}") + return None + + +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: + """ + Processes a single line of data, searching memory based on the question. + """ + i, line = line_data + try: + original_data = json.loads(line) + + user_id = original_data.get("user_id") + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + + if not user_id: + original_data["error"] = ( + "Error: user_id not found in this line. Please run 'add' mode first." + ) + return original_data + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + start_time_search = time.monotonic() + relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) + search_memories_duration = time.monotonic() - start_time_search + memories_str = relevant_memories + + memory_tokens_used = len(tokenizer.encode(memories_str)) + + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": memories_str, + } + ) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + # Check for None, as an empty string (no memories found) is a valid result + if memories_str is None: + original_data["error"] = ( + "Error: retrieved_memories_text not found in metrics. " + "Please run 'search' mode first." + ) + return original_data + + system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + response = openai_client.chat.completions.create(model=MODEL_NAME, messages=messages) + assistant_response = response.choices[0].message.content + original_data["response"] = assistant_response + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Process conversations with MemOS. Run 'add', then 'search', then 'response'." + ) + parser.add_argument( + "mode", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", + ) + parser.add_argument("--input", required=True, help="Path to the input JSONL file.") + parser.add_argument("--output", required=True, help="Path to the output JSONL file.") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) + parser.add_argument( + "--add-turn", + type=int, + choices=[0, 10, 300], + default=0, + help="Number of irrelevant turns to add (used in 'add' mode).", + ) + parser.add_argument( + "--lib", + type=str, + choices=["memobase"], + default="memobase", + help="Which Memobase library to use (used in 'add' mode).", + ) + parser.add_argument( + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", + ) + parser.add_argument( + "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." + ) + + args = parser.parse_args() + + try: + with open(args.input, "r", encoding="utf-8") as infile: + lines = infile.readlines() + except FileNotFoundError: + print(f"Error: Input file '{args.input}' not found") + return + + from utils.client import MemobaseClient + + mem_client = MemobaseClient() + + if args.mode == "add": + print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") + print(f"Adding {args.add_turn} irrelevant turns.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit( + add_memory_for_line, + (i, line), + mem_client, + args.add_turn, + args.lib, + args.version, + ) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Adding memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.") + + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") + print(f"Retrieving top {args.top_k} memories for each query.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) + + elif args.mode == "response": + print(f"Running in 'response' mode. Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") + openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(generate_response_for_line, (i, line), openai_client) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'response' mode complete! Final results written to '{args.output}'.") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index d1a901dd2..5ee064b1f 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -64,11 +64,9 @@ def add_memory_for_line( return None -def process_line_with_id( - line_data: tuple, mem_client, openai_client: OpenAI, top_k_value: int, lib: str, version: str -) -> dict: +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: """ - Processes a single line of data using a pre-existing user_id, searching memory and generating a response. + Processes a single line of data, searching memory based on the question. """ i, line = line_data try: @@ -79,12 +77,12 @@ def process_line_with_id( metrics_dict = original_data.get("metrics", {}) if not user_id: - original_data["response"] = ( + original_data["error"] = ( "Error: user_id not found in this line. Please run 'add' mode first." ) return original_data if not question: - original_data["response"] = "Question not found in this line." + original_data["error"] = "Question not found in this line." return original_data start_time_search = time.monotonic() @@ -96,6 +94,51 @@ def process_line_with_id( memory_tokens_used = len(tokenizer.encode(memories_str)) + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": memories_str, + } + ) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + # Check for None, as an empty string (no memories found) is a valid result + if memories_str is None: + original_data["error"] = ( + "Error: retrieved_memories_text not found in metrics. " + "Please run 'search' mode first." + ) + return original_data + system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" messages = [ {"role": "system", "content": system_prompt}, @@ -106,51 +149,50 @@ def process_line_with_id( assistant_response = response.choices[0].message.content original_data["response"] = assistant_response - metrics_dict.update( - { - "search_memories_duration_seconds": search_memories_duration, - "memory_tokens_used": memory_tokens_used, - "retrieved_memories_text": memories_str, - } - ) - original_data["metrics"] = metrics_dict - return original_data except Exception as e: user_id_from_data = json.loads(line).get("user_id", "N/A") - print(f"Error processing line {i + 1} (user_id: {user_id_from_data}): {e}") + print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}") return None def main(): parser = argparse.ArgumentParser( - description="Process conversations with MemOS. Run 'add' mode first, then 'process' mode." + description="Process conversations with MemOS. Run 'add', then 'search', then 'response'." ) parser.add_argument( "mode", - choices=["add", "process"], - help="The mode to run the script in ('add' or 'process').", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", ) parser.add_argument("--input", required=True, help="Path to the input JSONL file.") parser.add_argument("--output", required=True, help="Path to the output JSONL file.") - parser.add_argument("--top-k", type=int, default=10, help="Number of memories to retrieve.") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) parser.add_argument( "--add-turn", type=int, choices=[0, 10, 300], default=0, - help="Number of irrelevant turns to add (0, 10, or 300).", + help="Number of irrelevant turns to add (used in 'add' mode).", ) parser.add_argument( "--lib", type=str, choices=["memos-api", "memos-local"], default="memos-api", - help="Which MemOS library to use.", + help="Which MemOS library to use (used in 'add' mode).", ) parser.add_argument( - "--version", type=str, default="0929-1", help="Version identifier for user_id generation." + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", ) parser.add_argument( "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." @@ -165,9 +207,9 @@ def main(): print(f"Error: Input file '{args.input}' not found") return - from utils.client import memosApiClient + from utils.client import MemosApiClient - mem_client = memosApiClient() + mem_client = MemosApiClient() if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") @@ -200,38 +242,55 @@ def main(): outfile.write(json.dumps(result, ensure_ascii=False) + "\n") print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.") - elif args.mode == "process": - print(f"Running in 'process' mode. Processing questions from '{args.input}'...") + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") print(f"Retrieving top {args.top_k} memories for each query.") print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) + + elif args.mode == "response": + print(f"Running in 'response' mode. Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, ): futures = [ - executor.submit( - process_line_with_id, - (i, line), - mem_client, - openai_client, - args.top_k, - args.lib, - args.version, - ) + executor.submit(generate_response_for_line, (i, line), openai_client) for i, line in enumerate(lines) ] pbar = tqdm( concurrent.futures.as_completed(futures), total=len(lines), - desc="Processing questions...", + desc="Generating responses...", ) for future in pbar: result = future.result() if result: outfile.write(json.dumps(result, ensure_ascii=False) + "\n") - print(f"\n'process' mode complete! Final results written to '{args.output}'.") + print(f"\n'response' mode complete! Final results written to '{args.output}'.") if __name__ == "__main__": diff --git a/evaluation/scripts/PrefEval/pref_memu.py b/evaluation/scripts/PrefEval/pref_memu.py new file mode 100644 index 000000000..719f2b488 --- /dev/null +++ b/evaluation/scripts/PrefEval/pref_memu.py @@ -0,0 +1,301 @@ +import argparse +import concurrent.futures +import json +import os +import sys +import time +import tiktoken +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm +from datetime import datetime +from irrelevant_conv import irre_10, irre_300 + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) +load_dotenv() +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +BASE_URL = os.getenv("OPENAI_BASE_URL") +MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") +tokenizer = tiktoken.get_encoding("cl100k_base") + + +def add_memory_for_line( + line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str +) -> dict: + """ + Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. + """ + i, line = line_data + user_id = f"{lib}_user_pref_eval_{i}_{version}" + + try: + original_data = json.loads(line) + conversation = original_data.get("conversation", []) + + if num_irrelevant_turns == 10: + conversation = conversation + irre_10 + elif num_irrelevant_turns == 300: + conversation = conversation + irre_300 + + turns_add = 5 + start_time_add = time.monotonic() + if conversation: + if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": + for chunk_start in range(0, len(conversation), turns_add * 2): + chunk = conversation[chunk_start : chunk_start + turns_add * 2] + mem_client.add( + messages=chunk, user_id=user_id, iso_date=datetime.now().isoformat() + ) + else: + mem_client.add( + messages=conversation, user_id=user_id, iso_date=datetime.now().isoformat() + ) + end_time_add = time.monotonic() + add_duration = end_time_add - start_time_add + + original_data["user_id"] = user_id + original_data["metrics"] = {"add_memories_duration_seconds": add_duration} + return original_data + + except Exception as e: + print(f"Error adding memory for line {i + 1} (user_id: {user_id}): {e}") + return None + + +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: + """ + Processes a single line of data, searching memory based on the question. + """ + i, line = line_data + try: + original_data = json.loads(line) + + user_id = original_data.get("user_id") + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + + if not user_id: + original_data["error"] = ( + "Error: user_id not found in this line. Please run 'add' mode first." + ) + return original_data + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + start_time_search = time.monotonic() + relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) + search_memories_duration = time.monotonic() - start_time_search + memories_str = "\n".join( + f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] + ) + + memory_tokens_used = len(tokenizer.encode(memories_str)) + + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": memories_str, + } + ) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + # Check for None, as an empty string (no memories found) is a valid result + if memories_str is None: + original_data["error"] = ( + "Error: retrieved_memories_text not found in metrics. " + "Please run 'search' mode first." + ) + return original_data + + system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + response = openai_client.chat.completions.create(model=MODEL_NAME, messages=messages) + assistant_response = response.choices[0].message.content + original_data["response"] = assistant_response + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Process conversations with MemOS. Run 'add', then 'search', then 'response'." + ) + parser.add_argument( + "mode", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", + ) + parser.add_argument("--input", required=True, help="Path to the input JSONL file.") + parser.add_argument("--output", required=True, help="Path to the output JSONL file.") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) + parser.add_argument( + "--add-turn", + type=int, + choices=[0, 10, 300], + default=0, + help="Number of irrelevant turns to add (used in 'add' mode).", + ) + parser.add_argument( + "--lib", + type=str, + choices=["memu"], + default="memu", + help="Which Memu library to use (used in 'add' mode).", + ) + parser.add_argument( + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", + ) + parser.add_argument( + "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." + ) + + args = parser.parse_args() + + try: + with open(args.input, "r", encoding="utf-8") as infile: + lines = infile.readlines() + except FileNotFoundError: + print(f"Error: Input file '{args.input}' not found") + return + + from utils.client import MemuClient + + mem_client = MemuClient() + + if args.mode == "add": + print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") + print(f"Adding {args.add_turn} irrelevant turns.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit( + add_memory_for_line, + (i, line), + mem_client, + args.add_turn, + args.lib, + args.version, + ) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Adding memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.") + + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") + print(f"Retrieving top {args.top_k} memories for each query.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) + + elif args.mode == "response": + print(f"Running in 'response' mode. Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") + openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(generate_response_for_line, (i, line), openai_client) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'response' mode complete! Final results written to '{args.output}'.") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/PrefEval/pref_supermemory.py b/evaluation/scripts/PrefEval/pref_supermemory.py new file mode 100644 index 000000000..85e84b6c9 --- /dev/null +++ b/evaluation/scripts/PrefEval/pref_supermemory.py @@ -0,0 +1,334 @@ +import argparse +import concurrent.futures +import json +import os +import sys +import time +import tiktoken +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm +from datetime import datetime +from irrelevant_conv import irre_10, irre_300 + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) +load_dotenv() +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +BASE_URL = os.getenv("OPENAI_BASE_URL") +MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") +tokenizer = tiktoken.get_encoding("cl100k_base") + + +def add_memory_for_line( + line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str +) -> dict: + """ + Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. + """ + i, line = line_data + user_id = f"{lib}_user_pref_eval_{i}_{version}" + + try: + original_data = json.loads(line) + conversation = original_data.get("conversation", []) + + if num_irrelevant_turns == 10: + conversation = conversation + irre_10 + elif num_irrelevant_turns == 300: + conversation = conversation + irre_300 + + turns_add = 5 + start_time_add = time.monotonic() + if conversation: + if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": + for chunk_start in range(0, len(conversation), turns_add * 2): + chunk = conversation[chunk_start : chunk_start + turns_add * 2] + mem_client.add(messages=chunk, user_id=user_id) + else: + mem_client.add(messages=conversation, user_id=user_id) + end_time_add = time.monotonic() + add_duration = end_time_add - start_time_add + + original_data["user_id"] = user_id + original_data["metrics"] = {"add_memories_duration_seconds": add_duration} + return original_data + + except Exception as e: + print(f"Error adding memory for line {i + 1} (user_id: {user_id}): {e}") + return None + + +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: + """ + Processes a single line of data, searching memory based on the question. + """ + i, line = line_data + try: + original_data = json.loads(line) + + user_id = original_data.get("user_id") + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + + if not user_id: + original_data["error"] = ( + "Error: user_id not found in this line. Please run 'add' mode first." + ) + return original_data + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + start_time_search = time.monotonic() + relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) + search_memories_duration = time.monotonic() - start_time_search + memories_str = "\n".join( + f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] + ) + + memory_tokens_used = len(tokenizer.encode(memories_str)) + + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": memories_str, + } + ) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + # Check for None, as an empty string (no memories found) is a valid result + if memories_str is None: + original_data["error"] = ( + "Error: retrieved_memories_text not found in metrics. " + "Please run 'search' mode first." + ) + return original_data + + system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + response = openai_client.chat.completions.create(model=MODEL_NAME, messages=messages) + assistant_response = response.choices[0].message.content + original_data["response"] = assistant_response + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Process conversations with MemOS. Run 'add', then 'search', then 'response'." + ) + parser.add_argument( + "mode", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", + ) + parser.add_argument("--input", required=True, help="Path to the input JSONL file.") + parser.add_argument("--output", required=True, help="Path to the output JSONL file.") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) + parser.add_argument( + "--add-turn", + type=int, + choices=[0, 10, 300], + default=0, + help="Number of irrelevant turns to add (used in 'add' mode).", + ) + parser.add_argument( + "--lib", + type=str, + choices=["supermemory"], + default="supermemory", + help="Which Supermemory library to use (used in 'add' mode).", + ) + parser.add_argument( + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", + ) + parser.add_argument( + "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." + ) + + args = parser.parse_args() + + try: + with open(args.input, "r", encoding="utf-8") as infile: + lines = infile.readlines() + except FileNotFoundError: + print(f"Error: Input file '{args.input}' not found") + return + + class SupermemoryClient: + def __init__(self): + from supermemory import Supermemory + + self.client = Supermemory(api_key=os.getenv("SUPERMEMORY_API_KEY")) + + def add(self, messages, user_id): + content = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + max_retries = 5 + for attempt in range(max_retries): + try: + self.client.memories.add(content=content, container_tag=user_id) + break + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e + + def search(self, query, user_id, top_k): + max_retries = 10 + for attempt in range(max_retries): + try: + results = self.client.search.memories( + q=query, + container_tag=user_id, + threshold=0, + rerank=True, + rewrite_query=True, + limit=top_k, + ) + context = "\n\n".join([r.memory for r in results.results]) + return context + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e + + mem_client = SupermemoryClient() + + if args.mode == "add": + print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") + print(f"Adding {args.add_turn} irrelevant turns.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit( + add_memory_for_line, + (i, line), + mem_client, + args.add_turn, + args.lib, + args.version, + ) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Adding memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.") + + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") + print(f"Retrieving top {args.top_k} memories for each query.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) + + elif args.mode == "response": + print(f"Running in 'response' mode. Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") + openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(generate_response_for_line, (i, line), openai_client) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'response' mode complete! Final results written to '{args.output}'.") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/PrefEval/pref_zep.py b/evaluation/scripts/PrefEval/pref_zep.py new file mode 100644 index 000000000..699660787 --- /dev/null +++ b/evaluation/scripts/PrefEval/pref_zep.py @@ -0,0 +1,307 @@ +import argparse +import concurrent.futures +import json +import os +import sys +import time +import tiktoken +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm +from datetime import datetime +from irrelevant_conv import irre_10, irre_300 + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) +load_dotenv() +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +BASE_URL = os.getenv("OPENAI_BASE_URL") +MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") +tokenizer = tiktoken.get_encoding("cl100k_base") + + +def add_memory_for_line( + line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str +) -> dict: + """ + Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. + """ + i, line = line_data + user_id = f"{lib}_user_pref_eval_{i}_{version}" + + try: + original_data = json.loads(line) + conversation = original_data.get("conversation", []) + + if num_irrelevant_turns == 10: + conversation = conversation + irre_10 + elif num_irrelevant_turns == 300: + conversation = conversation + irre_300 + + turns_add = 5 + start_time_add = time.monotonic() + if conversation: + if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": + for chunk_start in range(0, len(conversation), turns_add * 2): + chunk = conversation[chunk_start : chunk_start + turns_add * 2] + mem_client.add( + messages=chunk, + user_id=user_id, + conv_id=None, + timestamp=datetime.now().isoformat(), + ) + else: + mem_client.add( + messages=conversation, + user_id=user_id, + conv_id=None, + timestamp=datetime.now().isoformat(), + ) + end_time_add = time.monotonic() + add_duration = end_time_add - start_time_add + + original_data["user_id"] = user_id + original_data["metrics"] = {"add_memories_duration_seconds": add_duration} + return original_data + + except Exception as e: + print(f"Error adding memory for line {i + 1} (user_id: {user_id}): {e}") + return None + + +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: + """ + Processes a single line of data, searching memory based on the question. + """ + i, line = line_data + try: + original_data = json.loads(line) + + user_id = original_data.get("user_id") + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + + if not user_id: + original_data["error"] = ( + "Error: user_id not found in this line. Please run 'add' mode first." + ) + return original_data + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + start_time_search = time.monotonic() + relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) + search_memories_duration = time.monotonic() - start_time_search + memories_str = "\n".join( + f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] + ) + + memory_tokens_used = len(tokenizer.encode(memories_str)) + + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": memories_str, + } + ) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." + return original_data + + # Check for None, as an empty string (no memories found) is a valid result + if memories_str is None: + original_data["error"] = ( + "Error: retrieved_memories_text not found in metrics. " + "Please run 'search' mode first." + ) + return original_data + + system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + + response = openai_client.chat.completions.create(model=MODEL_NAME, messages=messages) + assistant_response = response.choices[0].message.content + original_data["response"] = assistant_response + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Process conversations with MemOS. Run 'add', then 'search', then 'response'." + ) + parser.add_argument( + "mode", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", + ) + parser.add_argument("--input", required=True, help="Path to the input JSONL file.") + parser.add_argument("--output", required=True, help="Path to the output JSONL file.") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) + parser.add_argument( + "--add-turn", + type=int, + choices=[0, 10, 300], + default=0, + help="Number of irrelevant turns to add (used in 'add' mode).", + ) + parser.add_argument( + "--lib", + type=str, + choices=["zep"], + default="zep", + help="Which Zep library to use (used in 'add' mode).", + ) + parser.add_argument( + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", + ) + parser.add_argument( + "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." + ) + + args = parser.parse_args() + + try: + with open(args.input, "r", encoding="utf-8") as infile: + lines = infile.readlines() + except FileNotFoundError: + print(f"Error: Input file '{args.input}' not found") + return + + from utils.client import ZepClient + + mem_client = ZepClient() + + if args.mode == "add": + print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") + print(f"Adding {args.add_turn} irrelevant turns.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit( + add_memory_for_line, + (i, line), + mem_client, + args.add_turn, + args.lib, + args.version, + ) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Adding memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.") + + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") + print(f"Retrieving top {args.top_k} memories for each query.") + print(f"Using {args.max_workers} workers.") + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) + + elif args.mode == "response": + print(f"Running in 'response' mode. Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") + openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(generate_response_for_line, (i, line), openai_client) + for i, line in enumerate(lines) + ] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'response' mode complete! Final results written to '{args.output}'.") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/locomo/locomo_ingestion.py b/evaluation/scripts/locomo/locomo_ingestion.py index 2a177a52a..edb451dc0 100644 --- a/evaluation/scripts/locomo/locomo_ingestion.py +++ b/evaluation/scripts/locomo/locomo_ingestion.py @@ -103,12 +103,8 @@ def process_user(conv_idx, frame, locomo_df, version): from utils.client import MemobaseClient client = MemobaseClient() - all_users = client.client.get_all_users(limit=5000) - for user in all_users: - if user["additional_fields"]["user_id"] in [speaker_a_user_id, speaker_b_user_id]: - client.client.delete_user(user["id"]) - speaker_a_user_id = client.client.add_user({"user_id": speaker_a_user_id}) - speaker_b_user_id = client.client.add_user({"user_id": speaker_b_user_id}) + client.delete_user(speaker_a_user_id) + client.delete_user(speaker_b_user_id) elif frame == "memu": from utils.client import MemuClient @@ -193,7 +189,7 @@ def main(frame, version="default", num_workers=4): parser.add_argument( "--version", type=str, - default="default1", + default="default", help="Version identifier for saving results (e.g., 1010)", ) parser.add_argument( diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index d976b8f67..452fb4762 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -254,12 +254,6 @@ def process_user(conv_idx, locomo_df, frame, version, top_k=20, num_workers=1): from utils.client import MemobaseClient client = MemobaseClient() - users = client.client.get_all_users(limit=5000) - for u in users: - if u["additional_fields"]["user_id"] == speaker_a_user_id: - speaker_a_user_id = u["id"] - if u["additional_fields"]["user_id"] == speaker_b_user_id: - speaker_b_user_id = u["id"] elif frame == "memu": from utils.client import MemuClient @@ -348,7 +342,7 @@ def main(frame, version="default", num_workers=1, top_k=20): "--workers", type=int, default=5, help="Number of parallel workers to process users" ) parser.add_argument( - "--top_k", type=int, default=20, help="Number of results to retrieve in search queries" + "--top_k", type=int, default=15, help="Number of results to retrieve in search queries" ) args = parser.parse_args() lib = args.lib diff --git a/evaluation/scripts/longmemeval/lme_ingestion.py b/evaluation/scripts/longmemeval/lme_ingestion.py index 6e9bd5ab4..a1849757d 100644 --- a/evaluation/scripts/longmemeval/lme_ingestion.py +++ b/evaluation/scripts/longmemeval/lme_ingestion.py @@ -80,11 +80,7 @@ def ingest_conv(lme_df, version, conv_idx, frame, success_records, f): from utils.client import MemobaseClient client = MemobaseClient() - all_users = client.client.get_all_users(limit=5000) - for user in all_users: - if user["additional_fields"]["user_id"] == user_id: - client.client.delete_user(user["id"]) - user_id = client.client.add_user({"user_id": user_id}) + client.delete_user(user_id) elif frame == "memu": from utils.client import MemuClient diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index a24c0eaf5..67d2f1b04 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -114,10 +114,6 @@ def process_user(lme_df, conv_idx, frame, version, top_k=20): from utils.client import MemobaseClient client = MemobaseClient() - users = client.client.get_all_users(limit=5000) - for u in users: - if u["additional_fields"]["user_id"] == user_id: - user_id = u["id"] context, duration_ms = memobase_search(client, question, user_id, top_k) elif frame == "memos-api": from utils.client import MemosApiClient diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index 5cd9d38a6..8de23937c 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -1,48 +1,64 @@ import argparse -import os -import sys import csv import json - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import os +import sys from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime - from tqdm import tqdm -from utils.client import mem0_client,zep_client,memos_api_client -from zep_cloud.types import Message +import time + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def ingest_session(session, user_id, session_id, frame, client): messages = [] if frame == "zep": pass + elif "mem0" in frame: for idx, msg in enumerate(session): + messages.append({"role": msg["role"], "content": msg["content"][:8000]}) print( - f"[{frame}] 💬 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}...") - client.memory.add(messages=[Message(role=msg["role"], role_type=msg["role"], content=msg["content"], )], ) - elif frame == "mem0-local" or frame == "mem0-api": - for idx, msg in enumerate(session): - messages.append({"role": msg["role"], "content": msg["content"]}) - print( - f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}...") - if frame == "mem0-local": - client.add(messages=messages, user_id=user_id) - elif frame == "mem0-api": - client.add(messages=messages, - user_id=user_id, - session_id=session_id, - version="v2", ) + f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}..." + ) + timestamp_add = int(time.time() * 100) + client.add(messages=messages, user_id=user_id, timestamp=timestamp_add) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") - elif frame == "memos-local" or frame == "memos-api": - if os.getenv("PRE_SPLIT_CHUNK")=="true": + elif frame == "memos-api": + if os.getenv("PRE_SPLIT_CHUNK") == "true": for i in range(0, len(session), 10): - messages = session[i: i + 10] + messages = session[i : i + 10] client.add(messages=messages, user_id=user_id, conv_id=session_id) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") else: client.add(messages=session, user_id=user_id, conv_id=session_id) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages") + elif frame == "memobase": + for idx, msg in enumerate(session): + if msg["role"] != "system": + messages.append( + { + "role": msg["role"], + "content": msg["content"][:8000], + "created_at": datetime.now().isoformat(), + } + ) + client.add(messages, user_id) + print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") + elif frame == "supermemory": + for _idx, msg in enumerate(session): + messages.append( + { + "role": msg["role"], + "content": msg["content"][:8000], + "chat_time": datetime.now().astimezone().isoformat(), + } + ) + client.add(messages, user_id) + elif frame == "memu": + for _idx, msg in enumerate(session): + messages.append({"role": msg["role"], "content": msg["content"]}) + client.add(messages, user_id, datetime.now().astimezone().isoformat()) def build_jsonl_index(jsonl_path): @@ -51,7 +67,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, "r", encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -63,14 +79,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, "r", encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -82,7 +98,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -102,13 +118,13 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as f: + with open(csv_path, mode="r", newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 def ingest_conv(row_data, context, version, conv_idx, frame): end_index_in_shared_context = row_data["end_index_in_shared_context"] - context = context[:int(end_index_in_shared_context)] + context = context[: int(end_index_in_shared_context)] user_id = f"pm_exper_user_{conv_idx}_{version}" print(f"👤 User ID: {user_id}") print("\n" + "=" * 80) @@ -116,38 +132,45 @@ def ingest_conv(row_data, context, version, conv_idx, frame): print("=" * 80) if frame == "zep": - client = zep_client() + from utils.client import ZepClient + + client = ZepClient() print("🔌 Using Zep client for ingestion...") client.user.delete(user_id) print(f"🗑️ Deleted existing user {user_id} from Zep memory...") client.user.add(user_id=user_id) print(f"➕ Added user {user_id} to Zep memory...") - elif frame == "mem0-local": - client = mem0_client(mode="local") - print("🔌 Using Mem0 Local client for ingestion...") - client.delete_all(user_id=user_id) + elif frame == "mem0" or frame == "mem0_graph": + from utils.client import Mem0Client + + client = Mem0Client(enable_graph="graph" in frame) + print("🔌 Using Mem0 client for ingestion...") + client.client.delete_all(user_id=user_id) print(f"🗑️ Deleted existing memories for user {user_id}...") - elif frame == "mem0-api": - client = mem0_client(mode="api") - print("🔌 Using Mem0 API client for ingestion...") - client.delete_all(user_id=user_id) + print(f"🗑️ Deleted existing memories for user {user_id}...") - elif frame == "memos-local": - client = memos_client( - mode="local", - db_name=f"pm_{frame}-{version}", - user_id=user_id, - top_k=20, - mem_cube_path=f"results/pm/{frame}-{version}/storages/{user_id}", - mem_cube_config_path="configs/mu_mem_cube_config.json", - mem_os_config_path="configs/mos_memos_config.json", - addorsearch="add", - ) - print("🔌 Using Memos Local client for ingestion...") elif frame == "memos-api": - client = memos_api_client() + from utils.client import MemosApiClient + + client = MemosApiClient() + elif frame == "memobase": + from utils.client import MemobaseClient + + client = MemobaseClient() + print("🔌 Using Memobase client for ingestion...") + client.delte_user(user_id) + elif frame == "supermemory": + from utils.client import SupermemoryClient - ingest_session(session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client) + client = SupermemoryClient() + elif frame == "memu": + from utils.client import MemuClient + + client = MemuClient() + + ingest_session( + session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client + ) print(f"✅ Ingestion of conversation {conv_idx} completed") print("=" * 80) @@ -170,16 +193,25 @@ def main(frame, version, num_workers=2): with ThreadPoolExecutor(max_workers=num_workers) as executor: future_to_idx = { - executor.submit(ingest_conv, row_data=row_data, context=context, version=version, conv_idx=idx, - frame=frame, ): idx - for idx, (row_data, context) in enumerate(all_data)} - - for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations"): + executor.submit( + ingest_conv, + row_data=row_data, + context=context, + version=version, + conv_idx=idx, + frame=frame, + ): idx + for idx, (row_data, context) in enumerate(all_data) + } + + for future in tqdm( + as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + ): idx = future_to_idx[future] try: future.result() except Exception as exc: - print(f'\n❌ Conversation {idx} generated an exception: {exc}') + print(f"\n❌ Conversation {idx} generated an exception: {exc}") end_time = datetime.now() elapsed_time = end_time - start_time @@ -195,10 +227,18 @@ def main(frame, version, num_workers=2): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Ingestion Script") - parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], - default='memos-api') - parser.add_argument("--version", type=str, default="0925-1", help="Version of the evaluation framework.") - parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument( + "--lib", + type=str, + choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory", "zep"], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="0925-1", help="Version of the evaluation framework." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of parallel workers for processing users." + ) args = parser.parse_args() main(frame=args.lib, version=args.version, num_workers=args.workers) diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index 0f6a1e138..653c5fc10 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -8,40 +8,48 @@ def save_to_excel(results, output_path): """Save results to Excel file""" combined_data = [] - + # Add overall statistics row - overall_row = {"category": "overall", "accuracy": results["metrics"]["accuracy"], - "accuracy_std": results["metrics"]["accuracy_std"], - "total_questions": results["metrics"]["total_questions"], - "total_runs": results["metrics"]["total_runs"]} + overall_row = { + "category": "overall", + "accuracy": results["metrics"]["accuracy"], + "accuracy_std": results["metrics"]["accuracy_std"], + "total_questions": results["metrics"]["total_questions"], + "total_runs": results["metrics"]["total_runs"], + } # Add response duration metrics for metric, value in results["metrics"]["response_duration"].items(): overall_row[f"response_{metric}"] = value - + # Add search duration metrics (if exists) if "search_duration" in results["metrics"] and results["metrics"]["search_duration"]: for metric, value in results["metrics"]["search_duration"].items(): overall_row[f"search_{metric}"] = value - + combined_data.append(overall_row) - + # Add category statistics rows for category, scores in results["category_scores"].items(): - category_row = {"category": category, "accuracy": scores["accuracy"], "accuracy_std": scores["accuracy_std"], - "total_questions": scores["total_questions"], "total_runs": scores["total_runs"]} + category_row = { + "category": category, + "accuracy": scores["accuracy"], + "accuracy_std": scores["accuracy_std"], + "total_questions": scores["total_questions"], + "total_runs": scores["total_runs"], + } # Add response duration metrics for metric, value in scores["response_duration"].items(): category_row[f"response_{metric}"] = value - + # Add search duration metrics (if exists) if "search_duration" in scores and scores["search_duration"]: for metric, value in scores["search_duration"].items(): category_row[f"search_{metric}"] = value - + combined_data.append(category_row) - + # Save to Excel df = pd.DataFrame(combined_data) df.to_excel(output_path, sheet_name="PersonaMem_Metrics", index=False) @@ -50,62 +58,62 @@ def save_to_excel(results, output_path): def calculate_scores(data, grade_path, output_path): """Calculate PersonaMem evaluation metrics""" - + # Initialize statistics variables category_scores = {} user_metrics = {} - + # Overall metrics - collect accuracy for each run all_response_durations = [] all_search_durations = [] total_questions = 0 - + # For calculating accuracy across multiple runs num_runs = None # Will be determined from first user's data run_accuracies = [] # List to store accuracy for each run across all users - + # Category-wise statistics category_response_durations = {} category_search_durations = {} category_run_accuracies = {} # Store accuracy for each run by category - + print(f"📋 Processing response data for {len(data)} users...") - + # First pass: determine number of runs and initialize run accuracy arrays for user_id, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and "topic" in user_data: continue - + results = user_data.get("results", []) if not results: continue - + if num_runs is None: num_runs = len(results) run_accuracies = [[] for _ in range(num_runs)] # Initialize for each run print(f"📊 Detected {num_runs} runs per user") break - + if num_runs is None: print("❌ Error: Could not determine number of runs from data") return - + # Iterate through all user data for user_id, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and "topic" in user_data: print(f"⚠️ Skipping incomplete data for user {user_id}") continue - + # Get category and results category = user_data.get("category", "unknown") results = user_data.get("results", []) - + if not results: print(f"⚠️ No results found for user {user_id}") continue - + # Initialize category if not exists if category not in category_scores: category_scores[category] = { @@ -115,39 +123,39 @@ def calculate_scores(data, grade_path, output_path): "accuracy": 0.0, "accuracy_std": 0.0, "response_duration": {}, - "search_duration": {} + "search_duration": {}, } category_response_durations[category] = [] category_search_durations[category] = [] category_run_accuracies[category] = [[] for _ in range(num_runs)] - + # Process each run for this user user_response_durations = [] for run_idx, result in enumerate(results): is_correct = result.get("is_correct", False) - + # Collect accuracy for each run (1 if correct, 0 if not) if run_idx < num_runs: run_accuracies[run_idx].append(1.0 if is_correct else 0.0) category_run_accuracies[category][run_idx].append(1.0 if is_correct else 0.0) - + # Collect response duration response_duration = result.get("response_duration_ms", 0) if response_duration > 0: user_response_durations.append(response_duration) all_response_durations.append(response_duration) category_response_durations[category].append(response_duration) - + # Get search duration (usually same for all runs) search_duration = user_data.get("search_duration_ms", 0) if search_duration > 0: all_search_durations.append(search_duration) category_search_durations[category].append(search_duration) - + # Calculate user-level accuracy (average across runs) user_correct_count = sum(1 for result in results if result.get("is_correct", False)) user_accuracy = user_correct_count / len(results) if results else 0.0 - + # Store user-level metrics user_metrics[user_id] = { "user_id": user_id, @@ -156,22 +164,26 @@ def calculate_scores(data, grade_path, output_path): "accuracy": user_accuracy, "total_runs": len(results), "correct_runs": user_correct_count, - "avg_response_duration_ms": np.mean(user_response_durations) if user_response_durations else 0.0, + "avg_response_duration_ms": np.mean(user_response_durations) + if user_response_durations + else 0.0, "search_duration_ms": search_duration, "golden_answer": user_data.get("golden_answer", ""), - "topic": user_data.get("topic", "") + "topic": user_data.get("topic", ""), } - + # Count statistics total_questions += 1 category_scores[category]["total_questions"] += 1 category_scores[category]["total_runs"] += len(results) - + # Calculate overall accuracy and std across runs overall_run_accuracies = [np.mean(run_acc) for run_acc in run_accuracies if run_acc] overall_accuracy = np.mean(overall_run_accuracies) if overall_run_accuracies else 0.0 - overall_accuracy_std = np.std(overall_run_accuracies) if len(overall_run_accuracies) > 1 else 0.0 - + overall_accuracy_std = ( + np.std(overall_run_accuracies) if len(overall_run_accuracies) > 1 else 0.0 + ) + # Calculate response duration statistics response_duration_stats = {} if all_response_durations: @@ -182,9 +194,9 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(all_response_durations, 95), "std": np.std(all_response_durations), "min": np.min(all_response_durations), - "max": np.max(all_response_durations) + "max": np.max(all_response_durations), } - + # Calculate search duration statistics search_duration_stats = {} if all_search_durations: @@ -195,16 +207,22 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(all_search_durations, 95), "std": np.std(all_search_durations), "min": np.min(all_search_durations), - "max": np.max(all_search_durations) + "max": np.max(all_search_durations), } - + # Calculate category-wise metrics for category in category_scores: # Calculate accuracy mean and std across runs for this category - cat_run_accuracies = [np.mean(run_acc) for run_acc in category_run_accuracies[category] if run_acc] - category_scores[category]["accuracy"] = np.mean(cat_run_accuracies) if cat_run_accuracies else 0.0 - category_scores[category]["accuracy_std"] = np.std(cat_run_accuracies) if len(cat_run_accuracies) > 1 else 0.0 - + cat_run_accuracies = [ + np.mean(run_acc) for run_acc in category_run_accuracies[category] if run_acc + ] + category_scores[category]["accuracy"] = ( + np.mean(cat_run_accuracies) if cat_run_accuracies else 0.0 + ) + category_scores[category]["accuracy_std"] = ( + np.std(cat_run_accuracies) if len(cat_run_accuracies) > 1 else 0.0 + ) + # Response duration statistics for this category if category_response_durations[category]: durations = category_response_durations[category] @@ -215,14 +233,19 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(durations, 95), "std": np.std(durations), "min": np.min(durations), - "max": np.max(durations) + "max": np.max(durations), } else: category_scores[category]["response_duration"] = { - "mean": 0.0, "median": 0.0, "p50": 0.0, "p95": 0.0, - "std": 0.0, "min": 0.0, "max": 0.0 + "mean": 0.0, + "median": 0.0, + "p50": 0.0, + "p95": 0.0, + "std": 0.0, + "min": 0.0, + "max": 0.0, } - + # Search duration statistics for this category if category_search_durations[category]: durations = category_search_durations[category] @@ -233,14 +256,19 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(durations, 95), "std": np.std(durations), "min": np.min(durations), - "max": np.max(durations) + "max": np.max(durations), } else: category_scores[category]["search_duration"] = { - "mean": 0.0, "median": 0.0, "p50": 0.0, "p95": 0.0, - "std": 0.0, "min": 0.0, "max": 0.0 + "mean": 0.0, + "median": 0.0, + "p50": 0.0, + "p95": 0.0, + "std": 0.0, + "min": 0.0, + "max": 0.0, } - + # Build final results results = { "metrics": { @@ -249,22 +277,22 @@ def calculate_scores(data, grade_path, output_path): "total_questions": total_questions, "total_runs": total_questions * num_runs if num_runs else 0, "response_duration": response_duration_stats, - "search_duration": search_duration_stats + "search_duration": search_duration_stats, }, "category_scores": category_scores, - "user_scores": user_metrics + "user_scores": user_metrics, } - + # Save results to JSON file with open(grade_path, "w") as outfile: json.dump(results, outfile, indent=4, ensure_ascii=False) - + # Save to Excel save_to_excel(results, output_path) - + # Print summary print_summary(results) - + return results @@ -273,19 +301,19 @@ def print_summary(results): print("\n" + "=" * 80) print("📊 PERSONAMEM EVALUATION SUMMARY".center(80)) print("=" * 80) - + # Overall accuracy accuracy = results["metrics"]["accuracy"] accuracy_std = results["metrics"]["accuracy_std"] total_questions = results["metrics"]["total_questions"] total_runs = results["metrics"]["total_runs"] - + print(f"🎯 Overall Accuracy: {accuracy:.4f} ± {accuracy_std:.4f}") print(f"📋 Total Questions: {total_questions}") print(f"🔄 Total Runs: {total_runs}") - + print("-" * 80) - + # Response duration statistics if results["metrics"]["response_duration"]: rd = results["metrics"]["response_duration"] @@ -294,7 +322,7 @@ def print_summary(results): print(f" P50: \033[96m{rd['p50']:.2f}") print(f" P95: \033[91m{rd['p95']:.2f}") print(f" Std Dev: {rd['std']:.2f}") - + # Search duration statistics if results["metrics"]["search_duration"]: sd = results["metrics"]["search_duration"] @@ -303,9 +331,9 @@ def print_summary(results): print(f" P50: \033[96m{sd['p50']:.2f}") print(f" P95: \033[91m{sd['p95']:.2f}") print(f" Std Dev: {sd['std']:.2f}") - + print("-" * 80) - + # Category-wise accuracy print("📂 Category-wise Accuracy:") for category, scores in results["category_scores"].items(): @@ -313,50 +341,47 @@ def print_summary(results): acc_std = scores["accuracy_std"] total_cat = scores["total_questions"] total_runs_cat = scores["total_runs"] - print(f" {category:<35}: {acc:.4f} ± {acc_std:.4f} ({total_cat} questions, {total_runs_cat} runs)") - + print( + f" {category:<35}: {acc:.4f} ± {acc_std:.4f} ({total_cat} questions, {total_runs_cat} runs)" + ) + print("=" * 80 + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem evaluation metrics calculation script") parser.add_argument( - "--lib", - type=str, - choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], + "--lib", + type=str, + choices=["zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], required=True, help="Memory library to evaluate", - default='memos-api' - ) - parser.add_argument( - "--version", - type=str, - default="0925", - help="Evaluation framework version" + default="memos-api", ) - + parser.add_argument("--version", type=str, default="0925", help="Evaluation framework version") + args = parser.parse_args() lib, version = args.lib, args.version - + # Define file paths responses_path = f"results/pm/{lib}-{version}/{lib}_pm_responses.json" grade_path = f"results/pm/{lib}-{version}/{lib}_pm_grades.json" output_path = f"results/pm/{lib}-{version}/{lib}_pm_results.xlsx" - + print(f"📂 Loading response data from: {responses_path}") - + try: - with open(responses_path, 'r', encoding='utf-8') as file: + with open(responses_path, "r", encoding="utf-8") as file: data = json.load(file) - + # Calculate metrics results = calculate_scores(data, grade_path, output_path) - + print(f"📁 Results saved to: {grade_path}") print(f"📊 Excel report saved to: {output_path}") - + except FileNotFoundError: print(f"❌ Error: File not found {responses_path}") print("Please make sure to run pm_responses.py first to generate response data") except Exception as e: - print(f"❌ Error occurred during processing: {e}") \ No newline at end of file + print(f"❌ Error occurred during processing: {e}") diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index c48933c11..8bfeaf5f6 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -19,11 +19,11 @@ def extract_choice_answer(predicted_answer, correct_answer): def _extract_only_options(text): text = text.lower() - in_parens = re.findall(r'\(([a-d])\)', text) + in_parens = re.findall(r"\(([a-d])\)", text) if in_parens: return set(in_parens) else: - return set(re.findall(r'\b([a-d])\b', text)) + return set(re.findall(r"\b([a-d])\b", text)) correct = correct_answer.lower().strip("() ") @@ -33,7 +33,7 @@ def _extract_only_options(text): if "" in predicted_answer: predicted_answer = predicted_answer.split("")[-1].strip() if predicted_answer.endswith(""): - predicted_answer = predicted_answer[:-len("")].strip() + predicted_answer = predicted_answer[: -len("")].strip() pred_options = _extract_only_options(predicted_answer) @@ -79,12 +79,14 @@ def process_qa(user_id, search_result, num_runs, llm_client): is_correct, answer = extract_choice_answer(answer, search_result.get("golden_answer", "")) response_duration_ms = (time() - start) * 1000 - run_results.append({ - "run_id": idx + 1, - "answer": answer, - "is_correct": is_correct, - "response_duration_ms": response_duration_ms, - }) + run_results.append( + { + "run_id": idx + 1, + "answer": answer, + "is_correct": is_correct, + "response_duration_ms": response_duration_ms, + } + ) response_duration_ms = sum(result["response_duration_ms"] for result in run_results) / num_runs @@ -95,8 +97,11 @@ def process_qa(user_id, search_result, num_runs, llm_client): print(f"💡 Golden Answer: {search_result.get('golden_answer', 'N/A')}") for idx, result in enumerate(run_results, start=1): print(f"\n🔄 Run {idx}/{num_runs}:") - print(f"💬 Run Answer: {result['answer'][:150]}..." if len( - result['answer']) > 150 else f"💬 Run Answer: {result['answer']}") + print( + f"💬 Run Answer: {result['answer'][:150]}..." + if len(result["answer"]) > 150 + else f"💬 Run Answer: {result['answer']}" + ) print(f"✅ Run Is Correct: {result['is_correct']}") print(f"⏱️ Run Duration: {result['response_duration_ms']:.2f} ms") print("-" * 80) @@ -122,7 +127,9 @@ def main(frame, version, num_runs=3, num_workers=4): load_dotenv() - oai_client = OpenAI(api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")) + oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") search_path = f"results/pm/{frame}-{version}/{frame}_pm_search_results.json" @@ -146,9 +153,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id[future] = user_id for future in tqdm( - as_completed(future_to_user_id), - total=len(future_to_user_id), - desc="📝 Generating responses", + as_completed(future_to_user_id), + total=len(future_to_user_id), + desc="📝 Generating responses", ): user_id = future_to_user_id[future] try: @@ -177,10 +184,21 @@ def main(frame, version, num_runs=3, num_workers=4): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Response Generation Script") - parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], default='memos-api') - parser.add_argument("--version", type=str, default="0925", help="Version of the evaluation framework.") - parser.add_argument("--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation.") - parser.add_argument("--workers", type=int, default=3, help="Number of worker threads to use for processing.") + parser.add_argument( + "--lib", + type=str, + choices=["zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="0925", help="Version of the evaluation framework." + ) + parser.add_argument( + "--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of worker threads to use for processing." + ) args = parser.parse_args() main(frame=args.lib, version=args.version, num_runs=args.num_runs, num_workers=args.workers) diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 50f46f692..2e1a268fc 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -2,17 +2,15 @@ import json import os import sys - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time - +from tqdm import tqdm import csv -from tqdm import tqdm -from utils.client import mem0_client,zep_client,memos_api_client +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, @@ -50,93 +48,68 @@ def zep_search(client, user_id, query, top_k=20): return context, duration_ms -def mem0_search(client, user_id, query, top_k=20, enable_graph=False, frame="mem0-api"): +def mem0_search(client, query, user_id, top_k): start = time() - - if frame == "mem0-local": - results = client.search( - query=query, - user_id=user_id, - top_k=top_k, - ) - search_memories = "\n".join( + results = client.search(query, user_id, top_k) + memory = [f"{memory['created_at']}: {memory['memory']}" for memory in results["results"]] + if client.enable_graph: + graph = "\n".join( [ - f" - {item['memory']} (date: {item['metadata']['timestamp']})" - for item in results["results"] + f" - 'source': {item.get('source', '?')} -> 'target': {item.get('target', '?')} " + f"(relationship: {item.get('relationship', '?')})" + for item in results.get("relations", []) ] ) - search_graph = ( - "\n".join( - [ - f" - 'source': {item.get('source', '?')} -> 'target': {item.get('destination', '?')} (relationship: {item.get('relationship', '?')})" - for item in results.get("relations", []) - ] - ) - if enable_graph - else "" - ) - - elif frame == "mem0-api": - results = client.search( - query=query, - user_id=user_id, - top_k=top_k, - version="v2", - output_format="v1.1", - enable_graph=enable_graph, - filters={"AND": [{"user_id": user_id}, {"run_id": "*"}]}, - ) - search_memories = "\n".join( - [f" - {item['memory']} (date: {item['created_at']})" for item in results["results"]] - ) - search_graph = ( - "\n".join( - [ - f" - 'source': {item.get('source', '?')} -> 'target': {item.get('target', '?')} (relationship: {item.get('relationship', '?')})" - for item in results.get("relations", []) - ] - ) - if enable_graph - else "" - ) - if enable_graph: context = MEM0_GRAPH_CONTEXT_TEMPLATE.format( - user_id=user_id, memories=search_memories, relations=search_graph + user_id=user_id, memories=memory, relations=graph ) else: - context = MEM0_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) + context = MEM0_CONTEXT_TEMPLATE.format(user_id=user_id, memories=memory) duration_ms = (time() - start) * 1000 return context, duration_ms -def memos_search(client, user_id, query, top_k, frame="memos-local"): +def memobase_search(client, query, user_id, top_k): start = time() - if frame == "memos-local": - results = client.search( - query=query, - user_id=user_id, - ) + context = client.search(query=query, user_id=user_id, top_k=top_k) + duration_ms = (time() - start) * 1000 + return context, duration_ms - results = filter_memory_data(results)["text_mem"][0]["memories"] - search_memories = "\n".join([f" - {item['memory']}" for item in results]) - elif frame == "memos-api": - results = client.search(query=query, user_id=user_id, top_k=top_k) - search_memories = "\n".join(f"- {entry.get('memory_value', '')}" - for entry in results.get("memory_detail_list", [])) +def memos_search(client, user_id, query, top_k): + start = time() + results = client.search(query=query, user_id=user_id, top_k=top_k) + search_memories = "\n".join( + item["memory"] for cube in results["text_mem"] for item in cube["memories"] + ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) duration_ms = (time() - start) * 1000 return context, duration_ms +def supermemory_search(client, query, user_id, top_k): + start = time() + context = client.search(query, user_id, top_k) + duration_ms = (time() - start) * 1000 + return context, duration_ms + + +def memu_search(client, query, user_id, top_k): + start = time() + results = client.search(query, user_id, top_k) + context = "\n".join(results) + duration_ms = (time() - start) * 1000 + return context, duration_ms + + def build_jsonl_index(jsonl_path): """ Scan the JSONL file once to build a mapping: {key: file_offset}. Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, "r", encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -148,14 +121,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, "r", encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -167,7 +140,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -190,7 +163,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as f: + with open(csv_path, mode="r", newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 @@ -219,35 +192,36 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): return existing_results if frame == "zep": - client = zep_client() + from utils.client import ZepClient + + client = ZepClient() print("🔌 Using Zep client for search...") context, duration_ms = zep_search(client, user_id, question) - elif frame == "mem0-local": - client = mem0_client(mode="local") - print("🔌 Using Mem0 Local client for search...") - context, duration_ms = mem0_search(client, user_id, question, top_k=top_k, frame=frame) - elif frame == "mem0-api": - client = mem0_client(mode="api") + elif frame == "mem0" or frame == "mem0-graph": + from utils.client import Mem0Client + + client = Mem0Client(enable_graph="graph" in frame) print("🔌 Using Mem0 API client for search...") - context, duration_ms = mem0_search(client, user_id, question, top_k=top_k, frame=frame) - elif frame == "memos-local": - client = memos_client( - mode="local", - db_name=f"pm_{frame}-{version}", - user_id=user_id, - top_k=top_k, - mem_cube_path=f"results/pm/{frame}-{version}/storages/{user_id}", - mem_cube_config_path="configs/mu_mem_cube_config.json", - mem_os_config_path="configs/mos_memos_config.json", - addorsearch="search", - ) - print("🔌 Using Memos Local client for search...") - context, duration_ms = memos_search(client, user_id, question, frame=frame) + context, duration_ms = mem0_search(client, question, user_id, top_k) elif frame == "memos-api": - client = memos_api_client() + from utils.client import MemosApiClient + + client = MemosApiClient() print("🔌 Using Memos API client for search...") - context, duration_ms = memos_search(client, user_id, question, top_k=top_k, frame=frame) + context, duration_ms = memos_search(client, user_id, question, top_k=top_k) + elif frame == "supermemory": + from utils.client import SupermemoryClient + + client = SupermemoryClient() + print("🔌 Using supermemory client for search...") + context, duration_ms = supermemory_search(client, question, user_id, top_k) + elif frame == "memu": + from utils.client import MemuClient + + client = MemuClient() + print("🔌 Using memu client for search...") + context, duration_ms = memu_search(client, question, user_id, top_k) search_results[user_id].append( { @@ -266,25 +240,23 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): os.makedirs(f"results/pm/{frame}-{version}/tmp", exist_ok=True) with open( - f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" + f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" ) as f: json.dump(search_results, f, indent=4) - print(f"💾 \033[92mSearch results for conversation {conv_idx} saved...") + print(f"💾 Search results for conversation {conv_idx} saved...") print("-" * 80) return search_results def load_existing_results(frame, version, group_idx): - result_path = ( - f"results/locomo/{frame}-{version}/tmp/{frame}_locomo_search_results_{group_idx}.json" - ) + result_path = f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{group_idx}.json" if os.path.exists(result_path): try: with open(result_path) as f: return json.load(f), True except Exception as e: - print(f"\033[91m❌ Error loading existing results for group {group_idx}: {e}") + print(f"❌ Error loading existing results for group {group_idx}: {e}") return {}, False @@ -299,9 +271,7 @@ def main(frame, version, top_k=20, num_workers=2): print(f"📚 Loaded PersonaMem dataset from {question_csv_path} and {context_jsonl_path}") print(f"📊 Total conversations: {total_rows}") - print( - f"⚙️ Search parameters: top_k={top_k}, workers={num_workers}" - ) + print(f"⚙️ Search parameters: top_k={top_k}, workers={num_workers}") print("-" * 80) all_search_results = defaultdict(list) @@ -320,7 +290,9 @@ def main(frame, version, top_k=20, num_workers=2): for idx, (row_data, _) in enumerate(all_data) } - for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations"): + for future in tqdm( + as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + ): idx = future_to_idx[future] try: search_results = future.result() @@ -328,37 +300,41 @@ def main(frame, version, top_k=20, num_workers=2): all_search_results[user_id].extend(results) print(f"✅ Conversation {idx} processed successfully.") except Exception as exc: - print(f'\n❌ Conversation {idx} generated an exception: {exc}') + print(f"\n❌ Conversation {idx} generated an exception: {exc}") end_time = datetime.now() elapsed_time = end_time - start_time elapsed_time_str = str(elapsed_time).split(".")[0] print("\n" + "=" * 80) - print("✅ \033[1;32mSEARCH COMPLETE".center(80)) + print("✅ SEARCH COMPLETE".center(80)) print("=" * 80) - print( - f"⏱️ Total time taken to search {total_rows} users: \033[92m{elapsed_time_str}" - ) - print( - f"🔄 Framework: {frame} | Version: {version} | Workers: {num_workers}" - ) + print(f"⏱️ Total time taken to search {total_rows} users: {elapsed_time_str}") + print(f"🔄 Framework: {frame} | Version: {version} | Workers: {num_workers}") with open(f"results/pm/{frame}-{version}/{frame}_pm_search_results.json", "w") as f: json.dump(dict(all_search_results), f, indent=4) - print( - f"📁 Results saved to: \033[1;94mresults/pm/{frame}-{version}/{frame}_pm_search_results.json" - ) + print(f"📁 Results saved to: mresults/pm/{frame}-{version}/{frame}_pm_search_results.json") print("=" * 80 + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Search Script") - parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], - default='memos-api') - parser.add_argument("--version", type=str, default="0925", help="Version of the evaluation framework.") - parser.add_argument("--top_k", type=int, default=20, help="Number of top results to retrieve from the search.") - parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument( + "--lib", + type=str, + choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="default", help="Version of the evaluation framework." + ) + parser.add_argument( + "--top_k", type=int, default=20, help="Number of top results to retrieve from the search." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of parallel workers for processing users." + ) args = parser.parse_args() diff --git a/evaluation/scripts/run_pm_eval.sh b/evaluation/scripts/run_pm_eval.sh index 89484616b..f83893fed 100755 --- a/evaluation/scripts/run_pm_eval.sh +++ b/evaluation/scripts/run_pm_eval.sh @@ -1,41 +1,65 @@ #!/bin/bash # Common parameters for all scripts -LIB="memos-api" -VERSION="072201" +LIB="memu" +VERSION="072202" WORKERS=10 TOPK=20 -# echo "downloading data..." -# export HF_ENDPOINT=https://hf-mirror.com -# huggingface-cli download --repo-type dataset bowen-upenn/PersonaMem --local-dir /mnt/afs/codes/ljl/MemOS/evaluation/data/personamem +if [ "$LIB" = "mirix" ]; then + echo "Running pm_mirix.py 100 times..." + for i in {1..100}; do + echo "Iteration $i/100" + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_mirix.py --version $VERSION --workers 1 + if [ $? -ne 0 ]; then + echo "Error running xx.py on iteration $i" + exit 1 + fi + done +elif ["$LIB" = "zep"]; then + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_ingestion_zep.py --version $VERSION --workers $WORKERS + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_search_zep.py --version $VERSION --top_k $TOPK --workers $WORKERS + echo "Running pm_responses.py..." + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_responses.py --lib $LIB --version $VERSION --workers $WORKERS + if [ $? -ne 0 ]; then + echo "Error running pm_responses.py" + exit 1 + fi -echo "Running pm_ingestion.py..." -CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_ingestion.py --lib $LIB --version $VERSION --workers $WORKERS -if [ $? -ne 0 ]; then - echo "Error running pm_ingestion.py" - exit 1 -fi + echo "Running pm_metric.py..." + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_metric.py --lib $LIB --version $VERSION + if [ $? -ne 0 ]; then + echo "Error running pm_metric.py" + exit 1 + fi +else + echo "Running pm_ingestion.py..." + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_ingestion.py --lib $LIB --version $VERSION --workers $WORKERS + if [ $? -ne 0 ]; then + echo "Error running pm_ingestion.py" + exit 1 + fi -echo "Running pm_search.py..." -CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_search.py --lib $LIB --version $VERSION --top_k $TOPK --workers $WORKERS -if [ $? -ne 0 ]; then - echo "Error running pm_search.py" - exit 1 -fi + echo "Running pm_search.py..." + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_search.py --lib $LIB --version $VERSION --top_k $TOPK --workers $WORKERS + if [ $? -ne 0 ]; then + echo "Error running pm_search.py" + exit 1 + fi -echo "Running pm_responses.py..." -CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_responses.py --lib $LIB --version $VERSION --workers $WORKERS -if [ $? -ne 0 ]; then - echo "Error running pm_responses.py" - exit 1 -fi + echo "Running pm_responses.py..." + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_responses.py --lib $LIB --version $VERSION --workers $WORKERS + if [ $? -ne 0 ]; then + echo "Error running pm_responses.py" + exit 1 + fi -echo "Running pm_metric.py..." -CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_metric.py --lib $LIB --version $VERSION -if [ $? -ne 0 ]; then - echo "Error running pm_metric.py" - exit 1 + echo "Running pm_metric.py..." + CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_metric.py --lib $LIB --version $VERSION + if [ $? -ne 0 ]; then + echo "Error running pm_metric.py" + exit 1 + fi fi -echo "All scripts completed successfully!" +echo "All scripts completed successfully!" \ No newline at end of file diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index 8e718192a..001f8299d 100644 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -9,21 +9,43 @@ WORKERS=10 # Parameters for pref_memos.py -TOP_K=10 +TOP_K=6 ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" -VERSION="1021-5" +LIB="memos-api" +VERSION="1022-0" # --- File Paths --- # You may need to adjust these paths based on your project structure. -# Assumes Step 1 (preprocess) outputs this file: +# Step 1 (preprocess) outputs this file: PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" -# Intermediate file (output of 'add' mode, input for 'process' mode) -IDS_FILE="results/prefeval/pref_memos_add.jsonl" +# Create a directory name based on the *specific* LIB (e.g., "memos") +OUTPUT_DIR="results/prefeval/${LIB}_${VERSION}" + + +if [[ "$LIB" == *"mem0"* ]]; then + SCRIPT_NAME_BASE="mem0" +elif [[ "$LIB" == *"memos"* ]]; then + SCRIPT_NAME_BASE="memos" +elif [[ "$LIB" == *"memobase"* ]]; then + SCRIPT_NAME_BASE="memobase" +elif [[ "$LIB" == *"supermemory"* ]]; then + SCRIPT_NAME_BASE="supermemory" +elif [[ "$LIB" == *"memu"* ]]; then + SCRIPT_NAME_BASE="memu" +elif [[ "$LIB" == *"zep"* ]]; then + SCRIPT_NAME_BASE="zep" +else + SCRIPT_NAME_BASE=$LIB +fi + +# The script to be executed (e.g., pref_mem0.py) +LIB_SCRIPT="scripts/PrefEval/pref_${SCRIPT_NAME_BASE}.py" -# Final response file (output of 'process' mode, input for Step 3) -RESPONSE_FILE="results/prefeval/pref_memos_process.jsonl" +# Output files will be unique to the $LIB (e.g., pref_memos-api_add.jsonl) +IDS_FILE="${OUTPUT_DIR}/pref_${LIB}_add.jsonl" +SEARCH_FILE="${OUTPUT_DIR}/pref_${LIB}_search.jsonl" +RESPONSE_FILE="${OUTPUT_DIR}/pref_${LIB}_response.jsonl" # Set the Hugging Face mirror endpoint @@ -31,6 +53,8 @@ export HF_ENDPOINT="https://hf-mirror.com" echo "--- Starting PrefEval Pipeline ---" echo "Configuration: WORKERS=$WORKERS, TOP_K=$TOP_K, ADD_TURN=$ADD_TURN, LIB=$LIB, VERSION=$VERSION, HF_ENDPOINT=$HF_ENDPOINT" +echo "Results will be saved to: $OUTPUT_DIR" +echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" echo "" # --- Step 1: Preprocess the data --- @@ -42,11 +66,29 @@ if [ $? -ne 0 ]; then exit 1 fi -# --- Step 2: Generate responses using MemOS (split into 'add' and 'process') --- +# --- Create output directory --- +echo "" +echo "Creating output directory: $OUTPUT_DIR" +mkdir -p $OUTPUT_DIR +if [ $? -ne 0 ]; then + echo "Error: Could not create output directory '$OUTPUT_DIR'." + exit 1 +fi + +# Check if the *mapped* script exists +if [ ! -f "$LIB_SCRIPT" ]; then + echo "Error: Script not found for library '$LIB' (mapped to $LIB_SCRIPT)" + exit 1 +fi + +# --- Step 2: Generate responses based on LIB --- +echo "" +echo "--- Step 2: Generate responses using $LIB (3-Step Process) ---" + echo "" -echo "Running pref_memos.py in 'add' mode..." +echo "Running $LIB_SCRIPT in 'add' mode..." # Step 2a: Ingest conversations into memory and generate user_ids -python scripts/PrefEval/pref_memos.py add \ +python $LIB_SCRIPT add \ --input $PREPROCESSED_FILE \ --output $IDS_FILE \ --add-turn $ADD_TURN \ @@ -55,35 +97,49 @@ python scripts/PrefEval/pref_memos.py add \ --version $VERSION if [ $? -ne 0 ]; then - echo "Error: pref_memos.py 'add' mode failed." + echo "Error: $LIB_SCRIPT 'add' mode failed." exit 1 fi echo "" -echo "Running pref_memos.py in 'process' mode..." -# Step 2b: Search memories using user_ids and generate responses -python scripts/PrefEval/pref_memos.py process \ +echo "Running $LIB_SCRIPT in 'search' mode..." +# Step 2b: Search memories using user_ids +python $LIB_SCRIPT search \ --input $IDS_FILE \ - --output $RESPONSE_FILE \ + --output $SEARCH_FILE \ --top-k $TOP_K \ - --max-workers $WORKERS \ - --lib $LIB \ - --version $VERSION + --max-workers $WORKERS + +if [ $? -ne 0 ]; then + echo "Error: $LIB_SCRIPT 'search' mode failed." + exit 1 +fi + +echo "" +echo "Running $LIB_SCRIPT in 'response' mode..." +# Step 2c: Generate responses based on searched memories +python $LIB_SCRIPT response \ + --input $SEARCH_FILE \ + --output $RESPONSE_FILE \ + --max-workers $WORKERS if [ $? -ne 0 ]; then - echo "Error: pref_memos.py 'process' mode failed." + echo "Error: $LIB_SCRIPT 'response' mode failed." exit 1 fi # --- Step 3: Evaluate the generated responses --- echo "" echo "Running pref_eval.py..." -# Pass the WORKERS variable to the script's --concurrency-limit argument -python scripts/PrefEval/pref_eval.py --concurrency-limit $WORKERS +python scripts/PrefEval/pref_eval.py \ + --input $RESPONSE_FILE \ + --concurrency-limit $WORKERS + if [ $? -ne 0 ]; then echo "Error: Evaluation script failed." exit 1 fi echo "" -echo "--- PrefEval Pipeline completed successfully! ---" \ No newline at end of file +echo "--- PrefEval Pipeline completed successfully! ---" +echo "Final results are in $RESPONSE_FILE" diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 87b863e86..2efb0493d 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -2,6 +2,8 @@ import os import sys import time +import uuid +from contextlib import suppress from datetime import datetime from dotenv import load_dotenv import requests @@ -17,7 +19,7 @@ def __init__(self): api_key = os.getenv("ZEP_API_KEY") self.client = Zep(api_key=api_key) - def add(self, messages, user_id, conv_id, timestamp): + def add(self, messages, user_id, timestamp): iso_date = datetime.fromtimestamp(timestamp).isoformat() for msg in messages: self.client.graph.add( @@ -49,18 +51,31 @@ def __init__(self, enable_graph=False): self.client = MemoryClient(api_key=os.getenv("MEM0_API_KEY")) self.enable_graph = enable_graph - def add(self, messages, user_id, timestamp): - if self.enable_graph: - self.client.add( - messages=messages, - timestamp=timestamp, - user_id=user_id, - output_format="v1.1", - version="v2", - enable_graph=True, - ) - else: - self.client.add(messages=messages, timestamp=timestamp, user_id=user_id, version="v2") + def add(self, messages, user_id, timestamp, batch_size=2): + max_retries = 5 + for i in range(0, len(messages), batch_size): + batch_messages = messages[i : i + batch_size] + for attempt in range(max_retries): + try: + if self.enable_graph: + self.client.add( + messages=batch_messages, + timestamp=timestamp, + user_id=user_id, + enable_graph=True, + ) + else: + self.client.add( + messages=batch_messages, + timestamp=timestamp, + user_id=user_id, + ) + break + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e def search(self, query, user_id, top_k): if self.enable_graph: @@ -68,19 +83,15 @@ def search(self, query, user_id, top_k): query=query, top_k=top_k, user_id=user_id, - output_format="v1.1", - version="v2", enable_graph=True, - filters={"AND": [{"user_id": f"{user_id}"}, {"run_id": "*"}]}, + filters={"AND": [{"user_id": f"{user_id}"}]}, ) else: res = self.client.search( query=query, top_k=top_k, user_id=user_id, - output_format="v1.1", - version="v2", - filters={"AND": [{"user_id": f"{user_id}"}, {"run_id": "*"}]}, + filters={"AND": [{"user_id": f"{user_id}"}]}, ) return res @@ -93,18 +104,29 @@ def __init__(self): project_url=os.getenv("MEMOBASE_PROJECT_URL"), api_key=os.getenv("MEMOBASE_API_KEY") ) - def add(self, messages, user_id): - from memobase import ChatBlob - + def add(self, messages, user_id, batch_size=2): """ - user_id: memobase user_id messages = [{"role": "assistant", "content": data, "created_at": iso_date}] """ - user = self.client.get_user(user_id, no_get=True) - user.insert(ChatBlob(messages=messages), sync=True) + from memobase import ChatBlob + + real_uid = self.string_to_uuid(user_id) + user = self.client.get_or_create_user(real_uid) + for i in range(0, len(messages), batch_size): + batch_messages = messages[i : i + batch_size] + max_retries = 5 + for attempt in range(max_retries): + try: + _ = user.insert(ChatBlob(messages=batch_messages), sync=True) + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e def search(self, query, user_id, top_k): - user = self.client.get_user(user_id, no_get=True) + real_uid = self.string_to_uuid(user_id) + user = self.client.get_user(real_uid, no_get=True) memories = user.context( max_token_size=top_k * 100, chats=[{"role": "user", "content": query}], @@ -113,6 +135,16 @@ def search(self, query, user_id, top_k): ) return memories + def delete_user(self, user_id): + from memobase.error import ServerError + + real_uid = self.string_to_uuid(user_id) + with suppress(ServerError): + self.client.delete_user(real_uid) + + def string_to_uuid(self, s: str, salt="memobase_client"): + return str(uuid.uuid5(uuid.NAMESPACE_DNS, s + salt)) + class MemosApiClient: def __init__(self): @@ -120,6 +152,9 @@ def __init__(self): self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} def add(self, messages, user_id, conv_id): + """ + messages = [{"role": "assistant", "content": data, "chat_time": date_str}] + """ url = f"{self.memos_url}/product/add" payload = json.dumps( { @@ -155,6 +190,62 @@ def search(self, query, user_id, top_k): return json.loads(response.text)["data"] +class MemosApiOnlineClient: + def __init__(self): + self.memos_url = os.getenv("MEMOS_ONLINE_URL") + self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} + + def add(self, messages, user_id, conv_id=None): + url = f"{self.memos_url}/add/message" + payload = json.dumps( + { + "messages": messages, + "user_id": user_id, + "conversation_id": conv_id, + } + ) + + max_retries = 5 + for attempt in range(max_retries): + try: + response = requests.request("POST", url, data=payload, headers=self.headers) + assert response.status_code == 200, response.text + assert json.loads(response.text)["message"] == "ok", response.text + return response.text + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e + + def search(self, query, user_id, top_k): + """Search memories.""" + url = f"{self.memos_url}/search/memory" + payload = json.dumps( + { + "query": query, + "user_id": user_id, + "memory_limit_number": top_k, + } + ) + + max_retries = 5 + for attempt in range(max_retries): + try: + response = requests.request("POST", url, data=payload, headers=self.headers) + assert response.status_code == 200, response.text + assert json.loads(response.text)["message"] == "ok", response.text + res = json.loads(response.text)["data"]["memory_detail_list"] + for i in res: + i.update({"memory": i.pop("memory_value")}) + return {"text_mem": [{"memories": res}]} + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e + + class SupermemoryClient: def __init__(self): from supermemory import Supermemory @@ -172,7 +263,7 @@ def add(self, messages, user_id): break except Exception as e: if attempt < max_retries - 1: - time.sleep(2**attempt) # 指数退避 + time.sleep(2**attempt) else: raise e @@ -192,7 +283,7 @@ def search(self, query, user_id, top_k): return context except Exception as e: if attempt < max_retries - 1: - time.sleep(2**attempt) # 指数退避 + time.sleep(2**attempt) else: raise e @@ -245,3 +336,10 @@ def wait_for_completion(self, task_id): timestamp = 1682899200 query = "杭州西湖有什么" top_k = 5 + + # MEMOBASE + client = MemobaseClient() + for m in messages: + m["created_at"] = iso_date + client.add(messages, user_id) + memories = client.search(query, user_id, top_k) diff --git a/evaluation/scripts/utils/mirix_utils.py b/evaluation/scripts/utils/mirix_utils.py new file mode 100644 index 000000000..e1b5f3de6 --- /dev/null +++ b/evaluation/scripts/utils/mirix_utils.py @@ -0,0 +1,81 @@ +import os +import yaml +from tqdm import tqdm + + +def get_mirix_client(config_path, load_from=None): + if os.path.exists(os.path.expanduser(f"~/.mirix")): + os.system(f"rm -rf ~/.mirix/*") + + with open(config_path, "r") as f: + agent_config = yaml.safe_load(f) + + os.environ["OPENAI_API_KEY"] = agent_config["api_key"] + import mirix + from mirix import Mirix, EmbeddingConfig, LLMConfig + + embedding_default_config = EmbeddingConfig( + embedding_model=agent_config["embedding_model_name"], + embedding_endpoint_type="openai", + embedding_endpoint=agent_config["model_endpoint"], + embedding_dim=1536, + embedding_chunk_size=8191, + ) + + llm_default_config = LLMConfig( + model=agent_config["model_name"], + model_endpoint_type="openai", + model_endpoint=agent_config["model_endpoint"], + api_key=agent_config["api_key"], + model_wrapper=None, + context_window=128000, + ) + + def embedding_default_config_func(cls, model_name=None, provider=None): + return embedding_default_config + + def llm_default_config_func(cls, model_name=None, provider=None): + return llm_default_config + + mirix.EmbeddingConfig.default_config = embedding_default_config_func + mirix.LLMConfig.default_config = llm_default_config_func + + assistant = Mirix( + api_key=agent_config["api_key"], + config_path=config_path, + model=agent_config["model_name"], + load_from=load_from, + ) + return assistant + + +if __name__ == "__main__": + config_path = "configs-example/mirix_config.yaml" + out_dir = "results/mirix-test" + + assistant = get_mirix_client(config_path) + + chunks = [ + "I prefer coffee over tea", + "My work hours are 9 AM to 5 PM", + "Important meeting with client on Friday at 2 PM", + ] + + for _idx, chunk in tqdm(enumerate(chunks), total=len(chunks)): + response = assistant.add(chunk) + + assistant.save(out_dir) + + assistant = get_mirix_client(config_path, load_from=out_dir) + response = assistant.chat("What's my schedule like this week?") + + print(response) + assistant.create_user(user_name="user1") + assistant.create_user(user_name="user2") + user1 = assistant.get_user_by_name(user_name="user1") + user2 = assistant.get_user_by_name(user_name="user2") + assistant.add("i prefer tea over coffee", user_id=user1.id) + assistant.add("my favourite drink is coke", user_id=user2.id) + response1 = assistant.chat("What drink do I prefer?", user_id=user1.id) + response2 = assistant.chat("What drink do I prefer?", user_id=user2.id) + print(response1, response2) From 651e8df274ac1df0c5c32cbd4c636fcb32c09307 Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Sat, 25 Oct 2025 15:53:41 +0800 Subject: [PATCH 04/64] Meger update about scheduler and new api to Dev (#386) * debug an error function name * feat: Add DynamicCache compatibility for different transformers versions - Fix build_kv_cache method in hf.py to handle both old and new DynamicCache structures - Support new 'layers' attribute with key_cache/value_cache or keys/values - Maintain backward compatibility with direct key_cache/value_cache attributes - Add comprehensive error handling and logging for unsupported structures - Update move_dynamic_cache_htod function in kv.py for cross-version compatibility - Handle layers-based structure in newer transformers versions - Support alternative attribute names (keys/values vs key_cache/value_cache) - Preserve original functionality for older transformers versions - Add comprehensive tests for DynamicCache compatibility - Test activation memory update with mock DynamicCache layers - Verify layers attribute access across different transformers versions - Fix scheduler logger mock to include memory_manager attribute This resolves AttributeError issues when using different versions of the transformers library and ensures robust handling of DynamicCache objects. debug * feat: implement APIAnalyzerForScheduler for memory operations - Add APIAnalyzerForScheduler class with search/add operations - Support requests and http.client with connection reuse - Include comprehensive error handling and dynamic configuration - Add English test suite with real-world conversation scenarios * feat: Add search_ws API endpoint and enhance API analyzer functionality - Add search_ws endpoint in server_router.py for scheduler-enabled search - Fix missing imports: time module, SearchRequest class, and get_mos_product_instance function - Implement search_ws method in api_analyzer.py with HTTP client support - Add _search_ws_with_requests and _search_ws_with_http_client private methods - Include search_ws usage example in demonstration code - Enhance scheduler and dispatcher capabilities for improved memory management - Expand test coverage to ensure functionality stability This update primarily strengthens the memory scheduling system's search capabilities, providing users with more flexible API interface options. * fix: resolve test failures and warnings in test suite - Fix Pydantic serialization warning in test_memos_chen_tang_hello_world * Add warnings filter to suppress UserWarning from Pydantic serialization - Fix KeyError: 'past_key_values' in test_build_kv_cache_and_generation * Update mock configuration to properly return forward_output with past_key_values * Add DynamicCache version compatibility handling in test mocks * Support both old and new transformers versions with layers/key_cache attributes * Improve assertion logic to check all model calls for required parameters - Update base_scheduler.py to use centralized DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE constant * Add import for DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE from general_schemas * Replace hardcoded value 100 with configurable constant (1000) All tests now pass successfully with proper version compatibility handling. * feat: add a test_robustness execution to test thread pool execution * feat: optimize scheduler configuration and API search functionality - Add DEFAULT_TOP_K and DEFAULT_CONTEXT_WINDOW_SIZE global constants in general_schemas.py - Update base_scheduler.py to use global default values instead of hardcoded numbers - Fix SchedulerConfigFactory initialization issue by using keyword argument expansion - Resolve UnboundLocalError variable conflict in search_memories_ws function - Fix indentation and parameter issues in OptimizedScheduler search_for_api method - Improve code standardization and maintainability * feat: Add Redis auto-initialization with fallback strategies - Add auto_initialize_redis() with config/env/local fallback - Move Redis logic from dispatcher_monitor to redis_service - Update base_scheduler to use auto initialization - Add proper resource cleanup and error handling * feat: add database connection management to ORM module - Add MySQL engine loading from environment variables in BaseDBManager - Add Redis connection loading from environment variables in BaseDBManager - Enhance database configuration validation and error handling - Complete database adapter infrastructure for ORM module - Provide unified database connection management interface This update provides comprehensive database connection management capabilities for the mem_scheduler module, supporting dynamic MySQL and Redis configuration loading from environment variables, establishing reliable data persistence foundation for scheduling services and API services. * remove part of test * feat: add Redis-based ORM with multiprocess synchronization - Add RedisDBManager and RedisLockableORM classes - Implement atomic locking mechanism for concurrent access - Add merge functionality for different object types - Include comprehensive test suite and examples - Fix Redis key type conflicts in lock operations * fix: resolve scheduler module import and Redis integration issues * revise naive memcube creation in server router * remove long-time tests in test_scheduler * remove redis test which needs .env --- examples/mem_scheduler/orm_examples.py | 374 ++++++++++ src/memos/api/product_models.py | 3 +- src/memos/api/routers/server_router.py | 230 +++++- src/memos/configs/mem_scheduler.py | 41 +- src/memos/llms/hf.py | 54 +- src/memos/mem_os/core.py | 26 +- src/memos/mem_os/main.py | 36 +- .../mem_scheduler/analyzer/api_analyzer.py | 619 ++++++++++++++++ .../analyzer/mos_for_test_scheduler.py | 26 +- src/memos/mem_scheduler/base_scheduler.py | 218 ++++-- .../mem_scheduler/general_modules/api_misc.py | 115 +++ .../general_modules/dispatcher.py | 34 +- src/memos/mem_scheduler/general_scheduler.py | 4 +- .../monitors/dispatcher_monitor.py | 127 ++-- .../mem_scheduler/monitors/general_monitor.py | 5 +- .../mem_scheduler/optimized_scheduler.py | 117 ++- .../mem_scheduler/orm_modules/base_model.py | 217 +++++- .../mem_scheduler/orm_modules/redis_model.py | 699 ++++++++++++++++++ .../mem_scheduler/schemas/general_schemas.py | 15 + .../mem_scheduler/schemas/message_schemas.py | 9 +- .../mem_scheduler/schemas/task_schemas.py | 7 +- src/memos/mem_scheduler/utils/db_utils.py | 17 + .../webservice_modules/rabbitmq_service.py | 65 +- .../webservice_modules/redis_service.py | 225 +++++- src/memos/memories/activation/kv.py | 36 +- tests/llms/test_hf.py | 41 +- tests/mem_scheduler/test_dispatcher.py | 230 ++++-- tests/mem_scheduler/test_orm.py | 148 ++++ tests/mem_scheduler/test_scheduler.py | 366 ++++++++- tests/test_hello_world.py | 13 +- 30 files changed, 3799 insertions(+), 318 deletions(-) create mode 100644 examples/mem_scheduler/orm_examples.py create mode 100644 src/memos/mem_scheduler/general_modules/api_misc.py create mode 100644 src/memos/mem_scheduler/orm_modules/redis_model.py diff --git a/examples/mem_scheduler/orm_examples.py b/examples/mem_scheduler/orm_examples.py new file mode 100644 index 000000000..bbb57b4ab --- /dev/null +++ b/examples/mem_scheduler/orm_examples.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +ORM Examples for MemScheduler + +This script demonstrates how to use the BaseDBManager's new environment variable loading methods +for MySQL and Redis connections. +""" + +import multiprocessing +import os +import sys + +from pathlib import Path + + +# Add the src directory to the Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + +from memos.log import get_logger +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager, DatabaseError +from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager + + +logger = get_logger(__name__) + + +def test_mysql_engine_from_env(): + """Test loading MySQL engine from environment variables""" + print("\n" + "=" * 60) + print("Testing MySQL Engine from Environment Variables") + print("=" * 60) + + try: + # Test loading MySQL engine from current environment variables + mysql_engine = BaseDBManager.load_mysql_engine_from_env() + if mysql_engine is None: + print("❌ Failed to create MySQL engine - check environment variables") + return + + print(f"✅ Successfully created MySQL engine: {mysql_engine}") + print(f" Engine URL: {mysql_engine.url}") + + # Test connection + with mysql_engine.connect() as conn: + from sqlalchemy import text + + result = conn.execute(text("SELECT 'MySQL connection test successful' as message")) + message = result.fetchone()[0] + print(f" Connection test: {message}") + + mysql_engine.dispose() + print(" MySQL engine disposed successfully") + + except DatabaseError as e: + print(f"❌ DatabaseError: {e}") + except Exception as e: + print(f"❌ Unexpected error: {e}") + + +def test_redis_connection_from_env(): + """Test loading Redis connection from environment variables""" + print("\n" + "=" * 60) + print("Testing Redis Connection from Environment Variables") + print("=" * 60) + + try: + # Test loading Redis connection from current environment variables + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print("❌ Failed to create Redis connection - check environment variables") + return + + print(f"✅ Successfully created Redis connection: {redis_client}") + + # Test basic Redis operations + redis_client.set("test_key", "Hello from ORM Examples!") + value = redis_client.get("test_key") + print(f" Redis test - Set/Get: {value}") + + # Test Redis info + info = redis_client.info("server") + redis_version = info.get("redis_version", "unknown") + print(f" Redis server version: {redis_version}") + + # Clean up test key + redis_client.delete("test_key") + print(" Test key cleaned up") + + redis_client.close() + print(" Redis connection closed successfully") + + except DatabaseError as e: + print(f"❌ DatabaseError: {e}") + except Exception as e: + print(f"❌ Unexpected error: {e}") + + +def test_environment_variables(): + """Test and display current environment variables""" + print("\n" + "=" * 60) + print("Current Environment Variables") + print("=" * 60) + + # MySQL environment variables + mysql_vars = [ + "MYSQL_HOST", + "MYSQL_PORT", + "MYSQL_USERNAME", + "MYSQL_PASSWORD", + "MYSQL_DATABASE", + "MYSQL_CHARSET", + ] + + print("\nMySQL Environment Variables:") + for var in mysql_vars: + value = os.getenv(var, "Not set") + # Mask password for security + if "PASSWORD" in var and value != "Not set": + value = "*" * len(value) + print(f" {var}: {value}") + + # Redis environment variables + redis_vars = [ + "REDIS_HOST", + "REDIS_PORT", + "REDIS_DB", + "REDIS_PASSWORD", + "MEMSCHEDULER_REDIS_HOST", + "MEMSCHEDULER_REDIS_PORT", + "MEMSCHEDULER_REDIS_DB", + "MEMSCHEDULER_REDIS_PASSWORD", + ] + + print("\nRedis Environment Variables:") + for var in redis_vars: + value = os.getenv(var, "Not set") + # Mask password for security + if "PASSWORD" in var and value != "Not set": + value = "*" * len(value) + print(f" {var}: {value}") + + +def test_manual_env_loading(): + """Test loading environment variables manually from .env file""" + print("\n" + "=" * 60) + print("Testing Manual Environment Loading") + print("=" * 60) + + env_file_path = "/Users/travistang/Documents/codes/memos/.env" + + if not os.path.exists(env_file_path): + print(f"❌ Environment file not found: {env_file_path}") + return + + try: + from dotenv import load_dotenv + + # Load environment variables + load_dotenv(env_file_path) + print(f"✅ Successfully loaded environment variables from {env_file_path}") + + # Test some key variables + test_vars = ["OPENAI_API_KEY", "MOS_CHAT_MODEL", "TZ"] + for var in test_vars: + value = os.getenv(var, "Not set") + if "KEY" in var and value != "Not set": + value = f"{value[:10]}..." if len(value) > 10 else value + print(f" {var}: {value}") + + except ImportError: + print("❌ python-dotenv not installed. Install with: pip install python-dotenv") + except Exception as e: + print(f"❌ Error loading environment file: {e}") + + +def test_redis_lockable_orm_with_list(): + """Test RedisDBManager with list[str] type synchronization""" + print("\n" + "=" * 60) + print("Testing RedisDBManager with list[str]") + print("=" * 60) + + try: + from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager + + # Create a simple list manager instance + list_manager = SimpleListManager(["apple", "banana", "cherry"]) + print(f"Original list manager: {list_manager}") + + # Create RedisDBManager instance + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print("❌ Failed to create Redis connection - check environment variables") + return + + db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="test_list_cube", + obj=list_manager, + ) + + # Save to Redis + db_manager.save_to_db(list_manager) + print("✅ List manager saved to Redis") + + # Load from Redis + loaded_manager = db_manager.load_from_db() + if loaded_manager: + print(f"Loaded list manager: {loaded_manager}") + print(f"Items match: {list_manager.items == loaded_manager.items}") + else: + print("❌ Failed to load list manager from Redis") + + # Clean up + redis_client.delete("lockable_orm:test_user:test_list_cube:data") + redis_client.delete("lockable_orm:test_user:test_list_cube:lock") + redis_client.delete("lockable_orm:test_user:test_list_cube:version") + redis_client.close() + + except Exception as e: + print(f"❌ Error in RedisDBManager test: {e}") + + +def modify_list_process(process_id: int, items_to_add: list[str]): + """Function to be run in separate processes to modify the list using merge_items""" + try: + from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager + + # Create Redis connection + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print(f"Process {process_id}: Failed to create Redis connection") + return + + # Create a temporary list manager for this process with items to add + temp_manager = SimpleListManager() + + db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="multiprocess_list", + obj=temp_manager, + ) + + print(f"Process {process_id}: Starting modification with items: {items_to_add}") + for item in items_to_add: + db_manager.obj.add_item(item) + # Use sync_with_orm which internally uses merge_items + db_manager.sync_with_orm(size_limit=None) + + print(f"Process {process_id}: Successfully synchronized with Redis") + + redis_client.close() + + except Exception as e: + print(f"Process {process_id}: Error - {e}") + import traceback + + traceback.print_exc() + + +def test_multiprocess_synchronization(): + """Test multiprocess synchronization with RedisDBManager""" + print("\n" + "=" * 60) + print("Testing Multiprocess Synchronization") + print("=" * 60) + + try: + # Initialize Redis with empty list + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print("❌ Failed to create Redis connection") + return + + # Initialize with empty list + initial_manager = SimpleListManager([]) + db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="multiprocess_list", + obj=initial_manager, + ) + db_manager.save_to_db(initial_manager) + print("✅ Initialized empty list manager in Redis") + + # Define items for each process to add + process_items = [ + ["item1", "item2"], + ["item3", "item4"], + ["item5", "item6"], + ["item1", "item7"], # item1 is duplicate, should not be added twice + ] + + # Create and start processes + processes = [] + for i, items in enumerate(process_items): + p = multiprocessing.Process(target=modify_list_process, args=(i + 1, items)) + processes.append(p) + p.start() + + # Wait for all processes to complete + for p in processes: + p.join() + + print("\n" + "-" * 40) + print("All processes completed. Checking final result...") + + # Load final result + final_db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="multiprocess_list", + obj=SimpleListManager([]), + ) + final_manager = final_db_manager.load_from_db() + + if final_manager: + print(f"Final synchronized list manager: {final_manager}") + print(f"Final list length: {len(final_manager)}") + print("Expected items: {'item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7'}") + print(f"Actual items: {set(final_manager.items)}") + + # Check if all unique items are present + expected_items = {"item1", "item2", "item3", "item4", "item5", "item6", "item7"} + actual_items = set(final_manager.items) + + if expected_items == actual_items: + print("✅ All processes contributed correctly - synchronization successful!") + else: + print(f"❌ Expected items: {expected_items}") + print(f" Actual items: {actual_items}") + else: + print("❌ Failed to load final result") + + # Clean up + redis_client.delete("lockable_orm:test_user:multiprocess_list:data") + redis_client.delete("lockable_orm:test_user:multiprocess_list:lock") + redis_client.delete("lockable_orm:test_user:multiprocess_list:version") + redis_client.close() + + except Exception as e: + print(f"❌ Error in multiprocess synchronization test: {e}") + + +def main(): + """Main function to run all tests""" + print("ORM Examples - Environment Variable Loading Tests") + print("=" * 80) + + # Test environment variables display + test_environment_variables() + + # Test manual environment loading + test_manual_env_loading() + + # Test MySQL engine loading + test_mysql_engine_from_env() + + # Test Redis connection loading + test_redis_connection_from_env() + + # Test RedisLockableORM with list[str] + test_redis_lockable_orm_with_list() + + # Test multiprocess synchronization + test_multiprocess_synchronization() + + print("\n" + "=" * 80) + print("All tests completed!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 86751b008..d14c05993 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field # Import message types from core types module +from memos.mem_scheduler.schemas.general_schemas import SearchMode from memos.types import MessageDict, PermissionDict @@ -170,7 +171,7 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") - mode: str = Field("fast", description="search mode fast or fine") + mode: SearchMode = Field(SearchMode.FINE, description="search mode: fast, fine, or mixture") internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") top_k: int = Field(10, description="Number of results to return") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index a332de583..9f982ddd3 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,3 +1,4 @@ +import json import os import traceback @@ -18,6 +19,7 @@ from memos.configs.internet_retriever import InternetRetrieverConfigFactory from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory @@ -26,6 +28,14 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_os.product_server import MOSServer from memos.mem_reader.factory import MemReaderFactory +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.general_schemas import ( + API_MIX_SEARCH_LABEL, + SearchMode, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -134,6 +144,34 @@ def init_server(): llm=llm, online_bot=False, ) + + # Initialize Scheduler + scheduler_config_dict = APIConfig.get_scheduler_config() + scheduler_config = SchedulerConfigFactory( + backend="optimized_scheduler", config=scheduler_config_dict + ) + mem_scheduler = SchedulerFactory.from_config(scheduler_config) + mem_scheduler.initialize_modules( + chat_llm=llm, + process_llm=mem_reader.llm, + db_engine=BaseDBManager.create_default_sqlite_engine(), + ) + mem_scheduler.start() + + # Initialize SchedulerAPIModule + api_module = mem_scheduler.api_module + + naive_mem_cube = NaiveMemCube( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + internet_retriever=internet_retriever, + memory_manager=memory_manager, + default_cube_config=default_cube_config, + ) + return ( graph_db, mem_reader, @@ -144,6 +182,9 @@ def init_server(): memory_manager, default_cube_config, mos_server, + mem_scheduler, + naive_mem_cube, + api_module, ) @@ -158,24 +199,12 @@ def init_server(): memory_manager, default_cube_config, mos_server, + mem_scheduler, + naive_mem_cube, + api_module, ) = init_server() -def _create_naive_mem_cube() -> NaiveMemCube: - """Create a NaiveMemCube instance with initialized components.""" - naive_mem_cube = NaiveMemCube( - llm=llm, - embedder=embedder, - mem_reader=mem_reader, - graph_db=graph_db, - reranker=reranker, - internet_retriever=internet_retriever, - memory_manager=memory_manager, - default_cube_config=default_cube_config, - ) - return naive_mem_cube - - def _format_memory_item(memory_data: Any) -> dict[str, Any]: """Format a single memory item for API response.""" memory = memory_data.model_dump() @@ -207,18 +236,146 @@ def search_memories(search_req: APISearchRequest): "act_mem": [], "para_mem": [], } + + search_mode = search_req.mode + + if search_mode == SearchMode.FAST: + formatted_memories = fast_search_memories(search_req=search_req, user_context=user_context) + elif search_mode == SearchMode.FINE: + formatted_memories = fine_search_memories(search_req=search_req, user_context=user_context) + elif search_mode == SearchMode.MIXTURE: + formatted_memories = mix_search_memories(search_req=search_req, user_context=user_context) + else: + logger.error(f"Unsupported search mode: {search_mode}") + raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") + + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": formatted_memories, + } + ) + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) + + +def mix_search_memories( + search_req: APISearchRequest, + user_context: UserContext, +): + """ + Mix search memories: fast search + async fine search + """ + # Get fast memories first + fast_memories = fast_search_memories(search_req, user_context) + + # Check if scheduler and dispatcher are available for async execution + if mem_scheduler and hasattr(mem_scheduler, "dispatcher") and mem_scheduler.dispatcher: + try: + # Create message for async fine search + message_content = { + "search_req": { + "query": search_req.query, + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "top_k": search_req.top_k, + "internet_search": search_req.internet_search, + "moscube": search_req.moscube, + "chat_history": search_req.chat_history, + }, + "user_context": {"mem_cube_id": user_context.mem_cube_id}, + } + + message = ScheduleMessageItem( + item_id=f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}", + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + label=API_MIX_SEARCH_LABEL, + mem_cube=naive_mem_cube, + content=json.dumps(message_content), + timestamp=get_utc_now(), + ) + + # Submit async task + mem_scheduler.dispatcher.submit_message(message) + logger.info(f"Submitted async fine search task for user {search_req.user_id}") + + # Try to get pre-computed fine memories if available + try: + pre_fine_memories = api_module.get_pre_fine_memories( + user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id + ) + if pre_fine_memories: + # Merge fast and pre-computed fine memories + all_memories = fast_memories + pre_fine_memories + # Remove duplicates based on content + seen_contents = set() + unique_memories = [] + for memory in all_memories: + content_key = memory.get("content", "") + if content_key not in seen_contents: + seen_contents.add(content_key) + unique_memories.append(memory) + return unique_memories + except Exception as e: + logger.warning(f"Failed to get pre-computed fine memories: {e}") + + except Exception as e: + logger.error(f"Failed to submit async fine search task: {e}") + # Fall back to synchronous execution + + # Fallback: synchronous fine search + try: + fine_memories = fine_search_memories(search_req, user_context) + + # Merge fast and fine memories + all_memories = fast_memories + fine_memories + + # Remove duplicates based on content + seen_contents = set() + unique_memories = [] + for memory in all_memories: + content_key = memory.get("content", "") + if content_key not in seen_contents: + seen_contents.add(content_key) + unique_memories.append(memory) + + # Sync search data to Redis + try: + api_module.sync_search_data( + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + query=search_req.query, + formatted_memories=unique_memories, + ) + except Exception as e: + logger.error(f"Failed to sync search data: {e}") + + return unique_memories + + except Exception as e: + logger.error(f"Fine search failed: {e}") + return fast_memories + + +def fine_search_memories( + search_req: APISearchRequest, + user_context: UserContext, +): target_session_id = search_req.session_id if not target_session_id: target_session_id = "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None # Create MemCube and perform search - naive_mem_cube = _create_naive_mem_cube() search_results = naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=search_req.mode, + mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, @@ -230,17 +387,36 @@ def search_memories(search_req: APISearchRequest): ) formatted_memories = [_format_memory_item(data) for data in search_results] - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": formatted_memories, - } - ) + return formatted_memories - return SearchResponse( - message="Search completed successfully", - data=memories_result, + +def fast_search_memories( + search_req: APISearchRequest, + user_context: UserContext, +): + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + search_results = naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, ) + formatted_memories = [_format_memory_item(data) for data in search_results] + + return formatted_memories @router.post("/add", summary="Add memories", response_model=MemoryResponse) @@ -252,7 +428,6 @@ def add_memories(add_req: APIADDRequest): mem_cube_id=add_req.mem_cube_id, session_id=add_req.session_id or "default_session", ) - naive_mem_cube = _create_naive_mem_cube() target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" @@ -296,7 +471,6 @@ def chat_complete(chat_req: APIChatCompleteRequest): """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" try: # Collect all responses from the generator - naive_mem_cube = _create_naive_mem_cube() content, references = mos_server.chat( query=chat_req.query, user_id=chat_req.user_id, diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 2d6155ec2..bc22cfb63 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -11,8 +11,14 @@ from memos.mem_scheduler.schemas.general_schemas import ( BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, + DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, DEFAULT_CONSUME_INTERVAL_SECONDS, + DEFAULT_CONTEXT_WINDOW_SIZE, + DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, DEFAULT_THREAD_POOL_MAX_WORKERS, + DEFAULT_TOP_K, + DEFAULT_USE_REDIS_QUEUE, + DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, ) @@ -20,7 +26,8 @@ class BaseSchedulerConfig(BaseConfig): """Base configuration class for mem_scheduler.""" top_k: int = Field( - default=10, description="Number of top candidates to consider in initial retrieval" + default=DEFAULT_TOP_K, + description="Number of top candidates to consider in initial retrieval", ) enable_parallel_dispatch: bool = Field( default=True, description="Whether to enable parallel message processing using thread pool" @@ -39,6 +46,19 @@ class BaseSchedulerConfig(BaseConfig): default=None, description="Path to the authentication configuration file containing private credentials", ) + # Redis queue configuration + use_redis_queue: bool = Field( + default=DEFAULT_USE_REDIS_QUEUE, + description="Whether to use Redis queue instead of local memory queue", + ) + redis_config: dict[str, Any] = Field( + default_factory=lambda: {"host": "localhost", "port": 6379, "db": 0}, + description="Redis connection configuration", + ) + max_internal_message_queue_size: int = Field( + default=DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, + description="Maximum size of internal message queue when not using Redis", + ) class GeneralSchedulerConfig(BaseSchedulerConfig): @@ -47,7 +67,8 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): default=300, description="Interval in seconds for updating activation memory" ) context_window_size: int | None = Field( - default=10, description="Size of the context window for conversation history" + default=DEFAULT_CONTEXT_WINDOW_SIZE, + description="Size of the context window for conversation history", ) act_mem_dump_path: str | None = Field( default=DEFAULT_ACT_MEM_DUMP_PATH, # Replace with DEFAULT_ACT_MEM_DUMP_PATH @@ -57,10 +78,12 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): default=False, description="Whether to enable automatic activation memory updates" ) working_mem_monitor_capacity: int = Field( - default=30, description="Capacity of the working memory monitor" + default=DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, + description="Capacity of the working memory monitor", ) activation_mem_monitor_capacity: int = Field( - default=20, description="Capacity of the activation memory monitor" + default=DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, + description="Capacity of the activation memory monitor", ) # Database configuration for ORM persistence @@ -77,6 +100,14 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): ) +class OptimizedSchedulerConfig(GeneralSchedulerConfig): + """Configuration for the optimized scheduler. + + This class inherits all fields from `GeneralSchedulerConfig` + and is used to distinguish optimized scheduling logic via type. + """ + + class SchedulerConfigFactory(BaseConfig): """Factory class for creating scheduler configurations.""" @@ -86,7 +117,7 @@ class SchedulerConfigFactory(BaseConfig): model_config = ConfigDict(extra="forbid", strict=True) backend_to_class: ClassVar[dict[str, Any]] = { "general_scheduler": GeneralSchedulerConfig, - "optimized_scheduler": GeneralSchedulerConfig, # optimized_scheduler uses same config as general_scheduler + "optimized_scheduler": OptimizedSchedulerConfig, # optimized_scheduler uses same config as general_scheduler } @field_validator("backend") diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py index 00081b581..be0d1d95f 100644 --- a/src/memos/llms/hf.py +++ b/src/memos/llms/hf.py @@ -379,10 +379,52 @@ def build_kv_cache(self, messages) -> DynamicCache: raise ValueError( "Prompt after chat template is empty, cannot build KV cache. Check your messages input." ) - kv = DynamicCache() + # Create cache and perform forward pass without pre-existing cache with torch.no_grad(): - self.model(**inputs, use_cache=True, past_key_values=kv) - for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)): - kv.key_cache[i] = k[:, :, :seq_len, :] - kv.value_cache[i] = v[:, :, :seq_len, :] - return kv + outputs = self.model(**inputs, use_cache=True) + + # Get the cache from model outputs + if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None: + kv = outputs.past_key_values + + # Convert from legacy tuple format to DynamicCache if needed + if isinstance(kv, tuple): + kv = DynamicCache.from_legacy_cache(kv) + + # Handle compatibility between old and new transformers versions + # In newer versions, DynamicCache uses 'layers' attribute + # In older versions, it uses 'key_cache' and 'value_cache' attributes + if hasattr(kv, "layers"): + # New version: trim cache using layers attribute + for layer in kv.layers: + if hasattr(layer, "key_cache") and hasattr(layer, "value_cache"): + # Trim each layer's cache to the sequence length + if layer.key_cache is not None: + layer.key_cache = layer.key_cache[:, :, :seq_len, :] + if layer.value_cache is not None: + layer.value_cache = layer.value_cache[:, :, :seq_len, :] + elif hasattr(layer, "keys") and hasattr(layer, "values"): + # Alternative attribute names in some versions + if layer.keys is not None: + layer.keys = layer.keys[:, :, :seq_len, :] + if layer.values is not None: + layer.values = layer.values[:, :, :seq_len, :] + elif hasattr(kv, "key_cache") and hasattr(kv, "value_cache"): + # Old version: trim cache using key_cache and value_cache attributes + for i in range(len(kv.key_cache)): + if kv.key_cache[i] is not None: + kv.key_cache[i] = kv.key_cache[i][:, :, :seq_len, :] + if kv.value_cache[i] is not None: + kv.value_cache[i] = kv.value_cache[i][:, :, :seq_len, :] + else: + # Fallback: log warning but continue without trimming + logger.warning( + f"DynamicCache object of type {type(kv)} has unexpected structure. " + f"Cache trimming skipped. Available attributes: {dir(kv)}" + ) + + return kv + else: + raise RuntimeError( + "Failed to build KV cache: no cache data available from model outputs" + ) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 0010897c0..cedffd6fb 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -310,18 +310,20 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = past_key_values = None if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # TODO this only one cubes - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) + if self.config.chat_model.backend != "huggingface": + logger.error( + "Activation memory only used for huggingface backend. Skipping activation memory." + ) + else: + # TODO this only one cubes + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) break # Generate response response = self.chat_llm.generate(current_messages, past_key_values=past_key_values) diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 2e5b32548..6fc64c5e3 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -312,23 +312,25 @@ def _generate_enhanced_response_with_context( # Handle activation memory if enabled (same as core method) past_key_values = None if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # Get accessible cubes for the user - target_user_id = user_id if user_id is not None else self.user_id - accessible_cubes = self.user_manager.get_user_cubes(target_user_id) - user_cube_ids = [cube.cube_id for cube in accessible_cubes] - - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) - break + if self.config.chat_model.backend != "huggingface": + logger.error( + "Activation memory only used for huggingface backend. Skipping activation memory." + ) + else: + # Get accessible cubes for the user + target_user_id = user_id if user_id is not None else self.user_id + accessible_cubes = self.user_manager.get_user_cubes(target_user_id) + user_cube_ids = [cube.cube_id for cube in accessible_cubes] + + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) + break try: # Generate the enhanced response using the chat LLM with same parameters as core diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index e69de29bb..45a39e0de 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -0,0 +1,619 @@ +""" +API Analyzer for Scheduler + +This module provides the APIAnalyzerForScheduler class that handles API requests +for search and add operations with reusable instance variables. +""" + +import http.client +import json + +from typing import Any +from urllib.parse import urlparse + +import requests + +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class APIAnalyzerForScheduler: + """ + API Analyzer class for scheduler operations. + + This class provides methods to interact with APIs for search and add operations, + with reusable instance variables for better performance and configuration management. + """ + + def __init__( + self, + base_url: str = "http://127.0.0.1:8002", + default_headers: dict[str, str] | None = None, + timeout: int = 30, + ): + """ + Initialize the APIAnalyzerForScheduler. + + Args: + base_url: Base URL for API requests + default_headers: Default headers to use for all requests + timeout: Request timeout in seconds + """ + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + # Default headers + self.default_headers = default_headers or {"Content-Type": "application/json"} + + # Parse URL for http.client usage + parsed_url = urlparse(self.base_url) + self.host = parsed_url.hostname + self.port = parsed_url.port or 8002 + self.is_https = parsed_url.scheme == "https" + + # Reusable connection for http.client + self._connection = None + + # Attributes + self.user_id = "test_user_id" + self.mem_cube_id = "test_mem_cube_id" + + logger.info(f"APIAnalyzerForScheduler initialized with base_url: {self.base_url}") + + def _get_connection(self) -> http.client.HTTPConnection | http.client.HTTPSConnection: + """ + Get or create a reusable HTTP connection. + + Returns: + HTTP connection object + """ + if self._connection is None: + if self.is_https: + self._connection = http.client.HTTPSConnection(self.host, self.port) + else: + self._connection = http.client.HTTPConnection(self.host, self.port) + return self._connection + + def _close_connection(self): + """Close the HTTP connection if it exists.""" + if self._connection: + self._connection.close() + self._connection = None + + def search( + self, user_id: str, mem_cube_id: str, query: str, top: int = 50, use_requests: bool = True + ) -> dict[str, Any]: + """ + Search for memories using the product/search API endpoint. + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + query: Search query string + top: Number of top results to return + use_requests: Whether to use requests library (True) or http.client (False) + + Returns: + Dictionary containing the API response + """ + payload = {"user_id": user_id, "mem_cube_id": mem_cube_id, "query": query, "top": top} + + try: + if use_requests: + return self._search_with_requests(payload) + else: + return self._search_with_http_client(payload) + except Exception as e: + logger.error(f"Error in search operation: {e}") + return {"error": str(e), "success": False} + + def _search_with_requests(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform search using requests library. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + url = f"{self.base_url}/product/search" + + response = requests.post( + url, headers=self.default_headers, data=json.dumps(payload), timeout=self.timeout + ) + + logger.info(f"Search request to {url} completed with status: {response.status_code}") + + try: + return { + "success": True, + "status_code": response.status_code, + "data": response.json() if response.content else {}, + "text": response.text, + } + except json.JSONDecodeError: + return { + "success": True, + "status_code": response.status_code, + "data": {}, + "text": response.text, + } + + def _search_with_http_client(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform search using http.client. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + conn = self._get_connection() + + try: + conn.request("POST", "/product/search", json.dumps(payload), self.default_headers) + + response = conn.getresponse() + data = response.read() + response_text = data.decode("utf-8") + + logger.info(f"Search request completed with status: {response.status}") + + try: + response_data = json.loads(response_text) if response_text else {} + except json.JSONDecodeError: + response_data = {} + + return { + "success": True, + "status_code": response.status, + "data": response_data, + "text": response_text, + } + except Exception as e: + logger.error(f"Error in http.client search: {e}") + return {"error": str(e), "success": False} + + def add( + self, messages: list, user_id: str, mem_cube_id: str, use_requests: bool = True + ) -> dict[str, Any]: + """ + Add memories using the product/add API endpoint. + + Args: + messages: List of message objects with role and content + user_id: User identifier + mem_cube_id: Memory cube identifier + use_requests: Whether to use requests library (True) or http.client (False) + + Returns: + Dictionary containing the API response + """ + payload = {"messages": messages, "user_id": user_id, "mem_cube_id": mem_cube_id} + + try: + if use_requests: + return self._add_with_requests(payload) + else: + return self._add_with_http_client(payload) + except Exception as e: + logger.error(f"Error in add operation: {e}") + return {"error": str(e), "success": False} + + def _add_with_requests(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform add using requests library. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + url = f"{self.base_url}/product/add" + + response = requests.post( + url, headers=self.default_headers, data=json.dumps(payload), timeout=self.timeout + ) + + logger.info(f"Add request to {url} completed with status: {response.status_code}") + + try: + return { + "success": True, + "status_code": response.status_code, + "data": response.json() if response.content else {}, + "text": response.text, + } + except json.JSONDecodeError: + return { + "success": True, + "status_code": response.status_code, + "data": {}, + "text": response.text, + } + + def _add_with_http_client(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform add using http.client. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + conn = self._get_connection() + + try: + conn.request("POST", "/product/add", json.dumps(payload), self.default_headers) + + response = conn.getresponse() + data = response.read() + response_text = data.decode("utf-8") + + logger.info(f"Add request completed with status: {response.status}") + + try: + response_data = json.loads(response_text) if response_text else {} + except json.JSONDecodeError: + response_data = {} + + return { + "success": True, + "status_code": response.status, + "data": response_data, + "text": response_text, + } + except Exception as e: + logger.error(f"Error in http.client add: {e}") + return {"error": str(e), "success": False} + + def update_base_url(self, new_base_url: str): + """ + Update the base URL and reinitialize connection parameters. + + Args: + new_base_url: New base URL for API requests + """ + self._close_connection() + self.base_url = new_base_url.rstrip("/") + + # Re-parse URL + parsed_url = urlparse(self.base_url) + self.host = parsed_url.hostname + self.port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80) + self.is_https = parsed_url.scheme == "https" + + logger.info(f"Base URL updated to: {self.base_url}") + + def update_headers(self, headers: dict[str, str]): + """ + Update default headers. + + Args: + headers: New headers to merge with existing ones + """ + self.default_headers.update(headers) + logger.info("Headers updated") + + def __del__(self): + """Cleanup method to close connection when object is destroyed.""" + self._close_connection() + + def analyze_service(self): + # Example add operation + messages = [ + {"role": "user", "content": "Where should I go for New Year's Eve in Shanghai?"}, + { + "role": "assistant", + "content": "You could head to the Bund for the countdown, attend a rooftop party, or enjoy the fireworks at Disneyland Shanghai.", + }, + ] + + add_result = self.add( + messages=messages, user_id="test_user_id", mem_cube_id="test_mem_cube_id" + ) + print("Add result:", add_result) + + # Example search operation + search_result = self.search( + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + query="What are some good places to celebrate New Year's Eve in Shanghai?", + top=50, + ) + print("Search result:", search_result) + + def analyze_features(self): + try: + # Test basic search functionality + search_result = self.search( + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + query="What are some good places to celebrate New Year's Eve in Shanghai?", + top=50, + ) + print("Search result:", search_result) + except Exception as e: + logger.error(f"Feature analysis failed: {e}") + + +class DirectSearchMemoriesAnalyzer: + """ + Direct analyzer for testing search_memories function + Used for debugging and analyzing search_memories function behavior without starting a full API server + """ + + def __init__(self): + """Initialize the analyzer""" + # Import necessary modules + try: + from memos.api.product_models import APIADDRequest, APISearchRequest + from memos.api.routers.server_router import add_memories, search_memories + from memos.types import MessageDict, UserContext + + self.APISearchRequest = APISearchRequest + self.APIADDRequest = APIADDRequest + self.search_memories = search_memories + self.add_memories = add_memories + self.UserContext = UserContext + self.MessageDict = MessageDict + + logger.info("DirectSearchMemoriesAnalyzer initialized successfully") + except ImportError as e: + logger.error(f"Failed to import modules: {e}") + raise + + def create_test_search_request( + self, + query="test query", + user_id="test_user", + mem_cube_id="test_cube", + mode="fast", + top_k=10, + chat_history=None, + session_id=None, + ): + """ + Create a test APISearchRequest object with the given parameters. + + Args: + query: Search query string + user_id: User ID for the request + mem_cube_id: Memory cube ID for the request + mode: Search mode ("fast" or "fine") + top_k: Number of results to return + chat_history: Chat history for context (optional) + session_id: Session ID for the request (optional) + + Returns: + APISearchRequest: A configured request object + """ + return self.APISearchRequest( + query=query, + user_id=user_id, + mem_cube_id=mem_cube_id, + mode=mode, + top_k=top_k, + chat_history=chat_history, + session_id=session_id, + ) + + def create_test_add_request( + self, + user_id="test_user", + mem_cube_id="test_cube", + messages=None, + memory_content=None, + session_id=None, + ): + """ + Create a test APIADDRequest object with the given parameters. + + Args: + user_id: User ID for the request + mem_cube_id: Memory cube ID for the request + messages: List of messages to add (optional) + memory_content: Direct memory content to add (optional) + session_id: Session ID for the request (optional) + + Returns: + APIADDRequest: A configured request object + """ + if messages is None and memory_content is None: + # Default test messages + messages = [ + {"role": "user", "content": "What's the weather like today?"}, + { + "role": "assistant", + "content": "I don't have access to real-time weather data, but you can check a weather app or website for current conditions.", + }, + ] + + # Ensure we have a valid session_id + if session_id is None: + session_id = "test_session_" + str(hash(user_id + mem_cube_id))[:8] + + return self.APIADDRequest( + user_id=user_id, + mem_cube_id=mem_cube_id, + messages=messages, + memory_content=memory_content, + session_id=session_id, + doc_path=None, + source="api_analyzer_test", + chat_history=None, + operation=None, + ) + + def test_add_memories_basic(self, user_id="test_user_add", mem_cube_id="test_cube_add"): + """Basic add_memories test""" + print("=" * 60) + print("Starting basic add_memories test") + print("=" * 60) + + try: + # Create test request with default messages + add_req = self.create_test_add_request(user_id=user_id, mem_cube_id=mem_cube_id) + + print("Test request created:") + print(f" User ID: {add_req.user_id}") + print(f" Mem Cube ID: {add_req.mem_cube_id}") + print(f" Messages: {add_req.messages}") + print(f" Session ID: {add_req.session_id}") + + # Call add_memories function + print("\nCalling add_memories function...") + result = self.add_memories(add_req) + + print(f"Add result: {result}") + print("Basic add_memories test completed successfully") + return result + + except Exception as e: + print(f"Basic add_memories test failed: {e}") + import traceback + + traceback.print_exc() + return None + + def test_search_memories_basic(self, query: str, mode: str, topk: int): + """Basic search_memories test""" + print("=" * 60) + print("Starting basic search_memories test") + print("=" * 60) + + try: + # Create test request + search_req = self.create_test_search_request( + query=query, + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + mode=mode, + top_k=topk, + ) + + print("Test request parameters:") + print(f" - query: {search_req.query}") + print(f" - user_id: {search_req.user_id}") + print(f" - mem_cube_id: {search_req.mem_cube_id}") + print(f" - mode: {search_req.mode}") + print(f" - top_k: {search_req.top_k}") + print(f" - internet_search: {search_req.internet_search}") + print(f" - moscube: {search_req.moscube}") + print() + + # Call search_memories function + print("Calling search_memories function...") + result = self.search_memories(search_req) + + print("✅ Function call successful!") + print(f"Return result type: {type(result)}") + print(f"Return result: {result}") + + # Analyze return result + if hasattr(result, "message"): + print(f"Message: {result.message}") + if hasattr(result, "data"): + print(f"Data type: {type(result.data)}") + if result.data and isinstance(result.data, dict): + for key, value in result.data.items(): + print(f" {key}: {len(value) if isinstance(value, list) else value}") + + return result + + except Exception as e: + print(f"❌ Test failed: {e}") + import traceback + + print("Detailed error information:") + traceback.print_exc() + return None + + def run_all_tests(self): + """Run all available tests""" + print("🚀 Starting comprehensive test suite") + print("=" * 80) + + # Test add_memories functions (more likely to have dependency issues) + print("\n\n📝 Testing ADD_MEMORIES functions:") + try: + print("\n" + "-" * 40) + self.test_add_memories_basic() + print("✅ Basic add memories test completed") + except Exception as e: + print(f"❌ Basic add memories test failed: {e}") + + # Test search_memories functions first (less likely to fail) + print("\n🔍 Testing SEARCH_MEMORIES functions:") + try: + self.test_search_memories_basic( + query="What are some good places to celebrate New Year's Eve in Shanghai?", + mode="fast", + topk=3, + ) + print("✅ Search memories test completed successfully") + except Exception as e: + print(f"❌ Search memories test failed: {e}") + + print("\n" + "=" * 80) + print("✅ All tests completed!") + + +# Example usage +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="API Analyzer for Memory Scheduler") + parser.add_argument( + "--mode", + choices=["direct", "api"], + default="direct", + help="Test mode: 'direct' for direct function testing, 'api' for API testing (default: direct)", + ) + + args = parser.parse_args() + + if args.mode == "direct": + # Direct test mode for search_memories and add_memories functions + print("Using direct test mode") + try: + direct_analyzer = DirectSearchMemoriesAnalyzer() + direct_analyzer.run_all_tests() + except Exception as e: + print(f"Direct test mode failed: {e}") + import traceback + + traceback.print_exc() + else: + # Original API test mode + print("Using API test mode") + analyzer = APIAnalyzerForScheduler() + + # Test add operation + messages = [ + {"role": "user", "content": "Where should I go for New Year's Eve in Shanghai?"}, + { + "role": "assistant", + "content": "You could head to the Bund for the countdown, attend a rooftop party, or enjoy the fireworks at Disneyland Shanghai.", + }, + ] + + add_result = analyzer.add( + messages=messages, user_id="test_user_id", mem_cube_id="test_mem_cube_id" + ) + print("Add result:", add_result) + + # Test search operation + search_result = analyzer.search( + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + query="What are some good places to celebrate New Year's Eve in Shanghai?", + top=50, + ) + print("Search result:", search_result) diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index 7cd085ada..ace67eff6 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -485,18 +485,20 @@ def chat(self, query: str, user_id: str | None = None) -> str: past_key_values = None if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # TODO this only one cubes - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) + if self.config.chat_model.backend != "huggingface": + logger.error( + "Activation memory only used for huggingface backend. Skipping activation memory." + ) + else: + # TODO this only one cubes + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) break # Generate response response = self.chat_llm.generate(current_messages, past_key_values=past_key_values) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 1e8b042b1..e475ea225 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -22,8 +22,12 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, + DEFAULT_CONTEXT_WINDOW_SIZE, + DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, DEFAULT_STARTUP_MODE, DEFAULT_THREAD_POOL_MAX_WORKERS, + DEFAULT_TOP_K, + DEFAULT_USE_REDIS_QUEUE, STARTUP_BY_PROCESS, MemCubeID, TreeTextMemory_SEARCH_METHOD, @@ -34,6 +38,7 @@ ScheduleMessageItem, ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) @@ -57,11 +62,13 @@ def __init__(self, config: BaseSchedulerConfig): self.config = config # hyper-parameters - self.top_k = self.config.get("top_k", 10) - self.context_window_size = self.config.get("context_window_size", 5) + self.top_k = self.config.get("top_k", DEFAULT_TOP_K) + self.context_window_size = self.config.get( + "context_window_size", DEFAULT_CONTEXT_WINDOW_SIZE + ) self.enable_activation_memory = self.config.get("enable_activation_memory", False) self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH) - self.search_method = TreeTextMemory_SEARCH_METHOD + self.search_method = self.config.get("search_method", TreeTextMemory_SEARCH_METHOD) self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", True) self.thread_pool_max_workers = self.config.get( "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS @@ -86,13 +93,22 @@ def __init__(self, config: BaseSchedulerConfig): # optional configs self.disable_handlers: list | None = self.config.get("disable_handlers", None) - # internal message queue + # message queue configuration + self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) self.max_internal_message_queue_size = self.config.get( - "max_internal_message_queue_size", 10000 - ) - self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_message_queue_size + "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) + + # Initialize message queue based on configuration + if self.use_redis_queue: + self.memos_message_queue = None # Will use Redis instead + # Initialize Redis if using Redis queue with auto-initialization + self.auto_initialize_redis() + else: + self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( + maxsize=self.max_internal_message_queue_size + ) + self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50) self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( maxsize=self.max_web_log_queue_size @@ -390,7 +406,7 @@ def update_activation_memory( cache_item = act_mem.extract(new_text_memory) cache_item.records.text_memories = new_text_memories - cache_item.records.timestamp = datetime.utcnow() + cache_item.records.timestamp = get_utc_now() act_mem.add([cache_item]) act_mem.dump(self.act_mem_dump_path) @@ -471,7 +487,7 @@ def update_activation_memory_periodically( mem_cube=mem_cube, ) - self.monitor.last_activation_mem_update_time = datetime.utcnow() + self.monitor.last_activation_mem_update_time = get_utc_now() logger.debug( f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}" @@ -480,14 +496,14 @@ def update_activation_memory_periodically( else: logger.info( f"Skipping update - {interval_seconds} second interval not yet reached. " - f"Last update time is {self.monitor.last_activation_mem_update_time} and now is" - f"{datetime.utcnow()}" + f"Last update time is {self.monitor.last_activation_mem_update_time} and now is " + f"{get_utc_now()}" ) except Exception as e: logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) - def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): - """Submit multiple messages to the message queue.""" + async def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + """Submit messages to the message queue (either local queue or Redis).""" if isinstance(messages, ScheduleMessageItem): messages = [messages] # transform single message to list @@ -497,13 +513,20 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt logger.error(error_msg) raise TypeError(error_msg) - # Check if this handler is disabled if self.disable_handlers and message.label in self.disable_handlers: logger.info(f"Skipping disabled handler: {message.label} - {message.content}") continue - self.memos_message_queue.put(message) - logger.info(f"Submitted message: {message.label} - {message.content}") + if self.use_redis_queue: + # Use Redis stream for message queue + await self.redis_add_message_stream(message.to_dict()) + logger.info(f"Submitted message to Redis: {message.label} - {message.content}") + else: + # Use local queue + self.memos_message_queue.put(message) + logger.info( + f"Submitted message to local queue: {message.label} - {message.content}" + ) def _submit_web_logs( self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] @@ -556,36 +579,64 @@ def _message_consumer(self) -> None: Continuously checks the queue for messages and dispatches them. Runs in a dedicated thread to process messages at regular intervals. + For Redis queue, this method starts the Redis listener. """ - while self._running: # Use a running flag for graceful shutdown - try: - # Get all available messages at once (thread-safe approach) - messages = [] - while True: - try: - # Use get_nowait() directly without empty() check to avoid race conditions - message = self.memos_message_queue.get_nowait() - messages.append(message) - except queue.Empty: - # No more messages available - break - - if messages: - try: - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - finally: - # Mark all messages as processed - for _ in messages: - self.memos_message_queue.task_done() - - # Sleep briefly to prevent busy waiting - time.sleep(self._consume_interval) # Adjust interval as needed - - except Exception as e: - logger.error(f"Unexpected error in message consumer: {e!s}") - time.sleep(self._consume_interval) # Prevent tight error loops + if self.use_redis_queue: + # For Redis queue, start the Redis listener + def redis_message_handler(message_data): + """Handler for Redis messages""" + try: + # Redis message data needs to be decoded from bytes to string + decoded_data = {} + for key, value in message_data.items(): + if isinstance(key, bytes): + key = key.decode("utf-8") + if isinstance(value, bytes): + value = value.decode("utf-8") + decoded_data[key] = value + + message = ScheduleMessageItem.from_dict(decoded_data) + self.dispatcher.dispatch([message]) + except Exception as e: + logger.error(f"Error processing Redis message: {e}") + logger.error(f"Message data: {message_data}") + + self.redis_start_listening(handler=redis_message_handler) + + # Keep the thread alive while Redis listener is running + while self._running: + time.sleep(self._consume_interval) + else: + # Original local queue logic + while self._running: # Use a running flag for graceful shutdown + try: + # Get all available messages at once (thread-safe approach) + messages = [] + while True: + try: + # Use get_nowait() directly without empty() check to avoid race conditions + message = self.memos_message_queue.get_nowait() + messages.append(message) + except queue.Empty: + # No more messages available + break + + if messages: + try: + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error(f"Error dispatching messages: {e!s}") + finally: + # Mark all messages as processed + for _ in messages: + self.memos_message_queue.task_done() + + # Sleep briefly to prevent busy waiting + time.sleep(self._consume_interval) # Adjust interval as needed + + except Exception as e: + logger.error(f"Unexpected error in message consumer: {e!s}") + time.sleep(self._consume_interval) # Prevent tight error loops def start(self) -> None: """ @@ -722,14 +773,77 @@ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: return self.dispatcher.unregister_handlers(labels) + def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: + """ + Get currently running tasks, optionally filtered by a custom function. + + This method delegates to the dispatcher's get_running_tasks method. + + Args: + filter_func: Optional function to filter tasks. Should accept a RunningTaskItem + and return True if the task should be included in results. + + Returns: + dict[str, dict]: Dictionary mapping task IDs to task information dictionaries. + Each task dict contains: item_id, user_id, mem_cube_id, task_info, + task_name, start_time, end_time, status, result, error_message, messages + + Examples: + # Get all running tasks + all_tasks = scheduler.get_running_tasks() + + # Get tasks for specific user + user_tasks = scheduler.get_running_tasks( + filter_func=lambda task: task.user_id == "user123" + ) + + # Get tasks with specific status + active_tasks = scheduler.get_running_tasks( + filter_func=lambda task: task.status == "running" + ) + """ + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, returning empty tasks dict") + return {} + + running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func) + + # Convert RunningTaskItem objects to dictionaries for easier consumption + result = {} + for task_id, task_item in running_tasks.items(): + result[task_id] = { + "item_id": task_item.item_id, + "user_id": task_item.user_id, + "mem_cube_id": task_item.mem_cube_id, + "task_info": task_item.task_info, + "task_name": task_item.task_name, + "start_time": task_item.start_time, + "end_time": task_item.end_time, + "status": task_item.status, + "result": task_item.result, + "error_message": task_item.error_message, + "messages": task_item.messages, + } + + return result + def _cleanup_queues(self) -> None: """Ensure all queues are emptied and marked as closed.""" - try: - while not self.memos_message_queue.empty(): - self.memos_message_queue.get_nowait() - self.memos_message_queue.task_done() - except queue.Empty: - pass + if self.use_redis_queue: + # For Redis queue, stop the listener and close connection + try: + self.redis_stop_listening() + self.redis_close() + except Exception as e: + logger.error(f"Error cleaning up Redis connection: {e}") + else: + # Original local queue cleanup + try: + while not self.memos_message_queue.empty(): + self.memos_message_queue.get_nowait() + self.memos_message_queue.task_done() + except queue.Empty: + pass try: while not self._web_log_message_queue.empty(): diff --git a/src/memos/mem_scheduler/general_modules/api_misc.py b/src/memos/mem_scheduler/general_modules/api_misc.py new file mode 100644 index 000000000..6139a895a --- /dev/null +++ b/src/memos/mem_scheduler/general_modules/api_misc.py @@ -0,0 +1,115 @@ +import threading + +from typing import Any + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager + + +logger = get_logger(__name__) + + +class SchedulerAPIModule(BaseSchedulerModule): + def __init__(self): + super().__init__() + + self.search_history_managers: dict[str, RedisDBManager] = {} + + def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> RedisDBManager: + """Get or create a Redis manager for search history.""" + key = f"search_history:{user_id}:{mem_cube_id}" + if key not in self.search_history_managers: + self.search_history_managers[key] = RedisDBManager( + user_id=user_id, mem_cube_id=mem_cube_id + ) + return self.search_history_managers[key] + + def sync_search_data( + self, user_id: str, mem_cube_id: str, query: str, formatted_memories: Any + ) -> None: + """ + Sync search data to Redis, maintaining a list of size 5. + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + query: Search query string + formatted_memories: Formatted search results + """ + try: + # Get the search history manager + manager = self.get_search_history_manager(user_id, mem_cube_id) + + # Create search data entry + search_entry = { + "query": query, + "formatted_memories": formatted_memories, + "timestamp": threading.current_thread().ident, # Use thread ID as simple timestamp + } + + # Load existing search history + existing_data = manager.load_from_db() + + if existing_data is None: + search_history = SimpleListManager([]) + else: + # If existing data is a SimpleListManager, use it; otherwise create new one + if isinstance(existing_data, SimpleListManager): + search_history = existing_data + else: + search_history = SimpleListManager([]) + + # Add new entry and keep only latest 5 + search_history.add_item(str(search_entry)) + if len(search_history) > 5: + # Keep only the latest 5 items + search_history.items = search_history.items[-5:] + + # Save back to Redis + manager.save_to_db(search_history) + + logger.info( + f"Synced search data for user {user_id}, mem_cube {mem_cube_id}. History size: {len(search_history)}" + ) + + except Exception as e: + logger.error(f"Failed to sync search data: {e}", exc_info=True) + + def get_pre_fine_memories(self, user_id: str, mem_cube_id: str) -> list: + """ + Get the most recent pre-computed fine memories from search history. + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + + Returns: + List of formatted memories from the most recent search, or empty list if none found + """ + try: + manager = self.get_search_history_manager(user_id, mem_cube_id) + search_history_key = "search_history_list" + existing_data = manager.load_from_db(search_history_key) + + if existing_data is None: + return [] + + search_history = ( + existing_data.obj_instance + if hasattr(existing_data, "obj_instance") + else existing_data + ) + + if not search_history or len(search_history) == 0: + return [] + + # Return the formatted_memories from the most recent search + latest_entry = search_history[-1] + return ( + latest_entry.get("formatted_memories", []) if isinstance(latest_entry, dict) else [] + ) + + except Exception as e: + logger.error(f"Failed to get pre-computed fine memories: {e}", exc_info=True) + return [] diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index 4584beb96..c357e31b5 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -101,15 +101,43 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): return wrapped_handler - def get_running_tasks(self) -> dict[str, RunningTaskItem]: + def get_running_tasks( + self, filter_func: Callable[[RunningTaskItem], bool] | None = None + ) -> dict[str, RunningTaskItem]: """ - Get a copy of currently running tasks. + Get a copy of currently running tasks, optionally filtered by a custom function. + + Args: + filter_func: Optional function that takes a RunningTaskItem and returns True if it should be included. + Common filters can be created using helper methods like filter_by_user_id, filter_by_task_name, etc. Returns: Dictionary of running tasks keyed by task ID + + Examples: + # Get all running tasks + all_tasks = dispatcher.get_running_tasks() + + # Get tasks for specific user + user_tasks = dispatcher.get_running_tasks(lambda task: task.user_id == "user123") + + # Get tasks for specific task name + handler_tasks = dispatcher.get_running_tasks(lambda task: task.task_name == "test_handler") + + # Get tasks with multiple conditions + filtered_tasks = dispatcher.get_running_tasks( + lambda task: task.user_id == "user123" and task.status == "running" + ) """ with self._task_lock: - return self._running_tasks.copy() + if filter_func is None: + return self._running_tasks.copy() + + return { + task_id: task_item + for task_id, task_item in self._running_tasks.items() + if filter_func(task_item) + } def get_running_task_count(self) -> int: """ diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index f47cc0cc5..31bb9b3da 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -148,7 +148,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) @@ -170,7 +170,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL) diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 13fe07354..0ebb7da4f 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -1,7 +1,6 @@ import threading import time -from datetime import datetime from time import perf_counter from memos.configs.mem_scheduler import BaseSchedulerConfig @@ -14,6 +13,7 @@ DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES, DEFAULT_STUCK_THREAD_TOLERANCE, ) +from memos.mem_scheduler.utils.db_utils import get_utc_now logger = get_logger(__name__) @@ -84,7 +84,7 @@ def register_pool( "max_workers": max_workers, "restart": restart_on_failure, "failure_count": 0, - "last_active": datetime.utcnow(), + "last_active": get_utc_now(), "healthy": True, } logger.info(f"Registered thread pool '{name}' for monitoring") @@ -122,54 +122,6 @@ def _monitor_loop(self) -> None: logger.debug("Monitor loop exiting") - def start(self) -> bool: - """ - Start the monitoring thread. - - Returns: - bool: True if monitor started successfully, False if already running - """ - if self._running: - logger.warning("Dispatcher Monitor is already running") - return False - - self._running = True - self._monitor_thread = threading.Thread( - target=self._monitor_loop, name="threadpool_monitor", daemon=True - ) - self._monitor_thread.start() - logger.info("Dispatcher Monitor monitor started") - return True - - def stop(self) -> None: - """ - Stop the monitoring thread and clean up all managed thread pools. - Ensures proper shutdown of all monitored executors. - """ - if not self._running: - return - - # Stop the monitoring loop - self._running = False - if self._monitor_thread and self._monitor_thread.is_alive(): - self._monitor_thread.join(timeout=5) - - # Shutdown all registered pools - with self._pool_lock: - for name, pool_info in self._pools.items(): - executor = pool_info["executor"] - if not executor._shutdown: # pylint: disable=protected-access - try: - logger.info(f"Shutting down thread pool '{name}'") - executor.shutdown(wait=True, cancel_futures=True) - logger.info(f"Successfully shut down thread pool '{name}'") - except Exception as e: - logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True) - - # Clear the pool registry - self._pools.clear() - logger.info("Thread pool monitor and all pools stopped") - def _check_pools_health(self) -> None: """Check health of all registered thread pools.""" for name, pool_info in list(self._pools.items()): @@ -182,7 +134,6 @@ def _check_pools_health(self) -> None: if is_healthy: pool_info["failure_count"] = 0 pool_info["healthy"] = True - return else: pool_info["failure_count"] += 1 pool_info["healthy"] = False @@ -269,27 +220,24 @@ def _check_pool_health( f"Found {len(stuck_tasks)} stuck tasks (tolerance: {effective_tolerance})", ) - # Check thread activity - active_threads = sum( - 1 - for t in threading.enumerate() - if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access - ) - - # Check if no threads are active but should be - if active_threads == 0 and pool_info["max_workers"] > 0: - return False, "No active worker threads" - + # Only check for stuck threads, not inactive threads # Check if threads are stuck (no activity for specified intervals) - time_delta = (datetime.utcnow() - pool_info["last_active"]).total_seconds() + time_delta = (get_utc_now() - pool_info["last_active"]).total_seconds() if time_delta >= self.check_interval * stuck_max_interval: return False, f"No recent activity for {time_delta:.1f} seconds" # If we got here, pool appears healthy - pool_info["last_active"] = datetime.utcnow() + pool_info["last_active"] = get_utc_now() # Log health status with comprehensive information if self.dispatcher: + # Check thread activity + active_threads = sum( + 1 + for t in threading.enumerate() + if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access + ) + task_count = self.dispatcher.get_running_task_count() max_workers = pool_info.get("max_workers", 0) stuck_count = len(stuck_tasks) @@ -338,7 +286,7 @@ def _restart_pool(self, name: str, pool_info: dict) -> None: pool_info["executor"] = new_executor pool_info["failure_count"] = 0 pool_info["healthy"] = True - pool_info["last_active"] = datetime.utcnow() + pool_info["last_active"] = get_utc_now() elapsed_time = perf_counter() - start_time if elapsed_time > 1: @@ -379,3 +327,52 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit point.""" self.stop() + + def start(self) -> bool: + """ + Start the monitoring thread. + + Returns: + bool: True if monitor started successfully, False if already running + """ + if self._running: + logger.warning("Dispatcher Monitor is already running") + return False + + self._running = True + self._monitor_thread = threading.Thread( + target=self._monitor_loop, name="threadpool_monitor", daemon=True + ) + self._monitor_thread.start() + logger.info("Dispatcher Monitor monitor started") + return True + + def stop(self) -> None: + """ + Stop the monitoring thread and clean up all managed thread pools. + Ensures proper shutdown of all monitored executors. + """ + if not self._running: + return + + # Stop the monitoring loop + self._running = False + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=5) + + # Shutdown all registered pools + with self._pool_lock: + for name, pool_info in self._pools.items(): + executor = pool_info["executor"] + if not executor._shutdown: # pylint: disable=protected-access + try: + logger.info(f"Shutting down thread pool '{name}'") + executor.shutdown(wait=True, cancel_futures=True) + logger.info(f"Successfully shut down thread pool '{name}'") + except Exception as e: + logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True) + + # Clear the pool registry + self._pools.clear() + + logger.info("Thread pool monitor and all pools stopped") diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index 87d996549..22fb78445 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -28,6 +28,7 @@ MemoryMonitorManager, QueryMonitorQueue, ) +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import extract_json_dict from memos.memories.textual.tree import TreeTextMemory @@ -64,7 +65,7 @@ def __init__( "No database engine provided; falling back to default temporary SQLite engine. " "This is intended for testing only. Consider providing a configured engine for production use." ) - self.db_engine = BaseDBManager.create_default_engine() + self.db_engine = BaseDBManager.create_default_sqlite_engine() self.query_monitors: dict[UserID, dict[MemCubeID, DBManagerForQueryMonitorQueue]] = {} self.working_memory_monitors: dict[ @@ -256,7 +257,7 @@ def update_activation_memory_monitors( activation_db_manager.sync_with_orm(size_limit=self.activation_mem_monitor_capacity) def timed_trigger(self, last_time: datetime, interval_seconds: float) -> bool: - now = datetime.utcnow() + now = get_utc_now() elapsed = (now - last_time).total_seconds() if elapsed >= interval_seconds: return True diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index dd08954a9..fb5f4ce7c 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -1,14 +1,21 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.general_modules.api_misc import SchedulerAPIModule from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( + API_MIX_SEARCH_LABEL, + QUERY_LABEL, MemCubeID, + SearchMode, UserID, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.types import UserContext if TYPE_CHECKING: @@ -19,10 +26,116 @@ class OptimizedScheduler(GeneralScheduler): - """Optimized scheduler with improved working memory management""" + """Optimized scheduler with improved working memory management and support for api""" def __init__(self, config: GeneralSchedulerConfig): super().__init__(config) + self.api_module = SchedulerAPIModule() + self.message_consumers = { + API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, + } + + def _format_memory_item(self, memory_data: Any) -> dict[str, Any]: + """Format a single memory item for API response.""" + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + def fine_search_memories( + self, + search_req: APISearchRequest, + user_context: UserContext, + mem_cube: GeneralMemCube, + ): + """Fine search memories function copied from server_router to avoid circular import""" + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + search_results = mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FINE, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + formatted_memories = [self._format_memory_item(data) for data in search_results] + + return formatted_memories + + def update_search_memories_to_redis( + self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] + ): + mem_cube = messages[0].mem_cube + + # for status update + self._set_current_context_from_message(msg=messages[0]) + + # update query monitors + for msg in messages: + self.monitor.register_query_monitor_if_not_exists( + user_id=user_id, mem_cube_id=mem_cube_id + ) + + content_dict = msg.content + search_req = content_dict["search_req"] + user_context = content_dict["user_context"] + + formatted_memories = self.fine_search_memories( + search_req=search_req, user_context=user_context, mem_cube=mem_cube + ) + + # Sync search data to Redis + try: + self.api_module.sync_search_data( + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + query=search_req.query, + formatted_memories=formatted_memories, + ) + except Exception as e: + logger.error(f"Failed to sync search data: {e}") + + def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + """ + Process and handle query trigger messages from the queue. + + Args: + messages: List of query messages to process + """ + logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") + + # Process the query in a session turn + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + + self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + messages = grouped_messages[user_id][mem_cube_id] + if len(messages) == 0: + return + self.update_search_memories_to_redis( + user_id=user_id, mem_cube_id=mem_cube_id, messages=messages + ) def replace_working_memory( self, diff --git a/src/memos/mem_scheduler/orm_modules/base_model.py b/src/memos/mem_scheduler/orm_modules/base_model.py index 9d75a12bd..cf3fc904c 100644 --- a/src/memos/mem_scheduler/orm_modules/base_model.py +++ b/src/memos/mem_scheduler/orm_modules/base_model.py @@ -10,13 +10,16 @@ from sqlalchemy import Boolean, Column, DateTime, String, Text, and_, create_engine from sqlalchemy.engine import Engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, declarative_base, sessionmaker from memos.log import get_logger from memos.mem_user.user_manager import UserManager +class DatabaseError(Exception): + """Exception raised for database-related errors""" + + T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.) ORM = TypeVar("ORM") # The ORM model type @@ -561,7 +564,7 @@ def close(self): logger.error(f"Error during close operation: {e}") @staticmethod - def create_default_engine() -> Engine: + def create_default_sqlite_engine() -> Engine: """Create SQLAlchemy engine with default database path Returns: @@ -633,3 +636,211 @@ def create_mysql_db_path( else: db_path = f"mysql+pymysql://{username}@{host}:{port}/{database}?charset={charset}" return db_path + + @staticmethod + def load_mysql_engine_from_env(env_file_path: str | None = None) -> Engine | None: + """Load MySQL engine from environment variables + + Args: + env_file_path: Path to .env file (optional, defaults to loading from current environment) + + Returns: + SQLAlchemy Engine instance configured for MySQL + + Raises: + DatabaseError: If required environment variables are missing or connection fails + """ + # Load environment variables from file if provided + if env_file_path: + if os.path.exists(env_file_path): + from dotenv import load_dotenv + + load_dotenv(env_file_path) + logger.info(f"Loaded environment variables from {env_file_path}") + else: + logger.warning( + f"Environment file not found: {env_file_path}, using current environment variables" + ) + else: + logger.info("Using current environment variables (no env_file_path provided)") + + # Get MySQL configuration from environment variables + mysql_host = os.getenv("MYSQL_HOST") + mysql_port_str = os.getenv("MYSQL_PORT") + mysql_username = os.getenv("MYSQL_USERNAME") + mysql_password = os.getenv("MYSQL_PASSWORD") + mysql_database = os.getenv("MYSQL_DATABASE") + mysql_charset = os.getenv("MYSQL_CHARSET") + + # Check required environment variables + required_vars = { + "MYSQL_HOST": mysql_host, + "MYSQL_USERNAME": mysql_username, + "MYSQL_PASSWORD": mysql_password, + "MYSQL_DATABASE": mysql_database, + } + + missing_vars = [var for var, value in required_vars.items() if not value] + if missing_vars: + error_msg = f"Missing required MySQL environment variables: {', '.join(missing_vars)}" + logger.error(error_msg) + return None + + # Parse port with validation + try: + mysql_port = int(mysql_port_str) if mysql_port_str else 3306 + except ValueError: + error_msg = f"Invalid MYSQL_PORT value: {mysql_port_str}. Must be a valid integer." + logger.error(error_msg) + return None + + # Set default charset if not provided + if not mysql_charset: + mysql_charset = "utf8mb4" + + # Create MySQL connection URL + db_url = BaseDBManager.create_mysql_db_path( + host=mysql_host, + port=mysql_port, + username=mysql_username, + password=mysql_password, + database=mysql_database, + charset=mysql_charset, + ) + + try: + # Create and test the engine + engine = create_engine(db_url, echo=False) + + # Test connection + with engine.connect() as conn: + from sqlalchemy import text + + conn.execute(text("SELECT 1")) + + logger.info( + f"Successfully created MySQL engine: {mysql_host}:{mysql_port}/{mysql_database}" + ) + return engine + + except Exception as e: + error_msg = f"Failed to create MySQL engine from environment variables: {e}" + logger.error(error_msg) + raise DatabaseError(error_msg) from e + + @staticmethod + def load_redis_engine_from_env(env_file_path: str | None = None) -> Any: + """Load Redis connection from environment variables + + Args: + env_file_path: Path to .env file (optional, defaults to loading from current environment) + + Returns: + Redis connection instance + + Raises: + DatabaseError: If required environment variables are missing or connection fails + """ + try: + import redis + except ImportError as e: + error_msg = "Redis package not installed. Install with: pip install redis" + logger.error(error_msg) + raise DatabaseError(error_msg) from e + + # Load environment variables from file if provided + if env_file_path: + if os.path.exists(env_file_path): + from dotenv import load_dotenv + + load_dotenv(env_file_path) + logger.info(f"Loaded environment variables from {env_file_path}") + else: + logger.warning( + f"Environment file not found: {env_file_path}, using current environment variables" + ) + else: + logger.info("Using current environment variables (no env_file_path provided)") + + # Get Redis configuration from environment variables + redis_host = os.getenv("REDIS_HOST") or os.getenv("MEMSCHEDULER_REDIS_HOST") + redis_port_str = os.getenv("REDIS_PORT") or os.getenv("MEMSCHEDULER_REDIS_PORT") + redis_db_str = os.getenv("REDIS_DB") or os.getenv("MEMSCHEDULER_REDIS_DB") + redis_password = os.getenv("REDIS_PASSWORD") or os.getenv("MEMSCHEDULER_REDIS_PASSWORD") + + # Check required environment variables + if not redis_host: + error_msg = ( + "Missing required Redis environment variable: REDIS_HOST or MEMSCHEDULER_REDIS_HOST" + ) + logger.error(error_msg) + return None + + # Parse port with validation + try: + redis_port = int(redis_port_str) if redis_port_str else 6379 + except ValueError: + error_msg = f"Invalid REDIS_PORT value: {redis_port_str}. Must be a valid integer." + logger.error(error_msg) + return None + + # Parse database with validation + try: + redis_db = int(redis_db_str) if redis_db_str else 0 + except ValueError: + error_msg = f"Invalid REDIS_DB value: {redis_db_str}. Must be a valid integer." + logger.error(error_msg) + return None + + # Optional timeout settings + socket_timeout = os.getenv( + "REDIS_SOCKET_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None) + ) + socket_connect_timeout = os.getenv( + "REDIS_SOCKET_CONNECT_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None) + ) + + try: + # Build Redis connection parameters + redis_kwargs = { + "host": redis_host, + "port": redis_port, + "db": redis_db, + "decode_responses": True, + } + + if redis_password: + redis_kwargs["password"] = redis_password + + if socket_timeout: + try: + redis_kwargs["socket_timeout"] = float(socket_timeout) + except ValueError: + logger.warning( + f"Invalid REDIS_SOCKET_TIMEOUT value: {socket_timeout}, ignoring" + ) + + if socket_connect_timeout: + try: + redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout) + except ValueError: + logger.warning( + f"Invalid REDIS_SOCKET_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring" + ) + + # Create Redis connection + redis_client = redis.Redis(**redis_kwargs) + + # Test connection + if not redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info( + f"Successfully created Redis connection: {redis_host}:{redis_port}/{redis_db}" + ) + return redis_client + + except Exception as e: + error_msg = f"Failed to create Redis connection from environment variables: {e}" + logger.error(error_msg) + raise DatabaseError(error_msg) from e diff --git a/src/memos/mem_scheduler/orm_modules/redis_model.py b/src/memos/mem_scheduler/orm_modules/redis_model.py new file mode 100644 index 000000000..ccfe1b1c8 --- /dev/null +++ b/src/memos/mem_scheduler/orm_modules/redis_model.py @@ -0,0 +1,699 @@ +import json +import time + +from typing import Any, TypeVar + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import declarative_base + +from memos.log import get_logger +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorManager +from memos.mem_scheduler.utils.db_utils import get_utc_now + + +T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.) +ORM = TypeVar("ORM") # The ORM model type + +logger = get_logger(__name__) + +Base = declarative_base() + + +class SimpleListManager: + """Simple wrapper class for list[str] to work with RedisDBManager""" + + def __init__(self, items: list[str] | None = None): + self.items = items or [] + + def to_json(self) -> str: + """Serialize to JSON string""" + return json.dumps({"items": self.items}) + + @classmethod + def from_json(cls, json_str: str) -> "SimpleListManager": + """Deserialize from JSON string""" + data = json.loads(json_str) + return cls(items=data.get("items", [])) + + def add_item(self, item: str): + """Add an item to the list""" + self.items.append(item) + + def __len__(self): + return len(self.items) + + def __str__(self): + return f"SimpleListManager(items={self.items})" + + +class RedisLockableORM: + """Redis-based implementation of LockableORM interface + + This class provides Redis-based storage for lockable ORM objects, + mimicking the SQLAlchemy LockableORM interface but using Redis as the backend. + """ + + def __init__(self, redis_client, user_id: str, mem_cube_id: str): + self.redis_client = redis_client + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.serialized_data = None + self.lock_acquired = False + self.lock_expiry = None + self.version_control = "0" + + def _get_key_prefix(self) -> str: + """Generate Redis key prefix for this ORM instance""" + return f"lockable_orm:{self.user_id}:{self.mem_cube_id}" + + def _get_data_key(self) -> str: + """Get Redis key for serialized data""" + return f"{self._get_key_prefix()}:data" + + def _get_lock_key(self) -> str: + """Get Redis key for lock information""" + return f"{self._get_key_prefix()}:lock" + + def _get_version_key(self) -> str: + """Get Redis key for version control""" + return f"{self._get_key_prefix()}:version" + + def save(self): + """Save this ORM instance to Redis""" + try: + # Save serialized data + if self.serialized_data: + self.redis_client.set(self._get_data_key(), self.serialized_data) + + # Note: Lock information is now managed by acquire_lock/release_locks methods + # We don't save lock info here to avoid conflicts with atomic lock operations + + # Save version control + self.redis_client.set(self._get_version_key(), self.version_control) + + logger.debug(f"Saved RedisLockableORM to Redis: {self._get_key_prefix()}") + + except Exception as e: + logger.error(f"Failed to save RedisLockableORM to Redis: {e}") + raise + + def load(self): + """Load this ORM instance from Redis""" + try: + # Load serialized data + data = self.redis_client.get(self._get_data_key()) + if data: + self.serialized_data = data.decode() if isinstance(data, bytes) else data + else: + self.serialized_data = None + + # Note: Lock information is now managed by acquire_lock/release_locks methods + # We don't load lock info here to avoid conflicts with atomic lock operations + self.lock_acquired = False + self.lock_expiry = None + + # Load version control + version = self.redis_client.get(self._get_version_key()) + if version: + self.version_control = version.decode() if isinstance(version, bytes) else version + else: + self.version_control = "0" + + logger.debug(f"Loaded RedisLockableORM from Redis: {self._get_key_prefix()}") + # Return True if we found any data, False otherwise + return self.serialized_data is not None + + except Exception as e: + logger.error(f"Failed to load RedisLockableORM from Redis: {e}") + return False + + def delete(self): + """Delete this ORM instance from Redis""" + try: + keys_to_delete = [self._get_data_key(), self._get_lock_key(), self._get_version_key()] + self.redis_client.delete(*keys_to_delete) + logger.debug(f"Deleted RedisLockableORM from Redis: {self._get_key_prefix()}") + except Exception as e: + logger.error(f"Failed to delete RedisLockableORM from Redis: {e}") + raise + + +class RedisDBManager(BaseDBManager): + """Redis-based database manager for any serializable object + + This class handles persistence, synchronization, and locking + for any object that implements to_json/from_json methods using Redis as the backend storage. + """ + + def __init__( + self, + engine: Engine | None = None, + user_id: str | None = None, + mem_cube_id: str | None = None, + obj: Any | None = None, + lock_timeout: int = 10, + redis_client=None, + redis_config: dict | None = None, + ): + """Initialize the Redis database manager + + Args: + engine: SQLAlchemy engine (not used for Redis, kept for compatibility) + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + obj: Optional object instance to manage (must have to_json/from_json methods) + lock_timeout: Timeout in seconds for lock acquisition + redis_client: Redis client instance (optional) + redis_config: Redis configuration dictionary (optional) + """ + # Initialize Redis client + self.redis_client = redis_client + self.redis_config = redis_config or {} + + if self.redis_client is None: + self._init_redis_client() + + # Initialize base attributes without calling parent's init_manager + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.obj = obj + self.obj_type = type(obj) if obj is not None else None # Store the actual object type + self.lock_timeout = lock_timeout + self.engine = engine # Keep for compatibility but not used + self.SessionLocal = None # Not used for Redis + self.last_version_control = None + + logger.info( + f"RedisDBManager initialized for user_id: {user_id}, mem_cube_id: {mem_cube_id}" + ) + logger.info(f"Redis client: {type(self.redis_client).__name__}") + + # Test Redis connection + try: + self.redis_client.ping() + logger.info("Redis connection successful") + except Exception as e: + logger.warning(f"Redis ping failed: {e}") + # Don't raise error here as it might be a mock client in tests + + def _init_redis_client(self): + """Initialize Redis client from config or environment""" + try: + import redis + + # Try to get Redis client from environment first + if not self.redis_client: + self.redis_client = self.load_redis_engine_from_env() + + # If still no client, try from config + if not self.redis_client and self.redis_config: + redis_kwargs = { + "host": self.redis_config.get("host", "localhost"), + "port": self.redis_config.get("port", 6379), + "db": self.redis_config.get("db", 0), + "decode_responses": True, + } + + if self.redis_config.get("password"): + redis_kwargs["password"] = self.redis_config["password"] + + self.redis_client = redis.Redis(**redis_kwargs) + + # Final fallback to localhost + if not self.redis_client: + logger.warning("No Redis configuration found, using localhost defaults") + self.redis_client = redis.Redis( + host="localhost", port=6379, db=0, decode_responses=True + ) + + # Test connection + if not self.redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info("Redis client initialized successfully") + + except ImportError: + logger.error("Redis package not installed. Install with: pip install redis") + raise + except Exception as e: + logger.error(f"Failed to initialize Redis client: {e}") + raise + + @property + def orm_class(self) -> type[RedisLockableORM]: + """Return the Redis-based ORM class""" + return RedisLockableORM + + @property + def obj_class(self) -> type: + """Return the actual object class""" + return self.obj_type if self.obj_type is not None else MemoryMonitorManager + + def merge_items( + self, + orm_instance: RedisLockableORM, + obj_instance: Any, + size_limit: int, + ): + """Merge items from Redis with current object instance + + This method provides a generic way to merge data from Redis with the current + object instance. It handles different object types and their specific merge logic. + + Args: + orm_instance: Redis ORM instance from database + obj_instance: Current object instance (any type with to_json/from_json methods) + size_limit: Maximum number of items to keep after merge + """ + logger.debug(f"Starting merge_items with size_limit={size_limit}") + + try: + if not orm_instance.serialized_data: + logger.warning("No serialized data in Redis ORM instance to merge") + return obj_instance + + # Deserialize the database object using the actual object type + if self.obj_type is not None: + db_obj = self.obj_type.from_json(orm_instance.serialized_data) + else: + db_obj = MemoryMonitorManager.from_json(orm_instance.serialized_data) + + # Handle different object types with specific merge logic based on type + obj_type = type(obj_instance) + if obj_type.__name__ == "MemoryMonitorManager" or hasattr(obj_instance, "memories"): + # MemoryMonitorManager-like objects + return self._merge_memory_monitor_items(obj_instance, db_obj, size_limit) + elif obj_type.__name__ == "SimpleListManager" or hasattr(obj_instance, "items"): + # SimpleListManager-like objects + return self._merge_list_items(obj_instance, db_obj, size_limit) + else: + # Generic objects - just return the current instance + logger.info( + f"No specific merge logic for object type {obj_type.__name__}, returning current instance" + ) + return obj_instance + + except Exception as e: + logger.error(f"Failed to deserialize database instance: {e}", exc_info=True) + logger.warning("Skipping merge due to deserialization error, using current object only") + return obj_instance + + def _merge_memory_monitor_items(self, obj_instance, db_obj, size_limit: int): + """Merge MemoryMonitorManager items""" + # Create a mapping of existing memories by their mapping key + current_memories_dict = obj_instance.memories_mapping_dict + + # Add memories from database that don't exist in current object + for db_memory in db_obj.memories: + if db_memory.tree_memory_item_mapping_key not in current_memories_dict: + obj_instance.memories.append(db_memory) + + # Apply size limit if specified + if size_limit and len(obj_instance.memories) > size_limit: + # Sort by recording_count and keep the most recorded ones + obj_instance.memories.sort(key=lambda x: x.recording_count, reverse=True) + obj_instance.memories = obj_instance.memories[:size_limit] + logger.info( + f"Applied size limit {size_limit}, kept {len(obj_instance.memories)} memories" + ) + + logger.info(f"Merged {len(obj_instance.memories)} memory items") + return obj_instance + + def _merge_list_items(self, obj_instance, db_obj, size_limit: int): + """Merge SimpleListManager-like items""" + merged_items = [] + seen_items = set() + + # First, add all items from current object (higher priority) + for item in obj_instance.items: + if item not in seen_items: + merged_items.append(item) + seen_items.add(item) + + # Then, add items from database that aren't in current object + for item in db_obj.items: + if item not in seen_items: + merged_items.append(item) + seen_items.add(item) + + # Apply size limit if specified (keep most recent items) + if size_limit is not None and size_limit > 0 and len(merged_items) > size_limit: + merged_items = merged_items[:size_limit] + logger.debug(f"Applied size limit of {size_limit}, kept {len(merged_items)} items") + + # Update the object with merged items + obj_instance.items = merged_items + + logger.info(f"Merged {len(merged_items)} list items (size_limit: {size_limit})") + return obj_instance + + def _get_redis_orm_instance(self) -> RedisLockableORM: + """Get or create a Redis ORM instance""" + orm_instance = RedisLockableORM( + redis_client=self.redis_client, user_id=self.user_id, mem_cube_id=self.mem_cube_id + ) + return orm_instance + + def _get_key_prefix(self) -> str: + """Generate Redis key prefix for this ORM instance""" + return f"lockable_orm:{self.user_id}:{self.mem_cube_id}" + + def acquire_lock(self, block: bool = True, **kwargs) -> bool: + """Acquire a distributed lock using Redis with atomic operations + + Args: + block: Whether to block until lock is acquired + **kwargs: Additional filter criteria (ignored for Redis) + + Returns: + True if lock was acquired, False otherwise + """ + try: + lock_key = f"{self._get_key_prefix()}:lock" + now = get_utc_now() + + # Use Redis SET with NX (only if not exists) and EX (expiry) for atomic lock acquisition + lock_value = f"{self.user_id}:{self.mem_cube_id}:{now.timestamp()}" + + while True: + # Try to acquire lock atomically + result = self.redis_client.set( + lock_key, + lock_value, + nx=True, # Only set if key doesn't exist + ex=self.lock_timeout, # Set expiry in seconds + ) + + if result: + # Successfully acquired lock + logger.info(f"Redis lock acquired for {self.user_id}/{self.mem_cube_id}") + return True + + if not block: + logger.warning( + f"Redis lock is held for {self.user_id}/{self.mem_cube_id}, cannot acquire" + ) + return False + + # Wait a bit before retrying + logger.info( + f"Waiting for Redis lock to be released for {self.user_id}/{self.mem_cube_id}" + ) + time.sleep(0.1) + + except Exception as e: + logger.error(f"Failed to acquire Redis lock for {self.user_id}/{self.mem_cube_id}: {e}") + return False + + def release_locks(self, user_id: str, mem_cube_id: str, **kwargs): + """Release Redis locks for the specified user and memory cube + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + **kwargs: Additional filter criteria (ignored for Redis) + """ + try: + lock_key = f"lockable_orm:{user_id}:{mem_cube_id}:lock" + + # Delete the lock key to release the lock + result = self.redis_client.delete(lock_key) + + if result: + logger.info(f"Redis lock released for {user_id}/{mem_cube_id}") + else: + logger.warning(f"No Redis lock found to release for {user_id}/{mem_cube_id}") + + except Exception as e: + logger.error(f"Failed to release Redis lock for {user_id}/{mem_cube_id}: {e}") + + def sync_with_orm(self, size_limit: int | None = None) -> None: + """Synchronize data between Redis and the business object + + Args: + size_limit: Optional maximum number of items to keep after synchronization + """ + logger.info( + f"Starting Redis sync_with_orm for {self.user_id}/{self.mem_cube_id} with size_limit={size_limit}" + ) + + try: + # Acquire lock before any operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for synchronization") + return + + # Get existing data from Redis + orm_instance = self._get_redis_orm_instance() + exists = orm_instance.load() + + # If no existing record, create a new one + if not exists: + if self.obj is None: + logger.warning("No object to synchronize and no existing Redis record") + return + + orm_instance.serialized_data = self.obj.to_json() + orm_instance.version_control = "0" + orm_instance.save() + + logger.info("No existing Redis record found. Created a new one.") + self.last_version_control = "0" + return + + # Check version control and merge data + if self.obj is not None: + current_redis_tag = orm_instance.version_control + new_tag = self._increment_version_control(current_redis_tag) + + # Check if this is the first sync or if we need to merge + if self.last_version_control is None: + logger.info("First Redis sync, merging data from Redis") + # Always merge on first sync to load data from Redis + try: + self.merge_items( + orm_instance=orm_instance, obj_instance=self.obj, size_limit=size_limit + ) + except Exception as merge_error: + logger.error( + f"Error during Redis merge_items: {merge_error}", exc_info=True + ) + logger.warning("Continuing with current object data without merge") + elif current_redis_tag == self.last_version_control: + logger.info( + f"Redis version control unchanged ({current_redis_tag}), directly update" + ) + else: + logger.info( + f"Redis version control changed from {self.last_version_control} to {current_redis_tag}, merging data" + ) + try: + self.merge_items( + orm_instance=orm_instance, obj_instance=self.obj, size_limit=size_limit + ) + except Exception as merge_error: + logger.error( + f"Error during Redis merge_items: {merge_error}", exc_info=True + ) + logger.warning("Continuing with current object data without merge") + + # Write merged data back to Redis + orm_instance.serialized_data = self.obj.to_json() + orm_instance.version_control = new_tag + orm_instance.save() + + logger.info(f"Updated Redis serialized_data for {self.user_id}/{self.mem_cube_id}") + self.last_version_control = orm_instance.version_control + else: + logger.warning("No current object to merge with Redis data") + + logger.info(f"Redis synchronization completed for {self.user_id}/{self.mem_cube_id}") + + except Exception as e: + logger.error( + f"Error during Redis synchronization for {self.user_id}/{self.mem_cube_id}: {e}", + exc_info=True, + ) + finally: + # Always release locks + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + + def save_to_db(self, obj_instance: Any) -> None: + """Save the current state of the business object to Redis + + Args: + obj_instance: The object instance to save (must have to_json method) + """ + try: + # Acquire lock before operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for saving") + return + + # Get or create Redis ORM instance + orm_instance = self._get_redis_orm_instance() + exists = orm_instance.load() + + if not exists: + # Create new record + orm_instance.serialized_data = obj_instance.to_json() + orm_instance.version_control = "0" + orm_instance.save() + + logger.info(f"Created new Redis record for {self.user_id}/{self.mem_cube_id}") + self.last_version_control = "0" + else: + # Update existing record with version control + current_version = orm_instance.version_control + new_version = self._increment_version_control(current_version) + + orm_instance.serialized_data = obj_instance.to_json() + orm_instance.version_control = new_version + orm_instance.save() + + logger.info( + f"Updated existing Redis record for {self.user_id}/{self.mem_cube_id} with version {new_version}" + ) + self.last_version_control = new_version + + except Exception as e: + logger.error(f"Error saving to Redis for {self.user_id}/{self.mem_cube_id}: {e}") + finally: + # Always release locks + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + + def load_from_db(self, acquire_lock: bool = False) -> Any | None: + """Load the business object from Redis + + Args: + acquire_lock: Whether to acquire a lock during the load operation + + Returns: + The deserialized object instance, or None if not found + """ + try: + if acquire_lock: + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for loading") + return None + + # Load from Redis + orm_instance = self._get_redis_orm_instance() + exists = orm_instance.load() + + if not exists or not orm_instance.serialized_data: + logger.info(f"No Redis record found for {self.user_id}/{self.mem_cube_id}") + return None + + # Deserialize the business object using the actual object type + if self.obj_type is not None: + db_instance = self.obj_type.from_json(orm_instance.serialized_data) + else: + db_instance = MemoryMonitorManager.from_json(orm_instance.serialized_data) + self.last_version_control = orm_instance.version_control + + logger.info( + f"Successfully loaded object from Redis for {self.user_id}/{self.mem_cube_id} with version {orm_instance.version_control}" + ) + return db_instance + + except Exception as e: + logger.error(f"Error loading from Redis for {self.user_id}/{self.mem_cube_id}: {e}") + return None + finally: + if acquire_lock: + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + + def close(self): + """Close the Redis manager and clean up resources""" + try: + # Release any locks held by this manager instance + if self.user_id and self.mem_cube_id: + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + logger.info(f"Released Redis locks for {self.user_id}/{self.mem_cube_id}") + + # Close Redis connection + if self.redis_client: + self.redis_client.close() + logger.info("Redis connection closed") + + # Call parent close method for any additional cleanup + super().close() + + except Exception as e: + logger.error(f"Error during Redis close operation: {e}") + + @classmethod + def from_env( + cls, + user_id: str, + mem_cube_id: str, + obj: Any | None = None, + lock_timeout: int = 10, + env_file_path: str | None = None, + ) -> "RedisDBManager": + """Create RedisDBManager from environment variables + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + obj: Optional MemoryMonitorManager instance + lock_timeout: Lock timeout in seconds + env_file_path: Optional path to .env file + + Returns: + RedisDBManager instance + """ + try: + redis_client = cls.load_redis_engine_from_env(env_file_path) + return cls( + user_id=user_id, + mem_cube_id=mem_cube_id, + obj=obj, + lock_timeout=lock_timeout, + redis_client=redis_client, + ) + except Exception as e: + logger.error(f"Failed to create RedisDBManager from environment: {e}") + raise + + def list_keys(self, pattern: str | None = None) -> list[str]: + """List all Redis keys for this manager's data + + Args: + pattern: Optional pattern to filter keys + + Returns: + List of Redis keys + """ + try: + if pattern is None: + pattern = f"lockable_orm:{self.user_id}:{self.mem_cube_id}:*" + + keys = self.redis_client.keys(pattern) + return [key.decode() if isinstance(key, bytes) else key for key in keys] + + except Exception as e: + logger.error(f"Error listing Redis keys: {e}") + return [] + + def health_check(self) -> dict[str, bool]: + """Check the health of Redis connection + + Returns: + Dictionary with health status + """ + try: + redis_healthy = self.redis_client.ping() + return { + "redis": redis_healthy, + "mysql": False, # Not applicable for Redis manager + } + except Exception as e: + logger.error(f"Redis health check failed: {e}") + return {"redis": False, "mysql": False} diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 248c42e80..f0868e8df 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,7 +1,16 @@ +from enum import Enum from pathlib import Path from typing import NewType +class SearchMode(str, Enum): + """Enumeration for search modes.""" + + FAST = "fast" + FINE = "fine" + MIXTURE = "mixture" + + FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent @@ -10,6 +19,8 @@ ADD_LABEL = "add" MEM_READ_LABEL = "mem_read" MEM_ORGANIZE_LABEL = "mem_organize" +API_MIX_SEARCH_LABEL = "api_mix_search" + TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" @@ -24,6 +35,10 @@ DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 +DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 100000 +DEFAULT_TOP_K = 10 +DEFAULT_CONTEXT_WINDOW_SIZE = 5 +DEFAULT_USE_REDIS_QUEUE = False # startup mode configuration STARTUP_BY_THREAD = "thread" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 9b5bd5d81..efdaa44ef 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -8,6 +8,7 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.misc import DictConversionMixin +from memos.mem_scheduler.utils.db_utils import get_utc_now from .general_schemas import NOT_INITIALIZED @@ -39,7 +40,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") content: str = Field(..., description="Content of the schedule message") timestamp: datetime = Field( - default_factory=lambda: datetime.utcnow(), description="submit time for schedule_messages" + default_factory=get_utc_now, description="submit time for schedule_messages" ) # Pydantic V2 model configuration @@ -88,9 +89,9 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": return cls( item_id=data.get("item_id", str(uuid4())), user_id=data["user_id"], - cube_id=data["cube_id"], + mem_cube_id=data["cube_id"], label=data["label"], - cube="Not Applicable", # Custom cube deserialization + mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), ) @@ -131,7 +132,7 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin): description="Maximum capacities of memory partitions", ) timestamp: datetime = Field( - default_factory=lambda: datetime.utcnow(), + default_factory=get_utc_now, description="Timestamp indicating when the log entry was created", ) diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index d189797ae..168a25b5d 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -7,6 +7,7 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import DictConversionMixin +from memos.mem_scheduler.utils.db_utils import get_utc_now logger = get_logger(__name__) @@ -26,7 +27,7 @@ class RunningTaskItem(BaseModel, DictConversionMixin): mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1) task_info: str = Field(..., description="Information about the task being executed") task_name: str = Field(..., description="Name/type of the task handler") - start_time: datetime = Field(description="Task start time", default_factory=datetime.utcnow) + start_time: datetime = Field(description="Task start time", default_factory=get_utc_now) end_time: datetime | None = Field(default=None, description="Task completion time") status: str = Field(default="running", description="Task status: running, completed, failed") result: Any | None = Field(default=None, description="Task execution result") @@ -37,13 +38,13 @@ class RunningTaskItem(BaseModel, DictConversionMixin): def mark_completed(self, result: Any | None = None) -> None: """Mark task as completed with optional result.""" - self.end_time = datetime.utcnow() + self.end_time = get_utc_now() self.status = "completed" self.result = result def mark_failed(self, error_message: str) -> None: """Mark task as failed with error message.""" - self.end_time = datetime.utcnow() + self.end_time = get_utc_now() self.status = "failed" self.error_message = error_message diff --git a/src/memos/mem_scheduler/utils/db_utils.py b/src/memos/mem_scheduler/utils/db_utils.py index 5d7cc52c3..4c7402a9d 100644 --- a/src/memos/mem_scheduler/utils/db_utils.py +++ b/src/memos/mem_scheduler/utils/db_utils.py @@ -1,5 +1,22 @@ import os import sqlite3 +import sys + +from datetime import datetime, timezone + + +# Compatibility handling: Python 3.11+ supports UTC, earlier versions use timezone.utc +if sys.version_info >= (3, 11): + from datetime import UTC + + def get_utc_now(): + """Get current UTC datetime with compatibility for different Python versions""" + return datetime.now(UTC) +else: + + def get_utc_now(): + """Get current UTC datetime with compatibility for different Python versions""" + return datetime.now(timezone.utc) def print_db_tables(db_path: str): diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 8865c2232..b240f4369 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -67,39 +67,42 @@ def initialize_rabbitmq( """ Establish connection to RabbitMQ using pika. """ - from pika.adapters.select_connection import SelectConnection - - if config is None: - if config_path is None and AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_config() - elif Path(config_path).exists(): - auth_config = AuthConfig.from_local_config(config_path=config_path) + try: + from pika.adapters.select_connection import SelectConnection + + if config is None: + if config_path is None and AuthConfig.default_config_exists(): + auth_config = AuthConfig.from_local_config() + elif Path(config_path).exists(): + auth_config = AuthConfig.from_local_config(config_path=config_path) + else: + logger.error("Fail to initialize auth_config") + return + self.rabbitmq_config = auth_config.rabbitmq + elif isinstance(config, RabbitMQConfig): + self.rabbitmq_config = config + elif isinstance(config, dict): + self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq else: - logger.error("Fail to initialize auth_config") - return - self.rabbitmq_config = auth_config.rabbitmq - elif isinstance(config, RabbitMQConfig): - self.rabbitmq_config = config - elif isinstance(config, dict): - self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq - else: - logger.error("Not implemented") - - # Start connection process - parameters = self.get_rabbitmq_connection_param() - self.rabbitmq_connection = SelectConnection( - parameters, - on_open_callback=self.on_rabbitmq_connection_open, - on_open_error_callback=self.on_rabbitmq_connection_error, - on_close_callback=self.on_rabbitmq_connection_closed, - ) + logger.error("Not implemented") + + # Start connection process + parameters = self.get_rabbitmq_connection_param() + self.rabbitmq_connection = SelectConnection( + parameters, + on_open_callback=self.on_rabbitmq_connection_open, + on_open_error_callback=self.on_rabbitmq_connection_error, + on_close_callback=self.on_rabbitmq_connection_closed, + ) - # Start IOLoop in dedicated thread - self._io_loop_thread = threading.Thread( - target=self.rabbitmq_connection.ioloop.start, daemon=True - ) - self._io_loop_thread.start() - logger.info("RabbitMQ connection process started") + # Start IOLoop in dedicated thread + self._io_loop_thread = threading.Thread( + target=self.rabbitmq_connection.ioloop.start, daemon=True + ) + self._io_loop_thread.start() + logger.info("RabbitMQ connection process started") + except Exception: + logger.error("Fail to initialize auth_config", exc_info=True) def get_rabbitmq_queue_size(self) -> int: """Get the current number of messages in the queue. diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index 5b04ec280..239557bc9 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -1,5 +1,8 @@ import asyncio +import os +import subprocess import threading +import time from collections.abc import Callable from typing import Any @@ -27,10 +30,14 @@ def __init__(self): super().__init__() # settings for redis - self.redis_host: str = None - self.redis_port: int = None - self.redis_db: int = None + self.redis_host: str | None = None + self.redis_port: int | None = None + self.redis_db: int | None = None + self.redis_password: str | None = None + self.socket_timeout: float | None = None + self.socket_connect_timeout: float | None = None self._redis_conn = None + self._local_redis_process = None self.query_list_capacity = 1000 self._redis_listener_running = False @@ -46,19 +53,40 @@ def redis(self, value: Any) -> None: self._redis_conn = value def initialize_redis( - self, redis_host: str = "localhost", redis_port: int = 6379, redis_db: int = 0 + self, + redis_host: str = "localhost", + redis_port: int = 6379, + redis_db: int = 0, + redis_password: str | None = None, + socket_timeout: float | None = None, + socket_connect_timeout: float | None = None, ): import redis self.redis_host = redis_host self.redis_port = redis_port self.redis_db = redis_db + self.redis_password = redis_password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout try: logger.debug(f"Connecting to Redis at {redis_host}:{redis_port}/{redis_db}") - self._redis_conn = redis.Redis( - host=self.redis_host, port=self.redis_port, db=self.redis_db, decode_responses=True - ) + redis_kwargs = { + "host": self.redis_host, + "port": self.redis_port, + "db": self.redis_db, + "password": redis_password, + "decode_responses": True, + } + + # Add timeout parameters if provided + if socket_timeout is not None: + redis_kwargs["socket_timeout"] = socket_timeout + if socket_connect_timeout is not None: + redis_kwargs["socket_connect_timeout"] = socket_connect_timeout + + self._redis_conn = redis.Redis(**redis_kwargs) # test conn if not self._redis_conn.ping(): logger.error("Redis connection failed") @@ -68,6 +96,183 @@ def initialize_redis( self._redis_conn.xtrim("user:queries:stream", self.query_list_capacity) return self._redis_conn + @require_python_package( + import_name="redis", + install_command="pip install redis", + install_link="https://redis.readthedocs.io/en/stable/", + ) + def auto_initialize_redis(self) -> bool: + """ + Auto-initialize Redis with fallback strategies: + 1. Try to initialize from config + 2. Try to initialize from environment variables + 3. Try to start local Redis server as fallback + + Returns: + bool: True if Redis connection is successfully established, False otherwise + """ + import redis + + # Strategy 1: Try to initialize from config + if hasattr(self, "config") and hasattr(self.config, "redis_config"): + try: + redis_config = self.config.redis_config + logger.info("Attempting to initialize Redis from config") + + self._redis_conn = redis.Redis( + host=redis_config.get("host", "localhost"), + port=redis_config.get("port", 6379), + db=redis_config.get("db", 0), + password=redis_config.get("password", None), + decode_responses=True, + ) + + # Test connection + if self._redis_conn.ping(): + logger.info("Redis initialized successfully from config") + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_db = redis_config.get("db", 0) + self.redis_password = redis_config.get("password", None) + self.socket_timeout = redis_config.get("socket_timeout", None) + self.socket_connect_timeout = redis_config.get("socket_connect_timeout", None) + return True + else: + logger.warning("Redis config connection test failed") + self._redis_conn = None + except Exception as e: + logger.warning(f"Failed to initialize Redis from config: {e}") + self._redis_conn = None + + # Strategy 2: Try to initialize from environment variables + try: + redis_host = os.getenv("MEMSCHEDULER_REDIS_HOST", "localhost") + redis_port = int(os.getenv("MEMSCHEDULER_REDIS_PORT", "6379")) + redis_db = int(os.getenv("MEMSCHEDULER_REDIS_DB", "0")) + redis_password = os.getenv("MEMSCHEDULER_REDIS_PASSWORD", None) + socket_timeout = os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None) + socket_connect_timeout = os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None) + + logger.info( + f"Attempting to initialize Redis from environment variables: {redis_host}:{redis_port}" + ) + + redis_kwargs = { + "host": redis_host, + "port": redis_port, + "db": redis_db, + "password": redis_password, + "decode_responses": True, + } + + # Add timeout parameters if provided + if socket_timeout is not None: + try: + redis_kwargs["socket_timeout"] = float(socket_timeout) + except ValueError: + logger.warning( + f"Invalid MEMSCHEDULER_REDIS_TIMEOUT value: {socket_timeout}, ignoring" + ) + + if socket_connect_timeout is not None: + try: + redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout) + except ValueError: + logger.warning( + f"Invalid MEMSCHEDULER_REDIS_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring" + ) + + self._redis_conn = redis.Redis(**redis_kwargs) + + # Test connection + if self._redis_conn.ping(): + logger.info("Redis initialized successfully from environment variables") + self.redis_host = redis_host + self.redis_port = redis_port + self.redis_db = redis_db + self.redis_password = redis_password + self.socket_timeout = float(socket_timeout) if socket_timeout is not None else None + self.socket_connect_timeout = ( + float(socket_connect_timeout) if socket_connect_timeout is not None else None + ) + return True + else: + logger.warning("Redis environment connection test failed") + self._redis_conn = None + except Exception as e: + logger.warning(f"Failed to initialize Redis from environment variables: {e}") + self._redis_conn = None + + # Strategy 3: Try to start local Redis server as fallback + try: + logger.warning( + "Attempting to start local Redis server as fallback (not recommended for production)" + ) + + # Try to start Redis server locally + self._local_redis_process = subprocess.Popen( + ["redis-server", "--port", "6379", "--daemonize", "no"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid if hasattr(os, "setsid") else None, + ) + + # Wait a moment for Redis to start + time.sleep(0.5) + + # Try to connect to local Redis + self._redis_conn = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True) + + # Test connection + if self._redis_conn.ping(): + logger.warning("Local Redis server started and connected successfully") + logger.warning("WARNING: Using local Redis server - not suitable for production!") + self.redis_host = "localhost" + self.redis_port = 6379 + self.redis_db = 0 + self.redis_password = None + self.socket_timeout = None + self.socket_connect_timeout = None + return True + else: + logger.error("Local Redis server connection test failed") + self._cleanup_local_redis() + return False + + except Exception as e: + logger.error(f"Failed to start local Redis server: {e}") + self._cleanup_local_redis() + return False + + def _cleanup_local_redis(self): + """Clean up local Redis process if it exists""" + if self._local_redis_process: + try: + self._local_redis_process.terminate() + self._local_redis_process.wait(timeout=5) + logger.info("Local Redis process terminated") + except subprocess.TimeoutExpired: + logger.warning("Local Redis process did not terminate gracefully, killing it") + self._local_redis_process.kill() + self._local_redis_process.wait() + except Exception as e: + logger.error(f"Error cleaning up local Redis process: {e}") + finally: + self._local_redis_process = None + + def _cleanup_redis_resources(self): + """Clean up Redis connection and local process""" + if self._redis_conn: + try: + self._redis_conn.close() + logger.info("Redis connection closed") + except Exception as e: + logger.error(f"Error closing Redis connection: {e}") + finally: + self._redis_conn = None + + self._cleanup_local_redis() + async def redis_add_message_stream(self, message: dict): logger.debug(f"add_message_stream: {message}") return self._redis_conn.xadd("user:queries:stream", message) @@ -150,7 +355,5 @@ def redis_stop_listening(self): logger.info("Redis stream listener stopped") def redis_close(self): - """Close Redis connection""" - if self._redis_conn is not None: - self._redis_conn.close() - self._redis_conn = None + """Close Redis connection and clean up resources""" + self._cleanup_redis_resources() diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py index 2fa08590f..98d611dbf 100644 --- a/src/memos/memories/activation/kv.py +++ b/src/memos/memories/activation/kv.py @@ -237,16 +237,36 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache: """ + Move DynamicCache from CPU to GPU device. + Compatible with both old and new transformers versions. + In SimpleMemChat.run(), if self.config.enable_activation_memory is enabled, we load serialized kv cache from a [class KVCacheMemory] object, which has a kv_cache_memories on CPU. So before inferring with DynamicCache, we should move it to GPU in-place first. """ - # Currently, we put this function outside [class KVCacheMemory] - for i in range(len(dynamic_cache.key_cache)): - if dynamic_cache.key_cache[i] is not None: - dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to(device, non_blocking=True) - if dynamic_cache.value_cache[i] is not None: - dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to( - device, non_blocking=True - ) + # Handle compatibility between old and new transformers versions + if hasattr(dynamic_cache, "layers"): + # New version: use layers attribute + for layer in dynamic_cache.layers: + if hasattr(layer, "key_cache") and layer.key_cache is not None: + layer.key_cache = layer.key_cache.to(device, non_blocking=True) + if hasattr(layer, "value_cache") and layer.value_cache is not None: + layer.value_cache = layer.value_cache.to(device, non_blocking=True) + elif hasattr(layer, "keys") and hasattr(layer, "values"): + # Alternative attribute names in some versions + if layer.keys is not None: + layer.keys = layer.keys.to(device, non_blocking=True) + if layer.values is not None: + layer.values = layer.values.to(device, non_blocking=True) + elif hasattr(dynamic_cache, "key_cache") and hasattr(dynamic_cache, "value_cache"): + # Old version: use key_cache and value_cache attributes + for i in range(len(dynamic_cache.key_cache)): + if dynamic_cache.key_cache[i] is not None: + dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to( + device, non_blocking=True + ) + if dynamic_cache.value_cache[i] is not None: + dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to( + device, non_blocking=True + ) return dynamic_cache diff --git a/tests/llms/test_hf.py b/tests/llms/test_hf.py index 8a266e58d..595995ad1 100644 --- a/tests/llms/test_hf.py +++ b/tests/llms/test_hf.py @@ -93,15 +93,50 @@ def test_build_kv_cache_and_generation(self): add_generation_prompt=True, ) llm = self._create_llm(config) + + # Ensure the mock model returns an object with past_key_values attribute + forward_output = MagicMock() + forward_output.logits = torch.ones(1, 1, 100) + + # Create a DynamicCache that's compatible with both old and new transformers versions + kv_cache = DynamicCache() + + # Mock the DynamicCache to have both old and new version attributes for compatibility + # New version uses 'layers' attribute + mock_layer = MagicMock() + mock_layer.key_cache = torch.tensor([[[[1.0, 2.0]]]]) + mock_layer.value_cache = torch.tensor([[[[3.0, 4.0]]]]) + kv_cache.layers = [mock_layer] + + # Old version uses 'key_cache' and 'value_cache' lists + kv_cache.key_cache = [torch.tensor([[[[1.0, 2.0]]]])] + kv_cache.value_cache = [torch.tensor([[[[3.0, 4.0]]]])] + + forward_output.past_key_values = kv_cache + # Make sure the mock model call returns the forward_output when called with **kwargs + self.mock_model.return_value = forward_output + kv_cache = llm.build_kv_cache("The capital of France is Paris.") self.assertIsInstance(kv_cache, DynamicCache) resp = llm.generate( [{"role": "user", "content": "What's its population?"}], past_key_values=kv_cache ) self.assertEqual(resp, self.standard_response) - first_kwargs = self.mock_model.call_args_list[0][1] - self.assertIs(first_kwargs["past_key_values"], kv_cache) - self.assertTrue(first_kwargs["use_cache"]) + # Check that the model was called with past_key_values during _prefill + # The model should be called multiple times during generation with cache + found_past_key_values = False + for call_args in self.mock_model.call_args_list: + if len(call_args) > 1 and "past_key_values" in call_args[1]: + found_past_key_values = True + break + self.assertTrue(found_past_key_values, "Model should be called with past_key_values") + # Check that use_cache was used + found_use_cache = False + for call_args in self.mock_model.call_args_list: + if len(call_args) > 1 and call_args[1].get("use_cache"): + found_use_cache = True + break + self.assertTrue(found_use_cache, "Model should be called with use_cache=True") def test_think_prefix_removal(self): config = HFLLMConfig( diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index ed2093dea..e3064660b 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -233,7 +233,7 @@ def test_dispatch_parallel(self): self.assertEqual(len(label2_messages), 1) self.assertEqual(label2_messages[0].item_id, "msg2") - def test_group_messages_by_user_and_cube(self): + def test_group_messages_by_user_and_mem_cube(self): """Test grouping messages by user and cube.""" # Check actual grouping logic with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"): @@ -261,47 +261,6 @@ def test_group_messages_by_user_and_cube(self): for msg in expected[user_id][cube_id]: self.assertIn(msg.item_id, [m.item_id for m in result[user_id][cube_id]]) - def test_thread_race(self): - """Test the ThreadRace integration.""" - - # Define test tasks - def task1(stop_flag): - time.sleep(0.1) - return "result1" - - def task2(stop_flag): - time.sleep(0.2) - return "result2" - - # Run competitive tasks - tasks = { - "task1": task1, - "task2": task2, - } - - result = self.dispatcher.run_competitive_tasks(tasks, timeout=1.0) - - # Verify the result - self.assertIsNotNone(result) - self.assertEqual(result[0], "task1") # task1 should win - self.assertEqual(result[1], "result1") - - def test_thread_race_timeout(self): - """Test ThreadRace with timeout.""" - - # Define a task that takes longer than the timeout - def slow_task(stop_flag): - time.sleep(0.5) - return "slow_result" - - tasks = {"slow": slow_task} - - # Run with a short timeout - result = self.dispatcher.run_competitive_tasks(tasks, timeout=0.1) - - # Verify no result was returned due to timeout - self.assertIsNone(result) - def test_thread_race_cooperative_termination(self): """Test that ThreadRace properly terminates slower threads when one completes.""" @@ -459,3 +418,190 @@ def test_dispatcher_monitor_logs_stuck_task_messages(self): self.assertIn("Messages: 2 items", expected_log) self.assertIn("Stuck message 1", expected_log) self.assertIn("Stuck message 2", expected_log) + + def test_get_running_tasks_no_filter(self): + """Test get_running_tasks without filter returns all running tasks.""" + # Create test tasks manually + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + task2 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="handler2", + ) + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + + # Get all running tasks + running_tasks = self.dispatcher.get_running_tasks() + + # Verify all tasks are returned + self.assertEqual(len(running_tasks), 2) + self.assertIn(task1.item_id, running_tasks) + self.assertIn(task2.item_id, running_tasks) + self.assertEqual(running_tasks[task1.item_id], task1) + self.assertEqual(running_tasks[task2.item_id], task2) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_filter_by_user_id(self): + """Test get_running_tasks with user_id filter.""" + # Create test tasks with different user_ids + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + task2 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="handler2", + ) + task3 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube3", + task_info="Test task 3", + task_name="handler3", + ) + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + self.dispatcher._running_tasks[task3.item_id] = task3 + + # Filter by user_id + user1_tasks = self.dispatcher.get_running_tasks(lambda task: task.user_id == "user1") + + # Verify only user1 tasks are returned + self.assertEqual(len(user1_tasks), 2) + self.assertIn(task1.item_id, user1_tasks) + self.assertIn(task3.item_id, user1_tasks) + self.assertNotIn(task2.item_id, user1_tasks) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_filter_by_multiple_conditions(self): + """Test get_running_tasks with multiple filter conditions.""" + # Create test tasks with different attributes + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="test_handler", + ) + task2 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="other_handler", + ) + task3 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube1", + task_info="Test task 3", + task_name="test_handler", + ) + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + self.dispatcher._running_tasks[task3.item_id] = task3 + + # Filter by multiple conditions: user_id == "user1" AND task_name == "test_handler" + filtered_tasks = self.dispatcher.get_running_tasks( + lambda task: task.user_id == "user1" and task.task_name == "test_handler" + ) + + # Verify only task1 matches both conditions + self.assertEqual(len(filtered_tasks), 1) + self.assertIn(task1.item_id, filtered_tasks) + self.assertNotIn(task2.item_id, filtered_tasks) + self.assertNotIn(task3.item_id, filtered_tasks) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_filter_by_status(self): + """Test get_running_tasks with status filter.""" + # Create test tasks with different statuses + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + task2 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="handler2", + ) + + # Manually set different statuses + task1.status = "running" + task2.status = "completed" + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + + # Filter by status + running_status_tasks = self.dispatcher.get_running_tasks( + lambda task: task.status == "running" + ) + + # Verify only running tasks are returned + self.assertEqual(len(running_status_tasks), 1) + self.assertIn(task1.item_id, running_status_tasks) + self.assertNotIn(task2.item_id, running_status_tasks) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_thread_safety(self): + """Test get_running_tasks is thread-safe.""" + # Create test task + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + + # Add task to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + + # Get running tasks (should work without deadlock) + running_tasks = self.dispatcher.get_running_tasks() + + # Verify task is returned + self.assertEqual(len(running_tasks), 1) + self.assertIn(task1.item_id, running_tasks) + + # Test with filter (should also work without deadlock) + filtered_tasks = self.dispatcher.get_running_tasks(lambda task: task.user_id == "user1") + self.assertEqual(len(filtered_tasks), 1) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() diff --git a/tests/mem_scheduler/test_orm.py b/tests/mem_scheduler/test_orm.py index ddf4fea8b..a43231e4a 100644 --- a/tests/mem_scheduler/test_orm.py +++ b/tests/mem_scheduler/test_orm.py @@ -13,6 +13,7 @@ DBManagerForMemoryMonitorManager, DBManagerForQueryMonitorQueue, ) +from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager from memos.mem_scheduler.schemas.monitor_schemas import ( MemoryMonitorItem, MemoryMonitorManager, @@ -297,3 +298,150 @@ def test_concurrent_access(self, temp_db, query_queue_obj): manager1.close() manager2.close() + + +class TestRedisDBManager: + """Test class for RedisDBManager functionality""" + + @pytest.fixture + def memory_manager_obj(self): + """Create a MemoryMonitorManager object for testing""" + return MemoryMonitorManager( + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + memories=[ + MemoryMonitorItem( + item_id="redis-test-123", + memory_text="Redis test memory", + tree_memory_item=None, + tree_memory_item_mapping_key="redis_test_key", + keywords_score=0.8, + sorting_score=0.9, + importance_score=0.7, + recording_count=3, + ) + ], + ) + + @pytest.fixture + def mock_redis_client(self): + """Create a mock Redis client for testing""" + try: + from unittest.mock import MagicMock + + # Create a mock Redis client + mock_client = MagicMock() + + # Mock Redis data storage + mock_data = {} + + def mock_set(key, value, nx=False, ex=None, **kwargs): + if nx and key in mock_data: + # NX means "only set if not exists" + return False # Redis returns False when NX fails + mock_data[key] = value + return True + + def mock_get(key): + return mock_data.get(key) + + def mock_hset(key, mapping=None, **kwargs): + if key not in mock_data: + mock_data[key] = {} + if mapping: + mock_data[key].update(mapping) + if kwargs: + mock_data[key].update(kwargs) + return len(mapping) if mapping else len(kwargs) + + def mock_hgetall(key): + return mock_data.get(key, {}) + + def mock_delete(*keys): + deleted = 0 + for key in keys: + if key in mock_data: + del mock_data[key] + deleted += 1 + return deleted + + def mock_keys(pattern): + import fnmatch + + return [key for key in mock_data if fnmatch.fnmatch(key, pattern)] + + def mock_ping(): + return True + + def mock_close(): + pass + + # Configure mock methods + mock_client.set = mock_set + mock_client.get = mock_get + mock_client.hset = mock_hset + mock_client.hgetall = mock_hgetall + mock_client.delete = mock_delete + mock_client.keys = mock_keys + mock_client.ping = mock_ping + mock_client.close = mock_close + + return mock_client + + except ImportError: + pytest.skip("Redis package not available for testing") + + @pytest.fixture + def redis_manager(self, mock_redis_client, memory_manager_obj): + """Create RedisDBManager instance with mock Redis client""" + manager = RedisDBManager( + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=memory_manager_obj, + lock_timeout=10, + redis_client=mock_redis_client, + ) + yield manager + manager.close() + + def test_redis_manager_initialization(self, mock_redis_client): + """Test RedisDBManager initialization""" + manager = RedisDBManager( + user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID, redis_client=mock_redis_client + ) + + assert manager.user_id == TEST_USER_ID + assert manager.mem_cube_id == TEST_MEM_CUBE_ID + assert manager.redis_client is mock_redis_client + assert manager.orm_class.__name__ == "RedisLockableORM" + assert manager.obj_class == MemoryMonitorManager + + manager.close() + + def test_redis_lockable_orm_save_load(self, mock_redis_client): + """Test RedisLockableORM save and load operations""" + from memos.mem_scheduler.orm_modules.redis_model import RedisLockableORM + + orm = RedisLockableORM( + redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID + ) + + # Test save + orm.serialized_data = '{"test": "data"}' + orm.version_control = "1" + orm.lock_acquired = True + orm.lock_expiry = datetime.now() + + orm.save() + + # Test load + new_orm = RedisLockableORM( + redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID + ) + + exists = new_orm.load() + assert exists + assert new_orm.serialized_data == '{"test": "data"}' + assert new_orm.version_control == "1" + # Note: lock_acquired is False after load by design - locks are managed separately + assert not new_orm.lock_acquired diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 15338006d..369b4a6f1 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -26,6 +26,7 @@ ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, + ScheduleMessageItem, ) from memos.memories.textual.tree import TreeTextMemory @@ -36,6 +37,9 @@ class TestGeneralScheduler(unittest.TestCase): + # Control whether to run activation memory tests that require GPU, default is False + RUN_ACTIVATION_MEMORY_TESTS = True + def _create_mock_auth_config(self): """Create a mock AuthConfig for testing purposes.""" # Create mock configs with valid test values @@ -68,6 +72,19 @@ def setUp(self): self.llm = MagicMock(spec=BaseLLM) self.mem_cube = MagicMock(spec=GeneralMemCube) self.tree_text_memory = MagicMock(spec=TreeTextMemory) + # Add memory_manager mock to prevent AttributeError in scheduler_logger + self.tree_text_memory.memory_manager = MagicMock() + self.tree_text_memory.memory_manager.memory_size = { + "LongTermMemory": 10000, + "UserMemory": 10000, + "WorkingMemory": 20, + } + # Mock get_current_memory_size method + self.tree_text_memory.get_current_memory_size.return_value = { + "LongTermMemory": 100, + "UserMemory": 50, + "WorkingMemory": 10, + } self.mem_cube.text_mem = self.tree_text_memory self.mem_cube.act_mem = MagicMock() @@ -185,8 +202,72 @@ def test_scheduler_startup_mode_thread(self): # Stop the scheduler self.scheduler.stop() - # Verify cleanup - self.assertFalse(self.scheduler._running) + def test_redis_message_queue(self): + """Test Redis message queue functionality for sending and receiving messages.""" + import asyncio + import time + + from unittest.mock import MagicMock, patch + + # Mock Redis connection and operations + mock_redis = MagicMock() + mock_redis.xadd = MagicMock(return_value=b"1234567890-0") + + # Track received messages + received_messages = [] + + def redis_handler(messages: list[ScheduleMessageItem]) -> None: + """Handler for Redis messages.""" + received_messages.extend(messages) + + # Register Redis handler + redis_label = "test_redis" + handlers = {redis_label: redis_handler} + self.scheduler.register_handlers(handlers) + + # Enable Redis queue for this test + with ( + patch.object(self.scheduler, "use_redis_queue", True), + patch.object(self.scheduler, "_redis_conn", mock_redis), + ): + # Start scheduler + self.scheduler.start() + + # Create test message for Redis + redis_message = ScheduleMessageItem( + label=redis_label, + content="Redis test message", + user_id="redis_user", + mem_cube_id="redis_cube", + mem_cube="redis_mem_cube_obj", + timestamp=datetime.now(), + ) + + # Submit message to Redis queue + asyncio.run(self.scheduler.submit_messages(redis_message)) + + # Verify Redis xadd was called + mock_redis.xadd.assert_called_once() + call_args = mock_redis.xadd.call_args + self.assertEqual(call_args[0][0], "user:queries:stream") + + # Verify message data was serialized correctly + message_data = call_args[0][1] + self.assertEqual(message_data["label"], redis_label) + self.assertEqual(message_data["content"], "Redis test message") + self.assertEqual(message_data["user_id"], "redis_user") + self.assertEqual(message_data["cube_id"], "redis_cube") # Note: to_dict uses cube_id + + # Simulate Redis message consumption + # This would normally be handled by the Redis consumer in the scheduler + time.sleep(0.1) # Brief wait for async operations + + # Stop scheduler + self.scheduler.stop() + + print("Redis message queue test completed successfully!") + + # Removed test_robustness method - was too time-consuming for CI/CD pipeline def test_scheduler_startup_mode_process(self): """Test scheduler with process startup mode.""" @@ -219,3 +300,284 @@ def test_scheduler_startup_mode_constants(self): """Test that startup mode constants are properly defined.""" self.assertEqual(STARTUP_BY_THREAD, "thread") self.assertEqual(STARTUP_BY_PROCESS, "process") + + def test_activation_memory_update(self): + """Test activation memory update functionality with DynamicCache handling.""" + if not self.RUN_ACTIVATION_MEMORY_TESTS: + self.skipTest( + "Skipping activation memory test. Set RUN_ACTIVATION_MEMORY_TESTS=True to enable." + ) + + from unittest.mock import Mock + + from transformers import DynamicCache + + from memos.memories.activation.kv import KVCacheMemory + + # Mock the mem_cube with activation memory + mock_kv_cache_memory = Mock(spec=KVCacheMemory) + self.mem_cube.act_mem = mock_kv_cache_memory + + # Mock get_all to return empty list (no existing cache items) + mock_kv_cache_memory.get_all.return_value = [] + + # Create a mock DynamicCache with layers attribute + mock_cache = Mock(spec=DynamicCache) + mock_cache.layers = [] + + # Create mock layers with key_cache and value_cache + for _ in range(2): # Simulate 2 layers + mock_layer = Mock() + mock_layer.key_cache = Mock() + mock_layer.value_cache = Mock() + mock_cache.layers.append(mock_layer) + + # Mock the extract method to return a KVCacheItem + mock_cache_item = Mock() + mock_cache_item.records = Mock() + mock_cache_item.records.text_memories = [] + mock_cache_item.records.timestamp = None + mock_kv_cache_memory.extract.return_value = mock_cache_item + + # Test data + test_memories = ["Test memory 1", "Test memory 2"] + user_id = "test_user" + mem_cube_id = "test_cube" + + # Call the method under test + try: + self.scheduler.update_activation_memory( + new_memories=test_memories, + label=QUERY_LABEL, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.mem_cube, + ) + + # Verify that extract was called + mock_kv_cache_memory.extract.assert_called_once() + + # Verify that add was called with the extracted cache item + mock_kv_cache_memory.add.assert_called_once() + + # Verify that dump was called + mock_kv_cache_memory.dump.assert_called_once() + + print("✅ Activation memory update test passed - DynamicCache layers handled correctly") + + except Exception as e: + self.fail(f"Activation memory update failed: {e}") + + def test_dynamic_cache_layers_access(self): + """Test DynamicCache layers attribute access for compatibility.""" + if not self.RUN_ACTIVATION_MEMORY_TESTS: + self.skipTest( + "Skipping activation memory test. Set RUN_ACTIVATION_MEMORY_TESTS=True to enable." + ) + + from unittest.mock import Mock + + from transformers import DynamicCache + + # Create a real DynamicCache instance + cache = DynamicCache() + + # Check if it has layers attribute (may vary by transformers version) + if hasattr(cache, "layers"): + self.assertIsInstance(cache.layers, list, "DynamicCache.layers should be a list") + + # Test with mock layers + mock_layer = Mock() + mock_layer.key_cache = Mock() + mock_layer.value_cache = Mock() + cache.layers.append(mock_layer) + + # Verify we can access layer attributes + self.assertEqual(len(cache.layers), 1) + self.assertTrue(hasattr(cache.layers[0], "key_cache")) + self.assertTrue(hasattr(cache.layers[0], "value_cache")) + + print("✅ DynamicCache layers access test passed") + else: + # If layers attribute doesn't exist, verify our fix handles this case + print("⚠️ DynamicCache doesn't have 'layers' attribute in this transformers version") + print("✅ Test passed - our code should handle this gracefully") + + def test_get_running_tasks_with_filter(self): + """Test get_running_tasks method with filter function.""" + # Mock dispatcher and its get_running_tasks method + mock_task_item1 = MagicMock() + mock_task_item1.item_id = "task_1" + mock_task_item1.user_id = "user_1" + mock_task_item1.mem_cube_id = "cube_1" + mock_task_item1.task_info = {"type": "query"} + mock_task_item1.task_name = "test_task_1" + mock_task_item1.start_time = datetime.now() + mock_task_item1.end_time = None + mock_task_item1.status = "running" + mock_task_item1.result = None + mock_task_item1.error_message = None + mock_task_item1.messages = [] + + # Define a filter function + def user_filter(task): + return task.user_id == "user_1" + + # Mock the filtered result (only task_1 matches the filter) + with patch.object( + self.scheduler.dispatcher, "get_running_tasks", return_value={"task_1": mock_task_item1} + ) as mock_get_running_tasks: + # Call get_running_tasks with filter + result = self.scheduler.get_running_tasks(filter_func=user_filter) + + # Verify result + self.assertIsInstance(result, dict) + self.assertIn("task_1", result) + self.assertEqual(len(result), 1) + + # Verify dispatcher method was called with filter + mock_get_running_tasks.assert_called_once_with(filter_func=user_filter) + + def test_get_running_tasks_empty_result(self): + """Test get_running_tasks method when no tasks are running.""" + # Mock dispatcher to return empty dict + with patch.object( + self.scheduler.dispatcher, "get_running_tasks", return_value={} + ) as mock_get_running_tasks: + # Call get_running_tasks + result = self.scheduler.get_running_tasks() + + # Verify empty result + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 0) + + # Verify dispatcher method was called + mock_get_running_tasks.assert_called_once_with(filter_func=None) + + def test_get_running_tasks_no_dispatcher(self): + """Test get_running_tasks method when dispatcher is None.""" + # Temporarily set dispatcher to None + original_dispatcher = self.scheduler.dispatcher + self.scheduler.dispatcher = None + + # Call get_running_tasks + result = self.scheduler.get_running_tasks() + + # Verify empty result and warning behavior + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 0) + + # Restore dispatcher + self.scheduler.dispatcher = original_dispatcher + + def test_get_running_tasks_multiple_tasks(self): + """Test get_running_tasks method with multiple tasks.""" + # Mock multiple task items + mock_task_item1 = MagicMock() + mock_task_item1.item_id = "task_1" + mock_task_item1.user_id = "user_1" + mock_task_item1.mem_cube_id = "cube_1" + mock_task_item1.task_info = {"type": "query"} + mock_task_item1.task_name = "test_task_1" + mock_task_item1.start_time = datetime.now() + mock_task_item1.end_time = None + mock_task_item1.status = "running" + mock_task_item1.result = None + mock_task_item1.error_message = None + mock_task_item1.messages = [] + + mock_task_item2 = MagicMock() + mock_task_item2.item_id = "task_2" + mock_task_item2.user_id = "user_2" + mock_task_item2.mem_cube_id = "cube_2" + mock_task_item2.task_info = {"type": "answer"} + mock_task_item2.task_name = "test_task_2" + mock_task_item2.start_time = datetime.now() + mock_task_item2.end_time = None + mock_task_item2.status = "completed" + mock_task_item2.result = "success" + mock_task_item2.error_message = None + mock_task_item2.messages = ["message1", "message2"] + + with patch.object( + self.scheduler.dispatcher, + "get_running_tasks", + return_value={"task_1": mock_task_item1, "task_2": mock_task_item2}, + ) as mock_get_running_tasks: + # Call get_running_tasks + result = self.scheduler.get_running_tasks() + + # Verify result structure + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 2) + self.assertIn("task_1", result) + self.assertIn("task_2", result) + + # Verify task_1 details + task1_dict = result["task_1"] + self.assertEqual(task1_dict["item_id"], "task_1") + self.assertEqual(task1_dict["user_id"], "user_1") + self.assertEqual(task1_dict["status"], "running") + + # Verify task_2 details + task2_dict = result["task_2"] + self.assertEqual(task2_dict["item_id"], "task_2") + self.assertEqual(task2_dict["user_id"], "user_2") + self.assertEqual(task2_dict["status"], "completed") + self.assertEqual(task2_dict["result"], "success") + self.assertEqual(task2_dict["messages"], ["message1", "message2"]) + + # Verify dispatcher method was called + mock_get_running_tasks.assert_called_once_with(filter_func=None) + + def test_message_handler_receives_submitted_message(self): + """Test that handlers receive messages after scheduler startup and message submission.""" + # Create a mock handler that tracks received messages + received_messages = [] + + def mock_handler(messages: list[ScheduleMessageItem]) -> None: + """Mock handler that records received messages.""" + received_messages.extend(messages) + + # Register the mock handler + test_label = "test_handler" + handlers = {test_label: mock_handler} + self.scheduler.register_handlers(handlers) + + # Verify handler is registered + self.assertIn(test_label, self.scheduler.handlers) + self.assertEqual(self.scheduler.handlers[test_label], mock_handler) + + # Start the scheduler + self.scheduler.start() + + # Create and submit a test message + test_message = ScheduleMessageItem( + label=test_label, + content="Test message content", + user_id="test_user", + mem_cube_id="test_mem_cube", + mem_cube="test_mem_cube_obj", # Required field - can be string or GeneralMemCube + timestamp=datetime.now(), + ) + + import asyncio + + asyncio.run(self.scheduler.submit_messages(test_message)) + + # Wait for message processing to complete + import time + + time.sleep(2.0) # Allow sufficient time for message processing + + # Verify the handler received the message + self.assertEqual( + len(received_messages), 1, f"Expected 1 message, got {len(received_messages)}" + ) + self.assertEqual(received_messages[0].label, test_label) + self.assertEqual(received_messages[0].content, "Test message content") + self.assertEqual(received_messages[0].user_id, "test_user") + self.assertEqual(received_messages[0].mem_cube_id, "test_mem_cube") + + # Stop the scheduler + self.scheduler.stop() diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py index 986839bc9..e9c81c7f0 100644 --- a/tests/test_hello_world.py +++ b/tests/test_hello_world.py @@ -118,6 +118,8 @@ def test_memos_yuqingchen_hello_world_logger_called(): def test_memos_chen_tang_hello_world(): + import warnings + from memos.memories.textual.general import GeneralTextMemory # Define return values for os.getenv @@ -130,7 +132,10 @@ def mock_getenv(key, default=None): } return mock_values.get(key, default) - # Use patch to mock os.getenv - with patch("os.getenv", side_effect=mock_getenv): - memory = memos_chentang_hello_world() - assert isinstance(memory, GeneralTextMemory) + # Filter Pydantic serialization warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + # Use patch to mock os.getenv + with patch("os.getenv", side_effect=mock_getenv): + memory = memos_chentang_hello_world() + assert isinstance(memory, GeneralTextMemory) From 0b2b6ed69d1d8da568b3db987a7fcd216c392e54 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Sat, 25 Oct 2025 17:34:06 +0800 Subject: [PATCH 05/64] Feat/merge inst cplt to dev (#388) * add preference text memory * finish milvus support * add new builder * finish prefer textual memory base level * modify code struct * modify pref module * implement remain preference function * modify preference.py * modify bug in milvus * finish debug * modify user pref user id code * modify bug in milvus * finish debug in core * repair bug in milvus get_all * add pref mem esarch time in core * modify search for pref mem in product..py * add simple pref memos example * modify bug in examples/mem_os/simple_prefs_memos_product.py * repair bug in user id related part * modify search * repair bug in slow update * modify define error in extractor -> extract_implicit_preferences * reapair define error in extractor and modify split func in spliter * modify code * modify adder * optimize the code * repair bug in adder and extractor * finish make test and make pre-commit * repair bug in preference * add memory field for milvusvecdbitem and modify related module * pref code clean * modify prompt of extractor * modify extractor * add reranker to pref mem * remove assember in pref mem * modify code * add op trace based update method in add * modify slow update in adder * modify implicit part code in extractor and add duplicate in utils * modify depulicate threshold * modify api config * reapir bug in adder about search relate * repair bug in core , dupicate search * add pref to new naive cube and server api * add async pref add by mem_schedular * modify * replace print to logger * repair bug from make pre-commit * inst cplt * align to liji cloud server * repair pkg problem * modify example of pref * pre_commit * fix api bug * merge inst_cplt to dev * fix pre commit * fix pre commit * fix pre commit error * modify code fllow reviewer * fix bug in make pre_commit * repair bug in server router * fix pre commit bug --------- Co-authored-by: yuan.wang --- docker/requirements.txt | 2 +- docs/openapi.json | 8 +- evaluation/.env-example | 12 +- evaluation/scripts/PrefEval/pref_eval.py | 44 +- evaluation/scripts/PrefEval/pref_mem0.py | 6 +- evaluation/scripts/PrefEval/pref_memobase.py | 10 +- evaluation/scripts/PrefEval/pref_memos.py | 26 +- evaluation/scripts/PrefEval/pref_memu.py | 10 +- .../scripts/PrefEval/pref_supermemory.py | 8 +- evaluation/scripts/PrefEval/pref_zep.py | 10 +- .../scripts/PrefEval/prefeval_preprocess.py | 5 +- evaluation/scripts/locomo/locomo_ingestion.py | 10 +- evaluation/scripts/locomo/locomo_responses.py | 22 +- evaluation/scripts/locomo/locomo_search.py | 18 +- evaluation/scripts/locomo/prompts.py | 24 +- evaluation/scripts/longmemeval/lme_eval.py | 1 + .../scripts/longmemeval/lme_ingestion.py | 6 +- .../scripts/longmemeval/lme_responses.py | 14 +- evaluation/scripts/longmemeval/lme_search.py | 7 +- evaluation/scripts/personamem/pm_ingestion.py | 17 +- evaluation/scripts/personamem/pm_metric.py | 6 +- evaluation/scripts/personamem/pm_responses.py | 21 +- evaluation/scripts/personamem/pm_search.py | 20 +- evaluation/scripts/run_pm_eval.sh | 4 +- evaluation/scripts/run_prefeval_eval.sh | 8 +- evaluation/scripts/utils/client.py | 6 +- evaluation/scripts/utils/mirix_utils.py | 11 +- evaluation/scripts/utils/pref_mem_utils.py | 43 ++ evaluation/scripts/utils/prompts.py | 21 +- examples/mem_os/simple_prefs_memos_product.py | 399 ++++++++++++++++++ poetry.lock | 70 ++- pyproject.toml | 8 + src/memos/api/config.py | 45 ++ src/memos/api/product_models.py | 1 + src/memos/api/routers/server_router.py | 236 +++++++++-- src/memos/configs/mem_cube.py | 16 + src/memos/configs/mem_os.py | 4 + src/memos/configs/memory.py | 45 ++ src/memos/mem_cube/base.py | 1 + src/memos/mem_cube/general.py | 47 ++- src/memos/mem_cube/navie.py | 50 ++- src/memos/mem_os/core.py | 228 ++++++---- src/memos/mem_os/product.py | 18 + src/memos/mem_scheduler/general_scheduler.py | 48 ++- .../mem_scheduler/schemas/general_schemas.py | 2 +- .../mem_scheduler/schemas/message_schemas.py | 5 + src/memos/memories/factory.py | 4 + src/memos/memories/textual/item.py | 20 +- .../textual/prefer_text_memory/__init__.py | 0 .../textual/prefer_text_memory/adder.py | 284 +++++++++++++ .../textual/prefer_text_memory/config.py | 106 +++++ .../textual/prefer_text_memory/extractor.py | 184 ++++++++ .../textual/prefer_text_memory/factory.py | 78 ++++ .../textual/prefer_text_memory/retrievers.py | 88 ++++ .../textual/prefer_text_memory/spliter.py | 132 ++++++ .../textual/prefer_text_memory/utils.py | 70 +++ src/memos/memories/textual/preference.py | 283 +++++++++++++ .../memories/textual/simple_preference.py | 156 +++++++ src/memos/templates/instruction_completion.py | 43 ++ src/memos/templates/prefer_complete_prompt.py | 250 +++++++++++ src/memos/vec_dbs/factory.py | 2 + src/memos/vec_dbs/item.py | 6 + src/memos/vec_dbs/milvus.py | 45 +- tests/configs/test_mem_cube.py | 2 +- 64 files changed, 3105 insertions(+), 271 deletions(-) mode change 100644 => 100755 evaluation/scripts/run_prefeval_eval.sh create mode 100644 evaluation/scripts/utils/pref_mem_utils.py create mode 100644 examples/mem_os/simple_prefs_memos_product.py create mode 100644 src/memos/memories/textual/prefer_text_memory/__init__.py create mode 100644 src/memos/memories/textual/prefer_text_memory/adder.py create mode 100644 src/memos/memories/textual/prefer_text_memory/config.py create mode 100644 src/memos/memories/textual/prefer_text_memory/extractor.py create mode 100644 src/memos/memories/textual/prefer_text_memory/factory.py create mode 100644 src/memos/memories/textual/prefer_text_memory/retrievers.py create mode 100644 src/memos/memories/textual/prefer_text_memory/spliter.py create mode 100644 src/memos/memories/textual/prefer_text_memory/utils.py create mode 100644 src/memos/memories/textual/preference.py create mode 100644 src/memos/memories/textual/simple_preference.py create mode 100644 src/memos/templates/instruction_completion.py create mode 100644 src/memos/templates/prefer_complete_prompt.py diff --git a/docker/requirements.txt b/docker/requirements.txt index d20c0b36e..4846f1832 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6 watchfiles==1.1.0 websockets==15.0.1 xlrd==2.0.2 -xlsxwriter==3.2.5 \ No newline at end of file +xlsxwriter==3.2.5 diff --git a/docs/openapi.json b/docs/openapi.json index 5a3471ac0..ee2ff1368 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -884,7 +884,7 @@ "type": "string", "title": "Session Id", "description": "Session ID for the MOS. This is used to distinguish between different dialogue", - "default": "0ce84b9c-0615-4b9d-83dd-fba50537d5d3" + "default": "41bb5e18-252d-4948-918c-07d82aa47086" }, "chat_model": { "$ref": "#/components/schemas/LLMConfigFactory", @@ -939,6 +939,12 @@ "description": "Enable parametric memory for the MemChat", "default": false }, + "enable_preference_memory": { + "type": "boolean", + "title": "Enable Preference Memory", + "description": "Enable preference memory for the MemChat", + "default": false + }, "enable_mem_scheduler": { "type": "boolean", "title": "Enable Mem Scheduler", diff --git a/evaluation/.env-example b/evaluation/.env-example index 4b2b9311f..bda935442 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -22,9 +22,13 @@ SUPERMEMORY_API_KEY="sm_xxx" MEMOBASE_API_KEY="xxx" MEMOBASE_PROJECT_URL="http://***.***.***.***:8019" -# eval settings -PRE_SPLIT_CHUNK=false - +# pref +PRE_SPLIT_CHUNK=false # pre split chunk in client end, for personamem and prefeval +# 1. text_mem + pref_mem + instruction_completion: set INSTRUCT_COMPLETE=true, ABLATION_PREF=false +# 2. text_mem + pref_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=false +# 3. text_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=true +INSTRUCT_COMPLETE=true # use instruct complete format or not +ABLATION_PREF=false # remove pref mem, only text mem # Configuration Only For Scheduler # RabbitMQ Configuration @@ -45,4 +49,4 @@ MEMSCHEDULER_GRAPHDBAUTH_URI=bolt://localhost:7687 MEMSCHEDULER_GRAPHDBAUTH_USER=neo4j MEMSCHEDULER_GRAPHDBAUTH_PASSWORD=*** MEMSCHEDULER_GRAPHDBAUTH_DB_NAME=neo4j -MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true \ No newline at end of file +MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true diff --git a/evaluation/scripts/PrefEval/pref_eval.py b/evaluation/scripts/PrefEval/pref_eval.py index 10cf41bf3..f1966b847 100644 --- a/evaluation/scripts/PrefEval/pref_eval.py +++ b/evaluation/scripts/PrefEval/pref_eval.py @@ -1,14 +1,18 @@ +import argparse import asyncio import json +import os import re -import argparse -from typing import Dict, Any + from collections import Counter -from tqdm.asyncio import tqdm -import os +from typing import Any + import pandas as pd + from dotenv import load_dotenv from openai import OpenAI +from tqdm.asyncio import tqdm + load_dotenv() @@ -30,7 +34,7 @@ async def call_gpt4o_mini_async(client: OpenAI, prompt: str) -> str: ) return response.choices[0].message.content except Exception as e: - print(f"API call failed: {str(e)}") + print(f"API call failed: {e!s}") return "" @@ -41,7 +45,7 @@ def parse_xml_response(response: str, tag: str) -> str: async def evaluate_violate_preference_async( client: OpenAI, preference: str, question: str, response: str -) -> Dict[str, str]: +) -> dict[str, str]: prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's response violates the user's stated preference. Evaluate the response based on these criteria: @@ -73,8 +77,8 @@ async def evaluate_violate_preference_async( async def evaluate_acknowledge_preference_async( client: OpenAI, question: str, response: str -) -> Dict[str, str]: - prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. +) -> dict[str, str]: + prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. Evaluate the response based on these stringent criteria: @@ -82,7 +86,7 @@ async def evaluate_acknowledge_preference_async( Answer "Yes" if: -1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. +1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. 2. The response assumes the user preference in answering the user's question implicitly. For example, when the user asks 'Can you recommend me cars to drive?', if the response is 'Based on your preference, I will recommend non-EV cars, ...', then this indicates the assistant assumes that the user's preference is a dislike of EV cars, and you should answer 'Yes'. Answer "No" if the response does not mention or refer to any user preference explicitly or implicitly. If you cannot extract the sentence stating what the preference is, answer 'No'. @@ -103,7 +107,7 @@ async def evaluate_acknowledge_preference_async( async def evaluate_hallucinate_preference_async( client: OpenAI, preference: str, restatement: str -) -> Dict[str, str]: +) -> dict[str, str]: if not restatement.strip(): return {"explanation": "No restatement provided by assistant", "answer": "No"} prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's restatement of the user's stated preference is the same preference. Evaluate the response based on these stringent criteria to answer if the assistant has hallucinated the preference or not: @@ -120,10 +124,10 @@ async def evaluate_hallucinate_preference_async( 2. The assistant's restatement is a minor paraphrase that fully preserves the meaning and intent of the original preference. 3. The restatement is just empty, no hallucination. -Here is the information: -Original user preference: "{preference}" -Assistant's restatement: "{restatement}" -Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: +Here is the information: +Original user preference: "{preference}" +Assistant's restatement: "{restatement}" +Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: [1 short sentence explanation] [Yes/No]""" @@ -136,7 +140,7 @@ async def evaluate_hallucinate_preference_async( async def evaluate_helpful_response_async( client: OpenAI, question: str, response: str -) -> Dict[str, str]: +) -> dict[str, str]: prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant provides any substantive response to the user's query. Evaluate the response based on these stringent criteria: @@ -174,7 +178,7 @@ async def evaluate_helpful_response_async( } -def classify_error_type(evaluation_results: Dict[str, Any]) -> str: +def classify_error_type(evaluation_results: dict[str, Any]) -> str: violate = evaluation_results["violate_preference"]["answer"] acknowledge = evaluation_results["acknowledge_preference"]["answer"] hallucinate = evaluation_results["hallucinate_preference"]["answer"] @@ -192,7 +196,7 @@ def classify_error_type(evaluation_results: Dict[str, Any]) -> str: return "Personalized Response" -async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> Dict[str, Any]: +async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> dict[str, Any]: async with semaphore: data = json.loads(line.strip()) preference = data["preference"] @@ -223,7 +227,7 @@ async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) return result -def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[str, float]]: +def log_summary(error_counter: Counter, total_samples: int) -> dict[str, dict[str, float]]: summary_data = {} print("\n--- Error Type Summary ---") @@ -247,7 +251,7 @@ def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[st def generate_excel_summary( - summary_results: Dict[str, Dict[str, float]], + summary_results: dict[str, dict[str, float]], avg_search_time: float, avg_context_tokens: float, avg_add_time: float, @@ -317,7 +321,7 @@ async def main(concurrency_limit: int, input_file: str, output_file: str, output client = OpenAI(api_key=API_KEY, base_url=API_URL) try: - with open(input_file, "r", encoding="utf-8") as f: + with open(input_file, encoding="utf-8") as f: lines = f.readlines() except FileNotFoundError: print(f"Error: Input file not found at '{input_file}'") diff --git a/evaluation/scripts/PrefEval/pref_mem0.py b/evaluation/scripts/PrefEval/pref_mem0.py index 416d8045f..4bbdb0fd8 100644 --- a/evaluation/scripts/PrefEval/pref_mem0.py +++ b/evaluation/scripts/PrefEval/pref_mem0.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from irrelevant_conv import irre_10, irre_300 ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -199,7 +201,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_memobase.py b/evaluation/scripts/PrefEval/pref_memobase.py index 34d3ea86f..4f6174d3d 100644 --- a/evaluation/scripts/PrefEval/pref_memobase.py +++ b/evaluation/scripts/PrefEval/pref_memobase.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -import time -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -49,7 +51,7 @@ def add_memory_for_line( if conversation: messages = [] - for chunk_start in range(0, len(conversation)): + for chunk_start in range(len(conversation)): chunk = conversation[chunk_start : chunk_start + 1] timestamp_add = str(int(time.time() * 100)) time.sleep(0.001) # Ensure unique timestamp @@ -210,7 +212,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 5ee064b1f..753a77d99 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from irrelevant_conv import irre_10, irre_300 ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -18,6 +20,8 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) + + load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") BASE_URL = os.getenv("OPENAI_BASE_URL") @@ -68,6 +72,8 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di """ Processes a single line of data, searching memory based on the question. """ + from utils.pref_mem_utils import create_mem_string + i, line = line_data try: original_data = json.loads(line) @@ -88,9 +94,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di start_time_search = time.monotonic() relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) search_memories_duration = time.monotonic() - start_time_search - memories_str = "\n".join( - f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] - ) + memories_str = create_mem_string(relevant_memories) memory_tokens_used = len(tokenizer.encode(memories_str)) @@ -111,10 +115,13 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di return None -def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: +def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str) -> dict: """ Generates a response for a single line of data using pre-fetched memories. """ + from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string + from utils.prompts import PREFEVAL_ANSWER_PROMPT + i, line = line_data try: original_data = json.loads(line) @@ -139,7 +146,10 @@ def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: ) return original_data - system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}" + memories_str = remove_pref_mem_from_mem_string(memories_str, frame=lib) + + template = add_pref_instruction(PREFEVAL_ANSWER_PROMPT, frame=lib) + system_prompt = template.format(context=memories_str) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": question}, @@ -201,7 +211,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") @@ -277,7 +287,7 @@ def main(): concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, ): futures = [ - executor.submit(generate_response_for_line, (i, line), openai_client) + executor.submit(generate_response_for_line, (i, line), openai_client, args.lib) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/pref_memu.py b/evaluation/scripts/PrefEval/pref_memu.py index 719f2b488..2b9f769a4 100644 --- a/evaluation/scripts/PrefEval/pref_memu.py +++ b/evaluation/scripts/PrefEval/pref_memu.py @@ -4,12 +4,16 @@ import os import sys import time + +from datetime import datetime + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from datetime import datetime -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -205,7 +209,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_supermemory.py b/evaluation/scripts/PrefEval/pref_supermemory.py index 85e84b6c9..88a64038b 100644 --- a/evaluation/scripts/PrefEval/pref_supermemory.py +++ b/evaluation/scripts/PrefEval/pref_supermemory.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from datetime import datetime -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -201,7 +203,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_zep.py b/evaluation/scripts/PrefEval/pref_zep.py index 699660787..91aef1492 100644 --- a/evaluation/scripts/PrefEval/pref_zep.py +++ b/evaluation/scripts/PrefEval/pref_zep.py @@ -4,12 +4,16 @@ import os import sys import time + +from datetime import datetime + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from datetime import datetime -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -211,7 +215,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/prefeval_preprocess.py b/evaluation/scripts/PrefEval/prefeval_preprocess.py index 004d5e505..9ace9dec9 100644 --- a/evaluation/scripts/PrefEval/prefeval_preprocess.py +++ b/evaluation/scripts/PrefEval/prefeval_preprocess.py @@ -1,7 +1,8 @@ -from datasets import load_dataset import json import os +from datasets import load_dataset + def convert_dataset_to_jsonl(dataset_name, output_dir="./scripts/PrefEval"): if not os.path.exists(output_dir): @@ -64,7 +65,7 @@ def process_jsonl_file(input_filepath, output_filepath): line_count = 0 print(f"Start processing file: {input_filepath}") with ( - open(input_filepath, "r", encoding="utf-8") as infile, + open(input_filepath, encoding="utf-8") as infile, open(output_filepath, "w", encoding="utf-8") as outfile, ): for line in infile: diff --git a/evaluation/scripts/locomo/locomo_ingestion.py b/evaluation/scripts/locomo/locomo_ingestion.py index edb451dc0..fe7aa86f7 100644 --- a/evaluation/scripts/locomo/locomo_ingestion.py +++ b/evaluation/scripts/locomo/locomo_ingestion.py @@ -1,12 +1,16 @@ -import os -import sys import argparse import concurrent.futures +import os +import sys import time + from datetime import datetime, timezone + import pandas as pd + from dotenv import load_dotenv + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -88,8 +92,8 @@ def process_user(conv_idx, frame, locomo_df, version): client = None if frame == "mem0" or frame == "mem0_graph": - from utils.client import Mem0Client from prompts import custom_instructions + from utils.client import Mem0Client client = Mem0Client(enable_graph="graph" in frame) client.client.update_project(custom_instructions=custom_instructions) diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 4e3b966a3..2ae4dcb6e 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -2,6 +2,7 @@ import asyncio import json import os +import sys from time import time @@ -13,6 +14,15 @@ from tqdm import tqdm +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) + + async def locomo_response(frame, llm_client, context: str, question: str) -> str: if frame == "zep": prompt = ANSWER_PROMPT_ZEP.format( @@ -25,7 +35,10 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str question=question, ) else: - prompt = ANSWER_PROMPT_MEMOS.format( + from utils.pref_mem_utils import add_pref_instruction + + template = add_pref_instruction(ANSWER_PROMPT_MEMOS, frame=frame) + prompt = template.format( context=context, question=question, ) @@ -42,12 +55,17 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str async def process_qa(frame, qa, search_result, oai_client): + from utils.pref_mem_utils import remove_pref_mem_from_mem_string + start = time() query = qa.get("question") gold_answer = qa.get("answer") qa_category = qa.get("category") - answer = await locomo_response(frame, oai_client, search_result.get("context"), query) + context = search_result.get("context") + + context = remove_pref_mem_from_mem_string(context, frame) + answer = await locomo_response(frame, oai_client, context, query) response_duration_ms = (time() - start) * 1000 diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index 452fb4762..19efb5b92 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -1,14 +1,18 @@ -import os -import sys import argparse import json +import os +import sys + from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from time import time + import pandas as pd + from dotenv import load_dotenv from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -96,16 +100,14 @@ def memos_api_search( client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b ): from prompts import TEMPLATE_MEMOS + from utils.pref_mem_utils import create_mem_string start = time() search_a_results = client.search(query=query, user_id=speaker_a_user_id, top_k=top_k) search_b_results = client.search(query=query, user_id=speaker_b_user_id, top_k=top_k) - speaker_a_context = "\n".join( - [i["memory"] for i in search_a_results["text_mem"][0]["memories"]] - ) - speaker_b_context = "\n".join( - [i["memory"] for i in search_b_results["text_mem"][0]["memories"]] - ) + + speaker_a_context = create_mem_string(search_a_results) + speaker_b_context = create_mem_string(search_b_results) context = TEMPLATE_MEMOS.format( speaker_1=speaker_a, diff --git a/evaluation/scripts/locomo/prompts.py b/evaluation/scripts/locomo/prompts.py index 2827716a0..caf462f6a 100644 --- a/evaluation/scripts/locomo/prompts.py +++ b/evaluation/scripts/locomo/prompts.py @@ -1,3 +1,14 @@ +import os + + +PREF_INSTRUCTIONS = """ + # Note: + Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. + Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. + When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +""" + + ANSWER_PROMPT_MEM0 = """ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. @@ -49,12 +60,12 @@ 5. Always convert relative time references to specific dates, months, or years. 6. Be as specific as possible when talking about people, places, and events 7. Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message. - + Clarification: When interpreting memories, use the timestamp to determine when the described event happened, not when someone talked about the event. - + Example: - + Memory: (2023-03-15T16:33:00Z) I went to the vet yesterday. Question: What day did I go to the vet? Correct Answer: March 15, 2023 @@ -103,7 +114,7 @@ 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. 7. Ensure your final answer is specific and avoids vague time references. - + {pref_instructions} {context} Question: {question} @@ -111,6 +122,11 @@ Answer: """ +if os.getenv("INSTRUCT_COMPLETE") == "true": + ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", PREF_INSTRUCTIONS) +else: + ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", "") + custom_instructions = """ Generate personal memories that follow these guidelines: diff --git a/evaluation/scripts/longmemeval/lme_eval.py b/evaluation/scripts/longmemeval/lme_eval.py index 45c038a2b..73117b925 100644 --- a/evaluation/scripts/longmemeval/lme_eval.py +++ b/evaluation/scripts/longmemeval/lme_eval.py @@ -26,6 +26,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from utils.prompts import LME_JUDGE_MODEL_TEMPLATE + encoding = tiktoken.get_encoding("cl100k_base") logging.basicConfig(level=logging.CRITICAL) transformers.logging.set_verbosity_error() diff --git a/evaluation/scripts/longmemeval/lme_ingestion.py b/evaluation/scripts/longmemeval/lme_ingestion.py index a1849757d..325178292 100644 --- a/evaluation/scripts/longmemeval/lme_ingestion.py +++ b/evaluation/scripts/longmemeval/lme_ingestion.py @@ -1,11 +1,15 @@ import argparse import os import sys + from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone + import pandas as pd + from tqdm import tqdm + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -126,7 +130,7 @@ def main(frame, version, num_workers=2): success_records = [] record_file = f"results/lme/{frame}-{version}/success_records.txt" if os.path.exists(record_file): - with open(record_file, "r") as f: + with open(record_file) as f: for i in f.readlines(): success_records.append(i.strip()) diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py index 3df3e2da4..22f17c304 100644 --- a/evaluation/scripts/longmemeval/lme_responses.py +++ b/evaluation/scripts/longmemeval/lme_responses.py @@ -12,16 +12,17 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string from utils.prompts import LME_ANSWER_PROMPT -def lme_response(llm_client, context, question, question_date): - prompt = LME_ANSWER_PROMPT.format( +def lme_response(llm_client, context, question, question_date, frame): + template = add_pref_instruction(LME_ANSWER_PROMPT, frame=frame) + prompt = template.format( question=question, question_date=question_date, context=context, ) - response = llm_client.chat.completions.create( model=os.getenv("CHAT_MODEL"), messages=[ @@ -34,13 +35,14 @@ def lme_response(llm_client, context, question, question_date): return result -def process_qa(user_id, search_result, llm_client): +def process_qa(user_id, search_result, llm_client, frame): start = time() search_result = search_result[0] question = search_result.get("question") question_date = search_result.get("date") context = search_result.get("search_context", "") - anwer = lme_response(llm_client, context, question, question_date) + context = remove_pref_mem_from_mem_string(context, frame=frame) + anwer = lme_response(llm_client, context, question, question_date, frame) response_duration_ms = (time() - start) * 1000 @@ -95,7 +97,7 @@ def main(frame, version, num_workers=4): future_to_user_id = {} for user_id, search_results in lme_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, oai_client) + future = executor.submit(process_qa, user_id, search_results, oai_client, frame) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index 67d2f1b04..d21795eef 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -3,6 +3,7 @@ import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed @@ -10,13 +11,13 @@ from time import time import pandas as pd + from tqdm import tqdm +from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, - MEMOBASE_CONTEXT_TEMPLATE, MEMOS_CONTEXT_TEMPLATE, - ZEP_CONTEXT_TEMPLATE, ) @@ -44,7 +45,7 @@ def mem0_search(client, query, user_id, top_k): def memos_search(client, query, user_id, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) - context = "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]]) + context = create_mem_string(results) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context) duration_ms = (time() - start) * 1000 return context, duration_ms diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index 8de23937c..5204b5c2a 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -3,10 +3,13 @@ import json import os import sys +import time + from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime + from tqdm import tqdm -import time + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -34,7 +37,7 @@ def ingest_session(session, user_id, session_id, frame, client): client.add(messages=session, user_id=user_id, conv_id=session_id) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages") elif frame == "memobase": - for idx, msg in enumerate(session): + for _idx, msg in enumerate(session): if msg["role"] != "system": messages.append( { @@ -67,7 +70,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -79,14 +82,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -98,7 +101,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -118,7 +121,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index 653c5fc10..e88c538d4 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -44,7 +44,7 @@ def save_to_excel(results, output_path): category_row[f"response_{metric}"] = value # Add search duration metrics (if exists) - if "search_duration" in scores and scores["search_duration"]: + if scores.get("search_duration"): for metric, value in scores["search_duration"].items(): category_row[f"search_{metric}"] = value @@ -80,7 +80,7 @@ def calculate_scores(data, grade_path, output_path): print(f"📋 Processing response data for {len(data)} users...") # First pass: determine number of runs and initialize run accuracy arrays - for user_id, user_data in data.items(): + for _user_id, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and "topic" in user_data: continue @@ -371,7 +371,7 @@ def print_summary(results): print(f"📂 Loading response data from: {responses_path}") try: - with open(responses_path, "r", encoding="utf-8") as file: + with open(responses_path, encoding="utf-8") as file: data = json.load(file) # Calculate metrics diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 8bfeaf5f6..5b54f9bb8 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -10,11 +10,13 @@ from openai import OpenAI from tqdm import tqdm -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.prompts import PM_ANSWER_PROMPT +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re +from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string +from utils.prompts import PM_ANSWER_PROMPT + def extract_choice_answer(predicted_answer, correct_answer): def _extract_only_options(text): @@ -47,8 +49,9 @@ def _extract_only_options(text): return False, predicted_answer -def pm_response(llm_client, context, question, options): - prompt = PM_ANSWER_PROMPT.format( +def pm_response(llm_client, context, question, options, frame): + template = add_pref_instruction(PM_ANSWER_PROMPT, frame=frame) + prompt = template.format( question=question, context=context, options=options, @@ -65,17 +68,19 @@ def pm_response(llm_client, context, question, options): return result -def process_qa(user_id, search_result, num_runs, llm_client): +def process_qa(user_id, search_result, num_runs, llm_client, frame): search_result = search_result[0] question = search_result.get("question") context = search_result.get("search_context", "") options = search_result.get("all_options", []) + context = remove_pref_mem_from_mem_string(context, frame=frame) + run_results = [] for idx in range(num_runs): start = time() - answer = pm_response(llm_client, context, question, options) + answer = pm_response(llm_client, context, question, options, frame) is_correct, answer = extract_choice_answer(answer, search_result.get("golden_answer", "")) response_duration_ms = (time() - start) * 1000 @@ -149,7 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id = {} for user_id, search_results in pm_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client) + future = executor.submit( + process_qa, user_id, search_results, num_runs, oai_client, frame + ) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 2e1a268fc..243c64589 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -1,16 +1,20 @@ import argparse +import csv import json import os import sys + from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time + from tqdm import tqdm -import csv + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, @@ -79,9 +83,7 @@ def memobase_search(client, query, user_id, top_k): def memos_search(client, user_id, query, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) - search_memories = "\n".join( - item["memory"] for cube in results["text_mem"] for item in cube["memories"] - ) + search_memories = create_mem_string(results) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) duration_ms = (time() - start) * 1000 @@ -109,7 +111,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -121,14 +123,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -140,7 +142,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -163,7 +165,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 diff --git a/evaluation/scripts/run_pm_eval.sh b/evaluation/scripts/run_pm_eval.sh index f83893fed..a46440bfc 100755 --- a/evaluation/scripts/run_pm_eval.sh +++ b/evaluation/scripts/run_pm_eval.sh @@ -1,7 +1,7 @@ #!/bin/bash # Common parameters for all scripts -LIB="memu" +LIB="memos-api" VERSION="072202" WORKERS=10 TOPK=20 @@ -62,4 +62,4 @@ else fi fi -echo "All scripts completed successfully!" \ No newline at end of file +echo "All scripts completed successfully!" diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh old mode 100644 new mode 100755 index 001f8299d..a79cefcc2 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -11,13 +11,13 @@ WORKERS=10 # Parameters for pref_memos.py TOP_K=6 ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" +LIB="memos-api" VERSION="1022-0" # --- File Paths --- # You may need to adjust these paths based on your project structure. # Step 1 (preprocess) outputs this file: -PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" +PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" # Create a directory name based on the *specific* LIB (e.g., "memos") OUTPUT_DIR="results/prefeval/${LIB}_${VERSION}" @@ -54,7 +54,7 @@ export HF_ENDPOINT="https://hf-mirror.com" echo "--- Starting PrefEval Pipeline ---" echo "Configuration: WORKERS=$WORKERS, TOP_K=$TOP_K, ADD_TURN=$ADD_TURN, LIB=$LIB, VERSION=$VERSION, HF_ENDPOINT=$HF_ENDPOINT" echo "Results will be saved to: $OUTPUT_DIR" -echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" +echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" echo "" # --- Step 1: Preprocess the data --- @@ -134,7 +134,7 @@ echo "Running pref_eval.py..." python scripts/PrefEval/pref_eval.py \ --input $RESPONSE_FILE \ --concurrency-limit $WORKERS - + if [ $? -ne 0 ]; then echo "Error: Evaluation script failed." exit 1 diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 2efb0493d..ffc9dda12 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -3,11 +3,15 @@ import sys import time import uuid + from contextlib import suppress from datetime import datetime -from dotenv import load_dotenv + import requests +from dotenv import load_dotenv + + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) load_dotenv() diff --git a/evaluation/scripts/utils/mirix_utils.py b/evaluation/scripts/utils/mirix_utils.py index e1b5f3de6..63cd490df 100644 --- a/evaluation/scripts/utils/mirix_utils.py +++ b/evaluation/scripts/utils/mirix_utils.py @@ -1,18 +1,21 @@ import os + import yaml + from tqdm import tqdm def get_mirix_client(config_path, load_from=None): - if os.path.exists(os.path.expanduser(f"~/.mirix")): - os.system(f"rm -rf ~/.mirix/*") + if os.path.exists(os.path.expanduser("~/.mirix")): + os.system("rm -rf ~/.mirix/*") - with open(config_path, "r") as f: + with open(config_path) as f: agent_config = yaml.safe_load(f) os.environ["OPENAI_API_KEY"] = agent_config["api_key"] import mirix - from mirix import Mirix, EmbeddingConfig, LLMConfig + + from mirix import EmbeddingConfig, LLMConfig, Mirix embedding_default_config = EmbeddingConfig( embedding_model=agent_config["embedding_model_name"], diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py new file mode 100644 index 000000000..22a5bb86c --- /dev/null +++ b/evaluation/scripts/utils/pref_mem_utils.py @@ -0,0 +1,43 @@ +import os +import sys + + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from prompts import PREF_INSTRUCTIONS + + +def create_mem_string(relevant_memories) -> str: + text_memories = [] + explicit = [] + implicit = [] + for item in relevant_memories["text_mem"]: + for mem in item["memories"]: + text_memories.append(mem["memory"]) + text_memories_text = "\n".join(f"{i + 1}. {mem}" for i, mem in enumerate(text_memories)).strip() + text_context = f"Plaintext Memory:\n{text_memories_text}\n" if text_memories_text else "" + + for item in relevant_memories.get("prefs", []): + for mem in item["memories"]: + if mem["metadata"]["preference_type"] == "explicit_preference": + explicit.append(mem["metadata"]["explicit_preference"]) + elif mem["metadata"]["preference_type"] == "implicit_preference": + implicit.append(mem["metadata"]["implicit_preference"]) + explicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(explicit)).strip() + explicit_context = f"Explicit Preference:\n{explicit_text}\n" if explicit_text else "" + implicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(implicit)).strip() + implicit_context = f"Implicit Preference:\n{implicit_text}\n" if implicit_text else "" + return text_context + explicit_context + implicit_context + + +def remove_pref_mem_from_mem_string(mem_string: str, frame: str) -> str: + if os.getenv("ABLATION_PREF", "false").lower() == "true" and frame == "memos-api": + tmp_list = mem_string.split("Plaintext Memory:") + if len(tmp_list) > 1: + return tmp_list[1].split("Explicit Preference:")[0] + return mem_string + + +def add_pref_instruction(template: str, frame: str): + if os.getenv("INSTRUCT_COMPLETE", "false").lower() == "true" and frame == "memos-api": + return template.replace("{pref_instructions}", PREF_INSTRUCTIONS) + return template.replace("{pref_instructions}", "") diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py index bd418af54..902bbb1be 100644 --- a/evaluation/scripts/utils/prompts.py +++ b/evaluation/scripts/utils/prompts.py @@ -1,3 +1,11 @@ +PREF_INSTRUCTIONS = """ + # Note: + Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. + Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. + When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +""" + + LME_ANSWER_PROMPT = """ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. @@ -17,7 +25,7 @@ 5. Formulate a precise, concise answer based solely on the evidence in the memories. 6. Double-check that your answer directly addresses the question asked. 7. Ensure your final answer is specific and avoids vague time references. - + {pref_instructions} {context} Current Date: {question_date} @@ -27,6 +35,7 @@ Answer: """ + PM_ANSWER_PROMPT = """ You are a helpful assistant tasked with selecting the best answer to a user question, based solely on summarized conversation memories. @@ -46,7 +55,7 @@ - Your final answer **must use parentheses**, like (a) or (b). - Do NOT list multiple choices. Choose only one. - Do NOT include extra text after . Just output the answer. - + {pref_instructions} # QUESTION: {question} @@ -58,6 +67,14 @@ """ +PREFEVAL_ANSWER_PROMPT = """ + You are a helpful AI. Answer the question based on the query and the following memories: + User Memories: + {context} + {pref_instructions} +""" + + ZEP_CONTEXT_TEMPLATE = """ FACTS and ENTITIES represent relevant context to the current conversation. diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py new file mode 100644 index 000000000..40ec920f5 --- /dev/null +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -0,0 +1,399 @@ +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.product import MOSProduct + + +def get_config(user_id: str): + llm_config = { + "backend": "openai", + "config": { + "model_name_or_path": "gpt-4o-mini", + "api_key": "sk-xxxxx", + "api_base": "http://xxxx/v1", + "temperature": 0.1, + "remove_think_prefix": True, + "max_tokens": 4096, + }, + } + + embedder_config = { + "backend": "ollama", + "config": {"model_name_or_path": "nomic-embed-text:latest"}, + } + + # init MOS + mos_config = { + "user_id": user_id, + "chat_model": llm_config, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": llm_config, + "embedder": embedder_config, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + "enable_preference_memory": True, + } + + cube_config = { + "model_schema": "memos.configs.mem_cube.GeneralMemCubeConfig", + "user_id": user_id, + "cube_id": f"{user_id}/mem_cube", + "text_mem": { + "backend": "tree_text", + "config": { + "cube_id": f"{user_id}/mem_cube", + "extractor_llm": llm_config, + "dispatcher_llm": llm_config, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": "bolt://localhost:7687", + "user": "neo4j", + "password": "12345678", + "db_name": "neo4j", + "user_name": "memosneo4j", + "embedding_dimension": 768, + "use_multi_db": False, + "auto_create": False, + }, + }, + "embedder": embedder_config, + }, + }, + "act_mem": {"backend": "uninitialized", "config": {}}, + "para_mem": {"backend": "uninitialized", "config": {}}, + "pref_mem": { + "backend": "pref_text", + "config": { + "cube_id": f"{user_id}/mem_cube", + "extractor_llm": llm_config, + "vector_db": { + "backend": "milvus", + "config": { + "collection_name": [ + "explicit_preference", + "implicit_preference", + ], + "vector_dimension": 768, + "distance_metric": "cosine", + "uri": "./milvus_demo.db", + }, + }, + "embedder": embedder_config, + "extractor": {"backend": "naive", "config": {}}, + "adder": {"backend": "naive", "config": {}}, + "retriever": {"backend": "naive", "config": {}}, + }, + }, + } + mos_config = MOSConfig(**mos_config) + cube_config = GeneralMemCubeConfig.model_validate(cube_config) + + return mos_config, cube_config + + +user_id = "superman" +mos_config, mem_cube_config = get_config(user_id) + +memory = MOSProduct(default_config=mos_config) + +# register user and mem cube +mem_cube = GeneralMemCube(mem_cube_config) + +register_result = memory.user_register( + user_id=user_id, user_name=user_id, default_mem_cube=mem_cube +) +print(f"register_result: {register_result}") + +# add memory for user +memory.add( + messages=[ + # Sports and hobbies + {"role": "user", "content": "I like playing football."}, + { + "role": "assistant", + "content": "I like playing football too. It's a great way to stay active and work as a team.", + }, + {"role": "user", "content": "Do you enjoy watching basketball games?"}, + { + "role": "assistant", + "content": "Yes, I find basketball exciting to watch, especially the fast-paced action and strategic plays.", + }, + {"role": "user", "content": "I'm learning to play tennis this summer."}, + { + "role": "assistant", + "content": "That's wonderful! Tennis is a fantastic sport that improves both physical fitness and mental focus.", + }, + # Food and cooking + {"role": "user", "content": "What's your favorite type of cuisine?"}, + { + "role": "assistant", + "content": "I really enjoy Italian cuisine, especially pasta dishes and authentic pizza.", + }, + {"role": "user", "content": "I love cooking Asian food at home."}, + { + "role": "assistant", + "content": "Asian cooking is amazing! The balance of flavors and techniques is so interesting to learn.", + }, + {"role": "user", "content": "Have you ever tried making homemade bread?"}, + { + "role": "assistant", + "content": "Yes, there's something magical about the process of making bread from scratch. The smell is incredible!", + }, + # Travel and places + {"role": "user", "content": "I'm planning a trip to Japan next year."}, + { + "role": "assistant", + "content": "Japan is an incredible destination! The culture, food, and technology blend is fascinating.", + }, + {"role": "user", "content": "What's the most beautiful place you've ever seen?"}, + { + "role": "assistant", + "content": "I find mountain landscapes particularly breathtaking, especially during sunrise or sunset.", + }, + {"role": "user", "content": "I prefer beach vacations over city trips."}, + { + "role": "assistant", + "content": "Beach vacations are so relaxing! The sound of waves and warm sand is perfect for unwinding.", + }, + # Technology and innovation + {"role": "user", "content": "What do you think about artificial intelligence?"}, + { + "role": "assistant", + "content": "AI is transforming many aspects of our lives. It's exciting to see how it can help solve complex problems.", + }, + {"role": "user", "content": "I'm learning to code in Python."}, + { + "role": "assistant", + "content": "Python is a great language to start with! It's versatile and has a wonderful community.", + }, + {"role": "user", "content": "Do you use any productivity apps?"}, + { + "role": "assistant", + "content": "I find task management apps really helpful for organizing daily activities and staying focused.", + }, + # Music and entertainment + {"role": "user", "content": "What kind of music do you enjoy?"}, + { + "role": "assistant", + "content": "I appreciate many genres, but I particularly enjoy jazz and classical music for their complexity.", + }, + {"role": "user", "content": "I play the guitar in my free time."}, + { + "role": "assistant", + "content": "That's fantastic! Playing an instrument is such a rewarding creative outlet.", + }, + {"role": "user", "content": "Have you seen any good movies lately?"}, + { + "role": "assistant", + "content": "I enjoy films that make you think deeply about human nature and society.", + }, + # Books and learning + {"role": "user", "content": "Do you read books regularly?"}, + { + "role": "assistant", + "content": "Yes, reading is one of my favorite ways to learn and explore new ideas.", + }, + {"role": "user", "content": "I'm currently reading a science fiction novel."}, + { + "role": "assistant", + "content": "Science fiction is great for exploring possibilities and imagining different futures.", + }, + {"role": "user", "content": "What's your favorite subject to learn about?"}, + { + "role": "assistant", + "content": "I'm fascinated by psychology and how the human mind works.", + }, + # Nature and environment + {"role": "user", "content": "I love hiking in the mountains."}, + { + "role": "assistant", + "content": "Hiking is wonderful! It's a great way to connect with nature and get exercise.", + }, + {"role": "user", "content": "Are you concerned about climate change?"}, + { + "role": "assistant", + "content": "Yes, it's one of the most important challenges of our time that requires collective action.", + }, + {"role": "user", "content": "I try to recycle and reduce waste at home."}, + { + "role": "assistant", + "content": "That's commendable! Every small action contributes to environmental protection.", + }, + # Work and career + {"role": "user", "content": "What do you do for work?"}, + { + "role": "assistant", + "content": "I work in technology, helping people solve problems and improve their daily lives.", + }, + {"role": "user", "content": "I'm considering a career change."}, + { + "role": "assistant", + "content": "Career changes can be challenging but also very rewarding when you find the right fit.", + }, + {"role": "user", "content": "Do you work from home or in an office?"}, + { + "role": "assistant", + "content": "I work remotely, which gives me flexibility in how I organize my day.", + }, + # Health and wellness + {"role": "user", "content": "I've started doing yoga in the morning."}, + { + "role": "assistant", + "content": "Yoga is excellent for both physical and mental well-being. Great way to start the day!", + }, + {"role": "user", "content": "How do you manage stress?"}, + { + "role": "assistant", + "content": "I find meditation and regular exercise help me stay balanced and focused.", + }, + {"role": "user", "content": "I'm trying to eat more vegetables."}, + { + "role": "assistant", + "content": "That's a great health goal! Vegetables provide so many essential nutrients.", + }, + # Art and creativity + {"role": "user", "content": "Do you enjoy painting or drawing?"}, + { + "role": "assistant", + "content": "I appreciate visual arts, though I'm more of an observer than a creator.", + }, + {"role": "user", "content": "I'm learning photography."}, + { + "role": "assistant", + "content": "Photography is a beautiful way to capture moments and express your perspective.", + }, + {"role": "user", "content": "What's your favorite type of art?"}, + { + "role": "assistant", + "content": "I'm drawn to abstract art because it allows for personal interpretation.", + }, + # Family and relationships + {"role": "user", "content": "I have two siblings."}, + { + "role": "assistant", + "content": "Having siblings can be wonderful - they're often our first friends and lifelong companions.", + }, + {"role": "user", "content": "Do you have any pets?"}, + { + "role": "assistant", + "content": "I don't have pets, but I think they bring so much joy and companionship to people's lives.", + }, + {"role": "user", "content": "I'm planning to visit my grandparents this weekend."}, + { + "role": "assistant", + "content": "That's lovely! Grandparents often have such interesting stories and wisdom to share.", + }, + # Education and growth + {"role": "user", "content": "I'm taking an online course in data science."}, + { + "role": "assistant", + "content": "Data science is a fascinating field with many applications across different industries.", + }, + {"role": "user", "content": "What's the best way to learn a new language?"}, + { + "role": "assistant", + "content": "Immersion and consistent practice are key, along with finding methods that work for your learning style.", + }, + {"role": "user", "content": "I enjoy attending workshops and conferences."}, + { + "role": "assistant", + "content": "Workshops are great for hands-on learning and networking with like-minded people.", + }, + # Hobbies and interests + {"role": "user", "content": "I collect vintage postcards."}, + { + "role": "assistant", + "content": "What a unique hobby! Postcards offer glimpses into different times and places.", + }, + {"role": "user", "content": "Do you garden?"}, + { + "role": "assistant", + "content": "I appreciate gardening, though I'm more of a plant observer than a gardener myself.", + }, + {"role": "user", "content": "I love solving puzzles and brain teasers."}, + { + "role": "assistant", + "content": "Puzzles are excellent for keeping the mind sharp and developing problem-solving skills.", + }, + # Future and goals + {"role": "user", "content": "What are your goals for next year?"}, + { + "role": "assistant", + "content": "I aim to continue learning and growing, helping more people with their challenges.", + }, + {"role": "user", "content": "I want to start my own business someday."}, + { + "role": "assistant", + "content": "Entrepreneurship is exciting! It requires courage, planning, and persistence.", + }, + {"role": "user", "content": "Do you believe in setting New Year's resolutions?"}, + { + "role": "assistant", + "content": "I think any time is good for setting positive goals and making positive changes.", + }, + # Philosophy and life + {"role": "user", "content": "What's the meaning of life to you?"}, + { + "role": "assistant", + "content": "I believe it's about learning, growing, and making positive connections with others.", + }, + {"role": "user", "content": "I try to practice gratitude daily."}, + { + "role": "assistant", + "content": "Gratitude is such a powerful practice that can transform your perspective on life.", + }, + {"role": "user", "content": "What's your biggest fear?"}, + { + "role": "assistant", + "content": "I think fear of not being able to help when someone really needs it would be challenging.", + }, + # Technology and daily life + {"role": "user", "content": "How has technology changed your daily routine?"}, + { + "role": "assistant", + "content": "Technology has made many tasks more efficient and connected me with people worldwide.", + }, + {"role": "user", "content": "I'm trying to reduce my screen time."}, + { + "role": "assistant", + "content": "That's a healthy goal! Finding balance with technology is important for well-being.", + }, + {"role": "user", "content": "Do you use social media?"}, + { + "role": "assistant", + "content": "I think social media can be a great tool for connection when used mindfully.", + }, + # Final thoughts + {"role": "user", "content": "This has been a great conversation!"}, + { + "role": "assistant", + "content": "I've really enjoyed our chat too! It's wonderful to explore so many different topics together.", + }, + ], + user_id=user_id, + mem_cube_id=register_result["default_cube_id"], +) + +retrieved_memories = memory.search(query="What do you like?", user_id=user_id) +print( + f"len_pref_memories: {len(retrieved_memories['pref_mem'][0]['memories'])}" + if retrieved_memories["pref_mem"] + else 0 +) diff --git a/poetry.lock b/poetry.lock index d34f964b6..44265bca8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -690,6 +690,30 @@ toml = ["tomli (>=2.0.0) ; python_version < \"3.11\""] trio = ["trio (>=0.10.0)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "datasketch" +version = "1.6.5" +description = "Probabilistic data structures for processing and searching very large datasets" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"pref-mem\" or extra == \"all\"" +files = [ + {file = "datasketch-1.6.5-py3-none-any.whl", hash = "sha256:59311b2925b2f37536e9f7c2f46bbc25e8e54379c8635a3fa7ca55d2abb66d1b"}, + {file = "datasketch-1.6.5.tar.gz", hash = "sha256:ba2848cb74f23d6d3dd444cf24edcbc47b1c34a171b1803231793ed4d74d4fcf"}, +] + +[package.dependencies] +numpy = ">=1.11" +scipy = ">=1.0.0" + +[package.extras] +benchmark = ["SetSimilaritySearch (>=0.1.7)", "matplotlib (>=3.1.2)", "nltk (>=3.4.5)", "pandas (>=0.25.3)", "pyfarmhash (>=0.2.2)", "pyhash (>=0.9.3)", "scikit-learn (>=0.21.3)", "scipy (>=1.3.3)"] +cassandra = ["cassandra-driver (>=3.20)"] +experimental-aio = ["aiounittest ; python_version >= \"3.6\"", "motor ; python_version >= \"3.6\""] +redis = ["redis (>=2.10.0)"] +test = ["cassandra-driver (>=3.20)", "coverage", "mock (>=2.0.0)", "mockredispy", "nose (>=1.3.7)", "nose-exclude (>=0.5.0)", "pymongo (>=3.9.0)", "pytest", "redis (>=2.10.0)"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1222,7 +1246,7 @@ files = [ {file = "grpcio-1.73.1-cp39-cp39-win_amd64.whl", hash = "sha256:42f0660bce31b745eb9d23f094a332d31f210dcadd0fc8e5be7e4c62a87ce86b"}, {file = "grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87"}, ] -markers = {main = "extra == \"all\""} +markers = {main = "extra == \"pref-mem\" or extra == \"all\""} [package.extras] protobuf = ["grpcio-tools (>=1.73.1)"] @@ -3241,7 +3265,7 @@ files = [ {file = "pandas-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b4b0de34dc8499c2db34000ef8baad684cfa4cbd836ecee05f323ebfba348c7d"}, {file = "pandas-2.3.1.tar.gz", hash = "sha256:0a95b9ac964fe83ce317827f80304d37388ea77616b1425f0ae41c9d2d0d7bb2"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [package.dependencies] numpy = [ @@ -3560,7 +3584,7 @@ files = [ {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "pycparser" @@ -3773,6 +3797,33 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pymilvus" +version = "2.6.2" +description = "Python Sdk for Milvus" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"pref-mem\" or extra == \"all\"" +files = [ + {file = "pymilvus-2.6.2-py3-none-any.whl", hash = "sha256:933e447e09424d490dcf595053b01a7277dadea7ae3235cd704363bd6792509d"}, + {file = "pymilvus-2.6.2.tar.gz", hash = "sha256:b4802cc954de8f2d47bf8d6230e92196514dcb8a3726ba6098dc27909d4bc8e3"}, +] + +[package.dependencies] +grpcio = ">=1.66.2,<1.68.0 || >1.68.0,<1.68.1 || >1.68.1,<1.69.0 || >1.69.0,<1.70.0 || >1.70.0,<1.70.1 || >1.70.1,<1.71.0 || >1.71.0,<1.72.1 || >1.72.1,<1.73.0 || >1.73.0" +pandas = ">=1.2.4" +protobuf = ">=5.27.2" +python-dotenv = ">=1.0.1,<2.0.0" +setuptools = ">69" +ujson = ">=2.0.0" + +[package.extras] +bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests", "urllib3"] +dev = ["azure-storage-blob", "black", "grpcio (==1.66.2)", "grpcio-testing (==1.66.2)", "grpcio-tools (==1.66.2)", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "pytest (>=5.3.4)", "pytest-asyncio", "pytest-cov (>=5.0.0)", "pytest-timeout (>=1.3.4)", "requests", "ruff (>=0.12.9,<1)", "scipy", "urllib3"] +milvus-lite = ["milvus-lite (>=2.4.0) ; sys_platform != \"win32\""] +model = ["pymilvus.model (>=0.3.0)"] + [[package]] name = "pymysql" version = "1.1.2" @@ -3946,7 +3997,7 @@ files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] -markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\""} +markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\" or extra == \"pref-mem\""} [[package]] name = "pywin32" @@ -4955,7 +5006,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"all\" and platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5578,7 +5629,7 @@ files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "ujson" @@ -6301,13 +6352,14 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["chonkie", "markitdown", "neo4j", "pika", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] +pref-mem = ["datasketch", "pymilvus"] tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" \ No newline at end of file +content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" diff --git a/pyproject.toml b/pyproject.toml index a03b9174b..3745582f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,12 @@ mem-reader = [ "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", # Markdown parser for various file formats ] +# PreferenceTextMemory +pref-mem = [ + "pymilvus (>=2.6.1,<3.0.0)", # Milvus Vector DB + "datasketch (>=1.6.5,<2.0.0)", # MinHash library +] + # All optional dependencies # Allow users to install with `pip install MemoryOS[all]` all = [ @@ -99,6 +105,8 @@ all = [ "pymysql (>=1.1.0,<2.0.0)", "chonkie (>=1.0.7,<2.0.0)", "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", + "pymilvus (>=2.6.1,<3.0.0)", + "datasketch (>=1.6.5,<2.0.0)", # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d552369c5..d26672883 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -108,6 +108,25 @@ def get_activation_vllm_config() -> dict[str, Any]: }, } + @staticmethod + def get_preference_memory_config() -> dict[str, Any]: + """Get preference memory configuration.""" + return { + "backend": "pref_text", + "config": { + "extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()}, + "vector_db": { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + }, + "embedder": APIConfig.get_embedder_config(), + "reranker": APIConfig.get_reranker_config(), + "extractor": {"backend": "naive", "config": {}}, + "adder": {"backend": "naive", "config": {}}, + "retriever": {"backend": "naive", "config": {}}, + }, + } + @staticmethod def get_reranker_config() -> dict[str, Any]: """Get embedder configuration.""" @@ -275,6 +294,20 @@ def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } + @staticmethod + def get_milvus_config(): + return { + "collection_name": [ + "explicit_preference", + "implicit_preference", + ], + "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "distance_metric": "cosine", + "uri": os.getenv("MILVUS_URI", "http://localhost:19530"), + "user_name": os.getenv("MILVUS_USER_NAME", "root"), + "password": os.getenv("MILVUS_PASSWORD", "12345678"), + } + @staticmethod def get_mysql_config() -> dict[str, Any]: """Get MySQL configuration.""" @@ -385,6 +418,8 @@ def get_product_default_config() -> dict[str, Any]: "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": int(os.getenv("MOS_TOP_K", "50")), "max_turns_window": int(os.getenv("MOS_MAX_TURNS_WINDOW", "20")), } @@ -414,6 +449,8 @@ def get_start_default_config() -> dict[str, Any]: "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": int(os.getenv("MOS_TOP_K", "5")), "chat_model": { "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"), @@ -478,6 +515,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": 30, "max_turns_window": 20, } @@ -543,6 +582,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, + "pref_mem": {} + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" + else APIConfig.get_preference_memory_config(), } ) else: @@ -605,6 +647,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, + "pref_mem": {} + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" + else APIConfig.get_preference_memory_config(), } ) else: diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index d14c05993..e491e9feb 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -180,6 +180,7 @@ class APISearchRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) + handle_pref_mem: bool = Field(False, description="Whether to handle preference memory") class APIADDRequest(BaseRequest): diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 9f982ddd3..d2392f927 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -2,6 +2,7 @@ import os import traceback +from concurrent.futures import ThreadPoolExecutor from typing import Any from fastapi import APIRouter, HTTPException @@ -21,6 +22,7 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory +from memos.configs.vec_db import VectorDBConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory @@ -36,12 +38,24 @@ ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + ExtractorFactory, + RetrieverFactory, +) from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) from memos.reranker.factory import RerankerFactory +from memos.templates.instruction_completion import instruct_completion from memos.types import MOSSearchResult, UserContext +from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) @@ -66,6 +80,16 @@ def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: ) +def _build_vec_db_config() -> dict[str, Any]: + """Build vector database configuration.""" + return VectorDBConfigFactory.model_validate( + { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + } + ) + + def _build_llm_config() -> dict[str, Any]: """Build LLM configuration.""" return LLMConfigFactory.model_validate( @@ -98,6 +122,21 @@ def _build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) +def _build_pref_extractor_config() -> dict[str, Any]: + """Build extractor configuration.""" + return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def _build_pref_adder_config() -> dict[str, Any]: + """Build adder configuration.""" + return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def _build_pref_retriever_config() -> dict[str, Any]: + """Build retriever configuration.""" + return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) + + def _get_default_memory_size(cube_config) -> dict[str, int]: """Get default memory size configuration.""" return getattr(cube_config.text_mem.config, "memory_size", None) or { @@ -120,9 +159,14 @@ def init_server(): mem_reader_config = _build_mem_reader_config() reranker_config = _build_reranker_config() internet_retriever_config = _build_internet_retriever_config() + vector_db_config = _build_vec_db_config() + pref_extractor_config = _build_pref_extractor_config() + pref_adder_config = _build_pref_adder_config() + pref_retriever_config = _build_pref_retriever_config() # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) + vector_db = VecDBFactory.from_config(vector_db_config) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) @@ -130,6 +174,25 @@ def init_server(): internet_retriever = InternetRetrieverFactory.from_config( internet_retriever_config, embedder=embedder ) + pref_extractor = ExtractorFactory.from_config( + config_factory=pref_extractor_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + pref_adder = AdderFactory.from_config( + config_factory=pref_adder_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + pref_retriever = RetrieverFactory.from_config( + config_factory=pref_retriever_config, + llm_provider=llm, + embedder=embedder, + reranker=reranker, + vector_db=vector_db, + ) # Initialize memory manager memory_manager = MemoryManager( @@ -170,6 +233,10 @@ def init_server(): internet_retriever=internet_retriever, memory_manager=memory_manager, default_cube_config=default_cube_config, + vector_db=vector_db, + pref_extractor=pref_extractor, + pref_adder=pref_adder, + pref_retriever=pref_retriever, ) return ( @@ -185,6 +252,10 @@ def init_server(): mem_scheduler, naive_mem_cube, api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, ) @@ -202,6 +273,10 @@ def init_server(): mem_scheduler, naive_mem_cube, api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, ) = init_server() @@ -221,6 +296,28 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]: return memory +def _post_process_pref_mem( + memories_result: list[dict[str, Any]], + pref_formatted_mem: list[dict[str, Any]], + mem_cube_id: str, + handle_pref_mem: bool, +): + if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true" and pref_formatted_mem: + memories_result["prefs"] = [] + memories_result["prefs"].append( + { + "cube_id": mem_cube_id, + "memories": pref_formatted_mem, + } + ) + + if handle_pref_mem: + pref_instruction: str = instruct_completion(pref_formatted_mem) + memories_result["pref_mem"] = pref_instruction + + return memories_result + + @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: APISearchRequest): """Search memories for a specific user.""" @@ -239,23 +336,55 @@ def search_memories(search_req: APISearchRequest): search_mode = search_req.mode - if search_mode == SearchMode.FAST: - formatted_memories = fast_search_memories(search_req=search_req, user_context=user_context) - elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories(search_req=search_req, user_context=user_context) - elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories(search_req=search_req, user_context=user_context) - else: - logger.error(f"Unsupported search mode: {search_mode}") - raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") + def _search_text(): + if search_mode == SearchMode.FAST: + formatted_memories = fast_search_memories( + search_req=search_req, user_context=user_context + ) + elif search_mode == SearchMode.FINE: + formatted_memories = fine_search_memories( + search_req=search_req, user_context=user_context + ) + elif search_mode == SearchMode.MIXTURE: + formatted_memories = mix_search_memories( + search_req=search_req, user_context=user_context + ) + else: + logger.error(f"Unsupported search mode: {search_mode}") + raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") + return formatted_memories + + def _search_pref(): + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + results = naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [_format_memory_item(data) for data in results] + + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_search_text) + pref_future = executor.submit(_search_pref) + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() memories_result["text_mem"].append( { "cube_id": search_req.mem_cube_id, - "memories": formatted_memories, + "memories": text_formatted_memories, } ) + memories_result = _post_process_pref_mem( + memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem + ) + return SearchResponse( message="Search completed successfully", data=memories_result, @@ -431,38 +560,69 @@ def add_memories(add_req: APIADDRequest): target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" - memories = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) - # Flatten memory list - flattened_memories = [mm for m in memories for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - mem_id_list: list[str] = naive_mem_cube.text_mem.add( - flattened_memories, - user_name=user_context.mem_cube_id, - ) + def _process_text_mem() -> list[dict[str, str]]: + memories_local = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + flattened_local = [mm for m in memories_local for mm in m] + logger.info(f"Memory extraction completed for user {add_req.user_id}") + mem_ids_local: list[str] = naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] + + def _process_pref_mem() -> list[dict[str, str]]: + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_process_text_mem) + pref_future = executor.submit(_process_pref_mem) + text_response_data = text_future.result() + pref_response_data = pref_future.result() - logger.info( - f"Added {len(mem_id_list)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_id_list}" - ) - response_data = [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) - ] return MemoryResponse( message="Memory added successfully", - data=response_data, + data=text_response_data + pref_response_data, ) diff --git a/src/memos/configs/mem_cube.py b/src/memos/configs/mem_cube.py index b9868fa99..4bd709fab 100644 --- a/src/memos/configs/mem_cube.py +++ b/src/memos/configs/mem_cube.py @@ -54,6 +54,11 @@ class GeneralMemCubeConfig(BaseMemCubeConfig): default_factory=MemoryConfigFactory, description="Configuration for the parametric memory", ) + pref_mem: MemoryConfigFactory = Field( + ..., + default_factory=MemoryConfigFactory, + description="Configuration for the preference memory", + ) @field_validator("text_mem") @classmethod @@ -87,3 +92,14 @@ def validate_para_mem(cls, para_mem: MemoryConfigFactory) -> MemoryConfigFactory f"GeneralMemCubeConfig requires para_mem backend to be one of {allowed_backends}, got '{para_mem.backend}'" ) return para_mem + + @field_validator("pref_mem") + @classmethod + def validate_pref_mem(cls, pref_mem: MemoryConfigFactory) -> MemoryConfigFactory: + """Validate the pref_mem field.""" + allowed_backends = ["pref_text", "uninitialized"] + if pref_mem.backend not in allowed_backends: + raise ConfigurationError( + f"GeneralMemCubeConfig requires pref_mem backend to be one of {allowed_backends}, got '{pref_mem.backend}'" + ) + return pref_mem diff --git a/src/memos/configs/mem_os.py b/src/memos/configs/mem_os.py index 0645fce44..549e55792 100644 --- a/src/memos/configs/mem_os.py +++ b/src/memos/configs/mem_os.py @@ -58,6 +58,10 @@ class MOSConfig(BaseConfig): default=False, description="Enable parametric memory for the MemChat", ) + enable_preference_memory: bool = Field( + default=False, + description="Enable preference memory for the MemChat", + ) enable_mem_scheduler: bool = Field( default=False, description="Enable memory scheduler for automated memory management", diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 2c3a715f7..bf2493567 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -10,6 +10,11 @@ from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory from memos.exceptions import ConfigurationError +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) # ─── 1. Global Base Memory Config ───────────────────────────────────────────── @@ -189,6 +194,45 @@ class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): """Simple tree text memory configuration class.""" +class PreferenceTextMemoryConfig(BaseTextMemoryConfig): + """Preference memory configuration class.""" + + extractor_llm: LLMConfigFactory = Field( + ..., + default_factory=LLMConfigFactory, + description="LLM configuration for the memory extractor", + ) + vector_db: VectorDBConfigFactory = Field( + ..., + default_factory=VectorDBConfigFactory, + description="Vector database configuration for the memory storage", + ) + embedder: EmbedderConfigFactory = Field( + ..., + default_factory=EmbedderConfigFactory, + description="Embedder configuration for the memory embedding", + ) + reranker: RerankerConfigFactory | None = Field( + None, + description="Reranker configuration (optional).", + ) + extractor: ExtractorConfigFactory = Field( + ..., + default_factory=ExtractorConfigFactory, + description="Extractor configuration for the memory extracting", + ) + adder: AdderConfigFactory = Field( + ..., + default_factory=AdderConfigFactory, + description="Adder configuration for the memory adding", + ) + retriever: RetrieverConfigFactory = Field( + ..., + default_factory=RetrieverConfigFactory, + description="Retriever configuration for the memory retrieving", + ) + + # ─── 3. Global Memory Config Factory ────────────────────────────────────────── @@ -203,6 +247,7 @@ class MemoryConfigFactory(BaseConfig): "general_text": GeneralTextMemoryConfig, "simple_tree_text": SimpleTreeTextMemoryConfig, "tree_text": TreeTextMemoryConfig, + "pref_text": PreferenceTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache "lora": LoRAMemoryConfig, diff --git a/src/memos/mem_cube/base.py b/src/memos/mem_cube/base.py index 7d7c5e779..349d511fb 100644 --- a/src/memos/mem_cube/base.py +++ b/src/memos/mem_cube/base.py @@ -19,6 +19,7 @@ def __init__(self, config: BaseMemCubeConfig): self.text_mem: BaseTextMemory self.act_mem: BaseActMemory self.para_mem: BaseParaMemory + self.pref_mem: BaseTextMemory @abstractmethod def load(self, dir: str) -> None: diff --git a/src/memos/mem_cube/general.py b/src/memos/mem_cube/general.py index 17e45809c..1238ae050 100644 --- a/src/memos/mem_cube/general.py +++ b/src/memos/mem_cube/general.py @@ -41,16 +41,23 @@ def __init__(self, config: GeneralMemCubeConfig): if config.para_mem.backend != "uninitialized" else None ) + self._pref_mem: BaseTextMemory | None = ( + MemoryFactory.from_config(config.pref_mem) + if config.pref_mem.backend != "uninitialized" + else None + ) def load( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Load memories. Args: dir (str): The directory containing the memory files. memory_types (list[str], optional): List of memory types to load. If None, loads all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) if loaded_schema != self.config.model_schema: @@ -61,7 +68,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -76,17 +83,23 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.load(dir) + logger.info(f"Loaded pref_mem from {dir}") + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -98,7 +111,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -113,12 +126,16 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.dump(dir) + logger.info(f"Dumped pref_mem to {dir}") + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @staticmethod def init_from_dir( dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, default_config: GeneralMemCubeConfig | None = None, ) -> "GeneralMemCube": """Create a MemCube instance from a MemCube directory. @@ -148,7 +165,7 @@ def init_from_dir( def init_from_remote_repo( cube_id: str, base_url: str = "https://huggingface.co/datasets", - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, default_config: GeneralMemCubeConfig | None = None, ) -> "GeneralMemCube": """Create a MemCube instance from a remote repository. @@ -207,3 +224,17 @@ def para_mem(self, value: BaseParaMemory) -> None: if not isinstance(value, BaseParaMemory): raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value + + @property + def pref_mem(self) -> "BaseTextMemory | None": + """Get the preference memory.""" + if self._pref_mem is None: + logger.warning("Preference memory is not initialized. Returning None.") + return self._pref_mem + + @pref_mem.setter + def pref_mem(self, value: BaseTextMemory) -> None: + """Set the preference memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._pref_mem = value diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index 7ce3ca642..ba9f136b7 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -14,9 +14,14 @@ from memos.memories.activation.base import BaseActMemory from memos.memories.parametric.base import BaseParaMemory from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.prefer_text_memory.adder import BaseAdder +from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor +from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.reranker.base import BaseReranker +from memos.vec_dbs.base import BaseVecDB logger = get_logger(__name__) @@ -34,7 +39,11 @@ def __init__( reranker: BaseReranker, memory_manager: MemoryManager, default_cube_config: GeneralMemCubeConfig, + vector_db: BaseVecDB, internet_retriever: None = None, + pref_extractor: BaseExtractor | None = None, + pref_adder: BaseAdder | None = None, + pref_retriever: BaseRetriever | None = None, ): """Initialize the MemCube with a configuration.""" self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( @@ -49,6 +58,15 @@ def __init__( ) self._act_mem: BaseActMemory | None = None self._para_mem: BaseParaMemory | None = None + self._pref_mem: BaseTextMemory | None = SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, + ) def load( self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None @@ -69,7 +87,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -84,17 +102,23 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.load(dir) + logger.info(f"Loaded pref_mem from {dir}") + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -106,7 +130,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -121,6 +145,10 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.dump(dir) + logger.info(f"Dumped pref_mem to {dir}") + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @property @@ -164,3 +192,17 @@ def para_mem(self, value: BaseParaMemory) -> None: if not isinstance(value, BaseParaMemory): raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value + + @property + def pref_mem(self) -> "BaseTextMemory | None": + """Get the preference memory.""" + if self._pref_mem is None: + logger.warning("Preference memory is not initialized. Returning None.") + return self._pref_mem + + @pref_mem.setter + def pref_mem(self, value: BaseTextMemory) -> None: + """Set the preference memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._pref_mem = value diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index cedffd6fb..ec8a673d7 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -2,6 +2,7 @@ import os import time +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from threading import Lock @@ -18,6 +19,7 @@ ADD_LABEL, ANSWER_LABEL, MEM_READ_LABEL, + PREF_ADD_LABEL, QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -590,6 +592,7 @@ def search( "text_mem": [], "act_mem": [], "para_mem": [], + "pref_mem": [], } if install_cube_ids is None: install_cube_ids = user_cube_ids @@ -604,33 +607,78 @@ def search( ) for mem_cube_id, mem_cube in tmp_mem_cubes.items(): - if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.text_mem is not None) - and self.config.enable_textual_memory - ): - time_start = time.time() - memories = mem_cube.text_mem.search( - query, - top_k=top_k if top_k else self.config.top_k, - mode=mode, - manual_close_internet=not internet_search, - info={ - "user_id": target_user_id, - "session_id": target_session_id, - "chat_history": chat_history.chat_history, - }, - moscube=moscube, - search_filter=search_filter, - ) - result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories}) - logger.info( - f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n" - ) - search_time_end = time.time() - logger.info( - f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" - ) + # Define internal functions for parallel search execution + def search_textual_memory(cube_id, cube): + if ( + (cube_id in install_cube_ids) + and (cube.text_mem is not None) + and self.config.enable_textual_memory + ): + time_start = time.time() + memories = cube.text_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + mode=mode, + manual_close_internet=not internet_search, + info={ + "user_id": target_user_id, + "session_id": target_session_id, + "chat_history": chat_history.chat_history, + }, + moscube=moscube, + search_filter=search_filter, + ) + search_time_end = time.time() + logger.info( + f"🧠 [Memory] Searched memories from {cube_id}:\n{self._str_memories(memories)}\n" + ) + logger.info( + f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return {"cube_id": cube_id, "memories": memories} + return None + + def search_preference_memory(cube_id, cube): + if ( + (cube_id in install_cube_ids) + and (cube.pref_mem is not None) + and self.config.enable_preference_memory + ): + time_start = time.time() + memories = cube.pref_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "chat_history": chat_history.chat_history, + }, + ) + search_time_end = time.time() + logger.info( + f"🧠 [Memory] Searched preferences from {cube_id}:\n{self._str_memories(memories)}\n" + ) + logger.info( + f"time search pref: search pref time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return {"cube_id": cube_id, "memories": memories} + return None + + # Execute both search functions in parallel + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(search_textual_memory, mem_cube_id, mem_cube) + pref_future = executor.submit(search_preference_memory, mem_cube_id, mem_cube) + + # Wait for both tasks to complete and collect results + text_result = text_future.result() + pref_result = pref_future.result() + + # Add results to the main result dictionary + if text_result is not None: + result["text_mem"].append(text_result) + if pref_result is not None: + result["pref_mem"].append(pref_result) + return result def add( @@ -679,79 +727,111 @@ def add( f"time add: get mem_cube_id time user_id: {target_user_id} time is: {time.time() - time_start}" ) - time_start_0 = time.time() if mem_cube_id not in self.mem_cubes: raise ValueError(f"MemCube '{mem_cube_id}' is not loaded. Please register.") - logger.info( - f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}" - ) + sync_mode = self.mem_cubes[mem_cube_id].text_mem.mode if sync_mode == "async": assert self.mem_scheduler is not None, ( "Mem-Scheduler must be working when use asynchronous memory adding." ) logger.debug(f"Mem-reader mode is: {sync_mode}") - time_start_1 = time.time() - if ( - (messages is not None) - and self.config.enable_textual_memory - and self.mem_cubes[mem_cube_id].text_mem - ): - logger.info( - f"time add: messages is not None and enable_textual_memory and text_mem is not None time user_id: {target_user_id} time is: {time.time() - time_start_1}" - ) - if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": - add_memory = [] - metadata = TextualMemoryMetadata( - user_id=target_user_id, session_id=target_session_id, source="conversation" - ) - for message in messages: - add_memory.append( - TextualMemoryItem(memory=message["content"], metadata=metadata) + def process_textual_memory(): + if ( + (messages is not None) + and self.config.enable_textual_memory + and self.mem_cubes[mem_cube_id].text_mem + ): + if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": + add_memory = [] + metadata = TextualMemoryMetadata( + user_id=target_user_id, session_id=target_session_id, source="conversation" ) - self.mem_cubes[mem_cube_id].text_mem.add(add_memory) - else: - messages_list = [messages] - time_start_2 = time.time() - memories = self.mem_reader.get_memory( - messages_list, - type="chat", - info={"user_id": target_user_id, "session_id": target_session_id}, - mode="fast" if sync_mode == "async" else "fine", - ) - logger.info( - f"time add: get mem_reader time user_id: {target_user_id} time is: {time.time() - time_start_2}" - ) - memories_flatten = [m for m_list in memories for m in m_list] - mem_ids: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(memories_flatten) - logger.info( - f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_ids}" - ) - # submit messages for scheduler - if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] - if sync_mode == "async": + for message in messages: + add_memory.append( + TextualMemoryItem(memory=message["content"], metadata=metadata) + ) + self.mem_cubes[mem_cube_id].text_mem.add(add_memory) + else: + messages_list = [messages] + memories = self.mem_reader.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": target_session_id}, + mode="fast" if sync_mode == "async" else "fine", + ) + memories_flatten = [m for m_list in memories for m in m_list] + mem_ids: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(memories_flatten) + logger.info( + f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_ids}" + ) + # submit messages for scheduler + if self.enable_mem_scheduler and self.mem_scheduler is not None: + mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "async": + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - label=MEM_READ_LABEL, + label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) + def process_preference_memory(): + if ( + (messages is not None) + and self.config.enable_preference_memory + and self.mem_cubes[mem_cube_id].pref_mem + ): + messages_list = [messages] + mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "sync": + pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": self.session_id}, + ) + pref_ids = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) + logger.info( + f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {pref_ids}" + ) + elif sync_mode == "async": + assert self.mem_scheduler is not None, ( + "Mem-Scheduler must be working when use asynchronous memory adding." + ) message_item = ScheduleMessageItem( user_id=target_user_id, + session_id=target_session_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) + # Execute both memory processing functions in parallel + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(process_textual_memory) + pref_future = executor.submit(process_preference_memory) + + # Wait for both tasks to complete + text_future.result() + pref_future.result() + # user profile if ( (memory_content is not None) @@ -1030,7 +1110,7 @@ def load( load_dir: str, user_id: str | None = None, mem_cube_id: str | None = None, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump the MemCube to a dictionary. Args: diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 7e0ed9aef..fed8f7278 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1443,6 +1443,24 @@ def search( reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) logger.info(f"search memory list is : {reformat_memory_list}") search_result["text_mem"] = reformat_memory_list + + pref_memory_list = search_result["pref_mem"] + reformat_pref_memory_list = [] + for memory in pref_memory_list: + memories_list = [] + for data in memory["memories"]: + memories = data.model_dump() + memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" + memories["metadata"]["embedding"] = [] + memories["metadata"]["sources"] = [] + memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" + memories["metadata"]["id"] = memories["id"] + memories["metadata"]["memory"] = memories["memory"] + memories_list.append(memories) + reformat_pref_memory_list.append( + {"cube_id": memory["cube_id"], "memories": memories_list} + ) + search_result["pref_mem"] = reformat_pref_memory_list time_end = time.time() logger.info( f"time search: total time for user_id: {user_id} time is: {time_end - time_start}" diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 31bb9b3da..d84ebb242 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -12,6 +12,7 @@ DEFAULT_MAX_QUERY_KEY_WORDS, MEM_ORGANIZE_LABEL, MEM_READ_LABEL, + PREF_ADD_LABEL, QUERY_LABEL, WORKING_MEMORY_TYPE, MemCubeID, @@ -20,7 +21,9 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.preference import PreferenceTextMemory +from memos.memories.textual.tree import TreeTextMemory logger = get_logger(__name__) @@ -40,6 +43,7 @@ def __init__(self, config: GeneralSchedulerConfig): ADD_LABEL: self._add_message_consumer, MEM_READ_LABEL: self._mem_read_message_consumer, MEM_ORGANIZE_LABEL: self._mem_reorganize_message_consumer, + PREF_ADD_LABEL: self._pref_add_message_consumer, } self.dispatcher.register_handlers(handlers) @@ -468,6 +472,48 @@ def _process_memories_with_reorganize( f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True ) + def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {PREF_ADD_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + session_id = message.session_id + mem_cube_id = message.mem_cube_id + mem_cube = message.mem_cube + content = message.content + messages_list = json.loads(content) + + logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") + + # Get the preference memory from the mem_cube + pref_mem = mem_cube.pref_mem + if not isinstance(pref_mem, PreferenceTextMemory): + logger.error(f"Expected PreferenceTextMemory but got {type(pref_mem).__name__}") + return + + # Use pref_mem.get_memory to process the memories + pref_memories = pref_mem.get_memory( + messages_list, type="chat", info={"user_id": user_id, "session_id": session_id} + ) + # Add pref_mem to vector db + pref_ids = pref_mem.add(pref_memories) + + logger.info( + f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}" + ) + + except Exception as e: + logger.error(f"Error processing pref_add message: {e}", exc_info=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Thread task failed: {e}", exc_info=True) + def process_session_turn( self, queries: str | list[str], diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index f0868e8df..2bc7a3b98 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -20,7 +20,7 @@ class SearchMode(str, Enum): MEM_READ_LABEL = "mem_read" MEM_ORGANIZE_LABEL = "mem_organize" API_MIX_SEARCH_LABEL = "api_mix_search" - +PREF_ADD_LABEL = "pref_add" TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index efdaa44ef..541d2486d 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -35,6 +35,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) user_id: str = Field(..., description="user id") + session_id: str | None = Field(default=None, description="session id") mem_cube_id: str = Field(..., description="memcube id") label: str = Field(..., description="Label of the schedule message") mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") @@ -55,6 +56,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "example": { "item_id": "123e4567-e89b-12d3-a456-426614174000", # Sample UUID "user_id": "user123", # Example user identifier + "session_id": "session123", # Example session identifier "mem_cube_id": "cube456", # Sample memory cube ID "label": "sample_label", # Demonstration label value "mem_cube": "obj of GeneralMemCube", # Added mem_cube example @@ -76,6 +78,7 @@ def to_dict(self) -> dict: return { "item_id": self.item_id, "user_id": self.user_id, + "session_id": self.session_id, "cube_id": self.mem_cube_id, "label": self.label, "cube": "Not Applicable", # Custom cube serialization @@ -90,6 +93,8 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": item_id=data.get("item_id", str(uuid4())), user_id=data["user_id"], mem_cube_id=data["cube_id"], + session_id=data["session_id"], + cube_id=data["cube_id"], label=data["label"], mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index bcf7fdd9b..5ba1c6726 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,6 +10,8 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.preference import PreferenceTextMemory +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -22,6 +24,8 @@ class MemoryFactory(BaseMemory): "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, "simple_tree_text": SimpleTreeTextMemory, + "pref_text": PreferenceTextMemory, + "simple_pref_text": SimplePreferenceTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 2da283d47..6d975cfd7 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -167,6 +167,20 @@ class SearchedTreeNodeTextualMemoryMetadata(TreeNodeTextualMemoryMetadata): ) +class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): + """Metadata for preference memory item.""" + + preference_type: Literal["explicit_preference", "implicit_preference"] = Field( + default="explicit_preference", description="Type of preference." + ) + dialog_id: str | None = Field(default=None, description="ID of the dialog.") + dialog_str: str | None = Field(default=None, description="String of the dialog.") + embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") + explicit_preference: str | None = Field(default=None, description="Explicit preference.") + created_at: str | None = Field(default=None, description="Timestamp of the dialog.") + implicit_preference: str | None = Field(default=None, description="Implicit preference.") + + class TextualMemoryItem(BaseModel): """Represents a single memory item in the textual memory. @@ -180,6 +194,7 @@ class TextualMemoryItem(BaseModel): SearchedTreeNodeTextualMemoryMetadata | TreeNodeTextualMemoryMetadata | TextualMemoryMetadata + | PreferenceTextualMemoryMetadata ) = Field(default_factory=TextualMemoryMetadata) model_config = ConfigDict(extra="forbid") @@ -204,12 +219,15 @@ def _coerce_metadata(cls, v: Any): v, SearchedTreeNodeTextualMemoryMetadata | TreeNodeTextualMemoryMetadata - | TextualMemoryMetadata, + | TextualMemoryMetadata + | PreferenceTextualMemoryMetadata, ): return v if isinstance(v, dict): if v.get("relativity") is not None: return SearchedTreeNodeTextualMemoryMetadata(**v) + if v.get("preference_type") is not None: + return PreferenceTextualMemoryMetadata(**v) if any(k in v for k in ("sources", "memory_type", "embedding", "background", "usage")): return TreeNodeTextualMemoryMetadata(**v) return TextualMemoryMetadata(**v) diff --git a/src/memos/memories/textual/prefer_text_memory/__init__.py b/src/memos/memories/textual/prefer_text_memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py new file mode 100644 index 000000000..390f048ef --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -0,0 +1,284 @@ +import json + +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.templates.prefer_complete_prompt import ( + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE, +) +from memos.vec_dbs.item import MilvusVecDBItem + + +logger = get_logger(__name__) + + +class BaseAdder(ABC): + """Abstract base class for adders.""" + + @abstractmethod + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the adder.""" + + @abstractmethod + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwargs) -> list[str]: + """Add the instruct preference memories. + Args: + memories (list[TextualMemoryItem | dict[str, Any]]): The memories to add. + **kwargs: Additional keyword arguments. + Returns: + list[str]: List of added memory IDs. + """ + + +class NaiveAdder(BaseAdder): + """Naive adder.""" + + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the naive adder.""" + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + + def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: + """Judge if the new message expresses the same core content as the old message.""" + # Use the template prompt with placeholders + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace( + "{new_information}", new_msg + ) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + response = result.get("is_same", False) + return response if isinstance(response, bool) else response == "true" + except Exception as e: + logger.error(f"Error in judge_update_or_add: {e}") + # Fallback to simple string comparison + return old_msg == new_msg + + def _judge_update_or_add_trace_op( + self, new_mem: str, retrieved_mems: str + ) -> dict[str, Any] | None: + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace("{new_memory}", new_mem).replace( + "{retrieved_memories}", retrieved_mems + ) + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + return result + except Exception as e: + logger.error(f"Error in judge_update_or_add_trace_op: {e}") + return None + + def _update_memory_op_trace( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + ) -> list[str] | str: + if not retrieved_memories: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, + ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + new_mem_input = { + "context_summary": new_memory.memory, + "preference": new_memory.metadata.explicit_preference + if preference_type == "explicit_preference" + else new_memory.metadata.implicit_preference, + } + retrieved_mem_inputs = [ + { + "id": mem.id, + "context_summary": mem.memory, + "preference": mem.payload[preference_type], + } + for mem in retrieved_memories + ] + + rsp = self._judge_update_or_add_trace_op( + new_mem=json.dumps(new_mem_input), retrieved_mems=json.dumps(retrieved_mem_inputs) + ) + if not rsp: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, + ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + def execute_op(op): + op_type = op["type"].lower() + if op_type == "add": + payload = new_memory.to_dict()["metadata"] + payload = { + k: v + for k, v in payload.items() + if k not in {"dialog_id", "dialog_str", "embedding"} + } + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, + ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + elif op_type == "update": + payload = { + "preference_type": preference_type, + preference_type: op["new_preference"], + } + vec_db_item = MilvusVecDBItem( + id=op["target_id"], + memory=op["new_context_summary"], + vector=self.embedder.embed([op["new_context_summary"]])[0], + payload=payload, + ) + self.vector_db.update(collection_name, op["target_id"], vec_db_item) + return op["target_id"] + elif op_type == "delete": + self.vector_db.delete(collection_name, [op["target_id"]]) + return None + + with ThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: + future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]} + added_ids = [] + for future in as_completed(future_to_op): + result = future.result() + if result is not None: + added_ids.append(result) + + return added_ids + + def _update_memory_fast( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + ) -> str: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, + ) + recall = retrieved_memories[0] if retrieved_memories else None + if not recall or (recall.score is not None and recall.score < 0.5): + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + old_msg_str = recall.memory + new_msg_str = new_memory.memory + is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str) + if is_same: + self.vector_db.delete(collection_name, [recall.id]) + self.vector_db.update(collection_name, new_memory.id, vec_db_item) + return new_memory.id + + def _update_memory( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + update_mode: str = "op_trace", + ) -> list[str] | str | None: + """Update the memory. + Args: + new_memory: TextualMemoryItem + retrieved_memories: list[MilvusVecDBItem] + collection_name: str + preference_type: str + update_mode: str, "op_trace" or "fast" + """ + if update_mode == "op_trace": + return self._update_memory_op_trace( + new_memory, retrieved_memories, collection_name, preference_type + ) + elif update_mode == "fast": + return self._update_memory_fast(new_memory, retrieved_memories, collection_name) + else: + raise ValueError(f"Invalid update mode: {update_mode}") + + def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | None: + """Process a single memory and return its ID if added successfully.""" + try: + pref_type_collection_map = { + "explicit_preference": "explicit_preference", + "implicit_preference": "implicit_preference", + } + preference_type = memory.metadata.preference_type + collection_name = pref_type_collection_map[preference_type] + + search_results = self.vector_db.search( + memory.metadata.embedding, + collection_name, + top_k=5, + filter={"user_id": memory.metadata.user_id}, + ) + search_results.sort(key=lambda x: x.score, reverse=True) + + return self._update_memory( + memory, search_results, collection_name, preference_type, update_mode="fast" + ) + + except Exception as e: + logger.error(f"Error processing memory {memory.id}: {e}") + return None + + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + max_workers: int = 8, + *args, + **kwargs, + ) -> list[str]: + """Add the instruct preference memories using thread pool for acceleration.""" + if not memories: + return [] + + added_ids = [] + with ThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: + future_to_memory = { + executor.submit(self._process_single_memory, memory): memory for memory in memories + } + + for future in as_completed(future_to_memory): + try: + memory_id = future.result() + if memory_id: + if isinstance(memory_id, list): + added_ids.extend(memory_id) + else: + added_ids.append(memory_id) + except Exception as e: + memory = future_to_memory[future] + logger.error(f"Error processing memory {memory.id}: {e}") + continue + + return added_ids diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py new file mode 100644 index 000000000..7e8354747 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -0,0 +1,106 @@ +from typing import Any, ClassVar + +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig + + +class BaseAdderConfig(BaseConfig): + """Base configuration class for Adder.""" + + +class NaiveAdderConfig(BaseAdderConfig): + """Configuration for Naive Adder.""" + + # No additional config needed since components are passed from parent + + +class AdderConfigFactory(BaseConfig): + """Factory class for creating Adder configurations.""" + + backend: str = Field(..., description="Backend for Adder") + config: dict[str, Any] = Field(..., description="Configuration for the Adder backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveAdderConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "AdderConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self + + +class BaseExtractorConfig(BaseConfig): + """Base configuration class for Extractor.""" + + +class NaiveExtractorConfig(BaseExtractorConfig): + """Configuration for Naive Extractor.""" + + +class ExtractorConfigFactory(BaseConfig): + """Factory class for creating Extractor configurations.""" + + backend: str = Field(..., description="Backend for Extractor") + config: dict[str, Any] = Field(..., description="Configuration for the Extractor backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveExtractorConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "ExtractorConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self + + +class BaseRetrieverConfig(BaseConfig): + """Base configuration class for Retrievers.""" + + +class NaiveRetrieverConfig(BaseRetrieverConfig): + """Configuration for Naive Retriever.""" + + +class RetrieverConfigFactory(BaseConfig): + """Factory class for creating Retriever configurations.""" + + backend: str = Field(..., description="Backend for Retriever") + config: dict[str, Any] = Field(..., description="Configuration for the Retriever backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveRetrieverConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "RetrieverConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py new file mode 100644 index 000000000..460b31f4f --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -0,0 +1,184 @@ +import json +import uuid + +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from typing import Any + +from memos.log import get_logger +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.prefer_text_memory.spliter import Splitter +from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string +from memos.templates.prefer_complete_prompt import ( + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, +) +from memos.types import MessageList + + +logger = get_logger(__name__) + + +class BaseExtractor(ABC): + """Abstract base class for extractors.""" + + @abstractmethod + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the extractor.""" + + +class NaiveExtractor(BaseExtractor): + """Extractor.""" + + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the extractor.""" + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + self.splitter = Splitter() + + def extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]: + """Extract basic information from a QA pair (no LLM needed).""" + basic_info = { + "dialog_id": str(uuid.uuid4()), + "dialog_str": convert_messages_to_string(qa_pair), + "created_at": datetime.now().isoformat(), + } + + return basic_info + + def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None: + """Extract explicit preference from a QA pair.""" + qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair + prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + return result + except Exception as e: + logger.error(f"Error extracting explicit preference: {e}, return None") + return None + + def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None: + """Extract implicit preferences from cluster qa pairs.""" + if not qa_pair: + return None + qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair + prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + return result + except Exception as e: + logger.error(f"Error extracting implicit preferences: {e}, return None") + return None + + def _process_single_chunk_explicit( + self, chunk: MessageList, msg_type: str, info: dict[str, Any] + ) -> TextualMemoryItem | None: + """Process a single chunk and return a TextualMemoryItem.""" + basic_info = self.extract_basic_info(chunk) + if not basic_info["dialog_str"]: + return None + + explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) + if not explicit_pref: + return None + + memories = [] + for pref in explicit_pref: + vector_info = { + "embedding": self.embedder.embed([pref["context_summary"]])[0], + } + extract_info = {**basic_info, **pref, **vector_info, **info} + + metadata = PreferenceTextualMemoryMetadata( + type=msg_type, preference_type="explicit_preference", **extract_info + ) + memory = TextualMemoryItem( + id=str(uuid.uuid4()), memory=pref["context_summary"], metadata=metadata + ) + + memories.append(memory) + + return memories + + def _process_single_chunk_implicit( + self, chunk: MessageList, msg_type: str, info: dict[str, Any] + ) -> TextualMemoryItem | None: + basic_info = self.extract_basic_info(chunk) + if not basic_info["dialog_str"]: + return None + implicit_pref = self.extract_implicit_preference(basic_info["dialog_str"]) + if not implicit_pref: + return None + + vector_info = { + "embedding": self.embedder.embed([implicit_pref["context_summary"]])[0], + } + + extract_info = {**basic_info, **implicit_pref, **vector_info, **info} + + metadata = PreferenceTextualMemoryMetadata( + type=msg_type, preference_type="implicit_preference", **extract_info + ) + memory = TextualMemoryItem( + id=extract_info["dialog_id"], memory=implicit_pref["context_summary"], metadata=metadata + ) + + return memory + + def extract( + self, + messages: list[MessageList], + msg_type: str, + info: dict[str, Any], + max_workers: int = 10, + ) -> list[TextualMemoryItem]: + """Extract preference memories based on the messages using thread pool for acceleration.""" + chunks: list[MessageList] = [] + for message in messages: + chunk = self.splitter.split_chunks(message, split_type="overlap") + chunks.extend(chunk) + if not chunks: + return [] + + memories = [] + with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: + futures = { + executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ( + "explicit", + chunk, + ) + for chunk in chunks + } + futures.update( + { + executor.submit(self._process_single_chunk_implicit, chunk, msg_type, info): ( + "implicit", + chunk, + ) + for chunk in chunks + } + ) + + for future in as_completed(futures): + try: + memory = future.result() + if memory: + if isinstance(memory, list): + memories.extend(memory) + else: + memories.append(memory) + except Exception as e: + task_type, chunk = futures[future] + logger.error(f"Error processing {task_type} chunk: {chunk}\n{e}") + continue + + return memories diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py new file mode 100644 index 000000000..22182261a --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -0,0 +1,78 @@ +from typing import Any, ClassVar + +from memos.memories.textual.prefer_text_memory.adder import BaseAdder, NaiveAdder +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) +from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor, NaiveExtractor +from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever, NaiveRetriever + + +class AdderFactory(BaseAdder): + """Factory class for creating Adder instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveAdder, + } + + @classmethod + def from_config( + cls, config_factory: AdderConfigFactory, llm_provider=None, embedder=None, vector_db=None + ) -> BaseAdder: + """Create a Adder instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + adder_class = cls.backend_to_class[backend] + return adder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + + +class ExtractorFactory(BaseExtractor): + """Factory class for creating Extractor instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveExtractor, + } + + @classmethod + def from_config( + cls, + config_factory: ExtractorConfigFactory, + llm_provider=None, + embedder=None, + vector_db=None, + ) -> BaseExtractor: + """Create a Extractor instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + extractor_class = cls.backend_to_class[backend] + return extractor_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + + +class RetrieverFactory(BaseRetriever): + """Factory class for creating Retriever instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveRetriever, + } + + @classmethod + def from_config( + cls, + config_factory: RetrieverConfigFactory, + llm_provider=None, + embedder=None, + reranker=None, + vector_db=None, + ) -> BaseRetriever: + """Create a Retriever instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + retriever_class = cls.backend_to_class[backend] + return retriever_class( + llm_provider=llm_provider, embedder=embedder, reranker=reranker, vector_db=vector_db + ) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py new file mode 100644 index 000000000..7f70bac3b --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -0,0 +1,88 @@ +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem + + +class BaseRetriever(ABC): + """Abstract base class for retrievers.""" + + @abstractmethod + def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=None): + """Initialize the retriever.""" + + @abstractmethod + def retrieve( + self, query: str, top_k: int, info: dict[str, Any] | None = None + ) -> list[TextualMemoryItem]: + """Retrieve memories from the retriever.""" + + +class NaiveRetriever(BaseRetriever): + """Naive retriever.""" + + def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=None): + """Initialize the naive retriever.""" + super().__init__(llm_provider, embedder, reranker, vector_db) + self.reranker = reranker + self.vector_db = vector_db + self.embedder = embedder + + def retrieve( + self, query: str, top_k: int, info: dict[str, Any] | None = None + ) -> list[TextualMemoryItem]: + """Retrieve memories from the naive retriever.""" + # TODO: un-support rewrite query and session filter now + if info: + info = info.copy() # Create a copy to avoid modifying the original + info.pop("chat_history", None) + info.pop("session_id", None) + query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings + query_embedding = query_embeddings[0] # Get the first (and only) embedding + + # Use thread pool to parallelize the searches + with ThreadPoolExecutor(max_workers=2) as executor: + # Submit all search tasks + future_explicit = executor.submit( + self.vector_db.search, query_embedding, "explicit_preference", top_k * 2, info + ) + future_implicit = executor.submit( + self.vector_db.search, query_embedding, "implicit_preference", top_k * 2, info + ) + + # Wait for all results + explicit_prefs = future_explicit.result() + implicit_prefs = future_implicit.result() + + # sort by score + explicit_prefs.sort(key=lambda x: x.score, reverse=True) + implicit_prefs.sort(key=lambda x: x.score, reverse=True) + + explicit_prefs = [ + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**pref.payload), + ) + for pref in explicit_prefs + if pref.payload["explicit_preference"] + ] + + implicit_prefs = [ + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**pref.payload), + ) + for pref in implicit_prefs + if pref.payload["implicit_preference"] + ] + + if self.reranker: + explicit_prefs = self.reranker.rerank(query, explicit_prefs, top_k) + implicit_prefs = self.reranker.rerank(query, implicit_prefs, top_k) + explicit_prefs = [item for item, _ in explicit_prefs] + implicit_prefs = [item for item, _ in implicit_prefs] + + return explicit_prefs + implicit_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py new file mode 100644 index 000000000..59a6b0052 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -0,0 +1,132 @@ +import copy + +from memos.chunkers import ChunkerFactory +from memos.configs.chunker import ChunkerConfigFactory +from memos.configs.parser import ParserConfigFactory +from memos.parsers.factory import ParserFactory +from memos.types import MessageList + + +class Splitter: + """Splitter.""" + + def __init__( + self, + lookback_turns: int = 1, + chunk_size: int = 256, + chunk_overlap: int = 128, + min_sentences_per_chunk: int = 1, + tokenizer: str = "gpt2", + parser_backend: str = "markitdown", + chunker_backend: str = "sentence", + ): + """Initialize the splitter.""" + self.lookback_turns = lookback_turns + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.min_sentences_per_chunk = min_sentences_per_chunk + self.tokenizer = tokenizer + self.chunker_backend = chunker_backend + self.parser_backend = parser_backend + # Initialize parser + parser_config = ParserConfigFactory.model_validate( + { + "backend": self.parser_backend, + "config": {}, + } + ) + self.parser = ParserFactory.from_config(parser_config) + + # Initialize chunker + chunker_config = ChunkerConfigFactory.model_validate( + { + "backend": self.chunker_backend, + "config": { + "tokenizer_or_token_counter": self.tokenizer, + "chunk_size": self.chunk_size, + "chunk_overlap": self.chunk_overlap, + "min_sentences_per_chunk": self.min_sentences_per_chunk, + }, + } + ) + self.chunker = ChunkerFactory.from_config(chunker_config) + + def _split_with_lookback(self, data: MessageList) -> list[MessageList]: + """Split the messages or files into chunks by looking back fixed number of turns. + adjacent chunk with high duplicate rate, + default lookback turns is 1, only current turn in chunk""" + # Build QA pairs from chat history + pairs = self.build_qa_pairs(data) + chunks = [] + + # Create chunks by looking back fixed number of turns + for i in range(len(pairs)): + # Calculate the start index for lookback + start_idx = max(0, i + 1 - self.lookback_turns) + # Get the chunk of pairs (as many as available, up to lookback_turns) + chunk_pairs = pairs[start_idx : i + 1] + + # Flatten chunk_pairs (list[list[dict]]) to MessageList (list[dict]) + chunk_messages = [] + for pair in chunk_pairs: + chunk_messages.extend(pair) + + chunks.append(chunk_messages) + return chunks + + def _split_with_overlap(self, data: MessageList) -> list[MessageList]: + """split the messages or files into chunks with overlap. + adjacent chunk with low duplicate rate""" + chunks = [] + chunk = [] + for item in data: + chunk.append(item) + # 5 turns (Q + A = 10) each chunk + if len(chunk) >= 10: + chunks.append(chunk) + # overlap 1 turns (Q + A = 2) + context = copy.deepcopy(chunk[-2:]) + chunk = context + if chunk: + chunks.append(chunk) + + return chunks + + def split_chunks(self, data: MessageList | str, **kwargs) -> list[MessageList] | list[str]: + """Split the messages or files into chunks. + + Args: + data: MessageList or string to split + + Returns: + List of MessageList chunks or list of string chunks + """ + if isinstance(data, list): + if kwargs.get("split_type") == "lookback": + chunks = self._split_with_lookback(data) + elif kwargs.get("split_type") == "overlap": + chunks = self._split_with_overlap(data) + return chunks + else: + # Parse and chunk the string data using pre-initialized components + text = self.parser.parse(data) + chunks = self.chunker.chunk(text) + + return [chunk.text for chunk in chunks] + + def build_qa_pairs(self, chat_history: MessageList) -> list[MessageList]: + """Build QA pairs from chat history.""" + qa_pairs = [] + current_qa_pair = [] + + for message in chat_history: + if message["role"] == "user": + current_qa_pair.append(message) + elif message["role"] == "assistant": + if not current_qa_pair: + continue + current_qa_pair.append(message) + qa_pairs.append(current_qa_pair.copy()) + current_qa_pair = [] # reset + + return qa_pairs diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py new file mode 100644 index 000000000..85adc9304 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -0,0 +1,70 @@ +import re + +from memos.dependency import require_python_package +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessageList + + +def convert_messages_to_string(messages: MessageList) -> str: + """Convert a list of messages to a string.""" + message_text = "" + for message in messages: + if message["role"] == "user": + message_text += f"Query: {message['content']}\n" if message["content"].strip() else "" + elif message["role"] == "assistant": + message_text += f"Answer: {message['content']}\n" if message["content"].strip() else "" + message_text = message_text.strip() + return message_text + + +@require_python_package( + import_name="datasketch", + install_command="pip install datasketch", + install_link="https://github.com/ekzhu/datasketch", +) +def deduplicate_preferences( + prefs: list[TextualMemoryItem], similarity_threshold: float = 0.6, num_perm: int = 256 +) -> list[TextualMemoryItem]: + """ + Deduplicate preference texts using MinHash algorithm. + + Args: + prefs: List of preference memory items to deduplicate + similarity_threshold: Jaccard similarity threshold (0.0-1.0), default 0.8 + + Returns: + Deduplicated list of preference items + """ + from datasketch import MinHash, MinHashLSH + + if not prefs: + return prefs + + # Use MinHashLSH for efficient similarity search + lsh = MinHashLSH(threshold=similarity_threshold, num_perm=num_perm) + unique_prefs = [] + + for i, pref in enumerate(prefs): + # Extract preference text + if hasattr(pref.metadata, "implicit_preference") and pref.metadata.implicit_preference: + text = pref.metadata.implicit_preference + elif hasattr(pref.metadata, "explicit_preference") and pref.metadata.explicit_preference: + text = pref.metadata.explicit_preference + else: + text = pref.memory + + # Create MinHash from text tokens + minhash = MinHash(num_perm=num_perm) + # Simple tokenization: split by whitespace and clean + tokens = re.findall(r"\w+", text.lower()) + for token in tokens: + minhash.update(token.encode("utf8")) + + # Check for duplicates using LSH + similar_items = lsh.query(minhash) + + if not similar_items: # No similar items found + lsh.insert(i, minhash) + unique_prefs.append(pref) + + return unique_prefs diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py new file mode 100644 index 000000000..5f85aa907 --- /dev/null +++ b/src/memos/memories/textual/preference.py @@ -0,0 +1,283 @@ +import json +import os + +from typing import Any + +from memos.configs.memory import PreferenceTextMemoryConfig +from memos.embedders.factory import ( + ArkEmbedder, + EmbedderFactory, + OllamaEmbedder, + SenTranEmbedder, + UniversalAPIEmbedder, +) +from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM +from memos.log import get_logger +from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + ExtractorFactory, + RetrieverFactory, +) +from memos.reranker.factory import RerankerFactory +from memos.types import MessageList +from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB, VecDBFactory +from memos.vec_dbs.item import VecDBItem + + +logger = get_logger(__name__) + + +class PreferenceTextMemory(BaseTextMemory): + """Preference textual memory implementation for storing and retrieving memories.""" + + def __init__(self, config: PreferenceTextMemoryConfig): + """Initialize memory with the given configuration.""" + self.config: PreferenceTextMemoryConfig = config + self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( + config.extractor_llm + ) + self.vector_db: MilvusVecDB | QdrantVecDB = VecDBFactory.from_config(config.vector_db) + self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = ( + EmbedderFactory.from_config(config.embedder) + ) + self.reranker = RerankerFactory.from_config(config.reranker) + + self.extractor = ExtractorFactory.from_config( + config.extractor, + llm_provider=self.extractor_llm, + embedder=self.embedder, + vector_db=self.vector_db, + ) + + self.adder = AdderFactory.from_config( + config.adder, + llm_provider=self.extractor_llm, + embedder=self.embedder, + vector_db=self.vector_db, + ) + self.retriever = RetrieverFactory.from_config( + config.retriever, + llm_provider=self.extractor_llm, + embedder=self.embedder, + reranker=self.reranker, + vector_db=self.vector_db, + ) + + def get_memory( + self, messages: list[MessageList], type: str, info: dict[str, Any] + ) -> list[TextualMemoryItem]: + """Get memory based on the messages. + Args: + messages (list[MessageList]): The messages to get memory from. + type (str): The type of memory to get. + info (dict[str, Any]): The info to get memory. + """ + return self.extractor.extract(messages, type, info) + + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + return self.retriever.retrieve(query, top_k, info) + + def load(self, dir: str) -> None: + """Load memories from the specified directory. + Args: + dir (str): The directory containing the memory files. + """ + # For preference memory, we don't need to load from files + # as the data is stored in the vector database + try: + memory_file = os.path.join(dir, self.config.memory_filename) + + if not os.path.exists(memory_file): + logger.warning(f"Memory file not found: {memory_file}") + return + + with open(memory_file, encoding="utf-8") as f: + memories = json.load(f) + for collection_name, items in memories.items(): + vec_db_items = [VecDBItem.from_dict(m) for m in items] + self.vector_db.add(collection_name, vec_db_items) + logger.info(f"Loaded {len(items)} memories from {collection_name} in {memory_file}") + + except FileNotFoundError: + logger.error(f"Memory file not found in directory: {dir}") + except json.JSONDecodeError as e: + if e.pos == 0 and "Expecting value" in str(e): + logger.warning(f"Memory file is empty or contains only whitespace: {memory_file}") + else: + logger.error(f"Error decoding JSON from memory file: {e}") + except Exception as e: + logger.error(f"An error occurred while loading memories: {e}") + + def dump(self, dir: str) -> None: + """Dump memories to the specified directory. + Args: + dir (str): The directory where the memory files will be saved. + """ + # For preference memory, we don't need to dump to files + # as the data is stored in the vector database + try: + json_memories = {} + for collection_name in self.vector_db.config.collection_name: + items = self.vector_db.get_all(collection_name) + json_memories[collection_name] = [memory.to_dict() for memory in items] + + os.makedirs(dir, exist_ok=True) + memory_file = os.path.join(dir, self.config.memory_filename) + with open(memory_file, "w", encoding="utf-8") as f: + json.dump(json_memories, f, indent=4, ensure_ascii=False) + + logger.info( + f"Dumped {len(json_memories)} collections, {sum(len(items) for items in json_memories.values())} memories to {memory_file}" + ) + + except Exception as e: + logger.error(f"An error occurred while dumping memories: {e}") + raise + + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: + """Extract memories based on the messages. + Args: + messages (MessageList): The messages to extract memories from. + Returns: + list[TextualMemoryItem]: List of extracted memory items. + """ + raise NotImplementedError + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + """Add memories. + + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. + """ + return self.adder.add(memories) + + def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: + """Update a memory by memory_id.""" + raise NotImplementedError + + def get(self, memory_id: str) -> TextualMemoryItem: + """Get a memory by its ID. + Args: + memory_id (str): The ID of the memory to retrieve. + Returns: + TextualMemoryItem: The memory with the given ID. + """ + raise NotImplementedError + + def get_with_collection_name( + self, collection_name: str, memory_id: str + ) -> TextualMemoryItem | None: + """Get a memory by its ID and collection name. + Args: + memory_id (str): The ID of the memory to retrieve. + collection_name (str): The name of the collection to retrieve the memory from. + Returns: + TextualMemoryItem: The memory with the given ID and collection name. + """ + try: + res = self.vector_db.get_by_id(collection_name, memory_id) + if res is None: + return None + return TextualMemoryItem( + id=res.id, + memory=res.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**res.payload), + ) + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" + ) from e + + def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: + """Get memories by their IDs. + Args: + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs. + """ + raise NotImplementedError + + def get_by_ids_with_collection_name( + self, collection_name: str, memory_ids: list[str] + ) -> list[TextualMemoryItem]: + """Get memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to retrieve the memory from. + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs and collection name. + """ + try: + res = self.vector_db.get_by_ids(collection_name, memory_ids) + if not res: + return [] + return [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in res + ] + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" + ) from e + + def get_all(self) -> list[TextualMemoryItem]: + """Get all memories. + Returns: + list[TextualMemoryItem]: List of all memories. + """ + all_collections = self.vector_db.list_collections() + all_memories = {} + for collection_name in all_collections: + items = self.vector_db.get_all(collection_name) + all_memories[collection_name] = [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in items + ] + return all_memories + + def delete(self, memory_ids: list[str]) -> None: + """Delete memories. + Args: + memory_ids (list[str]): List of memory IDs to delete. + """ + raise NotImplementedError + + def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: + """Delete memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to delete the memory from. + memory_ids (list[str]): List of memory IDs to delete. + """ + self.vector_db.delete(collection_name, memory_ids) + + def delete_all(self) -> None: + """Delete all memories.""" + for collection_name in self.vector_db.config.collection_name: + self.vector_db.delete_collection(collection_name) + self.vector_db.create_collection() + + def drop( + self, + ) -> None: + """Drop all databases.""" + raise NotImplementedError diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py new file mode 100644 index 000000000..29f30d384 --- /dev/null +++ b/src/memos/memories/textual/simple_preference.py @@ -0,0 +1,156 @@ +from typing import Any + +from memos.embedders.factory import ( + ArkEmbedder, + OllamaEmbedder, + SenTranEmbedder, + UniversalAPIEmbedder, +) +from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM +from memos.log import get_logger +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.preference import PreferenceTextMemory +from memos.types import MessageList +from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB + + +logger = get_logger(__name__) + + +class SimplePreferenceTextMemory(PreferenceTextMemory): + """Preference textual memory implementation for storing and retrieving memories.""" + + def __init__( + self, + extractor_llm: OpenAILLM | OllamaLLM | AzureLLM, + vector_db: MilvusVecDB | QdrantVecDB, + embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder, + reranker, + extractor, + adder, + retriever, + ): + """Initialize memory with the given configuration.""" + self.extractor_llm = extractor_llm + self.vector_db = vector_db + self.embedder = embedder + self.reranker = reranker + self.extractor = extractor + self.adder = adder + self.retriever = retriever + + def get_memory( + self, messages: list[MessageList], type: str, info: dict[str, Any] + ) -> list[TextualMemoryItem]: + """Get memory based on the messages. + Args: + messages (MessageList): The messages to get memory from. + type (str): The type of memory to get. + info (dict[str, Any]): The info to get memory. + """ + return self.extractor.extract(messages, type, info) + + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + return self.retriever.retrieve(query, top_k, info) + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + """Add memories. + + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. + """ + return self.adder.add(memories) + + def get_with_collection_name( + self, collection_name: str, memory_id: str + ) -> TextualMemoryItem | None: + """Get a memory by its ID and collection name. + Args: + memory_id (str): The ID of the memory to retrieve. + collection_name (str): The name of the collection to retrieve the memory from. + Returns: + TextualMemoryItem: The memory with the given ID and collection name. + """ + try: + res = self.vector_db.get_by_id(collection_name, memory_id) + if res is None: + return None + return TextualMemoryItem( + id=res.id, + memory=res.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**res.payload), + ) + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" + ) from e + + def get_by_ids_with_collection_name( + self, collection_name: str, memory_ids: list[str] + ) -> list[TextualMemoryItem]: + """Get memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to retrieve the memory from. + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs and collection name. + """ + try: + res = self.vector_db.get_by_ids(collection_name, memory_ids) + if not res: + return [] + return [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in res + ] + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" + ) from e + + def get_all(self) -> list[TextualMemoryItem]: + """Get all memories. + Returns: + list[TextualMemoryItem]: List of all memories. + """ + all_collections = self.vector_db.list_collections() + all_memories = {} + for collection_name in all_collections: + items = self.vector_db.get_all(collection_name) + all_memories[collection_name] = [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in items + ] + return all_memories + + def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: + """Delete memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to delete the memory from. + memory_ids (list[str]): List of memory IDs to delete. + """ + self.vector_db.delete(collection_name, memory_ids) + + def delete_all(self) -> None: + """Delete all memories.""" + for collection_name in self.vector_db.config.collection_name: + self.vector_db.delete_collection(collection_name) + self.vector_db.create_collection() diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py new file mode 100644 index 000000000..7ad0fe190 --- /dev/null +++ b/src/memos/templates/instruction_completion.py @@ -0,0 +1,43 @@ +from typing import Any + +from memos.templates.prefer_complete_prompt import PREF_INSTRUCTIONS + + +def instruct_completion( + memories: list[dict[str, Any]] | None = None, +) -> str: + """Create instruction following the preferences.""" + explicit_pref = [] + implicit_pref = [] + for memory in memories: + pref_type = memory.get("metadata", {}).get("preference_type") + if pref_type == "explicit_preference": + pref = memory.get("metadata", {}).get("explicit_preference", None) + if pref: + explicit_pref.append(pref) + elif pref_type == "implicit_preference": + pref = memory.get("metadata", {}).get("implicit_preference", None) + if pref: + implicit_pref.append(pref) + + explicit_pref_str = ( + "Explicit Preference:\n" + + "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(explicit_pref)) + if explicit_pref + else "" + ) + implicit_pref_str = ( + "Implicit Preference:\n" + + "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(implicit_pref)) + if implicit_pref + else "" + ) + + if not explicit_pref_str and not implicit_pref_str: + return "" + if not explicit_pref_str: + return implicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("explicit preferences > ", "") + if not implicit_pref_str: + return explicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("implicit preferences > ", "") + + return explicit_pref_str + "\n" + implicit_pref_str + "\n" + PREF_INSTRUCTIONS diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py new file mode 100644 index 000000000..d40b7b778 --- /dev/null +++ b/src/memos/templates/prefer_complete_prompt.py @@ -0,0 +1,250 @@ +NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are a preference extraction assistant. +Please extract the user's explicitly mentioned preferences from the following conversation. + +Notes: +- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". +- This includes, but is not limited to, any user's explicitly expressed inclination, desire, rejection, or priority that counts as an explicit preference. +- Focus on extracting the user's preferences in query. Do not extract preferences from the assistant's responses unless the user explicitly agrees with or endorses the assistant's suggestions. +- When the user modifies or updates their preferences for the same topic or event, extract the complete evolution process of their preference changes, including both the original and updated preferences. + +Requirements: +1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. +2. Output should be a list of concise natural language summaries and the corresponding context summary, context summary must contain complete information of the conversation fragment that the preference is mentioned. +3. If multiple preferences are mentioned within the same topic, you need to merge the preferences and context summary. + +Conversation: +{qa_pair} + +Find ALL explicit preferences. If no explicit preferences found, return []. Output JSON only: +```json +[ + { + "explicit_preference": "A short natural language summary of the preferences", + "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation, do not lack any scenario information", + "reasoning": "reasoning process to find the explicit preferences" + }, +] +``` +""" + + +NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are a preference inference assistant. Please extract **implicit preferences** from the following conversation +(preferences that the user did not explicitly state but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices). + +Notes: +- Implicit preferences refer to user inclinations or choices that are not directly expressed, but can be reasonably inferred from factual cues in the conversation. +- Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned. + +Requirements: +1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. +2. Output a concise natural language statement; do not use lists, categories, or include the reasoning process. +3. Inferred implicit preferences must not conflict with explicit preferences. +4. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. +5. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). + +Conversation: +{qa_pair} + +Output format: +```json +{ + "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string", + "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation, do not lack any scenario information", + "reasoning": "Briefly explain the reasoning process for the implicit preference" +} +``` +Don't output anything except the JSON. +""" + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """ +You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description. +Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows: + +- Core content is consistent, that is, the essence of the question, goal or core concept to be solved is the same, it counts as "same". +- Different expressions, different examples, but the core meaning is consistent, also counts as "same". +- If the question goals, concepts involved or solution ideas are different, it counts as "different". + +Please output JSON format: +{ + "is_same": true/false, + "reasoning": "Briefly explain the judgment basis, highlighting whether the core content is consistent" +} + +**Old Information:** +{old_information} + +**New Information:** +{new_information} +""" + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE = """ +# User Preference Memory Management Agent + +You are a **User Preference Memory Management Agent**. +Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories. + +Each memory entry contains three fields: +- **id**: a unique identifier for the memory. +- **context_summary**: a factual summary of the dialogue or situation from which the preference was extracted. +- **preference**: the extracted statement describing the user's preference or tendency. + +When updating a preference, you should also integrate and update the corresponding `context_summary` to ensure both fields stay semantically consistent. + +You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations. + +## Input Format + +New preference memory (new_memory): +{new_memory} + +Retrieved preference memories (retrieved_memories): +{retrieved_memories} + +## Task Instructions + +1. Analyze each retrieved memory and determine its relationship to the new memory: + - **Unrelated** → perform `"ADD"` (insert as a new independent memory); + - **Related** → perform `"UPDATE"` (refine, supplement, or merge both the `preference` and the `context_summary`); + - **Conflicting or outdated** → perform `"DELETE"` (remove obsolete or contradictory memory). + +2. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry, combining both their `preference` information and their `context_summary` in a coherent and concise way. + +3. Output a structured list of **operation traces**, each explicitly stating: + - which memory (by ID) is affected, + - what operation is performed, + - the before/after `preference` and `context_summary`, + - and the reasoning behind it. + +4. Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. + +## Output Format (JSON) + +{ + "trace": [ + { + "op_id": "op_1", + "type": "ADD" | "UPDATE" | "DELETE", + "target_id": "(the old memory ID; null if ADD)", + "old_preference": "(the old preference text; null if ADD)", + "old_context_summary": "(the old context summary; null if ADD)", + "new_preference": "(the updated or newly created preference, if applicable)", + "new_context_summary": "(the updated or newly created context summary, if applicable)", + "reason": "(brief natural-language explanation for the decision)" + } + ], + "after_update_state": [ + { + "id": "id1", + "context_summary": "updated factual summary of the context", + "preference": "updated or final preference text" + } + ] +} + +## Example + +**Input:** +new_memory: +{ + "context_summary": "During a recent chat about study habits, the user mentioned that he often studies in quiet coffee shops and has started preferring lattes over Americanos, which he only drinks occasionally.", + "preference": "User now prefers lattes but occasionally drinks Americanos; he also enjoys studying in quiet coffee shops." +} + +retrieved_memories: +[ + { + "id": "id1", + "context_summary": "The user previously said he likes coffee in general.", + "preference": "User likes coffee." + }, + { + "id": "id2", + "context_summary": "The user once mentioned preferring Americanos during work breaks.", + "preference": "User prefers Americanos." + }, + { + "id": "id3", + "context_summary": "The user said he often works from home", + "preference": "User likes working from home." + }, + { + "id": "id4", + "context_summary": "The user noted he doesn't drink tea very often.", + "preference": "User has no particular interest in tea." + } +] + +**Output:** +{ + "trace": [ + { + "op_id": "op_1", + "type": "UPDATE", + "target_id": "id1", + "old_preference": "User likes coffee.", + "old_context_summary": "The user previously said he likes coffee in general.", + "new_preference": "User likes coffee, especially lattes, but occasionally drinks Americanos.", + "new_context_summary": "The user discussed his coffee habits, stating he now prefers lattes but only occasionally drinks Americanos", + "reason": "The new memory refines and expands the coffee preference and context while preserving frequency semantics ('occasionally')." + }, + { + "op_id": "op_2", + "type": "DELETE", + "target_id": "id2", + "old_preference": "User prefers Americanos.", + "old_context_summary": "The user once mentioned preferring Americanos during work breaks.", + "new_preference": null, + "new_context_summary": null, + "reason": "This old memory is now merged into the updated coffee preference (id1)." + }, + { + "op_id": "op_3", + "type": "UPDATE", + "target_id": "id3", + "old_preference": "User likes working from home.", + "old_context_summary": "The user said he often works from home.", + "new_preference": "User now prefers studying in quiet coffee shops instead of working from home.", + "new_context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.", + "reason": "The preference has changed for the working environment." + } + ], + "after_update_state": [ + { + "id": "id1", + "context_summary": "The user discussed his coffee habits, saying he now prefers lattes but only occasionally drinks Americanos.", + "preference": "User likes coffee, especially lattes, but occasionally drinks Americanos." + }, + { + "id": "id3", + "context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.", + "preference": "User now prefers studying in quiet coffee shops instead of working from home." + }, + { + "id": "id4", + "context_summary": "The user noted he doesn't drink tea very often.", + "preference": "User has no particular interest in tea." + } + ] +} + +## Output Requirements + +- The output **must** be valid JSON. +- Each operation must include both `preference` and `context_summary` updates where applicable. +- Each operation must include a clear `reason`. +- Multiple retrieved memories may be merged into one unified updated memory. +- `after_update_state` must reflect the final, post-update state of the preference memory base. +- Do **not** include any explanatory text outside the JSON. +""" + + +PREF_INSTRUCTIONS = """ +# Note: +Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. +Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. +When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +""" diff --git a/src/memos/vec_dbs/factory.py b/src/memos/vec_dbs/factory.py index 8df22d14d..f2950b4ea 100644 --- a/src/memos/vec_dbs/factory.py +++ b/src/memos/vec_dbs/factory.py @@ -2,6 +2,7 @@ from memos.configs.vec_db import VectorDBConfigFactory from memos.vec_dbs.base import BaseVecDB +from memos.vec_dbs.milvus import MilvusVecDB from memos.vec_dbs.qdrant import QdrantVecDB @@ -10,6 +11,7 @@ class VecDBFactory(BaseVecDB): backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDB, + "milvus": MilvusVecDB, } @classmethod diff --git a/src/memos/vec_dbs/item.py b/src/memos/vec_dbs/item.py index 6f74879ac..081400f15 100644 --- a/src/memos/vec_dbs/item.py +++ b/src/memos/vec_dbs/item.py @@ -41,3 +41,9 @@ def from_dict(cls, data: dict[str, Any]) -> "VecDBItem": def to_dict(self) -> dict[str, Any]: """Convert to dictionary format.""" return self.model_dump(exclude_none=True) + + +class MilvusVecDBItem(VecDBItem): + """Represents a single item in the Milvus vector database.""" + + memory: str | None = Field(default=None, description="Memory string") diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 7bb1ceeba..fb19fd6ff 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -4,7 +4,7 @@ from memos.dependency import require_python_package from memos.log import get_logger from memos.vec_dbs.base import BaseVecDB -from memos.vec_dbs.item import VecDBItem +from memos.vec_dbs.item import MilvusVecDBItem logger = get_logger(__name__) @@ -40,6 +40,7 @@ def create_schema(self): schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True ) + schema.add_field(field_name="memory", datatype=DataType.VARCHAR, max_length=65535) schema.add_field( field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension ) @@ -107,7 +108,7 @@ def search( collection_name: str, top_k: int, filter: dict[str, Any] | None = None, - ) -> list[VecDBItem]: + ) -> list[MilvusVecDBItem]: """ Search for similar items in the database. @@ -136,8 +137,9 @@ def search( entity = hit.get("entity", {}) items.append( - VecDBItem( + MilvusVecDBItem( id=str(hit["id"]), + memory=entity.get("memory"), vector=entity.get("vector"), payload=entity.get("payload", {}), score=1 - float(hit["distance"]), @@ -178,7 +180,7 @@ def _get_metric_type(self) -> str: } return metric_map.get(self.config.distance_metric, "L2") - def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: + def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None: """Get a single item by ID.""" results = self.client.get( collection_name=collection_name, @@ -191,13 +193,14 @@ def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: entity = results[0] payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} - return VecDBItem( + return MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) - def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: + def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]: """Get multiple items by their IDs.""" results = self.client.get( collection_name=collection_name, @@ -211,8 +214,9 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: for entity in results: payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} items.append( - VecDBItem( + MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) @@ -222,7 +226,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: def get_by_filter( self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 - ) -> list[VecDBItem]: + ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. @@ -252,13 +256,14 @@ def get_by_filter( if not batch_results: break - # Convert batch results to VecDBItem objects + # Convert batch results to MilvusVecDBItem objects for entity in batch_results: # Extract the actual payload from Milvus entity payload = entity.get("payload", {}) all_items.append( - VecDBItem( + MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) @@ -274,7 +279,7 @@ def get_by_filter( logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") return all_items - def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: + def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database.""" return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit) @@ -295,13 +300,14 @@ def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> i # Extract row count from stats - stats is a dict, not a list return int(stats.get("row_count", 0)) - def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + def add(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]]) -> None: """ Add data to the vector database. Args: - data: List of VecDBItem objects or dictionaries containing: + data: List of MilvusVecDBItem objects or dictionaries containing: - 'id': unique identifier + - 'memory': memory string - 'vector': embedding vector - 'payload': additional fields for filtering/retrieval """ @@ -309,11 +315,12 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N for item in data: if isinstance(item, dict): item = item.copy() - item = VecDBItem.from_dict(item) + item = MilvusVecDBItem.from_dict(item) # Prepare entity data entity = { "id": item.id, + "memory": item.memory, "vector": item.vector, "payload": item.payload if item.payload else {}, } @@ -326,11 +333,15 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N data=entities, ) - def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None: + def update(self, collection_name: str, id: str, data: MilvusVecDBItem | dict[str, Any]) -> None: """Update an item in the vector database.""" + if id != data.id: + raise ValueError( + f"The id of the data to update must be the same as the id of the item to update, ID mismatch: expected {id}, got {data.id}" + ) if isinstance(data, dict): data = data.copy() - data = VecDBItem.from_dict(data) + data = MilvusVecDBItem.from_dict(data) # Use upsert for updates self.upsert(collection_name, [data]) @@ -347,7 +358,7 @@ def ensure_payload_indexes(self, fields: list[str]) -> None: # Field indexes are created automatically for scalar fields logger.info(f"Milvus automatically indexes scalar fields: {fields}") - def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + def upsert(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]]) -> None: """ Add or update data in the vector database. diff --git a/tests/configs/test_mem_cube.py b/tests/configs/test_mem_cube.py index 6c962dd01..c50195558 100644 --- a/tests/configs/test_mem_cube.py +++ b/tests/configs/test_mem_cube.py @@ -28,7 +28,7 @@ def test_base_mem_cube_config(): def test_general_mem_cube_config(): check_config_base_class( GeneralMemCubeConfig, - factory_fields=["text_mem", "act_mem", "para_mem"], + factory_fields=["text_mem", "act_mem", "para_mem", "pref_mem"], required_fields=[], optional_fields=["config_filename", "user_id", "cube_id"], reserved_fields=["model_schema"], From f6e96d57cb65e8ebebb5842b501982c3c0a76ce0 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Mon, 27 Oct 2025 11:54:52 +0800 Subject: [PATCH 06/64] Feat: add reranker strategies and update configs (#390) * feat:change reranking source filed * fix: code ci * feat: add reranker strategy * fix: code suffix * fix: code suffix * fix:change strategy name * fix: code format * feat: update memory strategies * fix: code ci --------- Co-authored-by: CaralHsi --- src/memos/api/config.py | 5 +- src/memos/graph_dbs/neo4j.py | 2 +- src/memos/reranker/base.py | 3 +- src/memos/reranker/concat.py | 60 +++- src/memos/reranker/cosine_local.py | 6 +- src/memos/reranker/factory.py | 11 + src/memos/reranker/http_bge.py | 8 +- src/memos/reranker/http_bge_strategy.py | 317 ++++++++++++++++++ src/memos/reranker/strategies/__init__.py | 4 + src/memos/reranker/strategies/base.py | 61 ++++ .../reranker/strategies/concat_background.py | 94 ++++++ .../reranker/strategies/dialogue_common.py | 109 ++++++ src/memos/reranker/strategies/factory.py | 29 ++ src/memos/reranker/strategies/single_turn.py | 107 ++++++ .../reranker/strategies/singleturn_outmem.py | 98 ++++++ 15 files changed, 893 insertions(+), 21 deletions(-) create mode 100644 src/memos/reranker/http_bge_strategy.py create mode 100644 src/memos/reranker/strategies/__init__.py create mode 100644 src/memos/reranker/strategies/base.py create mode 100644 src/memos/reranker/strategies/concat_background.py create mode 100644 src/memos/reranker/strategies/dialogue_common.py create mode 100644 src/memos/reranker/strategies/factory.py create mode 100644 src/memos/reranker/strategies/single_turn.py create mode 100644 src/memos/reranker/strategies/singleturn_outmem.py diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d26672883..4805d1062 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -132,15 +132,16 @@ def get_reranker_config() -> dict[str, Any]: """Get embedder configuration.""" embedder_backend = os.getenv("MOS_RERANKER_BACKEND", "http_bge") - if embedder_backend == "http_bge": + if embedder_backend in ["http_bge", "http_bge_strategy"]: return { - "backend": "http_bge", + "backend": embedder_backend, "config": { "url": os.getenv("MOS_RERANKER_URL"), "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"), "timeout": 10, "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"), "rerank_source": os.getenv("MOS_RERANK_SOURCE"), + "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), }, } else: diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index f51b3465d..fd3a1ba22 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -157,7 +157,7 @@ def remove_oldest_memory( """ if not self.config.use_multi_db and (self.config.user_name or user_name): query += f"\nAND n.user_name = '{user_name}'" - + keep_latest = int(keep_latest) query += f""" WITH n ORDER BY n.updated_at DESC SKIP {keep_latest} diff --git a/src/memos/reranker/base.py b/src/memos/reranker/base.py index 77a24c164..1c2f86ac5 100644 --- a/src/memos/reranker/base.py +++ b/src/memos/reranker/base.py @@ -16,8 +16,9 @@ class BaseReranker(ABC): def rerank( self, query: str, - graph_results: list, + graph_results: list[TextualMemoryItem], top_k: int, + search_filter: dict | None = None, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: """Return top_k (item, score) sorted by score desc.""" diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py index 5ad339529..502af18b6 100644 --- a/src/memos/reranker/concat.py +++ b/src/memos/reranker/concat.py @@ -2,12 +2,49 @@ from typing import Any +from memos.memories.textual.item import SourceMessage + _TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") +def get_encoded_tokens(content: str) -> int: + """ + Get encoded tokens. + Args: + content: str + Returns: + int: Encoded tokens. + """ + return len(content) + + +def truncate_data(data: list[str | dict[str, Any] | Any], max_tokens: int) -> list[str]: + """ + Truncate data to max tokens. + Args: + data: List of strings or dictionaries. + max_tokens: Maximum number of tokens. + Returns: + str: Truncated string. + """ + truncated_string = "" + for item in data: + if isinstance(item, SourceMessage): + content = getattr(item, "content", "") + chat_time = getattr(item, "chat_time", "") + if not content: + continue + truncated_string += f"[{chat_time}]: {content}\n" + if get_encoded_tokens(truncated_string) > max_tokens: + break + return truncated_string + + def process_source( - items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3 + items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, + recent_num: int = 10, + max_tokens: int = 2048, ) -> str: """ Args: @@ -23,19 +60,16 @@ def process_source( memory = None for item in items: memory, source = item - for content in source: - if isinstance(content, str): - if "assistant:" in content: - continue - concat_data.append(content) + concat_data.extend(source[-recent_num:]) + truncated_string = truncate_data(concat_data, max_tokens) if memory is not None: - concat_data = [memory, *concat_data] - return "\n".join(concat_data) + truncated_string = f"{memory}\n{truncated_string}" + return truncated_string def concat_original_source( graph_results: list, - merge_field: list[str] | None = None, + rerank_source: str | None = None, ) -> list[str]: """ Merge memory items with original dialogue. @@ -45,14 +79,16 @@ def concat_original_source( Returns: list[str]: List of memory and concat orginal memory. """ - if merge_field is None: - merge_field = ["sources"] + merge_field = [] + merge_field = ["sources"] if rerank_source is None else rerank_source.split(",") documents = [] for item in graph_results: memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m sources = [] for field in merge_field: - source = getattr(item.metadata, field, "") + source = getattr(item.metadata, field, None) + if source is None: + continue sources.append((memory, source)) concat_string = process_source(sources) documents.append(concat_string) diff --git a/src/memos/reranker/cosine_local.py b/src/memos/reranker/cosine_local.py index 000b64cf4..fc1dada2b 100644 --- a/src/memos/reranker/cosine_local.py +++ b/src/memos/reranker/cosine_local.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING +from memos.log import get_logger + from .base import BaseReranker @@ -16,6 +18,8 @@ except Exception: _HAS_NUMPY = False +logger = get_logger(__name__) + def _cosine_one_to_many(q: list[float], m: list[list[float]]) -> list[float]: """ @@ -92,5 +96,5 @@ def get_weight(it: TextualMemoryItem) -> float: chosen = {it.id for it, _ in top_items} remain = [(it, -1.0) for it in graph_results if it.id not in chosen] top_items.extend(remain[: top_k - len(top_items)]) - + logger.info(f"CosineLocalReranker rerank result: {top_items[:1]}") return top_items diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py index 134e29eb9..57460a4af 100644 --- a/src/memos/reranker/factory.py +++ b/src/memos/reranker/factory.py @@ -8,6 +8,7 @@ from .cosine_local import CosineLocalReranker from .http_bge import HTTPBGEReranker +from .http_bge_strategy import HTTPBGERerankerStrategy from .noop import NoopReranker @@ -45,4 +46,14 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None: if backend in {"noop", "none", "disabled"}: return NoopReranker() + if backend in {"http_bge_strategy", "bge_strategy"}: + return HTTPBGERerankerStrategy( + reranker_url=c.get("url") or c.get("endpoint") or c.get("reranker_url"), + model=c.get("model", "bge-reranker-v2-m3"), + timeout=int(c.get("timeout", 10)), + headers_extra=c.get("headers_extra"), + rerank_source=c.get("rerank_source"), + reranker_strategy=c.get("reranker_strategy"), + ) + raise ValueError(f"Unknown reranker backend: {cfg.backend}") diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index f0f5d17a0..9cce12786 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -80,7 +80,7 @@ def __init__( model: str = "bge-reranker-v2-m3", timeout: int = 10, headers_extra: dict | None = None, - rerank_source: list[str] | None = None, + rerank_source: str | None = None, boost_weights: dict[str, float] | None = None, boost_default: float = 0.0, warn_unknown_filter_keys: bool = True, @@ -107,7 +107,7 @@ def __init__( self.model = model self.timeout = timeout self.headers_extra = headers_extra or {} - self.concat_source = rerank_source + self.rerank_source = rerank_source self.boost_weights = ( DEFAULT_BOOST_WEIGHTS.copy() @@ -152,8 +152,8 @@ def rerank( # Build a mapping from "payload docs index" -> "original graph_results index" # Only include items that have a non-empty string memory. This ensures that # any index returned by the server can be mapped back correctly. - if self.concat_source: - documents = concat_original_source(graph_results, self.concat_source) + if self.rerank_source: + documents = concat_original_source(graph_results, self.rerank_source) else: documents = [ (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m) diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py new file mode 100644 index 000000000..8cbf633a6 --- /dev/null +++ b/src/memos/reranker/http_bge_strategy.py @@ -0,0 +1,317 @@ +# memos/reranker/http_bge.py +from __future__ import annotations + +import re + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +import requests + +from memos.log import get_logger +from memos.reranker.strategies import RerankerStrategyFactory + +from .base import BaseReranker + + +logger = get_logger(__name__) + + +if TYPE_CHECKING: + from memos.memories.textual.item import TextualMemoryItem + +# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...") +# before sending text to the reranker. This keeps inputs clean and +# avoids misleading the model with bracketed prefixes. +_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") +DEFAULT_BOOST_WEIGHTS = {"user_id": 0.5, "tags": 0.2, "session_id": 0.3} + + +def _value_matches(item_value: Any, wanted: Any) -> bool: + """ + Generic matching: + - if item_value is list/tuple/set: check membership (any match if wanted is iterable) + - else: equality (any match if wanted is iterable) + """ + + def _iterable(x): + # exclude strings from "iterable" + return isinstance(x, Iterable) and not isinstance(x, str | bytes) + + if _iterable(item_value): + if _iterable(wanted): + return any(w in item_value for w in wanted) + return wanted in item_value + else: + if _iterable(wanted): + return any(item_value == w for w in wanted) + return item_value == wanted + + +class HTTPBGERerankerStrategy(BaseReranker): + """ + HTTP-based BGE reranker. + + This class sends (query, documents[]) to a remote HTTP endpoint that + performs cross-encoder-style re-ranking (e.g., BGE reranker) and returns + relevance scores. It then maps those scores back onto the original + TextualMemoryItem list and returns (item, score) pairs sorted by score. + + Notes + ----- + - The endpoint is expected to accept JSON: + { + "model": "", + "query": "", + "documents": ["doc1", "doc2", ...] + } + - Two response shapes are supported: + 1) {"results": [{"index": , "relevance_score": }, ...]} + where "index" refers to the *position in the documents array*. + 2) {"data": [{"score": }, ...]} (aligned by list order) + - If the service fails or responds unexpectedly, this falls back to + returning the original items with 0.0 scores (best-effort). + """ + + def __init__( + self, + reranker_url: str, + token: str = "", + model: str = "bge-reranker-v2-m3", + timeout: int = 10, + headers_extra: dict | None = None, + rerank_source: str | None = None, + boost_weights: dict[str, float] | None = None, + boost_default: float = 0.0, + warn_unknown_filter_keys: bool = True, + reranker_strategy: str = "single_turn", + **kwargs, + ): + """ + Parameters + ---------- + reranker_url : str + HTTP endpoint for the reranker service. + token : str, optional + Bearer token for auth. If non-empty, added to the Authorization header. + model : str, optional + Model identifier understood by the server. + timeout : int, optional + Request timeout (seconds). + headers_extra : dict | None, optional + Additional headers to merge into the request headers. + """ + if not reranker_url: + raise ValueError("reranker_url must not be empty") + self.reranker_url = reranker_url + self.token = token or "" + self.model = model + self.timeout = timeout + self.headers_extra = headers_extra or {} + + self.boost_weights = ( + DEFAULT_BOOST_WEIGHTS.copy() + if boost_weights is None + else {k: float(v) for k, v in boost_weights.items()} + ) + self.boost_default = float(boost_default) + self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys) + self._warned_missing_keys: set[str] = set() + self.reranker_strategy = RerankerStrategyFactory.from_config(reranker_strategy) + + def rerank( + self, + query: str, + graph_results: list[TextualMemoryItem], + top_k: int, + search_filter: dict | None = None, + **kwargs, + ) -> list[tuple[TextualMemoryItem, float]]: + """ + Rank candidate memories by relevance to the query. + + Parameters + ---------- + query : str + The search query. + graph_results : list[TextualMemoryItem] + Candidate items to re-rank. Each item is expected to have a + `.memory` str field; non-strings are ignored. + top_k : int + Return at most this many items. + search_filter : dict | None + Currently unused. Present to keep signature compatible. + + Returns + ------- + list[tuple[TextualMemoryItem, float]] + Re-ranked items with scores, sorted descending by score. + """ + if not graph_results: + return [] + + tracker, original_items, documents = self.reranker_strategy.prepare_documents( + query, graph_results, top_k + ) + + logger.info( + f"[HTTPBGEWithSourceReranker] strategy: {self.reranker_strategy}, " + f"query: {query}, documents count: {len(documents)}" + ) + logger.info(f"[HTTPBGEWithSourceReranker] sample documents: {documents[:3]}...") + + if not documents: + return [] + + headers = {"Content-Type": "application/json", **self.headers_extra} + payload = {"model": self.model, "query": query, "documents": documents} + + try: + # Make the HTTP request to the reranker service + resp = requests.post( + self.reranker_url, headers=headers, json=payload, timeout=self.timeout + ) + resp.raise_for_status() + data = resp.json() + + scored_items: list[tuple[TextualMemoryItem, float]] = [] + + if "results" in data: + # Format: + # dict("results": [{"index": int, "relevance_score": float}, + # ...]) + rows = data.get("results", []) + + ranked_indices = [] + scores = [] + for r in rows: + idx = r.get("index") + # The returned index refers to 'documents' (i.e., our 'pairs' order), + # so we must map it back to the original graph_results index. + if isinstance(idx, int) and 0 <= idx < len(graph_results): + raw_score = float(r.get("relevance_score", r.get("score", 0.0))) + ranked_indices.append(idx) + scores.append(raw_score) + reconstructed_items = self.reranker_strategy.reconstruct_items( + ranked_indices=ranked_indices, + scores=scores, + tracker=tracker, + original_items=original_items, + top_k=top_k, + graph_results=graph_results, + documents=documents, + ) + return reconstructed_items + + elif "data" in data: + # Format: {"data": [{"score": float}, ...]} aligned by list order + rows = data.get("data", []) + # Build a list of scores aligned with our 'documents' (pairs) + score_list = [float(r.get("score", 0.0)) for r in rows] + + if len(score_list) < len(graph_results): + score_list += [0.0] * (len(graph_results) - len(score_list)) + elif len(score_list) > len(graph_results): + score_list = score_list[: len(graph_results)] + + scored_items = [] + for item, raw_score in zip(graph_results, score_list, strict=False): + score = self._apply_boost_generic(item, raw_score, search_filter) + scored_items.append((item, score)) + + scored_items.sort(key=lambda x: x[1], reverse=True) + return scored_items[: min(top_k, len(scored_items))] + + else: + # Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs + # Note: we use 'pairs' to keep alignment with valid (string) docs. + return [(item, 0.0) for item in graph_results[:top_k]] + + except Exception as e: + # Network error, timeout, JSON decode error, etc. + # Degrade gracefully by returning first top_k valid docs with 0.0 score. + logger.error(f"[HTTPBGEReranker] request failed: {e}") + return [(item, 0.0) for item in graph_results[:top_k]] + + def _get_attr_or_key(self, obj: Any, key: str) -> Any: + """ + Resolve `key` on `obj` with one-level fallback into `obj.metadata`. + + Priority: + 1) obj. + 2) obj[key] + 3) obj.metadata. + 4) obj.metadata[key] + """ + if obj is None: + return None + + # support input like "metadata.user_id" + if "." in key: + head, tail = key.split(".", 1) + base = self._get_attr_or_key(obj, head) + return self._get_attr_or_key(base, tail) + + def _resolve(o: Any, k: str): + if o is None: + return None + v = getattr(o, k, None) + if v is not None: + return v + if hasattr(o, "get"): + try: + return o.get(k) + except Exception: + return None + return None + + # 1) find in obj + v = _resolve(obj, key) + if v is not None: + return v + + # 2) find in obj.metadata + meta = _resolve(obj, "metadata") + if meta is not None: + return _resolve(meta, key) + + return None + + def _apply_boost_generic( + self, + item: TextualMemoryItem, + base_score: float, + search_filter: dict | None, + ) -> float: + """ + Multiply base_score by (1 + weight) for each matching key in search_filter. + - key resolution: self._get_attr_or_key(item, key) + - weight = boost_weights.get(key, self.boost_default) + - unknown key -> one-time warning + """ + if not search_filter: + return base_score + + score = float(base_score) + + for key, wanted in search_filter.items(): + # _get_attr_or_key automatically find key in item and + # item.metadata ("metadata.user_id" supported) + resolved = self._get_attr_or_key(item, key) + + if resolved is None: + if self.warn_unknown_filter_keys and key not in self._warned_missing_keys: + logger.warning( + "[HTTPBGEReranker] search_filter key '%s' not found on TextualMemoryItem or metadata", + key, + ) + self._warned_missing_keys.add(key) + continue + + if _value_matches(resolved, wanted): + w = float(self.boost_weights.get(key, self.boost_default)) + if w != 0.0: + score *= 1.0 + w + score = min(max(0.0, score), 1.0) + + return score diff --git a/src/memos/reranker/strategies/__init__.py b/src/memos/reranker/strategies/__init__.py new file mode 100644 index 000000000..cee60f2be --- /dev/null +++ b/src/memos/reranker/strategies/__init__.py @@ -0,0 +1,4 @@ +from .factory import RerankerStrategyFactory + + +__all__ = ["RerankerStrategyFactory"] diff --git a/src/memos/reranker/strategies/base.py b/src/memos/reranker/strategies/base.py new file mode 100644 index 000000000..43166dd92 --- /dev/null +++ b/src/memos/reranker/strategies/base.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from typing import Any + +from memos.memories.textual.item import TextualMemoryItem + +from .dialogue_common import DialogueRankingTracker + + +class BaseRerankerStrategy(ABC): + """Abstract interface for memory rerankers with concatenation strategy.""" + + @abstractmethod + def prepare_documents( + self, + query: str, + graph_results: list[TextualMemoryItem], + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents for ranking based on the strategy. + + Args: + query: The search query + graph_results: List of TextualMemoryItem objects to process + top_k: Maximum number of items to return + **kwargs: Additional strategy-specific parameters + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + raise NotImplementedError + + @abstractmethod + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[TextualMemoryItem, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked results. + + Args: + ranked_indices: List of indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + **kwargs: Additional strategy-specific parameters + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + raise NotImplementedError diff --git a/src/memos/reranker/strategies/concat_background.py b/src/memos/reranker/strategies/concat_background.py new file mode 100644 index 000000000..a52313548 --- /dev/null +++ b/src/memos/reranker/strategies/concat_background.py @@ -0,0 +1,94 @@ +# memos/reranker/strategies/single_turn.py +from __future__ import annotations + +import re + +from typing import Any + +from .base import BaseRerankerStrategy +from .dialogue_common import DialogueRankingTracker + + +_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") + + +class ConcatBackgroundStrategy(BaseRerankerStrategy): + """ + Concat background strategy. + + This strategy processes dialogue pairs by concatenating background and + user and assistant messages into single strings for ranking. Each dialogue pair becomes a + separate document for ranking. + """ + + def prepare_documents( + self, + query: str, + graph_results: list, + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents based on single turn concatenation strategy. + + Args: + query: The search query + graph_results: List of graph results + top_k: Maximum number of items to return + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + + original_items = {} + tracker = DialogueRankingTracker() + documents = [] + for item in graph_results: + memory = getattr(item, "memory", None) + if isinstance(memory, str): + memory = _TAG1.sub("", memory) + + background = "" + if hasattr(item, "metadata") and hasattr(item.metadata, "background"): + background = getattr(item.metadata, "background", "") + if not isinstance(background, str): + background = "" + + documents.append(f"{memory}\n{background}") + return tracker, original_items, documents + + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[Any, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked dialogue pairs. + + Args: + ranked_indices: List of dialogue pair indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + graph_results = kwargs.get("graph_results") + documents = kwargs.get("documents") + reconstructed_items = [] + for idx in ranked_indices: + item = graph_results[idx] + item.memory = f"{item.memory}\n{documents[idx]}" + reconstructed_items.append((item, scores[idx])) + + reconstructed_items.sort(key=lambda x: x[1], reverse=True) + return reconstructed_items[:top_k] diff --git a/src/memos/reranker/strategies/dialogue_common.py b/src/memos/reranker/strategies/dialogue_common.py new file mode 100644 index 000000000..ce0138284 --- /dev/null +++ b/src/memos/reranker/strategies/dialogue_common.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import re + +from typing import Any, Literal + +from pydantic import BaseModel + +from memos.memories.textual.item import SourceMessage, TextualMemoryItem + + +# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...") +# before sending text to the reranker. This keeps inputs clean and +# avoids misleading the model with bracketed prefixes. +_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") + + +def strip_memory_tags(item: TextualMemoryItem) -> str: + """Strip leading tags from memory text.""" + memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m + return memory + + +def extract_content(msg: dict[str, Any] | str) -> str: + """Extract content from message, handling both string and dict formats.""" + if isinstance(msg, dict): + return msg.get("content", str(msg)) + if isinstance(msg, SourceMessage): + return msg.content + return str(msg) + + +class DialoguePair(BaseModel): + """Represents a single dialogue pair extracted from sources.""" + + pair_id: str # Unique identifier for this dialogue pair + memory_id: str # ID of the source TextualMemoryItem + memory: str + pair_index: int # Index of this pair within the source memory's dialogue + user_msg: str | dict[str, Any] | SourceMessage # User message content + assistant_msg: str | dict[str, Any] | SourceMessage # Assistant message content + combined_text: str # The concatenated text used for ranking + chat_time: str | None = None + + @property + def user_content(self) -> str: + """Get user message content as string.""" + return extract_content(self.user_msg) + + @property + def assistant_content(self) -> str: + """Get assistant message content as string.""" + return extract_content(self.assistant_msg) + + +class DialogueRankingTracker: + """Tracks dialogue pairs and their rankings for memory reconstruction.""" + + def __init__(self): + self.dialogue_pairs: list[DialoguePair] = [] + + def add_dialogue_pair( + self, + memory_id: str, + pair_index: int, + user_msg: str | dict[str, Any], + assistant_msg: str | dict[str, Any], + memory: str, + chat_time: str | None = None, + concat_format: Literal["user_assistant", "user_only"] = "user_assistant", + ) -> str: + """Add a dialogue pair and return its unique ID.""" + user_content = extract_content(user_msg) + assistant_content = extract_content(assistant_msg) + if concat_format == "user_assistant": + combined_text = f"[{chat_time}]: \nuser: {user_content}\nassistant: {assistant_content}" + elif concat_format == "user_only": + combined_text = f"[{chat_time}]: \nuser: {user_content}" + else: + raise ValueError(f"Invalid concat format: {concat_format}") + + pair_id = f"{memory_id}_{pair_index}" + + dialogue_pair = DialoguePair( + pair_id=pair_id, + memory_id=memory_id, + pair_index=pair_index, + user_msg=user_msg, + assistant_msg=assistant_msg, + combined_text=combined_text, + memory=memory, + chat_time=chat_time, + ) + + self.dialogue_pairs.append(dialogue_pair) + return pair_id + + def get_documents_for_ranking(self, concat_memory: bool = True) -> list[str]: + """Get the combined text documents for ranking.""" + if concat_memory: + return [(pair.memory + "\n\n" + pair.combined_text) for pair in self.dialogue_pairs] + else: + return [pair.combined_text for pair in self.dialogue_pairs] + + def get_dialogue_pair_by_index(self, index: int) -> DialoguePair | None: + """Get dialogue pair by its index in the ranking results.""" + if 0 <= index < len(self.dialogue_pairs): + return self.dialogue_pairs[index] + return None diff --git a/src/memos/reranker/strategies/factory.py b/src/memos/reranker/strategies/factory.py new file mode 100644 index 000000000..d93cbd65a --- /dev/null +++ b/src/memos/reranker/strategies/factory.py @@ -0,0 +1,29 @@ +# memos/reranker/factory.py +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from .concat_background import ConcatBackgroundStrategy +from .single_turn import SingleTurnStrategy +from .singleturn_outmem import SingleTurnOutMemStrategy + + +if TYPE_CHECKING: + from .base import BaseRerankerStrategy + + +class RerankerStrategyFactory: + """Factory class for creating reranker strategy instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "single_turn": SingleTurnStrategy, + "concat_background": ConcatBackgroundStrategy, + "singleturn_outmem": SingleTurnOutMemStrategy, + } + + @classmethod + def from_config(cls, config_factory: str = "single_turn") -> BaseRerankerStrategy: + if config_factory not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {config_factory}") + strategy_class = cls.backend_to_class[config_factory] + return strategy_class() diff --git a/src/memos/reranker/strategies/single_turn.py b/src/memos/reranker/strategies/single_turn.py new file mode 100644 index 000000000..d86744811 --- /dev/null +++ b/src/memos/reranker/strategies/single_turn.py @@ -0,0 +1,107 @@ +# memos/reranker/strategies/single_turn.py +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +from .base import BaseRerankerStrategy +from .dialogue_common import DialogueRankingTracker, extract_content, strip_memory_tags + + +class SingleTurnStrategy(BaseRerankerStrategy): + """ + Single turn dialogue strategy. + + This strategy processes dialogue pairs by concatenating user and assistant + messages into single strings for ranking. Each dialogue pair becomes a + separate document for ranking. + example: + >>> documents = ["chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + >>> output memory item: ["Memory:xxx \n\n chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + """ + + def prepare_documents( + self, + query: str, + graph_results: list, + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents based on single turn concatenation strategy. + + Args: + query: The search query + graph_results: List of graph results + top_k: Maximum number of items to return + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + + original_items = {} + tracker = DialogueRankingTracker() + for item in graph_results: + memory = strip_memory_tags(item) + sources = getattr(item.metadata, "sources", []) + original_items[item.id] = item + + # Group messages into pairs and concatenate + for i in range(0, len(sources), 2): + user_msg = sources[i] if i < len(sources) else {} + assistant_msg = sources[i + 1] if i + 1 < len(sources) else {} + + user_content = extract_content(user_msg) + assistant_content = extract_content(assistant_msg) + chat_time = getattr(user_msg, "chat_time", "") + + if user_content or assistant_content: # Only add non-empty pairs + pair_index = i // 2 + tracker.add_dialogue_pair( + item.id, pair_index, user_msg, assistant_msg, memory or "", chat_time + ) + + documents = tracker.get_documents_for_ranking() + return tracker, original_items, documents + + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[Any, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked dialogue pairs. + + Args: + ranked_indices: List of dialogue pair indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + reconstructed_items = [] + for idx, score in zip(ranked_indices, scores, strict=False): + dialogue_pair = tracker.get_dialogue_pair_by_index(idx) + if dialogue_pair and (dialogue_pair.memory_id in original_items): + original_item = original_items[dialogue_pair.memory_id] + reconstructed_item = deepcopy(original_item) + reconstructed_item.memory = ( + dialogue_pair.memory + + "\n\nsources-dialogue-pairs" + + dialogue_pair.combined_text + ) + reconstructed_items.append((reconstructed_item, score)) + + # Sort by aggregated score and return top_k + reconstructed_items.sort(key=lambda x: x[1], reverse=True) + return reconstructed_items[:top_k] diff --git a/src/memos/reranker/strategies/singleturn_outmem.py b/src/memos/reranker/strategies/singleturn_outmem.py new file mode 100644 index 000000000..de59fec97 --- /dev/null +++ b/src/memos/reranker/strategies/singleturn_outmem.py @@ -0,0 +1,98 @@ +# memos/reranker/strategies/single_turn.py +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING, Any + +from .dialogue_common import DialogueRankingTracker +from .single_turn import SingleTurnStrategy + + +if TYPE_CHECKING: + from .dialogue_common import DialogueRankingTracker + + +class SingleTurnOutMemStrategy(SingleTurnStrategy): + """ + Single turn dialogue strategy. + + This strategy processes dialogue pairs by concatenating user and assistant + messages into single strings for ranking. Each dialogue pair becomes a + separate document for ranking. + example: + >>> documents = ["chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + >>> output memory item: ["Memory:xxx \n\n chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + """ + + def prepare_documents( + self, + query: str, + graph_results: list, + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents based on single turn concatenation strategy. + + Args: + query: The search query + graph_results: List of graph results + top_k: Maximum number of items to return + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + return super().prepare_documents(query, graph_results, top_k, **kwargs) + + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[Any, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked dialogue pairs. + + Args: + ranked_indices: List of dialogue pair indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + # Group ranked pairs by memory_id + memory_groups = defaultdict(list) + memory_scores = defaultdict(list) + + for idx, score in zip(ranked_indices, scores, strict=False): + dialogue_pair = tracker.get_dialogue_pair_by_index(idx) + if dialogue_pair: + memory_groups[dialogue_pair.memory_id].append(dialogue_pair) + memory_scores[dialogue_pair.memory_id].append(score) + + reconstructed_items = [] + + for memory_id, _pairs in memory_groups.items(): + if memory_id not in original_items: + continue + original_item = original_items[memory_id] + + # Calculate aggregated score (e.g., max, mean, or weighted average) + pair_scores = memory_scores[memory_id] + + aggregated_score = max(pair_scores) if pair_scores else 0.0 + + reconstructed_items.append((original_item, aggregated_score)) + + # Sort by aggregated score and return top_k + reconstructed_items.sort(key=lambda x: x[1], reverse=True) + return reconstructed_items[:top_k] From e069928df30616dd2bf23665e4d8599ac6342d2d Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 27 Oct 2025 14:32:20 +0800 Subject: [PATCH 07/64] modify code in evaluation (#392) * modify code in evaluation * modify code in evaluation --------- Co-authored-by: yuan.wang Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- evaluation/.env-example | 9 +--- evaluation/scripts/PrefEval/pref_memos.py | 15 ++++--- evaluation/scripts/locomo/locomo_responses.py | 8 +--- evaluation/scripts/locomo/locomo_search.py | 11 +++-- evaluation/scripts/locomo/prompts.py | 17 +------- .../scripts/longmemeval/lme_responses.py | 13 +++--- evaluation/scripts/longmemeval/lme_search.py | 6 ++- evaluation/scripts/personamem/pm_responses.py | 16 +++---- evaluation/scripts/personamem/pm_search.py | 6 ++- evaluation/scripts/utils/pref_mem_utils.py | 43 ------------------- evaluation/scripts/utils/prompts.py | 13 +----- src/memos/vec_dbs/milvus.py | 2 +- 12 files changed, 41 insertions(+), 118 deletions(-) delete mode 100644 evaluation/scripts/utils/pref_mem_utils.py diff --git a/evaluation/.env-example b/evaluation/.env-example index bda935442..0e94e9caa 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -22,13 +22,8 @@ SUPERMEMORY_API_KEY="sm_xxx" MEMOBASE_API_KEY="xxx" MEMOBASE_PROJECT_URL="http://***.***.***.***:8019" -# pref -PRE_SPLIT_CHUNK=false # pre split chunk in client end, for personamem and prefeval -# 1. text_mem + pref_mem + instruction_completion: set INSTRUCT_COMPLETE=true, ABLATION_PREF=false -# 2. text_mem + pref_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=false -# 3. text_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=true -INSTRUCT_COMPLETE=true # use instruct complete format or not -ABLATION_PREF=false # remove pref mem, only text mem +# eval settings +PRE_SPLIT_CHUNK=false # Configuration Only For Scheduler # RabbitMQ Configuration diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 753a77d99..7336d4612 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -72,7 +72,6 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di """ Processes a single line of data, searching memory based on the question. """ - from utils.pref_mem_utils import create_mem_string i, line = line_data try: @@ -94,7 +93,13 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di start_time_search = time.monotonic() relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) search_memories_duration = time.monotonic() - start_time_search - memories_str = create_mem_string(relevant_memories) + memories_str = ( + "\n".join( + f"- {entry.get('memory', '')}" + for entry in relevant_memories["text_mem"][0]["memories"] + ) + + f"\n{relevant_memories['pref_mem']}" + ) memory_tokens_used = len(tokenizer.encode(memories_str)) @@ -119,7 +124,6 @@ def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str """ Generates a response for a single line of data using pre-fetched memories. """ - from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string from utils.prompts import PREFEVAL_ANSWER_PROMPT i, line = line_data @@ -146,10 +150,7 @@ def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str ) return original_data - memories_str = remove_pref_mem_from_mem_string(memories_str, frame=lib) - - template = add_pref_instruction(PREFEVAL_ANSWER_PROMPT, frame=lib) - system_prompt = template.format(context=memories_str) + system_prompt = PREFEVAL_ANSWER_PROMPT.format(context=memories_str) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": question}, diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 2ae4dcb6e..35a444b7d 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -35,10 +35,7 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str question=question, ) else: - from utils.pref_mem_utils import add_pref_instruction - - template = add_pref_instruction(ANSWER_PROMPT_MEMOS, frame=frame) - prompt = template.format( + prompt = ANSWER_PROMPT_MEMOS.format( context=context, question=question, ) @@ -55,8 +52,6 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str async def process_qa(frame, qa, search_result, oai_client): - from utils.pref_mem_utils import remove_pref_mem_from_mem_string - start = time() query = qa.get("question") gold_answer = qa.get("answer") @@ -64,7 +59,6 @@ async def process_qa(frame, qa, search_result, oai_client): context = search_result.get("context") - context = remove_pref_mem_from_mem_string(context, frame) answer = await locomo_response(frame, oai_client, context, query) response_duration_ms = (time() - start) * 1000 diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index 19efb5b92..c629124dd 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -100,14 +100,19 @@ def memos_api_search( client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b ): from prompts import TEMPLATE_MEMOS - from utils.pref_mem_utils import create_mem_string start = time() search_a_results = client.search(query=query, user_id=speaker_a_user_id, top_k=top_k) search_b_results = client.search(query=query, user_id=speaker_b_user_id, top_k=top_k) - speaker_a_context = create_mem_string(search_a_results) - speaker_b_context = create_mem_string(search_b_results) + speaker_a_context = ( + "\n".join([i["memory"] for i in search_a_results["text_mem"][0]["memories"]]) + + f"\n{search_a_results['pref_mem']}" + ) + speaker_b_context = ( + "\n".join([i["memory"] for i in search_b_results["text_mem"][0]["memories"]]) + + f"\n{search_b_results['pref_mem']}" + ) context = TEMPLATE_MEMOS.format( speaker_1=speaker_a, diff --git a/evaluation/scripts/locomo/prompts.py b/evaluation/scripts/locomo/prompts.py index caf462f6a..152e5b87f 100644 --- a/evaluation/scripts/locomo/prompts.py +++ b/evaluation/scripts/locomo/prompts.py @@ -1,14 +1,3 @@ -import os - - -PREF_INSTRUCTIONS = """ - # Note: - Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. - Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. - When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. -""" - - ANSWER_PROMPT_MEM0 = """ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. @@ -114,7 +103,7 @@ 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. 7. Ensure your final answer is specific and avoids vague time references. - {pref_instructions} + {context} Question: {question} @@ -122,10 +111,6 @@ Answer: """ -if os.getenv("INSTRUCT_COMPLETE") == "true": - ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", PREF_INSTRUCTIONS) -else: - ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", "") custom_instructions = """ Generate personal memories that follow these guidelines: diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py index 22f17c304..a4adf90b5 100644 --- a/evaluation/scripts/longmemeval/lme_responses.py +++ b/evaluation/scripts/longmemeval/lme_responses.py @@ -12,13 +12,11 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string from utils.prompts import LME_ANSWER_PROMPT -def lme_response(llm_client, context, question, question_date, frame): - template = add_pref_instruction(LME_ANSWER_PROMPT, frame=frame) - prompt = template.format( +def lme_response(llm_client, context, question, question_date): + prompt = LME_ANSWER_PROMPT.format( question=question, question_date=question_date, context=context, @@ -35,14 +33,13 @@ def lme_response(llm_client, context, question, question_date, frame): return result -def process_qa(user_id, search_result, llm_client, frame): +def process_qa(user_id, search_result, llm_client): start = time() search_result = search_result[0] question = search_result.get("question") question_date = search_result.get("date") context = search_result.get("search_context", "") - context = remove_pref_mem_from_mem_string(context, frame=frame) - anwer = lme_response(llm_client, context, question, question_date, frame) + anwer = lme_response(llm_client, context, question, question_date) response_duration_ms = (time() - start) * 1000 @@ -97,7 +94,7 @@ def main(frame, version, num_workers=4): future_to_user_id = {} for user_id, search_results in lme_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, oai_client, frame) + future = executor.submit(process_qa, user_id, search_results, oai_client) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index d21795eef..c02518083 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -13,7 +13,6 @@ import pandas as pd from tqdm import tqdm -from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, @@ -45,7 +44,10 @@ def mem0_search(client, query, user_id, top_k): def memos_search(client, query, user_id, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) - context = create_mem_string(results) + context = ( + "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]]) + + f"\n{results['pref_mem']}" + ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context) duration_ms = (time() - start) * 1000 return context, duration_ms diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 5b54f9bb8..ff561f8d8 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -14,7 +14,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re -from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string from utils.prompts import PM_ANSWER_PROMPT @@ -49,9 +48,8 @@ def _extract_only_options(text): return False, predicted_answer -def pm_response(llm_client, context, question, options, frame): - template = add_pref_instruction(PM_ANSWER_PROMPT, frame=frame) - prompt = template.format( +def pm_response(llm_client, context, question, options): + prompt = PM_ANSWER_PROMPT.format( question=question, context=context, options=options, @@ -68,19 +66,17 @@ def pm_response(llm_client, context, question, options, frame): return result -def process_qa(user_id, search_result, num_runs, llm_client, frame): +def process_qa(user_id, search_result, num_runs, llm_client): search_result = search_result[0] question = search_result.get("question") context = search_result.get("search_context", "") options = search_result.get("all_options", []) - context = remove_pref_mem_from_mem_string(context, frame=frame) - run_results = [] for idx in range(num_runs): start = time() - answer = pm_response(llm_client, context, question, options, frame) + answer = pm_response(llm_client, context, question, options) is_correct, answer = extract_choice_answer(answer, search_result.get("golden_answer", "")) response_duration_ms = (time() - start) * 1000 @@ -154,9 +150,7 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id = {} for user_id, search_results in pm_search_results.items(): - future = executor.submit( - process_qa, user_id, search_results, num_runs, oai_client, frame - ) + future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 243c64589..c18e05623 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -14,7 +14,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, @@ -83,7 +82,10 @@ def memobase_search(client, query, user_id, top_k): def memos_search(client, user_id, query, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) - search_memories = create_mem_string(results) + search_memories = ( + "\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"]) + + f"\n{results['pref_mem']}" + ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) duration_ms = (time() - start) * 1000 diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py deleted file mode 100644 index 22a5bb86c..000000000 --- a/evaluation/scripts/utils/pref_mem_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -import sys - - -sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from prompts import PREF_INSTRUCTIONS - - -def create_mem_string(relevant_memories) -> str: - text_memories = [] - explicit = [] - implicit = [] - for item in relevant_memories["text_mem"]: - for mem in item["memories"]: - text_memories.append(mem["memory"]) - text_memories_text = "\n".join(f"{i + 1}. {mem}" for i, mem in enumerate(text_memories)).strip() - text_context = f"Plaintext Memory:\n{text_memories_text}\n" if text_memories_text else "" - - for item in relevant_memories.get("prefs", []): - for mem in item["memories"]: - if mem["metadata"]["preference_type"] == "explicit_preference": - explicit.append(mem["metadata"]["explicit_preference"]) - elif mem["metadata"]["preference_type"] == "implicit_preference": - implicit.append(mem["metadata"]["implicit_preference"]) - explicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(explicit)).strip() - explicit_context = f"Explicit Preference:\n{explicit_text}\n" if explicit_text else "" - implicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(implicit)).strip() - implicit_context = f"Implicit Preference:\n{implicit_text}\n" if implicit_text else "" - return text_context + explicit_context + implicit_context - - -def remove_pref_mem_from_mem_string(mem_string: str, frame: str) -> str: - if os.getenv("ABLATION_PREF", "false").lower() == "true" and frame == "memos-api": - tmp_list = mem_string.split("Plaintext Memory:") - if len(tmp_list) > 1: - return tmp_list[1].split("Explicit Preference:")[0] - return mem_string - - -def add_pref_instruction(template: str, frame: str): - if os.getenv("INSTRUCT_COMPLETE", "false").lower() == "true" and frame == "memos-api": - return template.replace("{pref_instructions}", PREF_INSTRUCTIONS) - return template.replace("{pref_instructions}", "") diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py index 902bbb1be..32e6d6729 100644 --- a/evaluation/scripts/utils/prompts.py +++ b/evaluation/scripts/utils/prompts.py @@ -1,11 +1,3 @@ -PREF_INSTRUCTIONS = """ - # Note: - Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. - Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. - When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. -""" - - LME_ANSWER_PROMPT = """ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. @@ -25,7 +17,7 @@ 5. Formulate a precise, concise answer based solely on the evidence in the memories. 6. Double-check that your answer directly addresses the question asked. 7. Ensure your final answer is specific and avoids vague time references. - {pref_instructions} + {context} Current Date: {question_date} @@ -55,7 +47,7 @@ - Your final answer **must use parentheses**, like (a) or (b). - Do NOT list multiple choices. Choose only one. - Do NOT include extra text after . Just output the answer. - {pref_instructions} + # QUESTION: {question} @@ -71,7 +63,6 @@ You are a helpful AI. Answer the question based on the query and the following memories: User Memories: {context} - {pref_instructions} """ diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index fb19fd6ff..c1cb26362 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -138,7 +138,7 @@ def search( items.append( MilvusVecDBItem( - id=str(hit["id"]), + id=str(entity.get("id")), memory=entity.get("memory"), vector=entity.get("vector"), payload=entity.get("payload", {}), From 84adda675d2906ad28b229a572d06c6b6d2afacc Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 27 Oct 2025 20:20:37 +0800 Subject: [PATCH 08/64] fix bug in pref_mem return (#399) Co-authored-by: yuan.wang --- src/memos/api/routers/server_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index d2392f927..86e7a300a 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -332,6 +332,7 @@ def search_memories(search_req: APISearchRequest): "text_mem": [], "act_mem": [], "para_mem": [], + "pref_mem": "", } search_mode = search_req.mode From ce34bd185d2e3b1df34b6b5a3833f5d8520cbadf Mon Sep 17 00:00:00 2001 From: Wustzdy <67457465+wustzdy@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:07:01 +0800 Subject: [PATCH 09/64] add polardb (#395) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add polardb.py * add polardb.py * add polar factory * delete * update get_memory_count * update get_memory_count * update node_not_exist * update remove_oldest_memory * fix * update get_node * update get_node * update update_node * update delete_node * add edge * add create_extension,create_graph,create_edge * add add_edge * edge_exist * edge_exist * edge_exist * update edge_exists * update polardb.py * update get_children_with_embeddings * update get_children_with_embeddings * update get_subgraph * update get_grouped_counts * update get_all_memory_items * update export_graph * remove * insert Memory * fix add_node * fix polardb.py * fix * fix get_subgraph * fix * get_grouped_counts * update get_by_metadata * get_grouped_counts * update get_grouped_counts * update get_grouped_counts * get_grouped_counts * update get_nodes * update search_by_embedding filter user_name * update search_by_embedding filter user_name * add filter user_name for update_node * get_structure_optimization_candidates * add filter user_name for update_node * fix * fix * fix * feat: 增加polardb的启动配置 * fix * fix * fix get_structure_optimization_candidates * fix get_all_memory_items * fix get_all_memory_items * remove embedding for get_nodes * fix get_structure_optimization_candidates * add _parse_node_new * update get_all_memory_items * update get_all_memory_items * update get_all_memory_items for include_embedding * feat: server router add polardb config * feat: server router add polardb config * update get_all_memory_items for include_embedding False * update get_all_memory_items for include_embedding False * fix * fix get_all_memory_items * update get_all_memory_items for include_embedding False * fix get_all_memory_items * update get_all_memory_items for include_embedding False * update get_grouped_counts * update get_grouped_counts * add_node and graph_id * fix * fix get_all_memory_items false * fix * fix get_all_memory_items true * fix * fix * fix * fix export_graph * fix export_graph * fix get_by_metadata * update get_neighbors_by_tag * update get_neighbors_by_tag * update get_neighbors_by_tag * fix * fix * add import_graph * fix * add get_edges * add clear * get_neighbors_by_tag * get_neighbors_by_tag * update get_by_metadata * search_by_emdedding remove embedding * fix:parseJson.py * fix:get_my_metadata * fix * fix get_by_metadata result * update polardb.py * fix _coerce_metadata * feat: add rerank time * feat: add rerank time * fix:node_not_exist * import node * import node * feat: fix merge_config_with_default * import node * fix * fix * feat: fix polardb * feat: fix scheduler method name * fix get_by_metadata for "query": "How long ago was Caroline's 18th birthday?" * fix get_by_metadata for "query": "How long ago was Caroline's 18th birthday?" * fix get_node format_param_value * feat: fix CONFIG * fix * feat: fix import * feat: delete test file * feat: fix polardb * feat: fix recall * Comment out unused configuration handling code Commented out code related to auto_create and embedding_dimension handling. * fix * feat: fix polardb * import polardb * feat: fix polardb * fix * feat: fix polardb * fix * fix * feat: fix polardb * feat: delete polardb * feat: fix utils * feat: fix polardb * feat: format polardb * feat: format utils --------- Co-authored-by: ccl <13282138256@163.com> Co-authored-by: liji <532311301@qq.com> Co-authored-by: CaralHsi --- docker/requirements.txt | 2 +- src/memos/api/config.py | 30 + src/memos/api/routers/server_router.py | 1 + src/memos/configs/graph_db.py | 54 + src/memos/graph_dbs/factory.py | 2 + src/memos/graph_dbs/polardb.py | 2777 ++++++++++++++++++++++++ src/memos/mem_cube/utils.py | 106 +- src/memos/memories/textual/item.py | 11 + src/memos/reranker/cosine_local.py | 2 + src/memos/reranker/http_bge.py | 2 + src/memos/reranker/noop.py | 3 +- 11 files changed, 2953 insertions(+), 37 deletions(-) create mode 100644 src/memos/graph_dbs/polardb.py diff --git a/docker/requirements.txt b/docker/requirements.txt index 4846f1832..d20c0b36e 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6 watchfiles==1.1.0 websockets==15.0.1 xlrd==2.0.2 -xlsxwriter==3.2.5 +xlsxwriter==3.2.5 \ No newline at end of file diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 4805d1062..d1bc6efff 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -309,6 +309,32 @@ def get_milvus_config(): "password": os.getenv("MILVUS_PASSWORD", "12345678"), } + @staticmethod + def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: + """Get PolarDB configuration.""" + use_multi_db = os.getenv("POLAR_DB_USE_MULTI_DB", "false").lower() == "true" + + if use_multi_db: + # Multi-DB mode: each user gets their own database (physical isolation) + db_name = f"memos{user_id.replace('-', '')}" if user_id else "memos_default" + user_name = None + else: + # Shared-DB mode: all users share one database with user_name tag (logical isolation) + db_name = os.getenv("POLAR_DB_DB_NAME", "shared_memos_db") + user_name = f"memos{user_id.replace('-', '')}" if user_id else "memos_default" + + return { + "host": os.getenv("POLAR_DB_HOST", "localhost"), + "port": int(os.getenv("POLAR_DB_PORT", "5432")), + "user": os.getenv("POLAR_DB_USER", "root"), + "password": os.getenv("POLAR_DB_PASSWORD", "123456"), + "db_name": db_name, + "user_name": user_name, + "use_multi_db": use_multi_db, + "auto_create": True, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + } + @staticmethod def get_mysql_config() -> dict[str, Any]: """Get MySQL configuration.""" @@ -540,6 +566,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General neo4j_community_config = APIConfig.get_neo4j_community_config(user_id) neo4j_config = APIConfig.get_neo4j_config(user_id) nebular_config = APIConfig.get_nebular_config(user_id) + polardb_config = APIConfig.get_polardb_config(user_id) internet_config = ( APIConfig.get_internet_config() if os.getenv("ENABLE_INTERNET", "false").lower() == "true" @@ -549,6 +576,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "nebular": nebular_config, + "polardb": polardb_config, } graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() if graph_db_backend in graph_db_backend_map: @@ -607,10 +635,12 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default") neo4j_config = APIConfig.get_neo4j_config(user_id="default") nebular_config = APIConfig.get_nebular_config(user_id="default") + polardb_config = APIConfig.get_polardb_config(user_id="default") graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "nebular": nebular_config, + "polardb": polardb_config, } internet_config = ( APIConfig.get_internet_config() diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 86e7a300a..1331094a8 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -69,6 +69,7 @@ def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), "neo4j": APIConfig.get_neo4j_config(user_id=user_id), "nebular": APIConfig.get_nebular_config(user_id=user_id), + "polardb": APIConfig.get_polardb_config(user_id=user_id), } graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 2df917166..ce180606b 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -154,6 +154,59 @@ def validate_config(self): return self +class PolarDBGraphDBConfig(BaseConfig): + """ + PolarDB-specific configuration. + + Key concepts: + - `db_name`: The name of the target PolarDB database + - `user_name`: Used for logical tenant isolation if needed + - `auto_create`: Whether to automatically create the target database if it does not exist + - `use_multi_db`: Whether to use multi-database mode for physical isolation + + Example: + --- + host = "localhost" + port = 5432 + user = "postgres" + password = "password" + db_name = "memos_db" + user_name = "alice" + use_multi_db = True + auto_create = True + """ + + host: str = Field(..., description="Database host") + port: int = Field(default=5432, description="Database port") + user: str = Field(..., description="Database user") + password: str = Field(..., description="Database password") + db_name: str = Field(..., description="The name of the target PolarDB database") + user_name: str | None = Field( + default=None, + description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)", + ) + auto_create: bool = Field( + default=False, + description="Whether to auto-create the database if it does not exist", + ) + use_multi_db: bool = Field( + default=True, + description=( + "If True: use multi-database mode for physical isolation; " + "each tenant typically gets a separate database. " + "If False: use a single shared database with logical isolation by user_name." + ), + ) + embedding_dimension: int = Field(default=1024, description="Dimension of vector embedding") + + @model_validator(mode="after") + def validate_config(self): + """Validate config.""" + if not self.db_name: + raise ValueError("`db_name` must be provided") + return self + + class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") @@ -162,6 +215,7 @@ class GraphDBConfigFactory(BaseModel): "neo4j": Neo4jGraphDBConfig, "neo4j-community": Neo4jCommunityGraphDBConfig, "nebular": NebulaGraphDBConfig, + "polardb": PolarDBGraphDBConfig, } @field_validator("backend") diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index 0b38287eb..ec9cbcda0 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -5,6 +5,7 @@ from memos.graph_dbs.nebular import NebulaGraphDB from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB +from memos.graph_dbs.polardb import PolarDBGraphDB class GraphStoreFactory(BaseGraphDB): @@ -14,6 +15,7 @@ class GraphStoreFactory(BaseGraphDB): "neo4j": Neo4jGraphDB, "neo4j-community": Neo4jCommunityGraphDB, "nebular": NebulaGraphDB, + "polardb": PolarDBGraphDB, } @classmethod diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py new file mode 100644 index 000000000..38e71298f --- /dev/null +++ b/src/memos/graph_dbs/polardb.py @@ -0,0 +1,2777 @@ +import json +import time +import random +from datetime import datetime +from typing import Any, Literal + +import numpy as np + + +from memos.configs.graph_db import PolarDBGraphDBConfig +from memos.dependency import require_python_package +from memos.graph_dbs.base import BaseGraphDB +from memos.log import get_logger +from memos.utils import timed + +logger = get_logger(__name__) + +# Graph database configuration +GRAPH_NAME = "test_memos_graph" + + +def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: + node_id = item["id"] + memory = item["memory"] + metadata = item.get("metadata", {}) + return node_id, memory, metadata + + +def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """ + Ensure metadata has proper datetime fields and normalized types. + + - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). + - Convert embedding to list of float if present. + """ + now = datetime.utcnow().isoformat() + + # Fill timestamps if missing + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +def generate_vector(dim=1024, low=-0.2, high=0.2): + """Generate a random vector for testing purposes.""" + return [round(random.uniform(low, high), 6) for _ in range(dim)] + + +def find_embedding(metadata): + def find_embedding(item): + """Find an embedding vector within nested structures""" + for key in ["embedding", "embedding_1024", "embedding_3072", "embedding_768"]: + if key in item and isinstance(item[key], list): + return item[key] + if "metadata" in item and key in item["metadata"]: + return item["metadata"][key] + if "properties" in item and key in item["properties"]: + return item["properties"][key] + return None + + +def detect_embedding_field(embedding_list): + if not embedding_list: + return None + dim = len(embedding_list) + if dim == 1024: + return "embedding" + else: + print(f"⚠️ Unknown embedding dimension {dim}, skipping this vector") + return None + + +def convert_to_vector(embedding_list): + if not embedding_list: + return None + if isinstance(embedding_list, np.ndarray): + embedding_list = embedding_list.tolist() + return "[" + ",".join(str(float(x)) for x in embedding_list) + "]" + + +def clean_properties(props): + """Remove vector fields""" + vector_keys = {"embedding", "embedding_1024", "embedding_3072", "embedding_768"} + if not isinstance(props, dict): + return {} + return {k: v for k, v in props.items() if k not in vector_keys} + + +class PolarDBGraphDB(BaseGraphDB): + """PolarDB-based implementation using Apache AGE graph database extension.""" + + @require_python_package( + import_name="psycopg2", + install_command="pip install psycopg2-binary", + install_link="https://pypi.org/project/psycopg2-binary/", + ) + def __init__(self, config: PolarDBGraphDBConfig): + """PolarDB-based implementation using Apache AGE. + + Tenant Modes: + - use_multi_db = True: + Dedicated Database Mode (Multi-Database Multi-Tenant). + Each tenant or logical scope uses a separate PolarDB database. + `db_name` is the specific tenant database. + `user_name` can be None (optional). + + - use_multi_db = False: + Shared Database Multi-Tenant Mode. + All tenants share a single PolarDB database. + `db_name` is the shared database. + `user_name` is required to isolate each tenant's data at the node level. + All node queries will enforce `user_name` in WHERE conditions and store it in metadata, + but it will be removed automatically before returning to external consumers. + """ + import psycopg2 + + self.config = config + + # Handle both dict and object config + if isinstance(config, dict): + self.db_name = config.get("db_name") + self.user_name = config.get("user_name") + host = config.get("host") + port = config.get("port") + user = config.get("user") + password = config.get("password") + else: + self.db_name = config.db_name + self.user_name = config.user_name + host = config.host + port = config.port + user = config.user + password = config.password + + # Create connection + self.connection = psycopg2.connect( + host=host, port=port, user=user, password=password, dbname=self.db_name + ) + self.connection.autocommit = True + + """ + # Handle auto_create + # auto_create = config.get("auto_create", False) if isinstance(config, dict) else config.auto_create + # if auto_create: + # self._ensure_database_exists() + + # Create graph and tables + # self.create_graph() + # self.create_edge() + # self._create_graph() + + # Handle embedding_dimension + # embedding_dim = config.get("embedding_dimension", 1024) if isinstance(config,dict) else config.embedding_dimension + # self.create_index(dimensions=embedding_dim) + """ + + def _get_config_value(self, key: str, default=None): + """Safely get config value from either dict or object.""" + if isinstance(self.config, dict): + return self.config.get(key, default) + else: + return getattr(self.config, key, default) + + def _ensure_database_exists(self): + """Create database if it doesn't exist.""" + try: + # For PostgreSQL/PolarDB, we need to connect to a default database first + # This is a simplified implementation - in production you might want to handle this differently + logger.info(f"Using database '{self.db_name}'") + except Exception as e: + logger.error(f"Failed to access database '{self.db_name}': {e}") + raise + + @timed + def _create_graph(self): + """Create PostgreSQL schema and table for graph storage.""" + try: + with self.connection.cursor() as cursor: + # Create schema if it doesn't exist + cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') + logger.info(f"Schema '{self.db_name}_graph' ensured.") + + # Create Memory table if it doesn't exist + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" ( + id TEXT PRIMARY KEY, + properties JSONB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """) + logger.info(f"Memory table created in schema '{self.db_name}_graph'.") + + # Add embedding column if it doesn't exist (using JSONB for compatibility) + try: + cursor.execute(f""" + ALTER TABLE "{self.db_name}_graph"."Memory" + ADD COLUMN IF NOT EXISTS embedding JSONB; + """) + logger.info(f"Embedding column added to Memory table.") + except Exception as e: + logger.warning(f"Failed to add embedding column: {e}") + + # Create indexes + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Create vector index for embedding field + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 100); + """) + logger.info(f"Vector index created for Memory table.") + except Exception as e: + logger.warning(f"Vector index creation failed (might not be supported): {e}") + + logger.info(f"Indexes created for Memory table.") + + except Exception as e: + logger.error(f"Failed to create graph schema: {e}") + raise e + + def create_index( + self, + label: str = "Memory", + vector_property: str = "embedding", + dimensions: int = 1024, + index_name: str = "memory_vector_index", + ) -> None: + """ + Create indexes for embedding and other fields. + Note: This creates PostgreSQL indexes on the underlying tables. + """ + try: + with self.connection.cursor() as cursor: + # Create indexes on the underlying PostgreSQL tables + # Apache AGE stores data in regular PostgreSQL tables + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Try to create vector index, but don't fail if it doesn't work + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); + """) + except Exception as ve: + logger.warning(f"Vector index creation failed (might not be supported): {ve}") + + logger.debug(f"Indexes created successfully.") + except Exception as e: + logger.warning(f"Failed to create indexes: {e}") + + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + """Get count of memory nodes by type.""" + user_name = user_name if user_name else self._get_config_value("user_name") + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype + """ + query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + params = [f'"{memory_type}"', f'"{user_name}"'] + + print(f"[get_memory_count] Query: {query}, Params: {params}") + + try: + with self.connection.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result[0] if result else 0 + except Exception as e: + logger.error(f"[get_memory_count] Failed: {e}") + return -1 + + @timed + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + """Check if a node with given scope exists.""" + user_name = user_name if user_name else self._get_config_value("user_name") + query = f""" + SELECT id + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype + """ + query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + query += "\nLIMIT 1" + params = [f'"{scope}"', f'"{user_name}"'] + + print(f"[node_not_exist] Query: {query}, Params: {params}") + + try: + with self.connection.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + print(f"[node_not_exist] Query result: {result}") + return 1 if result else 0 + except Exception as e: + logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) + raise + + @timed + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all WorkingMemory nodes except the latest `keep_latest` entries. + + Args: + memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Use actual OFFSET logic, consistent with nebular.py + # First find IDs to delete, then delete them + select_query = f""" + SELECT id FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype + AND ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = %s::agtype + ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC + OFFSET %s + """ + select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] + print(f"[remove_oldest_memory] Select query: {select_query}") + print(f"[remove_oldest_memory] Select params: {select_params}") + + try: + with self.connection.cursor() as cursor: + # Execute query to get IDs to delete + cursor.execute(select_query, select_params) + ids_to_delete = [row[0] for row in cursor.fetchall()] + + if not ids_to_delete: + logger.info(f"No {memory_type} memories to remove for user {user_name}") + return + + # Build delete query + placeholders = ",".join(["%s"] * len(ids_to_delete)) + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE id IN ({placeholders}) + """ + delete_params = ids_to_delete + + # Execute deletion + cursor.execute(delete_query, delete_params) + deleted_count = cursor.rowcount + logger.info( + f"Removed {deleted_count} oldest {memory_type} memories, keeping {keep_latest} latest for user {user_name}" + ) + except Exception as e: + logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) + raise + + @timed + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + """ + Update node fields in PolarDB, auto-converting `created_at` and `updated_at` to datetime type if present. + """ + if not fields: + return + + user_name = user_name if user_name else self.config.user_name + + # Get the current node + current_node = self.get_node(id, user_name=user_name) + if not current_node: + return + + # Update properties but keep original id and memory fields + properties = current_node["metadata"].copy() + original_id = properties.get("id", id) # Preserve original ID + original_memory = current_node.get("memory", "") # Preserve original memory + + # If fields include memory, use it; otherwise keep original memory + if "memory" in fields: + original_memory = fields.pop("memory") + + properties.update(fields) + properties["id"] = original_id # Ensure ID is not overwritten + properties["memory"] = original_memory # Ensure memory is not overwritten + + # Handle embedding field + embedding_vector = None + if "embedding" in fields: + embedding_vector = fields.pop("embedding") + if not isinstance(embedding_vector, list): + embedding_vector = None + + # Build update query + if embedding_vector is not None: + query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = %s, embedding = %s + WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype + """ + params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"'] + else: + query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = %s + WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype + """ + params = [json.dumps(properties), f'"{id}"'] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + params.append(f'"{user_name}"') + + print(f"[update_node] query: {query}, params: {params}") + try: + with self.connection.cursor() as cursor: + cursor.execute(query, params) + except Exception as e: + logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) + raise + + @timed + def delete_node(self, id: str, user_name: str | None = None) -> None: + """ + Delete a node from the graph. + Args: + id: Node identifier to delete. + user_name (str, optional): User name for filtering in non-multi-db mode + """ + query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype + """ + params = [f'"{id}"'] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + params.append(f'"{user_name}"') + + print(f"[delete_node] query: {query}, params: {params}") + try: + with self.connection.cursor() as cursor: + cursor.execute(query, params) + except Exception as e: + logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) + raise + + @timed + def create_extension(self): + extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] + try: + with self.connection.cursor() as cursor: + # Ensure in the correct database context + cursor.execute(f"SELECT current_database();") + current_db = cursor.fetchone()[0] + print(f"Current database context: {current_db}") + + for ext_name, ext_desc in extensions: + try: + cursor.execute(f"create extension if not exists {ext_name};") + print(f"✅ Extension '{ext_name}' ({ext_desc}) ensured.") + except Exception as e: + if "already exists" in str(e): + print(f"ℹ️ Extension '{ext_name}' ({ext_desc}) already exists.") + else: + print(f"⚠️ Failed to create extension '{ext_name}' ({ext_desc}): {e}") + logger.error( + f"Failed to create extension '{ext_name}': {e}", exc_info=True + ) + except Exception as e: + print(f"⚠️ Failed to access database context: {e}") + logger.error(f"Failed to access database context: {e}", exc_info=True) + + @timed + def create_graph(self): + try: + with self.connection.cursor() as cursor: + cursor.execute(f""" + SELECT COUNT(*) FROM ag_catalog.ag_graph + WHERE name = '{self.db_name}_graph'; + """) + graph_exists = cursor.fetchone()[0] > 0 + + if graph_exists: + print(f"ℹ️ Graph '{self.db_name}_graph' already exists.") + else: + cursor.execute(f"select create_graph('{self.db_name}_graph');") + print(f"✅ Graph database '{self.db_name}_graph' created.") + except Exception as e: + print(f"⚠️ Failed to create graph '{self.db_name}_graph': {e}") + logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) + + @timed + def create_edge(self): + """Create all valid edge types if they do not exist""" + + valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} + + for label_name in valid_rel_types: + print(f"🪶 Creating elabel: {label_name}") + try: + with self.connection.cursor() as cursor: + cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") + print(f"✅ Successfully created elabel: {label_name}") + except Exception as e: + if "already exists" in str(e): + print(f"ℹ️ Label '{label_name}' already exists, skipping.") + else: + print(f"⚠️ Failed to create label {label_name}: {e}") + logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) + + @timed + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + if not source_id or not target_id: + raise ValueError("[add_edge] source_id and target_id must be provided") + + source_exists = self.get_node(source_id) is not None + target_exists = self.get_node(target_id) is not None + + if not source_exists or not target_exists: + raise ValueError("[add_edge] source_id and target_id must be provided") + + properties = {} + if user_name is not None: + properties["user_name"] = user_name + query = f""" + INSERT INTO {self.db_name}_graph."{type}"(id, start_id, end_id, properties) + SELECT + ag_catalog._next_graph_id('{self.db_name}_graph'::name, '{type}'), + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{source_id}'::text::cstring), + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring), + jsonb_build_object('user_name', '{user_name}')::text::agtype + WHERE NOT EXISTS ( + SELECT 1 FROM {self.db_name}_graph."{type}" + WHERE start_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{source_id}'::text::cstring) + AND end_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring) + ); + """ + print(f"Executing add_edge: {query}") + + try: + with self.connection.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) + logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") + except Exception as e: + logger.error(f"Failed to insert edge: {e}", exc_info=True) + raise + + @timed + def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + """ + Delete a specific edge between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type to remove. + """ + query = f""" + DELETE FROM "{self.db_name}_graph"."Edges" + WHERE source_id = %s AND target_id = %s AND edge_type = %s + """ + + with self.connection.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + + @timed + def edge_exists_old( + self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" + ) -> bool: + """ + Check if an edge exists between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type. Use "ANY" to match any relationship type. + direction: Direction of the edge. + Use "OUTGOING" (default), "INCOMING", or "ANY". + Returns: + True if the edge exists, otherwise False. + """ + where_clauses = [] + params = [] + # SELECT * FROM + # cypher('memtensor_memos_graph', $$ + # MATCH(a: Memory + # {id: "13bb9df6-0609-4442-8bed-bba77dadac92"})-[r] - (b:Memory {id: "2dd03a5b-5d5f-49c9-9e0a-9a2a2899b98d"}) + # RETURN + # r + # $$) AS(r + # agtype); + + if direction == "OUTGOING": + where_clauses.append("source_id = %s AND target_id = %s") + params.extend([source_id, target_id]) + elif direction == "INCOMING": + where_clauses.append("source_id = %s AND target_id = %s") + params.extend([target_id, source_id]) + elif direction == "ANY": + where_clauses.append( + "((source_id = %s AND target_id = %s) OR (source_id = %s AND target_id = %s))" + ) + params.extend([source_id, target_id, target_id, source_id]) + else: + raise ValueError( + f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." + ) + + if type != "ANY": + where_clauses.append("edge_type = %s") + params.append(type) + + where_clause = " AND ".join(where_clauses) + + query = f""" + SELECT 1 FROM "{self.db_name}_graph"."Edges" + WHERE {where_clause} + LIMIT 1 + """ + + with self.connection.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None + + @timed + def edge_exists( + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, + ) -> bool: + """ + Check if an edge exists between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type. Use "ANY" to match any relationship type. + direction: Direction of the edge. + Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + True if the edge exists, otherwise False. + """ + + # Prepare the relationship pattern + user_name = user_name if user_name else self.config.user_name + print(f"edge_exists direction: {direction}") + + # Prepare the match pattern with direction + if direction == "OUTGOING": + pattern = f"(a:Memory)-[r]->(b:Memory)" + elif direction == "INCOMING": + pattern = f"(a:Memory)<-[r]-(b:Memory)" + elif direction == "ANY": + pattern = f"(a:Memory)-[r]-(b:Memory)" + else: + raise ValueError( + f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." + ) + query = f"SELECT * FROM cypher('{self.db_name}_graph', $$" + query += f"\nMATCH {pattern}" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nAND a.id = '{source_id}' AND b.id = '{target_id}'" + if type != "ANY": + query += f"\n AND type(r) = '{type}'" + + query += "\nRETURN r" + query += "\n$$) AS (r agtype)" + + print(f"edge_exists query: {query}") + with self.connection.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None + + @timed + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any] | None: + """ + Retrieve a Memory node by its unique ID. + + Args: + id (str): Node ID (Memory.id) + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + dict: Node properties as key-value pairs, or None if not found. + """ + + select_fields = "id, properties, embedding" if include_embedding else "id, properties" + + # Helper function to format parameter value + def format_param_value(value: str) -> str: + """Format parameter value to handle both quoted and unquoted formats""" + # Remove outer quotes if they exist + if value.startswith('"') and value.endswith('"'): + # Already has double quotes, return as is + return value + else: + # Add double quotes + return f'"{value}"' + + query = f""" + SELECT {select_fields} + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype + """ + params = [format_param_value(id)] + + # Only add user filter when user_name is provided + if user_name is not None: + query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + params.append(format_param_value(user_name)) + + print(f"[get_node] query: {query}, params: {params}") + try: + with self.connection.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + + if result: + if include_embedding: + node_id, properties_json, embedding_json = result + else: + node_id, properties_json = result + embedding_json = None + + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {id}") + properties = {} + else: + properties = properties_json if properties_json else {} + + # Parse embedding from JSONB if it exists and include_embedding is True + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {id}") + + return self._parse_node( + {"id": id, "memory": properties.get("memory", ""), **properties} + ) + return None + + except Exception as e: + logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) + return None + + @timed + def get_nodes( + self, ids: list[str], user_name: str | None = None, **kwargs + ) -> list[dict[str, Any]]: + """ + Retrieve the metadata and memory of a list of nodes. + Args: + ids: List of Node identifier. + Returns: + list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. + + Notes: + - Assumes all provided IDs are valid and exist. + - Returns empty list if input is empty. + """ + if not ids: + return [] + + # Build WHERE clause using agtype_access_operator like get_node method + where_conditions = [] + params = [] + + for id_val in ids: + where_conditions.append( + "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype" + ) + params.append(f"{id_val}") + + where_clause = " OR ".join(where_conditions) + + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE ({where_clause}) + """ + + user_name = user_name if user_name else self.config.user_name + query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + params.append(f'"{user_name}"') + + print(f"[get_nodes] query: {query}, params: {params}") + with self.connection.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} + + # Parse embedding from JSONB if it exists + if embedding_json is not None: + try: + print("embedding_json:", embedding_json) + # remove embedding + """ + embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json + # properties["embedding"] = embedding + """ + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + ) + return nodes + + @timed + def get_edges_old( + self, id: str, type: str = "ANY", direction: str = "ANY" + ) -> list[dict[str, str]]: + """ + Get edges connected to a node, with optional type and direction filter. + + Args: + id: Node ID to retrieve edges for. + type: Relationship type to match, or 'ANY' to match all. + direction: 'OUTGOING', 'INCOMING', or 'ANY'. + + Returns: + List of edges: + [ + {"from": "source_id", "to": "target_id", "type": "RELATE"}, + ... + ] + """ + + # Create a simple edge table to store relationships (if not exists) + try: + with self.connection.cursor() as cursor: + # Create edge table + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Edges" ( + id SERIAL PRIMARY KEY, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + properties JSONB, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (source_id) REFERENCES "{self.db_name}_graph"."Memory"(id), + FOREIGN KEY (target_id) REFERENCES "{self.db_name}_graph"."Memory"(id) + ); + """) + + # Create indexes + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_source + ON "{self.db_name}_graph"."Edges" (source_id); + """) + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_target + ON "{self.db_name}_graph"."Edges" (target_id); + """) + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_type + ON "{self.db_name}_graph"."Edges" (edge_type); + """) + except Exception as e: + logger.warning(f"Failed to create edges table: {e}") + + # Query edges + where_clauses = [] + params = [id] + + if type != "ANY": + where_clauses.append("edge_type = %s") + params.append(type) + + if direction == "OUTGOING": + where_clauses.append("source_id = %s") + elif direction == "INCOMING": + where_clauses.append("target_id = %s") + else: # ANY + where_clauses.append("(source_id = %s OR target_id = %s)") + params.append(id) # Add second parameter for ANY direction + + where_clause = " AND ".join(where_clauses) + + query = f""" + SELECT source_id, target_id, edge_type + FROM "{self.db_name}_graph"."Edges" + WHERE {where_clause} + """ + + with self.connection.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + edges = [] + for row in results: + source_id, target_id, edge_type = row + edges.append({"from": source_id, "to": target_id, "type": edge_type}) + return edges + + def get_neighbors( + self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + ) -> list[str]: + """Get connected node IDs in a specific direction and relationship type.""" + raise NotImplementedError + + @timed + def get_neighbors_by_tag_old( + self, + tags: list[str], + exclude_ids: list[str], + top_k: int = 5, + min_overlap: int = 1, + ) -> list[dict[str, Any]]: + """ + Find top-K neighbor nodes with maximum tag overlap. + + Args: + tags: The list of tags to match. + exclude_ids: Node IDs to exclude (e.g., local cluster). + top_k: Max number of neighbors to return. + min_overlap: Minimum number of overlapping tags required. + + Returns: + List of dicts with node details and overlap count. + """ + # Build query conditions + where_clauses = [] + params = [] + + # Exclude specified IDs + if exclude_ids: + placeholders = ",".join(["%s"] * len(exclude_ids)) + where_clauses.append(f"id NOT IN ({placeholders})") + params.extend(exclude_ids) + + # Status filter + where_clauses.append("properties->>'status' = %s") + params.append("activated") + + # Type filter + where_clauses.append("properties->>'type' != %s") + params.append("reasoning") + + where_clauses.append("properties->>'memory_type' != %s") + params.append("WorkingMemory") + + # User filter + if not self._get_config_value("use_multi_db", True) and self._get_config_value("user_name"): + where_clauses.append("properties->>'user_name' = %s") + params.append(self._get_config_value("user_name")) + + where_clause = " AND ".join(where_clauses) + + # Get all candidate nodes + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + with self.connection.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes_with_overlap = [] + for row in results: + node_id, properties_json, embedding_json = row + properties = properties_json if properties_json else {} + + # Parse embedding + if embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + + # Compute tag overlap + node_tags = properties.get("tags", []) + if isinstance(node_tags, str): + try: + node_tags = json.loads(node_tags) + except (json.JSONDecodeError, TypeError): + node_tags = [] + + overlap_tags = [tag for tag in tags if tag in node_tags] + overlap_count = len(overlap_tags) + + if overlap_count >= min_overlap: + node_data = self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + nodes_with_overlap.append((node_data, overlap_count)) + + # Sort by overlap count and return top_k + nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) + return [node for node, _ in nodes_with_overlap[:top_k]] + + @timed + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + """Get children nodes with their embeddings.""" + user_name = user_name if user_name else self._get_config_value("user_name") + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + + query = f""" + WITH t as ( + SELECT * + FROM cypher('{self.db_name}_graph', $$ + MATCH (p:Memory)-[r:PARENT]->(c:Memory) + WHERE p.id = '{id}' {where_user} + RETURN id(c) as cid, c.id AS id, c.memory AS memory + $$) as (cid agtype, id agtype, memory agtype) + ) + SELECT t.id, m.embedding, t.memory FROM t, + "{self.db_name}_graph"."Memory" m + WHERE t.cid::graphid = m.id; + """ + + print("[get_children_with_embeddings] query:", query) + + try: + with self.connection.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + children = [] + for row in results: + # Handle child_id - remove possible quotes + child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) + if isinstance(child_id_raw, str): + # If string starts and ends with quotes, remove quotes + if child_id_raw.startswith('"') and child_id_raw.endswith('"'): + child_id = child_id_raw[1:-1] + else: + child_id = child_id_raw + else: + child_id = str(child_id_raw) + + # Handle embedding - get from database embedding column + embedding_raw = row[1] + embedding = [] + if embedding_raw is not None: + try: + if isinstance(embedding_raw, str): + # If it is a JSON string, parse it + embedding = json.loads(embedding_raw) + elif isinstance(embedding_raw, list): + # If already a list, use directly + embedding = embedding_raw + else: + # Try converting to list + embedding = list(embedding_raw) + except (json.JSONDecodeError, TypeError, ValueError) as e: + logger.warning( + f"Failed to parse embedding for child node {child_id}: {e}" + ) + embedding = [] + + # Handle memory - remove possible quotes + memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) + if isinstance(memory_raw, str): + # If string starts and ends with quotes, remove quotes + if memory_raw.startswith('"') and memory_raw.endswith('"'): + memory = memory_raw[1:-1] + else: + memory = memory_raw + else: + memory = str(memory_raw) + + children.append({"id": child_id, "embedding": embedding, "memory": memory}) + + return children + + except Exception as e: + logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) + return [] + + def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + """Get the path of nodes from source to target within a limited depth.""" + raise NotImplementedError + + @timed + def get_subgraph( + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, + ) -> dict[str, Any]: + """ + Retrieve a local subgraph centered at a given node. + Args: + center_id: The ID of the center node. + depth: The hop distance for neighbors. + center_status: Required status for center node. + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + { + "core_node": {...}, + "neighbors": [...], + "edges": [...] + } + """ + if not 1 <= depth <= 5: + raise ValueError("depth must be 1-5") + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Use a simplified query to get the subgraph (temporarily only direct neighbors) + """ + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN + collect(DISTINCT + center), collect(DISTINCT + neighbor), collect(DISTINCT + r) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + RETURN + collect(DISTINCT + center), collect(DISTINCT + neighbor), collect(DISTINCT + r) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + + try: + with self.connection.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + print("[get_subgraph] result:", result) + + if not result or not result[0]: + return {"core_node": None, "neighbors": [], "edges": []} + + # Parse center node + centers_data = result[0] if result[0] else "[]" + neighbors_data = result[1] if result[1] else "[]" + edges_data = result[2] if result[2] else "[]" + + # Parse JSON data + try: + # Clean ::vertex and ::edge suffixes in data + if isinstance(centers_data, str): + centers_data = centers_data.replace("::vertex", "") + if isinstance(neighbors_data, str): + neighbors_data = neighbors_data.replace("::vertex", "") + if isinstance(edges_data, str): + edges_data = edges_data.replace("::edge", "") + + centers_list = ( + json.loads(centers_data) if isinstance(centers_data, str) else centers_data + ) + neighbors_list = ( + json.loads(neighbors_data) + if isinstance(neighbors_data, str) + else neighbors_data + ) + edges_list = ( + json.loads(edges_data) if isinstance(edges_data, str) else edges_data + ) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON data: {e}") + return {"core_node": None, "neighbors": [], "edges": []} + + # Parse center node + core_node = None + if centers_list and len(centers_list) > 0: + center_data = centers_list[0] + if isinstance(center_data, dict) and "properties" in center_data: + core_node = self._parse_node(center_data["properties"]) + + # Parse neighbor nodes + neighbors = [] + if isinstance(neighbors_list, list): + for neighbor_data in neighbors_list: + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_parsed = self._parse_node(neighbor_data["properties"]) + neighbors.append(neighbor_parsed) + + # Parse edges + edges = [] + if isinstance(edges_list, list): + for edge_group in edges_list: + if isinstance(edge_group, list): + for edge_data in edge_group: + if isinstance(edge_data, dict): + edges.append( + { + "type": edge_data.get("label", ""), + "source": edge_data.get("start_id", ""), + "target": edge_data.get("end_id", ""), + } + ) + + return {"core_node": core_node, "neighbors": neighbors, "edges": edges} + + except Exception as e: + logger.error(f"Failed to get subgraph: {e}", exc_info=True) + return {"core_node": None, "neighbors": [], "edges": []} + + def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: + """Get the ordered context chain starting from a node.""" + raise NotImplementedError + + @timed + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ) -> list[dict]: + """ + Retrieve node IDs based on vector similarity using PostgreSQL vector operations. + """ + # Build WHERE clause dynamically like nebular.py + where_clauses = [] + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + where_clauses.append("embedding is not null") + # Add user_name filter like nebular.py + + """ + # user_name = self._get_config_value("user_name") + # if not self.config.use_multi_db and user_name: + # if kwargs.get("cube_name"): + # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kwargs['cube_name']}\"'::agtype") + # else: + # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype") + """ + user_name = user_name if user_name else self.config.user_name + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + + # Add search_filter conditions like nebular.py + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Keep original simple query structure but add dynamic WHERE clause + query = f""" + WITH t AS ( + SELECT id, + properties, + timeline, + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + (1 - (embedding <=> %s::vector(1024))) AS scope + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY scope DESC + LIMIT {top_k} + ) + SELECT * + FROM t + WHERE scope > 0.1; + """ + params = [vector] + + print( + f"[search_by_embedding] query: {query}, params: {params}, where_clause: {where_clause}" + ) + with self.connection.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + """ + polarId = row[0] # id + properties = row[1] # properties + # embedding = row[3] # embedding + """ + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + return output[:top_k] + + @timed + def get_by_metadata( + self, filters: list[dict[str, Any]], user_name: str | None = None + ) -> list[str]: + """ + Retrieve node IDs that match given metadata filters. + Supports exact match. + + Args: + filters: List of filter dicts like: + [ + {"field": "key", "op": "in", "value": ["A", "B"]}, + {"field": "confidence", "op": ">=", "value": 80}, + {"field": "tags", "op": "contains", "value": "AI"}, + ... + ] + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[str]: Node IDs whose metadata match the filter conditions. (AND logic). + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build WHERE conditions for cypher query + where_conditions = [] + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + # Format value + if isinstance(value, str): + # Escape single quotes in string values + escaped_str = value.replace("'", "''") + escaped_value = f"'{escaped_str}'" + elif isinstance(value, list): + # Handle list values - use double quotes for Cypher arrays + list_items = [] + for v in value: + if isinstance(v, str): + # Escape double quotes in string values for Cypher + escaped_str = v.replace('"', '\\"') + list_items.append(f'"{escaped_str}"') + else: + list_items.append(str(v)) + escaped_value = f"[{', '.join(list_items)}]" + else: + escaped_value = f"'{value}'" if isinstance(value, str) else str(value) + print("op=============:", op) + # Build WHERE conditions + if op == "=": + where_conditions.append(f"n.{field} = {escaped_value}") + elif op == "in": + where_conditions.append(f"n.{field} IN {escaped_value}") + """ + # where_conditions.append(f"{escaped_value} IN n.{field}") + """ + elif op == "contains": + where_conditions.append(f"{escaped_value} IN n.{field}") + """ + # where_conditions.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") + """ + elif op == "starts_with": + where_conditions.append(f"n.{field} STARTS WITH {escaped_value}") + elif op == "ends_with": + where_conditions.append(f"n.{field} ENDS WITH {escaped_value}") + elif op in [">", ">=", "<", "<="]: + where_conditions.append(f"n.{field} {op} {escaped_value}") + else: + raise ValueError(f"Unsupported operator: {op}") + + # Add user_name filter + escaped_user_name = user_name.replace("'", "''") + where_conditions.append(f"n.user_name = '{escaped_user_name}'") + + where_str = " AND ".join(where_conditions) + + # Use cypher query + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_str} + RETURN n.id AS id + $$) AS (id agtype) + """ + + print(f"[get_by_metadata] query: {cypher_query}, where_str: {where_str}") + ids = [] + try: + with self.connection.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + print("[get_by_metadata] result:", results) + ids = [str(item[0]).strip('"') for item in results] + except Exception as e: + print("Failed to get metadata:", {e}) + logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + + return ids + + @timed + def get_grouped_counts1( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by any fields. + + Args: + group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] + where_clause (str, optional): Extra WHERE condition. E.g., + "WHERE n.status = 'activated'" + params (dict, optional): Parameters for WHERE clause. + + Returns: + list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] + """ + user_name = user_name if user_name else self.config.user_name + if not group_fields: + raise ValueError("group_fields cannot be empty") + + final_params = params.copy() if params else {} + print("username:" + user_name) + if not self.config.use_multi_db and (self.config.user_name or user_name): + user_clause = "n.user_name = $user_name" + final_params["user_name"] = user_name + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" + else: + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" + print("where_clause:" + where_clause) + # Force RETURN field AS field to guarantee key match + group_fields_cypher = ", ".join([f"n.{field} AS {field}" for field in group_fields]) + """ + # group_fields_cypher_polardb = "agtype, ".join([f"{field}" for field in group_fields]) + """ + group_fields_cypher_polardb = ", ".join([f"{field} agtype" for field in group_fields]) + print("group_fields_cypher_polardb:" + group_fields_cypher_polardb) + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + {where_clause} + RETURN {group_fields_cypher}, COUNT(n) AS count1 + $$ ) as ({group_fields_cypher_polardb}, count1 agtype); + """ + print("get_grouped_counts:" + query) + try: + with self.connection.cursor() as cursor: + # Handle parameterized query + if params and isinstance(params, list): + cursor.execute(query, final_params) + else: + cursor.execute(query) + results = cursor.fetchall() + + output = [] + for row in results: + group_values = {} + for i, field in enumerate(group_fields): + value = row[i] + if hasattr(value, "value"): + group_values[field] = value.value + else: + group_values[field] = str(value) + count_value = row[-1] # Last column is count + output.append({**group_values, "count": count_value}) + + return output + + except Exception as e: + logger.error(f"Failed to get grouped counts: {e}", exc_info=True) + return [] + + @timed + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by any fields. + + Args: + group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] + where_clause (str, optional): Extra WHERE condition. E.g., + "WHERE n.status = 'activated'" + params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] + """ + if not group_fields: + raise ValueError("group_fields cannot be empty") + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build user clause + user_clause = f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" + else: + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" + + # Inline parameters if provided + if params and isinstance(params, dict): + for key, value in params.items(): + # Handle different value types appropriately + if isinstance(value, str): + value = f"'{value}'" + where_clause = where_clause.replace(f"${key}", str(value)) + + # Handle user_name parameter in where_clause + if "user_name = %s" in where_clause: + where_clause = where_clause.replace( + "user_name = %s", + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype", + ) + + # Build return fields and group by fields + return_fields = [] + group_by_fields = [] + + for field in group_fields: + alias = field.replace(".", "_") + return_fields.append( + f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype) AS {alias}" + ) + group_by_fields.append(alias) + + # Full SQL query construction + query = f""" + SELECT {", ".join(return_fields)}, COUNT(*) AS count + FROM "{self.db_name}_graph"."Memory" + {where_clause} + GROUP BY {", ".join(group_by_fields)} + """ + + print("[get_grouped_counts] query:", query) + + try: + with self.connection.cursor() as cursor: + # Handle parameterized query + if params and isinstance(params, list): + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() + + output = [] + for row in results: + group_values = {} + for i, field in enumerate(group_fields): + value = row[i] + if hasattr(value, "value"): + group_values[field] = value.value + else: + group_values[field] = str(value) + count_value = row[-1] # Last column is count + output.append({**group_values, "count": count_value}) + + return output + + except Exception as e: + logger.error(f"Failed to get grouped counts: {e}", exc_info=True) + return [] + + def deduplicate_nodes(self) -> None: + """Deduplicate redundant or semantically similar nodes.""" + raise NotImplementedError + + def detect_conflicts(self) -> list[tuple[str, str]]: + """Detect conflicting nodes based on logical or semantic inconsistency.""" + raise NotImplementedError + + def merge_nodes(self, id1: str, id2: str) -> str: + """Merge two similar or duplicate nodes into one.""" + raise NotImplementedError + + @timed + def clear(self, user_name: str | None = None) -> None: + """ + Clear the entire graph if the target database exists. + + Args: + user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + try: + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.user_name = '{user_name}' + DETACH DELETE n + $$) AS (result agtype) + """ + + with self.connection.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") + + except Exception as e: + logger.error(f"[ERROR] Failed to clear database: {e}") + + @timed + def export_graph( + self, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any]: + """ + Export all graph nodes and edges in a structured form. + Args: + include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + { + "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], + "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] + } + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + try: + # Export nodes + if include_embedding: + node_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype + """ + else: + node_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype + """ + + with self.connection.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + + for row in node_results: + if include_embedding: + node_id, properties_json, embedding_json = row + else: + node_id, properties_json = row + embedding_json = None + + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} + + # # Build node data + + """ + # node_data = { + # "id": properties.get("id", node_id), + # "memory": properties.get("memory", ""), + # "metadata": properties + # } + """ + + if include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json + + nodes.append(self._parse_node(properties)) + + except Exception as e: + logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) + raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + + try: + # Export edges using cypher query + edge_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}' + RETURN a.id AS source, b.id AS target, type(r) as edge + $$) AS (source agtype, target agtype, edge agtype) + """ + + with self.connection.cursor() as cursor: + cursor.execute(edge_query) + edge_results = cursor.fetchall() + edges = [] + + for row in edge_results: + source_agtype, target_agtype, edge_agtype = row + edges.append( + { + "source": source_agtype.value + if hasattr(source_agtype, "value") + else str(source_agtype), + "target": target_agtype.value + if hasattr(target_agtype, "value") + else str(target_agtype), + "type": edge_agtype.value + if hasattr(edge_agtype, "value") + else str(edge_agtype), + } + ) + + except Exception as e: + logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) + raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e + + return {"nodes": nodes, "edges": edges} + + @timed + def count_nodes(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name + + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' + AND n.user_name = '{user_name}' + RETURN count(n) + $$) AS (count agtype) + """ + + result = self.execute_query(query) + return int(result.one_or_none()["count"].value) + + @timed + def get_all_memory_items( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> list[dict]: + """ + Retrieve all memory items of a specific memory_type. + + Args: + scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[dict]: Full list of memory items under this scope. + """ + user_name = user_name if user_name else self._get_config_value("user_name") + if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: + raise ValueError(f"Unsupported memory type scope: {scope}") + + # Use cypher query to retrieve memory items + if include_embedding: + cypher_query = f""" + WITH t as ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' + RETURN id(n) as id1,n + LIMIT 100 + $$) AS (id1 agtype,n agtype) + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id; + """ + nodes = [] + node_ids = set() + print("[get_all_memory_items embedding true ] cypher_query:", cypher_query) + try: + with self.connection.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + + for row in results: + """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ + if isinstance(row, list | tuple) and len(row) >= 2: + embedding_val, node_val = row[0], row[1] + else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + nodes.append(node) + node_ids.add(node_id) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + + return nodes + else: + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' + RETURN properties(n) as props + LIMIT 100 + $$) AS (nprops agtype) + """ + print("[get_all_memory_items embedding false ] cypher_query:", cypher_query) + + nodes = [] + try: + with self.connection.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + + for row in results: + """ + if isinstance(row[0], str): + memory_data = json.loads(row[0]) + else: + memory_data = row[0] # 如果已经是字典,直接使用 + nodes.append(self._parse_node(memory_data)) + """ + memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0] + nodes.append(self._parse_node(memory_data)) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + + return nodes + + def get_all_memory_items_old( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> list[dict]: + """ + Retrieve all memory items of a specific memory_type. + + Args: + scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + list[dict]: Full list of memory items under this scope. + """ + user_name = user_name if user_name else self._get_config_value("user_name") + if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: + raise ValueError(f"Unsupported memory type scope: {scope}") + + # Use cypher query to retrieve memory items + if include_embedding: + cypher_query = f""" + WITH t as ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' + RETURN id(n) as id1,n + LIMIT 100 + $$) AS (id1 agtype,n agtype) + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id; + """ + else: + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' + RETURN properties(n) as props + LIMIT 100 + $$) AS (nprops agtype) + """ + print("[get_all_memory_items] cypher_query:", cypher_query) + + nodes = [] + try: + with self.connection.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + print("[get_all_memory_items] results:", results) + + for row in results: + node_agtype = row[0] + + # Handle string-formatted data + if isinstance(node_agtype, str): + try: + # Remove ::vertex suffix + json_str = node_agtype.replace("::vertex", "") + node_data = json.loads(json_str) + + if isinstance(node_data, dict) and "properties" in node_data: + properties = node_data["properties"] + # Build node data + parsed_node_data = { + "id": properties.get("id", ""), + "memory": properties.get("memory", ""), + "metadata": properties, + } + + if include_embedding and "embedding" in properties: + parsed_node_data["embedding"] = properties["embedding"] + + nodes.append(self._parse_node(parsed_node_data)) + print( + f"[get_all_memory_items] ✅ Parsed node successfully: {properties.get('id', '')}" + ) + else: + print( + f"[get_all_memory_items] ❌ Invalid node data format: {node_data}" + ) + + except (json.JSONDecodeError, TypeError) as e: + print(f"[get_all_memory_items] ❌ JSON parsing failed: {e}") + elif node_agtype and hasattr(node_agtype, "value"): + # Handle agtype object + node_props = node_agtype.value + if isinstance(node_props, dict): + # Parse node properties + node_data = { + "id": node_props.get("id", ""), + "memory": node_props.get("memory", ""), + "metadata": node_props, + } + + if include_embedding and "embedding" in node_props: + node_data["embedding"] = node_props["embedding"] + + nodes.append(self._parse_node(node_data)) + print( + f"[get_all_memory_items] ✅ Parsed agtype node successfully: {node_props.get('id', '')}" + ) + else: + print( + f"[get_all_memory_items] ❌ Unknown data format: {type(node_agtype)}" + ) + + except Exception as e: + logger.error(f"Failed to get memories: {e}", exc_info=True) + + return nodes + + @timed + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> list[dict]: + """ + Find nodes that are likely candidates for structure optimization: + - Isolated nodes, nodes with empty background, or nodes with exactly one child. + - Plus: the child of any parent node that has exactly one child. + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build return fields based on include_embedding flag + if include_embedding: + return_fields = "id(n) as id1,n" + return_fields_agtype = " id1 agtype,n agtype" + else: + # Build field list without embedding + return_fields = ",".join( + [ + "n.id AS id", + "n.memory AS memory", + "n.user_name AS user_name", + "n.user_id AS user_id", + "n.session_id AS session_id", + "n.status AS status", + "n.key AS key", + "n.confidence AS confidence", + "n.tags AS tags", + "n.created_at AS created_at", + "n.updated_at AS updated_at", + "n.memory_type AS memory_type", + "n.sources AS sources", + "n.source AS source", + "n.node_type AS node_type", + "n.visibility AS visibility", + "n.usage AS usage", + "n.background AS background", + "n.graph_id as graph_id", + ] + ) + fields = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + return_fields_agtype = ", ".join([f"{field} agtype" for field in fields]) + + # Use OPTIONAL MATCH to find isolated nodes (no parents or children) + cypher_query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE n.memory_type = '{scope}' + AND n.status = 'activated' + AND n.user_name = '{user_name}' + OPTIONAL MATCH (n)-[:PARENT]->(c:Memory) + OPTIONAL MATCH (p:Memory)-[:PARENT]->(n) + WITH n, c, p + WHERE c IS NULL AND p IS NULL + RETURN {return_fields} + $$) AS ({return_fields_agtype}) + """ + if include_embedding: + cypher_query = f""" + WITH t as ( + {cypher_query} + ) + SELECT + m.embedding, + t.n + FROM t, + {self.db_name}_graph."Memory" m + WHERE t.id1 = m.id + """ + print("[get_structure_optimization_candidates] query:", cypher_query) + + candidates = [] + node_ids = set() + try: + with self.connection.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + print("result------", len(results)) + for row in results: + if include_embedding: + # When include_embedding=True, return full node object + """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ + if isinstance(row, list | tuple) and len(row) >= 2: + embedding_val, node_val = row[0], row[1] + else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + else: + # When include_embedding=False, return field dictionary + # Define field names matching the RETURN clause + field_names = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + + # Convert row to dictionary + node_data = {} + for i, field_name in enumerate(field_names): + if i < len(row): + value = row[i] + # Handle special fields + if field_name in ["tags", "sources", "usage"] and isinstance( + value, str + ): + try: + # Try parsing JSON string + node_data[field_name] = json.loads(value) + except (json.JSONDecodeError, TypeError): + node_data[field_name] = value + else: + node_data[field_name] = value + + # Parse node using _parse_node_new + try: + node = self._parse_node_new(node_data) + node_id = node["id"] + + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + print(f"✅ Parsed node successfully: {node_id}") + except Exception as e: + print(f"❌ Failed to parse node: {e}") + + except Exception as e: + logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) + + return candidates + + def drop_database(self) -> None: + """Permanently delete the entire graph this instance is using.""" + return + if self._get_config_value("use_multi_db", True): + with self.connection.cursor() as cursor: + cursor.execute(f"SELECT drop_graph('{self.db_name}_graph', true)") + print(f"Graph '{self.db_name}_graph' has been dropped.") + else: + raise ValueError( + f"Refusing to drop graph '{self.db_name}_graph' in " + f"Shared Database Multi-Tenant mode" + ) + + def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: + """Parse node data from database format to standard format.""" + node = node_data.copy() + + # Convert datetime to string + for time_field in ("created_at", "updated_at"): + if time_field in node and hasattr(node[time_field], "isoformat"): + node[time_field] = node[time_field].isoformat() + + return {"id": node.get("id"), "memory": node.get("memory", ""), "metadata": node} + + def _parse_node_new(self, node_data: dict[str, Any]) -> dict[str, Any]: + """Parse node data from database format to standard format.""" + node = node_data.copy() + + # Normalize string values that may arrive as quoted literals (e.g., '"abc"') + def _strip_wrapping_quotes(value: Any) -> Any: + """ + if isinstance(value, str) and len(value) >= 2: + if value[0] == value[-1] and value[0] in ("'", '"'): + return value[1:-1] + return value + """ + if ( + isinstance(value, str) + and len(value) >= 2 + and value[0] == value[-1] + and value[0] in ("'", '"') + ): + return value[1:-1] + return value + + for k, v in list(node.items()): + if isinstance(v, str): + node[k] = _strip_wrapping_quotes(v) + + # Convert datetime to string + for time_field in ("created_at", "updated_at"): + if time_field in node and hasattr(node[time_field], "isoformat"): + node[time_field] = node[time_field].isoformat() + + # Do not remove user_name; keep all fields + + return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node} + + def __del__(self): + """Close database connection when object is destroyed.""" + if hasattr(self, "connection") and self.connection: + self.connection.close() + + @timed + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + """Add a memory node to the graph.""" + # user_name comes from metadata; fallback to config if missing + metadata["user_name"] = user_name if user_name else self.config.user_name + + # Safely process metadata + metadata = _prepare_node_metadata(metadata) + + # Merge node and set metadata + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Prepare properties + properties = { + "id": id, + "memory": memory, + "created_at": created_at, + "updated_at": updated_at, + **metadata, + } + + # Generate embedding if not provided + if "embedding" not in properties or not properties["embedding"]: + properties["embedding"] = generate_vector( + self._get_config_value("embedding_dimension", 1024) + ) + + # serialization - JSON-serialize sources and usage fields + for field_name in ["sources", "usage"]: + if properties.get(field_name): + if isinstance(properties[field_name], list): + for idx in range(len(properties[field_name])): + # Serialize only when element is not a string + if not isinstance(properties[field_name][idx], str): + properties[field_name][idx] = json.dumps(properties[field_name][idx]) + elif isinstance(properties[field_name], str): + # If already a string, leave as-is + pass + + # Extract embedding for separate column + embedding_vector = properties.pop("embedding", []) + if not isinstance(embedding_vector, list): + embedding_vector = [] + + # Select column name based on embedding dimension + embedding_column = "embedding" # default column + if len(embedding_vector) == 3072: + embedding_column = "embedding_3072" + elif len(embedding_vector) == 1024: + embedding_column = "embedding" + elif len(embedding_vector) == 768: + embedding_column = "embedding_768" + + with self.connection.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) + ) + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + + def _build_node_from_agtype(self, node_agtype, embedding=None): + """ + Parse the cypher-returned column `n` (agtype or JSON string) + into a standard node and merge embedding into properties. + """ + try: + # String case: '{"id":...,"label":[...],"properties":{...}}::vertex' + if isinstance(node_agtype, str): + json_str = node_agtype.replace("::vertex", "") + obj = json.loads(json_str) + if not (isinstance(obj, dict) and "properties" in obj): + return None + props = obj["properties"] + # agtype case: has `value` attribute + elif node_agtype and hasattr(node_agtype, "value"): + val = node_agtype.value + if not (isinstance(val, dict) and "properties" in val): + return None + props = val["properties"] + else: + return None + + if embedding is not None: + props["embedding"] = embedding + + # Return standard format directly + return {"id": props.get("id", ""), "memory": props.get("memory", ""), "metadata": props} + except Exception: + return None + + @timed + def get_neighbors_by_tag( + self, + tags: list[str], + exclude_ids: list[str], + top_k: int = 5, + min_overlap: int = 1, + include_embedding: bool = False, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Find top-K neighbor nodes with maximum tag overlap. + + Args: + tags: The list of tags to match. + exclude_ids: Node IDs to exclude (e.g., local cluster). + top_k: Max number of neighbors to return. + min_overlap: Minimum number of overlapping tags required. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of dicts with node details and overlap count. + """ + if not tags: + return [] + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build query conditions - more relaxed filters + where_clauses = [] + params = [] + + # Exclude specified IDs - use id in properties + if exclude_ids: + exclude_conditions = [] + for exclude_id in exclude_ids: + exclude_conditions.append( + "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) != %s::agtype" + ) + params.append(f'"{exclude_id}"') + where_clauses.append(f"({' AND '.join(exclude_conditions)})") + + # Status filter - keep only 'activated' + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Type filter - exclude 'reasoning' type + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"node_type\"'::agtype) != '\"reasoning\"'::agtype" + ) + + # User filter + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + ) + params.append(f'"{user_name}"') + + # Testing showed no data; annotate. + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) != '\"WorkingMemory\"'::agtype" + ) + + where_clause = " AND ".join(where_clauses) + + # Fetch all candidate nodes + query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + print(f"[get_neighbors_by_tag] query: {query}, params: {params}") + + try: + with self.connection.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + + nodes_with_overlap = [] + for row in results: + node_id, properties_json, embedding_json = row + properties = properties_json if properties_json else {} + + # Parse embedding + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + + # Compute tag overlap + node_tags = properties.get("tags", []) + if isinstance(node_tags, str): + try: + node_tags = json.loads(node_tags) + except (json.JSONDecodeError, TypeError): + node_tags = [] + + overlap_tags = [tag for tag in tags if tag in node_tags] + overlap_count = len(overlap_tags) + + if overlap_count >= min_overlap: + node_data = self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + nodes_with_overlap.append((node_data, overlap_count)) + + # Sort by overlap count and return top_k items + nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) + return [node for node, _ in nodes_with_overlap[:top_k]] + + except Exception as e: + logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) + return [] + + def get_neighbors_by_tag_ccl( + self, + tags: list[str], + exclude_ids: list[str], + top_k: int = 5, + min_overlap: int = 1, + include_embedding: bool = False, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Find top-K neighbor nodes with maximum tag overlap. + + Args: + tags: The list of tags to match. + exclude_ids: Node IDs to exclude (e.g., local cluster). + top_k: Max number of neighbors to return. + min_overlap: Minimum number of overlapping tags required. + include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of dicts with node details and overlap count. + """ + if not tags: + return [] + + user_name = user_name if user_name else self._get_config_value("user_name") + + # Build query conditions; keep consistent with nebular.py + where_clauses = [ + 'n.status = "activated"', + 'NOT (n.node_type = "reasoning")', + 'NOT (n.memory_type = "WorkingMemory")', + ] + where_clauses = [ + 'n.status = "activated"', + 'NOT (n.memory_type = "WorkingMemory")', + ] + + if exclude_ids: + exclude_ids_str = "[" + ", ".join(f'"{id}"' for id in exclude_ids) + "]" + where_clauses.append(f"NOT (n.id IN {exclude_ids_str})") + + where_clauses.append(f'n.user_name = "{user_name}"') + + where_clause = " AND ".join(where_clauses) + tag_list_literal = "[" + ", ".join(f'"{t}"' for t in tags) + "]" + + return_fields = [ + "n.id AS id", + "n.memory AS memory", + "n.user_name AS user_name", + "n.user_id AS user_id", + "n.session_id AS session_id", + "n.status AS status", + "n.key AS key", + "n.confidence AS confidence", + "n.tags AS tags", + "n.created_at AS created_at", + "n.updated_at AS updated_at", + "n.memory_type AS memory_type", + "n.sources AS sources", + "n.source AS source", + "n.node_type AS node_type", + "n.visibility AS visibility", + "n.background AS background", + ] + + if include_embedding: + return_fields.append("n.embedding AS embedding") + + return_fields_str = ", ".join(return_fields) + result_fields = [] + for field in return_fields: + # Extract field name 'id' from 'n.id AS id' + field_name = field.split(" AS ")[-1] + result_fields.append(f"{field_name} agtype") + + # Add overlap_count + result_fields.append("overlap_count agtype") + result_fields_str = ", ".join(result_fields) + # Use Cypher query; keep consistent with nebular.py + query = f""" + SELECT * FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + WITH {tag_list_literal} AS tag_list + MATCH (n:Memory) + WHERE {where_clause} + RETURN {return_fields_str}, + size([tag IN n.tags WHERE tag IN tag_list]) AS overlap_count + $$) AS ({result_fields_str}) + ) AS subquery + ORDER BY (overlap_count::integer) DESC + LIMIT {top_k} + """ + print("get_neighbors_by_tag:", query) + try: + with self.connection.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + neighbors = [] + for row in results: + # Parse results + props = {} + overlap_count = None + + # Manually parse each field + field_names = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "background", + ] + + if include_embedding: + field_names.append("embedding") + field_names.append("overlap_count") + + for i, field in enumerate(field_names): + if field == "overlap_count": + overlap_count = row[i].value if hasattr(row[i], "value") else row[i] + else: + props[field] = row[i].value if hasattr(row[i], "value") else row[i] + overlap_int = int(overlap_count) + if overlap_count is not None and overlap_int >= min_overlap: + parsed = self._parse_node(props) + parsed["overlap_count"] = overlap_int + neighbors.append(parsed) + + # Sort by overlap count + neighbors.sort(key=lambda x: x["overlap_count"], reverse=True) + neighbors = neighbors[:top_k] + + # Remove overlap_count field + result = [] + for neighbor in neighbors: + neighbor.pop("overlap_count", None) + result.append(neighbor) + + return result + + except Exception as e: + logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) + return [] + + @timed + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: + """ + Import the entire graph from a serialized dictionary. + + Args: + data: A dictionary containing all nodes and edges to be loaded. + user_name (str, optional): User name for filtering in non-multi-db mode + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + # Import nodes + for node in data.get("nodes", []): + try: + id, memory, metadata = _compose_node(node) + metadata["user_name"] = user_name + metadata = _prepare_node_metadata(metadata) + metadata.update({"id": id, "memory": memory}) + + # Use add_node to insert node + self.add_node(id, memory, metadata) + + except Exception as e: + logger.error(f"Fail to load node: {node}, error: {e}") + + # Import edges + for edge in data.get("edges", []): + try: + source_id, target_id = edge["source"], edge["target"] + edge_type = edge["type"] + + # Use add_edge to insert edge + self.add_edge(source_id, target_id, edge_type, user_name) + + except Exception as e: + logger.error(f"Fail to load edge: {edge}, error: {e}") + + @timed + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: + """ + Get edges connected to a node, with optional type and direction filter. + + Args: + id: Node ID to retrieve edges for. + type: Relationship type to match, or 'ANY' to match all. + direction: 'OUTGOING', 'INCOMING', or 'ANY'. + user_name (str, optional): User name for filtering in non-multi-db mode + + Returns: + List of edges: + [ + {"from": "source_id", "to": "target_id", "type": "RELATE"}, + ... + ] + """ + user_name = user_name if user_name else self._get_config_value("user_name") + + if direction == "OUTGOING": + pattern = f"(a:Memory)-[r]->(b:Memory)" + where_clause = f"a.id = '{id}'" + elif direction == "INCOMING": + pattern = f"(a:Memory)<-[r]-(b:Memory)" + where_clause = f"a.id = '{id}'" + elif direction == "ANY": + pattern = f"(a:Memory)-[r]-(b:Memory)" + where_clause = f"a.id = '{id}' OR b.id = '{id}'" + else: + raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") + + # Add type filter + if type != "ANY": + where_clause += f" AND type(r) = '{type}'" + + # Add user filter + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH {pattern} + WHERE {where_clause} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + $$) AS (from_id agtype, to_id agtype, edge_type agtype) + """ + + try: + with self.connection.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() + + edges = [] + for row in results: + from_id = row[0].value if hasattr(row[0], "value") else row[0] + to_id = row[1].value if hasattr(row[1], "value") else row[1] + edge_type = row[2].value if hasattr(row[2], "value") else row[2] + + edges.append({"from": from_id, "to": to_id, "type": edge_type}) + return edges + + except Exception as e: + logger.error(f"Failed to get edges: {e}", exc_info=True) + return [] diff --git a/src/memos/mem_cube/utils.py b/src/memos/mem_cube/utils.py index a413ccce5..24836c509 100644 --- a/src/memos/mem_cube/utils.py +++ b/src/memos/mem_cube/utils.py @@ -68,44 +68,80 @@ def merge_config_with_default( if "graph_db" in existing_text_config and "graph_db" in default_text_config: existing_graph_config = existing_text_config["graph_db"]["config"] default_graph_config = default_text_config["graph_db"]["config"] - - # Define graph_db fields to preserve (user-specific) - preserve_graph_fields = { - "auto_create", - "user_name", - "use_multi_db", - } - - # Create merged graph_db config - merged_graph_config = copy.deepcopy(existing_graph_config) - for key, value in default_graph_config.items(): - if key not in preserve_graph_fields: - merged_graph_config[key] = value - logger.debug( - f"Updated graph_db field '{key}': {existing_graph_config.get(key)} -> {value}" + existing_backend = existing_text_config["graph_db"]["backend"] + default_backend = default_text_config["graph_db"]["backend"] + + # Detect backend change + backend_changed = existing_backend != default_backend + + if backend_changed: + logger.info( + f"Detected graph_db backend change: {existing_backend} -> {default_backend}. " + f"Migrating configuration..." + ) + # Start with default config as base when backend changes + merged_graph_config = copy.deepcopy(default_graph_config) + + # Preserve user-specific fields if they exist in both configs + preserve_graph_fields = { + "auto_create", + "user_name", + "use_multi_db", + } + for field in preserve_graph_fields: + if field in existing_graph_config: + merged_graph_config[field] = existing_graph_config[field] + logger.debug( + f"Preserved graph_db field '{field}': {existing_graph_config[field]}" + ) + + # Clean up backend-specific fields that don't exist in the new backend + # This approach is generic: remove any field from merged config that's not in default config + # and not in the preserve list + fields_to_remove = [] + for field in list(merged_graph_config.keys()): + if field not in default_graph_config and field not in preserve_graph_fields: + fields_to_remove.append(field) + + for field in fields_to_remove: + removed_value = merged_graph_config.pop(field) + logger.info( + f"Removed {existing_backend}-specific field '{field}' (value: {removed_value}) " + f"during migration to {default_backend}" ) - if not default_graph_config.get("use_multi_db", True): - # set original use_multi_db to False if default_graph_config.use_multi_db is False - if merged_graph_config.get("use_multi_db", True): - merged_graph_config["use_multi_db"] = False - merged_graph_config["user_name"] = merged_graph_config.get("db_name") - merged_graph_config["db_name"] = default_graph_config.get("db_name") - else: - logger.info("use_multi_db is already False, no need to change") - if "neo4j" not in default_text_config["graph_db"]["backend"]: - if "db_name" in merged_graph_config: - merged_graph_config.pop("db_name") - logger.info("neo4j is not supported, remove db_name") - else: - logger.info("db_name is not in merged_graph_config, no need to remove") else: - if "space" in merged_graph_config: - merged_graph_config.pop("space") - logger.info("neo4j is not supported, remove db_name") - else: - logger.info("space is not in merged_graph_config, no need to remove") + # Same backend: merge configs while preserving user-specific fields + logger.debug(f"Same graph_db backend ({default_backend}), merging configurations") + preserve_graph_fields = { + "auto_create", + "user_name", + "use_multi_db", + } + + # Start with existing config as base + merged_graph_config = copy.deepcopy(existing_graph_config) + + # Update with default config except preserved fields + for key, value in default_graph_config.items(): + if key not in preserve_graph_fields: + merged_graph_config[key] = value + logger.debug( + f"Updated graph_db field '{key}': {existing_graph_config.get(key)} -> {value}" + ) + + # Handle use_multi_db transition + if not default_graph_config.get("use_multi_db", True) and merged_graph_config.get( + "use_multi_db", True + ): + merged_graph_config["use_multi_db"] = False + # For Neo4j: db_name becomes user_name in single-db mode + if "neo4j" in default_backend and "db_name" in merged_graph_config: + merged_graph_config["user_name"] = merged_graph_config.get("db_name") + merged_graph_config["db_name"] = default_graph_config.get("db_name") + logger.info("Transitioned to single-db mode (use_multi_db=False)") + preserved_graph_db = { - "backend": default_text_config["graph_db"]["backend"], + "backend": default_backend, "config": merged_graph_config, } diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 6d975cfd7..f6254efbb 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -224,6 +224,17 @@ def _coerce_metadata(cls, v: Any): ): return v if isinstance(v, dict): + if "metadata" in v and isinstance(v["metadata"], dict): + nested_metadata = v["metadata"] + nested_metadata = nested_metadata.copy() + nested_metadata.pop("id", None) + nested_metadata.pop("memory", None) + v = nested_metadata + else: + v = v.copy() + v.pop("id", None) + v.pop("memory", None) + if v.get("relativity") is not None: return SearchedTreeNodeTextualMemoryMetadata(**v) if v.get("preference_type") is not None: diff --git a/src/memos/reranker/cosine_local.py b/src/memos/reranker/cosine_local.py index fc1dada2b..38ace458f 100644 --- a/src/memos/reranker/cosine_local.py +++ b/src/memos/reranker/cosine_local.py @@ -6,6 +6,7 @@ from memos.log import get_logger from .base import BaseReranker +from memos.utils import timed if TYPE_CHECKING: @@ -58,6 +59,7 @@ def __init__( self.level_weights = level_weights or {"topic": 1.0, "concept": 1.0, "fact": 1.0} self.level_field = level_field + @timed def rerank( self, query: str, diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index 9cce12786..2c423e6b6 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -12,6 +12,7 @@ from .base import BaseReranker from .concat import concat_original_source +from memos.utils import timed logger = get_logger(__name__) @@ -118,6 +119,7 @@ def __init__( self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys) self._warned_missing_keys: set[str] = set() + @timed def rerank( self, query: str, diff --git a/src/memos/reranker/noop.py b/src/memos/reranker/noop.py index 7a9c02f60..4f6ba0438 100644 --- a/src/memos/reranker/noop.py +++ b/src/memos/reranker/noop.py @@ -3,13 +3,14 @@ from typing import TYPE_CHECKING from .base import BaseReranker - +from memos.utils import timed if TYPE_CHECKING: from memos.memories.textual.item import TextualMemoryItem class NoopReranker(BaseReranker): + @timed def rerank( self, query: str, graph_results: list, top_k: int, **kwargs ) -> list[tuple[TextualMemoryItem, float]]: From 83a7c34535c6da1529692ae21861d6daecbe6ee5 Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Tue, 28 Oct 2025 10:52:42 +0800 Subject: [PATCH 10/64] feat: fix mode (#400) --- src/memos/api/product_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index e491e9feb..dd2fde22b 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -171,7 +171,7 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") - mode: SearchMode = Field(SearchMode.FINE, description="search mode: fast, fine, or mixture") + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") top_k: int = Field(10, description="Number of results to return") From 018d759cfa72a2b7f167f4fd1eced3114cddb3df Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Tue, 28 Oct 2025 11:47:12 +0800 Subject: [PATCH 11/64] Feat: remove long waring for internet and add content for memreader (#401) * feat: add content and remove log waring * fix: wrong format * fix: fix code format --- src/memos/mem_reader/simple_struct.py | 11 +++++++++-- src/memos/memories/textual/simple_tree.py | 3 --- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 9f5eb9832..0f74adead 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -224,7 +224,6 @@ def _iter_chat_windows(self, scene_data_info, max_tokens=None, overlap=200): max_tokens = max_tokens or self.chat_window_max_tokens buf, sources, start_idx = [], [], 0 cur_text = "" - for idx, item in enumerate(scene_data_info): role = item.get("role", "") content = item.get("content", "") @@ -247,7 +246,15 @@ def _iter_chat_windows(self, scene_data_info, max_tokens=None, overlap=200): cur_text = "".join(buf) buf.append(line) - sources.append({"type": "chat", "index": idx, "role": role, "chat_time": chat_time}) + sources.append( + { + "type": "chat", + "index": idx, + "role": role, + "chat_time": chat_time, + "content": content, + } + ) cur_text = "".join(buf) if buf: diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 9c67db288..52bf62c6d 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -151,9 +151,6 @@ def search( list[TextualMemoryItem]: List of matching memories. """ if (self.internet_retriever is not None) and manual_close_internet: - logger.warning( - "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" - ) searcher = Searcher( self.dispatcher_llm, self.graph_store, From 3680286d1d0320244a68366f3e6e389de39097f6 Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Tue, 28 Oct 2025 16:26:54 +0800 Subject: [PATCH 12/64] feat: redis for sync history memories and new api of mixture search (#398) * debug an error function name * feat: Add DynamicCache compatibility for different transformers versions - Fix build_kv_cache method in hf.py to handle both old and new DynamicCache structures - Support new 'layers' attribute with key_cache/value_cache or keys/values - Maintain backward compatibility with direct key_cache/value_cache attributes - Add comprehensive error handling and logging for unsupported structures - Update move_dynamic_cache_htod function in kv.py for cross-version compatibility - Handle layers-based structure in newer transformers versions - Support alternative attribute names (keys/values vs key_cache/value_cache) - Preserve original functionality for older transformers versions - Add comprehensive tests for DynamicCache compatibility - Test activation memory update with mock DynamicCache layers - Verify layers attribute access across different transformers versions - Fix scheduler logger mock to include memory_manager attribute This resolves AttributeError issues when using different versions of the transformers library and ensures robust handling of DynamicCache objects. debug * feat: implement APIAnalyzerForScheduler for memory operations - Add APIAnalyzerForScheduler class with search/add operations - Support requests and http.client with connection reuse - Include comprehensive error handling and dynamic configuration - Add English test suite with real-world conversation scenarios * feat: Add search_ws API endpoint and enhance API analyzer functionality - Add search_ws endpoint in server_router.py for scheduler-enabled search - Fix missing imports: time module, SearchRequest class, and get_mos_product_instance function - Implement search_ws method in api_analyzer.py with HTTP client support - Add _search_ws_with_requests and _search_ws_with_http_client private methods - Include search_ws usage example in demonstration code - Enhance scheduler and dispatcher capabilities for improved memory management - Expand test coverage to ensure functionality stability This update primarily strengthens the memory scheduling system's search capabilities, providing users with more flexible API interface options. * fix: resolve test failures and warnings in test suite - Fix Pydantic serialization warning in test_memos_chen_tang_hello_world * Add warnings filter to suppress UserWarning from Pydantic serialization - Fix KeyError: 'past_key_values' in test_build_kv_cache_and_generation * Update mock configuration to properly return forward_output with past_key_values * Add DynamicCache version compatibility handling in test mocks * Support both old and new transformers versions with layers/key_cache attributes * Improve assertion logic to check all model calls for required parameters - Update base_scheduler.py to use centralized DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE constant * Add import for DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE from general_schemas * Replace hardcoded value 100 with configurable constant (1000) All tests now pass successfully with proper version compatibility handling. * feat: add a test_robustness execution to test thread pool execution * feat: optimize scheduler configuration and API search functionality - Add DEFAULT_TOP_K and DEFAULT_CONTEXT_WINDOW_SIZE global constants in general_schemas.py - Update base_scheduler.py to use global default values instead of hardcoded numbers - Fix SchedulerConfigFactory initialization issue by using keyword argument expansion - Resolve UnboundLocalError variable conflict in search_memories_ws function - Fix indentation and parameter issues in OptimizedScheduler search_for_api method - Improve code standardization and maintainability * feat: Add Redis auto-initialization with fallback strategies - Add auto_initialize_redis() with config/env/local fallback - Move Redis logic from dispatcher_monitor to redis_service - Update base_scheduler to use auto initialization - Add proper resource cleanup and error handling * feat: add database connection management to ORM module - Add MySQL engine loading from environment variables in BaseDBManager - Add Redis connection loading from environment variables in BaseDBManager - Enhance database configuration validation and error handling - Complete database adapter infrastructure for ORM module - Provide unified database connection management interface This update provides comprehensive database connection management capabilities for the mem_scheduler module, supporting dynamic MySQL and Redis configuration loading from environment variables, establishing reliable data persistence foundation for scheduling services and API services. * remove part of test * feat: add Redis-based ORM with multiprocess synchronization - Add RedisDBManager and RedisLockableORM classes - Implement atomic locking mechanism for concurrent access - Add merge functionality for different object types - Include comprehensive test suite and examples - Fix Redis key type conflicts in lock operations * fix: resolve scheduler module import and Redis integration issues * revise naive memcube creation in server router * remove long-time tests in test_scheduler * remove redis test which needs .env * refactor all codes about mixture search with scheduler * fix: resolve Redis API synchronization issues and implement search API with reranker - Fix running_entries to running_task_ids migration across codebase - Update sync_search_data method to properly handle TaskRunningStatus - Correct variable naming and logic in API synchronization flow - Implement search API endpoint with reranker functionality - Update test files to reflect new running_task_ids convention - Ensure proper Redis state management for concurrent tasks * remove a test for api module * revise to pass the test suite * address some bugs to make mix_search normally running * modify codes according to evaluation logs --------- Co-authored-by: CaralHsi --- evaluation/scripts/utils/client.py | 4 +- src/memos/api/config.py | 4 +- src/memos/api/routers/server_router.py | 138 +---- src/memos/configs/mem_scheduler.py | 5 + .../mem_scheduler/analyzer/api_analyzer.py | 302 ++++++---- src/memos/mem_scheduler/base_scheduler.py | 32 +- .../mem_scheduler/general_modules/api_misc.py | 184 ++++--- .../general_modules/dispatcher.py | 30 +- .../mem_scheduler/general_modules/misc.py | 2 +- .../general_modules/task_threads.py | 100 ++-- .../mem_scheduler/optimized_scheduler.py | 187 +++++-- .../orm_modules/api_redis_model.py | 517 ++++++++++++++++++ .../mem_scheduler/orm_modules/base_model.py | 117 ---- .../mem_scheduler/schemas/api_schemas.py | 232 ++++++++ .../mem_scheduler/schemas/general_schemas.py | 1 + .../mem_scheduler/schemas/message_schemas.py | 15 +- src/memos/mem_scheduler/utils/api_utils.py | 76 +++ .../webservice_modules/redis_service.py | 2 +- src/memos/memories/activation/item.py | 4 +- tests/mem_scheduler/test_orm.py | 447 --------------- tests/mem_scheduler/test_scheduler.py | 55 +- 21 files changed, 1407 insertions(+), 1047 deletions(-) create mode 100644 src/memos/mem_scheduler/orm_modules/api_redis_model.py create mode 100644 src/memos/mem_scheduler/schemas/api_schemas.py create mode 100644 src/memos/mem_scheduler/utils/api_utils.py delete mode 100644 tests/mem_scheduler/test_orm.py diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index ffc9dda12..91d695acc 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -183,6 +183,7 @@ def search(self, query, user_id, top_k): "mem_cube_id": user_id, "conversation_id": "", "top_k": top_k, + "mode": "mixture", }, ensure_ascii=False, ) @@ -230,6 +231,7 @@ def search(self, query, user_id, top_k): "query": query, "user_id": user_id, "memory_limit_number": top_k, + "mode": "mixture", } ) @@ -311,7 +313,7 @@ def add(self, messages, user_id, iso_date): agent_name=self.agent_id, session_date=iso_date, ) - self.wait_for_completion(response.task_id) + self.wait_for_completion(response.item_id) except Exception as error: print("❌ Error saving conversation:", error) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d1bc6efff..6de013313 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -361,8 +361,8 @@ def get_scheduler_config() -> dict[str, Any]: "thread_pool_max_workers": int( os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10") ), - "consume_interval_seconds": int( - os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "3") + "consume_interval_seconds": float( + os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01") ), "enable_parallel_dispatch": os.getenv( "MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 1331094a8..f50d3ad75 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,9 +1,8 @@ -import json import os import traceback from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException @@ -33,11 +32,8 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas import ( - API_MIX_SEARCH_LABEL, SearchMode, ) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, ExtractorConfigFactory, @@ -54,6 +50,10 @@ ) from memos.reranker.factory import RerankerFactory from memos.templates.instruction_completion import instruct_completion + + +if TYPE_CHECKING: + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler from memos.types import MOSSearchResult, UserContext from memos.vec_dbs.factory import VecDBFactory @@ -154,7 +154,6 @@ def init_server(): # Build component configurations graph_db_config = _build_graph_db_config() - print(graph_db_config) llm_config = _build_llm_config() embedder_config = _build_embedder_config() mem_reader_config = _build_mem_reader_config() @@ -209,22 +208,6 @@ def init_server(): online_bot=False, ) - # Initialize Scheduler - scheduler_config_dict = APIConfig.get_scheduler_config() - scheduler_config = SchedulerConfigFactory( - backend="optimized_scheduler", config=scheduler_config_dict - ) - mem_scheduler = SchedulerFactory.from_config(scheduler_config) - mem_scheduler.initialize_modules( - chat_llm=llm, - process_llm=mem_reader.llm, - db_engine=BaseDBManager.create_default_sqlite_engine(), - ) - mem_scheduler.start() - - # Initialize SchedulerAPIModule - api_module = mem_scheduler.api_module - naive_mem_cube = NaiveMemCube( llm=llm, embedder=embedder, @@ -240,6 +223,23 @@ def init_server(): pref_retriever=pref_retriever, ) + # Initialize Scheduler + scheduler_config_dict = APIConfig.get_scheduler_config() + scheduler_config = SchedulerConfigFactory( + backend="optimized_scheduler", config=scheduler_config_dict + ) + mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) + mem_scheduler.initialize_modules( + chat_llm=llm, + process_llm=mem_reader.llm, + db_engine=BaseDBManager.create_default_sqlite_engine(), + ) + mem_scheduler.current_mem_cube = naive_mem_cube + mem_scheduler.start() + + # Initialize SchedulerAPIModule + api_module = mem_scheduler.api_module + return ( graph_db, mem_reader, @@ -400,96 +400,12 @@ def mix_search_memories( """ Mix search memories: fast search + async fine search """ - # Get fast memories first - fast_memories = fast_search_memories(search_req, user_context) - - # Check if scheduler and dispatcher are available for async execution - if mem_scheduler and hasattr(mem_scheduler, "dispatcher") and mem_scheduler.dispatcher: - try: - # Create message for async fine search - message_content = { - "search_req": { - "query": search_req.query, - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "top_k": search_req.top_k, - "internet_search": search_req.internet_search, - "moscube": search_req.moscube, - "chat_history": search_req.chat_history, - }, - "user_context": {"mem_cube_id": user_context.mem_cube_id}, - } - - message = ScheduleMessageItem( - item_id=f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}", - user_id=search_req.user_id, - mem_cube_id=user_context.mem_cube_id, - label=API_MIX_SEARCH_LABEL, - mem_cube=naive_mem_cube, - content=json.dumps(message_content), - timestamp=get_utc_now(), - ) - - # Submit async task - mem_scheduler.dispatcher.submit_message(message) - logger.info(f"Submitted async fine search task for user {search_req.user_id}") - - # Try to get pre-computed fine memories if available - try: - pre_fine_memories = api_module.get_pre_fine_memories( - user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id - ) - if pre_fine_memories: - # Merge fast and pre-computed fine memories - all_memories = fast_memories + pre_fine_memories - # Remove duplicates based on content - seen_contents = set() - unique_memories = [] - for memory in all_memories: - content_key = memory.get("content", "") - if content_key not in seen_contents: - seen_contents.add(content_key) - unique_memories.append(memory) - return unique_memories - except Exception as e: - logger.warning(f"Failed to get pre-computed fine memories: {e}") - - except Exception as e: - logger.error(f"Failed to submit async fine search task: {e}") - # Fall back to synchronous execution - - # Fallback: synchronous fine search - try: - fine_memories = fine_search_memories(search_req, user_context) - - # Merge fast and fine memories - all_memories = fast_memories + fine_memories - - # Remove duplicates based on content - seen_contents = set() - unique_memories = [] - for memory in all_memories: - content_key = memory.get("content", "") - if content_key not in seen_contents: - seen_contents.add(content_key) - unique_memories.append(memory) - - # Sync search data to Redis - try: - api_module.sync_search_data( - user_id=search_req.user_id, - mem_cube_id=user_context.mem_cube_id, - query=search_req.query, - formatted_memories=unique_memories, - ) - except Exception as e: - logger.error(f"Failed to sync search data: {e}") - - return unique_memories - except Exception as e: - logger.error(f"Fine search failed: {e}") - return fast_memories + formatted_memories = mem_scheduler.mix_search_memories( + search_req=search_req, + user_context=user_context, + ) + return formatted_memories def fine_search_memories( diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index bc22cfb63..e757f243b 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -15,6 +15,7 @@ DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, + DEFAULT_MULTI_TASK_RUNNING_TIMEOUT, DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, DEFAULT_USE_REDIS_QUEUE, @@ -59,6 +60,10 @@ class BaseSchedulerConfig(BaseConfig): default=DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, description="Maximum size of internal message queue when not using Redis", ) + multi_task_running_timeout: int = Field( + default=DEFAULT_MULTI_TASK_RUNNING_TIMEOUT, + description="Default timeout for multi-task running operations in seconds", + ) class GeneralSchedulerConfig(BaseSchedulerConfig): diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index 45a39e0de..28ca182e5 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -7,6 +7,7 @@ import http.client import json +import time from typing import Any from urllib.parse import urlparse @@ -364,11 +365,204 @@ def __init__(self): self.UserContext = UserContext self.MessageDict = MessageDict + # Initialize conversation history for continuous conversation support + self.conversation_history = [] + self.current_session_id = None + self.current_user_id = None + self.current_mem_cube_id = None + logger.info("DirectSearchMemoriesAnalyzer initialized successfully") except ImportError as e: logger.error(f"Failed to import modules: {e}") raise + def start_conversation(self, user_id="test_user", mem_cube_id="test_cube", session_id=None): + """ + Start a new conversation session for continuous dialogue. + + Args: + user_id: User ID for the conversation + mem_cube_id: Memory cube ID for the conversation + session_id: Session ID for the conversation (auto-generated if None) + """ + self.current_user_id = user_id + self.current_mem_cube_id = mem_cube_id + self.current_session_id = ( + session_id or f"session_{hash(user_id + mem_cube_id)}_{len(self.conversation_history)}" + ) + self.conversation_history = [] + + logger.info(f"Started conversation session: {self.current_session_id}") + print(f"🚀 Started new conversation session: {self.current_session_id}") + print(f" User ID: {self.current_user_id}") + print(f" Mem Cube ID: {self.current_mem_cube_id}") + + def add_to_conversation(self, user_message, assistant_message=None): + """ + Add messages to the current conversation and store them in memory. + + Args: + user_message: User's message content + assistant_message: Assistant's response (optional) + + Returns: + Result from add_memories function + """ + if not self.current_session_id: + raise ValueError("No active conversation session. Call start_conversation() first.") + + # Prepare messages for adding to memory + messages = [{"role": "user", "content": user_message}] + if assistant_message: + messages.append({"role": "assistant", "content": assistant_message}) + + # Add to conversation history + self.conversation_history.extend(messages) + + # Create add request + add_req = self.create_test_add_request( + user_id=self.current_user_id, + mem_cube_id=self.current_mem_cube_id, + messages=messages, + session_id=self.current_session_id, + ) + + print(f"💬 Adding to conversation (Session: {self.current_session_id}):") + print(f" User: {user_message}") + if assistant_message: + print(f" Assistant: {assistant_message}") + + # Add to memory + result = self.add_memories(add_req) + print(" ✅ Added to memory successfully") + + return result + + def search_in_conversation(self, query, mode="fast", top_k=10, include_history=True): + """ + Search memories within the current conversation context. + + Args: + query: Search query + mode: Search mode ("fast", "fine", or "mixture") + top_k: Number of results to return + include_history: Whether to include conversation history in the search + + Returns: + Search results + """ + if not self.current_session_id: + raise ValueError("No active conversation session. Call start_conversation() first.") + + # Prepare chat history if requested + chat_history = self.conversation_history if include_history else None + + # Create search request + search_req = self.create_test_search_request( + query=query, + user_id=self.current_user_id, + mem_cube_id=self.current_mem_cube_id, + mode=mode, + top_k=top_k, + chat_history=chat_history, + session_id=self.current_session_id, + ) + + print(f"🔍 Searching in conversation (Session: {self.current_session_id}):") + print(f" Query: {query}") + print(f" Mode: {mode}") + print(f" Top K: {top_k}") + print(f" Include History: {include_history}") + print(f" History Length: {len(self.conversation_history) if chat_history else 0}") + + # Perform search + result = self.search_memories(search_req) + + print(" ✅ Search completed") + if hasattr(result, "data") and result.data: + total_memories = sum( + len(mem_list) for mem_list in result.data.values() if isinstance(mem_list, list) + ) + print(f" 📊 Found {total_memories} total memories") + + return result + + def test_continuous_conversation(self): + """Test continuous conversation functionality""" + print("=" * 80) + print("Testing Continuous Conversation Functionality") + print("=" * 80) + + try: + # Start a conversation + self.start_conversation(user_id="conv_test_user", mem_cube_id="conv_test_cube") + + # Prepare all conversation messages for batch addition + all_messages = [ + { + "role": "user", + "content": "I'm planning a trip to Shanghai for New Year's Eve. What are some good places to visit?", + }, + { + "role": "assistant", + "content": "Shanghai has many great places for New Year's Eve! You could visit the Bund for the countdown, go to a rooftop party, or enjoy fireworks at Disneyland Shanghai. The French Concession also has nice bars and restaurants.", + }, + {"role": "user", "content": "What about food? Any restaurant recommendations?"}, + { + "role": "assistant", + "content": "For New Year's Eve dining in Shanghai, I'd recommend trying some local specialties like xiaolongbao at Din Tai Fung, or for a fancy dinner, you could book at restaurants in the Bund area with great views.", + }, + {"role": "user", "content": "I'm on a budget though. Any cheaper alternatives?"}, + { + "role": "assistant", + "content": "For budget-friendly options, try street food in Yuyuan Garden area, local noodle shops, or food courts in shopping malls. You can also watch the fireworks from free public areas along the Huangpu River.", + }, + ] + + # Add all conversation messages at once + print("\n📝 Adding all conversation messages at once:") + add_req = self.create_test_add_request( + user_id=self.current_user_id, + mem_cube_id=self.current_mem_cube_id, + messages=all_messages, + session_id=self.current_session_id, + ) + + print( + f"💬 Adding {len(all_messages)} messages to conversation (Session: {self.current_session_id})" + ) + self.add_memories(add_req) + + # Update conversation history + self.conversation_history.extend(all_messages) + print(" ✅ Added all messages to memory successfully") + + # Test searching within the conversation + print("\n🔍 Testing search within conversation:") + + # Search for trip-related information + self.search_in_conversation( + query="New Year's Eve Shanghai recommendations", mode="mixture", top_k=5 + ) + + # Search for food-related information + self.search_in_conversation(query="budget food Shanghai", mode="mixture", top_k=3) + + # Search without conversation history + self.search_in_conversation( + query="Shanghai travel", mode="mixture", top_k=3, include_history=False + ) + + print("\n✅ Continuous conversation test completed successfully!") + return True + + except Exception as e: + print(f"❌ Continuous conversation test failed: {e}") + import traceback + + traceback.print_exc() + return False + def create_test_search_request( self, query="test query", @@ -451,115 +645,19 @@ def create_test_add_request( operation=None, ) - def test_add_memories_basic(self, user_id="test_user_add", mem_cube_id="test_cube_add"): - """Basic add_memories test""" - print("=" * 60) - print("Starting basic add_memories test") - print("=" * 60) - - try: - # Create test request with default messages - add_req = self.create_test_add_request(user_id=user_id, mem_cube_id=mem_cube_id) - - print("Test request created:") - print(f" User ID: {add_req.user_id}") - print(f" Mem Cube ID: {add_req.mem_cube_id}") - print(f" Messages: {add_req.messages}") - print(f" Session ID: {add_req.session_id}") - - # Call add_memories function - print("\nCalling add_memories function...") - result = self.add_memories(add_req) - - print(f"Add result: {result}") - print("Basic add_memories test completed successfully") - return result - - except Exception as e: - print(f"Basic add_memories test failed: {e}") - import traceback - - traceback.print_exc() - return None - - def test_search_memories_basic(self, query: str, mode: str, topk: int): - """Basic search_memories test""" - print("=" * 60) - print("Starting basic search_memories test") - print("=" * 60) - - try: - # Create test request - search_req = self.create_test_search_request( - query=query, - user_id="test_user_id", - mem_cube_id="test_mem_cube_id", - mode=mode, - top_k=topk, - ) - - print("Test request parameters:") - print(f" - query: {search_req.query}") - print(f" - user_id: {search_req.user_id}") - print(f" - mem_cube_id: {search_req.mem_cube_id}") - print(f" - mode: {search_req.mode}") - print(f" - top_k: {search_req.top_k}") - print(f" - internet_search: {search_req.internet_search}") - print(f" - moscube: {search_req.moscube}") - print() - - # Call search_memories function - print("Calling search_memories function...") - result = self.search_memories(search_req) - - print("✅ Function call successful!") - print(f"Return result type: {type(result)}") - print(f"Return result: {result}") - - # Analyze return result - if hasattr(result, "message"): - print(f"Message: {result.message}") - if hasattr(result, "data"): - print(f"Data type: {type(result.data)}") - if result.data and isinstance(result.data, dict): - for key, value in result.data.items(): - print(f" {key}: {len(value) if isinstance(value, list) else value}") - - return result - - except Exception as e: - print(f"❌ Test failed: {e}") - import traceback - - print("Detailed error information:") - traceback.print_exc() - return None - def run_all_tests(self): """Run all available tests""" print("🚀 Starting comprehensive test suite") print("=" * 80) - # Test add_memories functions (more likely to have dependency issues) - print("\n\n📝 Testing ADD_MEMORIES functions:") - try: - print("\n" + "-" * 40) - self.test_add_memories_basic() - print("✅ Basic add memories test completed") - except Exception as e: - print(f"❌ Basic add memories test failed: {e}") - - # Test search_memories functions first (less likely to fail) - print("\n🔍 Testing SEARCH_MEMORIES functions:") + # Test continuous conversation functionality + print("\n💬 Testing CONTINUOUS CONVERSATION functions:") try: - self.test_search_memories_basic( - query="What are some good places to celebrate New Year's Eve in Shanghai?", - mode="fast", - topk=3, - ) - print("✅ Search memories test completed successfully") + self.test_continuous_conversation() + time.sleep(5) + print("✅ Continuous conversation test completed successfully") except Exception as e: - print(f"❌ Search memories test failed: {e}") + print(f"❌ Continuous conversation test failed: {e}") print("\n" + "=" * 80) print("✅ All tests completed!") diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index e475ea225..3958ee382 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -502,7 +502,7 @@ def update_activation_memory_periodically( except Exception as e: logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) - async def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): """Submit messages to the message queue (either local queue or Redis).""" if isinstance(messages, ScheduleMessageItem): messages = [messages] # transform single message to list @@ -519,7 +519,7 @@ async def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMes if self.use_redis_queue: # Use Redis stream for message queue - await self.redis_add_message_stream(message.to_dict()) + self.redis_add_message_stream(message.to_dict()) logger.info(f"Submitted message to Redis: {message.label} - {message.content}") else: # Use local queue @@ -774,34 +774,6 @@ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: return self.dispatcher.unregister_handlers(labels) def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: - """ - Get currently running tasks, optionally filtered by a custom function. - - This method delegates to the dispatcher's get_running_tasks method. - - Args: - filter_func: Optional function to filter tasks. Should accept a RunningTaskItem - and return True if the task should be included in results. - - Returns: - dict[str, dict]: Dictionary mapping task IDs to task information dictionaries. - Each task dict contains: item_id, user_id, mem_cube_id, task_info, - task_name, start_time, end_time, status, result, error_message, messages - - Examples: - # Get all running tasks - all_tasks = scheduler.get_running_tasks() - - # Get tasks for specific user - user_tasks = scheduler.get_running_tasks( - filter_func=lambda task: task.user_id == "user123" - ) - - # Get tasks with specific status - active_tasks = scheduler.get_running_tasks( - filter_func=lambda task: task.status == "running" - ) - """ if not self.dispatcher: logger.warning("Dispatcher is not initialized, returning empty tasks dict") return {} diff --git a/src/memos/mem_scheduler/general_modules/api_misc.py b/src/memos/mem_scheduler/general_modules/api_misc.py index 6139a895a..bb993de38 100644 --- a/src/memos/mem_scheduler/general_modules/api_misc.py +++ b/src/memos/mem_scheduler/general_modules/api_misc.py @@ -1,115 +1,145 @@ -import threading - from typing import Any from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager +from memos.mem_scheduler.orm_modules.api_redis_model import APIRedisDBManager +from memos.mem_scheduler.schemas.api_schemas import ( + APIMemoryHistoryEntryItem, + APISearchHistoryManager, + TaskRunningStatus, +) +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) class SchedulerAPIModule(BaseSchedulerModule): - def __init__(self): + def __init__(self, window_size=5): super().__init__() + self.window_size = window_size + self.search_history_managers: dict[str, APIRedisDBManager] = {} + self.pre_memory_turns = 5 - self.search_history_managers: dict[str, RedisDBManager] = {} - - def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> RedisDBManager: + def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> APIRedisDBManager: """Get or create a Redis manager for search history.""" key = f"search_history:{user_id}:{mem_cube_id}" if key not in self.search_history_managers: - self.search_history_managers[key] = RedisDBManager( - user_id=user_id, mem_cube_id=mem_cube_id + self.search_history_managers[key] = APIRedisDBManager( + user_id=user_id, + mem_cube_id=mem_cube_id, + obj=APISearchHistoryManager(window_size=self.window_size), ) return self.search_history_managers[key] def sync_search_data( - self, user_id: str, mem_cube_id: str, query: str, formatted_memories: Any - ) -> None: - """ - Sync search data to Redis, maintaining a list of size 5. + self, + item_id: str, + user_id: str, + mem_cube_id: str, + query: str, + memories: list[TextualMemoryItem], + formatted_memories: Any, + conversation_id: str | None = None, + ) -> Any: + # Get the search history manager + manager = self.get_search_history_manager(user_id, mem_cube_id) + manager.sync_with_redis(size_limit=self.window_size) + + search_history = manager.obj + + # Check if entry with item_id already exists + existing_entry, location = search_history.find_entry_by_item_id(item_id) + + if existing_entry is not None: + # Update existing entry + success = search_history.update_entry_by_item_id( + item_id=item_id, + query=query, + formatted_memories=formatted_memories, + task_status=TaskRunningStatus.COMPLETED, # Use the provided running_status + conversation_id=conversation_id, + memories=memories, + ) - Args: - user_id: User identifier - mem_cube_id: Memory cube identifier - query: Search query string - formatted_memories: Formatted search results - """ - try: - # Get the search history manager - manager = self.get_search_history_manager(user_id, mem_cube_id) - - # Create search data entry - search_entry = { - "query": query, - "formatted_memories": formatted_memories, - "timestamp": threading.current_thread().ident, # Use thread ID as simple timestamp - } - - # Load existing search history - existing_data = manager.load_from_db() - - if existing_data is None: - search_history = SimpleListManager([]) + if success: + logger.info(f"Updated existing entry with item_id: {item_id} in {location} list") else: - # If existing data is a SimpleListManager, use it; otherwise create new one - if isinstance(existing_data, SimpleListManager): - search_history = existing_data - else: - search_history = SimpleListManager([]) - - # Add new entry and keep only latest 5 - search_history.add_item(str(search_entry)) - if len(search_history) > 5: - # Keep only the latest 5 items - search_history.items = search_history.items[-5:] - - # Save back to Redis - manager.save_to_db(search_history) - - logger.info( - f"Synced search data for user {user_id}, mem_cube {mem_cube_id}. History size: {len(search_history)}" + logger.warning(f"Failed to update entry with item_id: {item_id}") + else: + # Add new entry based on running_status + search_entry = APIMemoryHistoryEntryItem( + item_id=item_id, + query=query, + formatted_memories=formatted_memories, + memories=memories, + task_status=TaskRunningStatus.COMPLETED, + conversation_id=conversation_id, + created_time=get_utc_now(), ) - except Exception as e: - logger.error(f"Failed to sync search data: {e}", exc_info=True) + # Add directly to completed list as APIMemoryHistoryEntryItem instance + search_history.completed_entries.append(search_entry) + + # Maintain window size + if len(search_history.completed_entries) > search_history.window_size: + search_history.completed_entries = search_history.completed_entries[ + -search_history.window_size : + ] + + # Remove from running task IDs + if item_id in search_history.running_item_ids: + search_history.running_item_ids.remove(item_id) + + logger.info(f"Created new entry with item_id: {item_id}") - def get_pre_fine_memories(self, user_id: str, mem_cube_id: str) -> list: + # Update manager's object with the modified search history + manager.obj = search_history + + # Use sync_with_redis to handle Redis synchronization with merging + manager.sync_with_redis(size_limit=self.window_size) + return manager + + def get_pre_memories(self, user_id: str, mem_cube_id: str) -> list: """ - Get the most recent pre-computed fine memories from search history. + Get pre-computed memories from the most recent completed search entry. Args: user_id: User identifier mem_cube_id: Memory cube identifier Returns: - List of formatted memories from the most recent search, or empty list if none found + List of TextualMemoryItem objects from the most recent completed search """ - try: - manager = self.get_search_history_manager(user_id, mem_cube_id) - search_history_key = "search_history_list" - existing_data = manager.load_from_db(search_history_key) + manager = self.get_search_history_manager(user_id, mem_cube_id) - if existing_data is None: - return [] + existing_data = manager.load_from_db() + if existing_data is None: + return [] - search_history = ( - existing_data.obj_instance - if hasattr(existing_data, "obj_instance") - else existing_data - ) + search_history: APISearchHistoryManager = existing_data - if not search_history or len(search_history) == 0: - return [] + # Get memories from the most recent completed entry + history_memories = search_history.get_history_memories(turns=self.pre_memory_turns) + return history_memories - # Return the formatted_memories from the most recent search - latest_entry = search_history[-1] - return ( - latest_entry.get("formatted_memories", []) if isinstance(latest_entry, dict) else [] - ) + def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list: + """Get history memories for backward compatibility with tests.""" + manager = self.get_search_history_manager(user_id, mem_cube_id) + existing_data = manager.load_from_db() - except Exception as e: - logger.error(f"Failed to get pre-computed fine memories: {e}", exc_info=True) + if existing_data is None: return [] + + # Handle different data formats + if isinstance(existing_data, APISearchHistoryManager): + search_history = existing_data + else: + # Try to convert to APISearchHistoryManager + try: + search_history = APISearchHistoryManager(**existing_data) + except Exception: + return [] + + return search_history.get_history_memories(turns=n) diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index c357e31b5..2e5779f19 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -36,6 +36,11 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): # Main dispatcher thread pool self.max_workers = max_workers + # Get multi-task timeout from config + self.multi_task_running_timeout = ( + self.config.get("multi_task_running_timeout") if self.config else None + ) + # Only initialize thread pool if in parallel mode self.enable_parallel_dispatch = enable_parallel_dispatch self.thread_name_prefix = "dispatcher" @@ -62,6 +67,8 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): # Task tracking for monitoring self._running_tasks: dict[str, RunningTaskItem] = {} self._task_lock = threading.Lock() + self._completed_tasks = [] + self.completed_tasks_max_show_size = 10 def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ @@ -85,7 +92,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if task_item.item_id in self._running_tasks: task_item.mark_completed(result) del self._running_tasks[task_item.item_id] - + self._completed_tasks.append(task_item) + if len(self._completed_tasks) > self.completed_tasks_max_show_size: + self._completed_tasks[-self.completed_tasks_max_show_size :] logger.info(f"Task completed: {task_item.get_execution_info()}") return result @@ -95,7 +104,8 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if task_item.item_id in self._running_tasks: task_item.mark_failed(str(e)) del self._running_tasks[task_item.item_id] - + if len(self._completed_tasks) > self.completed_tasks_max_show_size: + self._completed_tasks[-self.completed_tasks_max_show_size :] logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}") raise @@ -356,17 +366,17 @@ def run_competitive_tasks( def run_multiple_tasks( self, - tasks: dict[str, tuple[Callable, tuple, dict]], + tasks: dict[str, tuple[Callable, tuple]], use_thread_pool: bool | None = None, - timeout: float | None = 30.0, + timeout: float | None = None, ) -> dict[str, Any]: """ Execute multiple tasks concurrently and return all results. Args: - tasks: Dictionary mapping task names to (function, args, kwargs) tuples + tasks: Dictionary mapping task names to (task_execution_function, task_execution_parameters) tuples use_thread_pool: Whether to use ThreadPoolExecutor. If None, uses dispatcher's parallel mode setting - timeout: Maximum time to wait for all tasks to complete (in seconds). None for infinite timeout. + timeout: Maximum time to wait for all tasks to complete (in seconds). If None, uses config default. Returns: Dictionary mapping task names to their results @@ -378,7 +388,13 @@ def run_multiple_tasks( if use_thread_pool is None: use_thread_pool = self.enable_parallel_dispatch - logger.info(f"Executing {len(tasks)} tasks concurrently (thread_pool: {use_thread_pool})") + # Use config timeout if not explicitly provided + if timeout is None: + timeout = self.multi_task_running_timeout + + logger.info( + f"Executing {len(tasks)} tasks concurrently (thread_pool: {use_thread_pool}, timeout: {timeout})" + ) try: results = self.thread_manager.run_multiple_tasks( diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index 6f05bf72f..b6f48d043 100644 --- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -127,7 +127,7 @@ class DictConversionMixin: @field_serializer("timestamp", check_fields=False) def serialize_datetime(self, dt: datetime | None, _info) -> str | None: """ - Custom datetime serialization logic. + Custom timestamp serialization logic. - Supports timezone-aware datetime objects - Compatible with models without timestamp field (via check_fields=False) """ diff --git a/src/memos/mem_scheduler/general_modules/task_threads.py b/src/memos/mem_scheduler/general_modules/task_threads.py index 913d5fa1d..551e8b726 100644 --- a/src/memos/mem_scheduler/general_modules/task_threads.py +++ b/src/memos/mem_scheduler/general_modules/task_threads.py @@ -89,7 +89,7 @@ def worker( def run_multiple_tasks( self, - tasks: dict[str, tuple[Callable, tuple, dict]], + tasks: dict[str, tuple[Callable, tuple]], use_thread_pool: bool = False, timeout: float | None = None, ) -> dict[str, Any]: @@ -97,7 +97,7 @@ def run_multiple_tasks( Run multiple tasks concurrently and return all results. Args: - tasks: Dictionary mapping task names to (function, args, kwargs) tuples + tasks: Dictionary mapping task names to (task_execution_function, task_execution_parameters) tuples use_thread_pool: Whether to use ThreadPoolExecutor (True) or regular threads (False) timeout: Maximum time to wait for all tasks to complete (in seconds). None for infinite timeout. @@ -115,17 +115,21 @@ def run_multiple_tasks( start_time = time.time() if use_thread_pool: - return self.run_with_thread_pool(tasks, timeout) + # Convert tasks format for thread pool compatibility + thread_pool_tasks = {} + for task_name, (func, args) in tasks.items(): + thread_pool_tasks[task_name] = (func, args, {}) + return self.run_with_thread_pool(thread_pool_tasks, timeout) else: # Use regular threads threads = {} thread_results = {} exceptions = {} - def worker(task_name: str, func: Callable, args: tuple, kwargs: dict): + def worker(task_name: str, func: Callable, args: tuple): """Worker function for regular threads""" try: - result = func(*args, **kwargs) + result = func(*args) thread_results[task_name] = result logger.debug(f"Task '{task_name}' completed successfully") except Exception as e: @@ -133,9 +137,9 @@ def worker(task_name: str, func: Callable, args: tuple, kwargs: dict): logger.error(f"Task '{task_name}' failed with error: {e}") # Start all threads - for task_name, (func, args, kwargs) in tasks.items(): + for task_name, (func, args) in tasks.items(): thread = threading.Thread( - target=worker, args=(task_name, func, args, kwargs), name=f"task-{task_name}" + target=worker, args=(task_name, func, args), name=f"task-{task_name}" ) threads[task_name] = thread thread.start() @@ -197,44 +201,60 @@ def run_with_thread_pool( results = {} start_time = time.time() - # Use ThreadPoolExecutor for better resource management - with self.thread_pool_executor as executor: - # Submit all tasks - future_to_name = {} - for task_name, (func, args, kwargs) in tasks.items(): + # Check if executor is shutdown before using it + if self.thread_pool_executor._shutdown: + logger.error("ThreadPoolExecutor is already shutdown, cannot submit new tasks") + raise RuntimeError("ThreadPoolExecutor is already shutdown") + + # Use ThreadPoolExecutor directly without context manager + # The executor lifecycle is managed by the parent SchedulerDispatcher + executor = self.thread_pool_executor + + # Submit all tasks + future_to_name = {} + for task_name, (func, args, kwargs) in tasks.items(): + try: future = executor.submit(func, *args, **kwargs) future_to_name[future] = task_name logger.debug(f"Submitted task '{task_name}' to thread pool") + except RuntimeError as e: + if "cannot schedule new futures after shutdown" in str(e): + logger.error( + f"Cannot submit task '{task_name}': ThreadPoolExecutor is shutdown" + ) + results[task_name] = None + else: + raise - # Collect results as they complete - try: - # Handle infinite timeout case - timeout_param = None if timeout is None else timeout - for future in as_completed(future_to_name, timeout=timeout_param): - task_name = future_to_name[future] - try: - result = future.result() - results[task_name] = result - logger.debug(f"Task '{task_name}' completed successfully") - except Exception as e: - logger.error(f"Task '{task_name}' failed with error: {e}") - results[task_name] = None + # Collect results as they complete + try: + # Handle infinite timeout case + timeout_param = None if timeout is None else timeout + for future in as_completed(future_to_name, timeout=timeout_param): + task_name = future_to_name[future] + try: + result = future.result() + results[task_name] = result + logger.debug(f"Task '{task_name}' completed successfully") + except Exception as e: + logger.error(f"Task '{task_name}' failed with error: {e}") + results[task_name] = None - except Exception: - elapsed_time = time.time() - start_time - timeout_msg = "infinite" if timeout is None else f"{timeout}s" - logger.error( - f"Tasks execution timed out after {elapsed_time:.2f} seconds (timeout: {timeout_msg})" - ) - # Cancel remaining futures - for future in future_to_name: - if not future.done(): - future.cancel() - task_name = future_to_name[future] - logger.warning(f"Cancelled task '{task_name}' due to timeout") - results[task_name] = None - timeout_seconds = "infinite" if timeout is None else timeout - logger.error(f"Tasks execution timed out after {timeout_seconds} seconds") + except Exception: + elapsed_time = time.time() - start_time + timeout_msg = "infinite" if timeout is None else f"{timeout}s" + logger.error( + f"Tasks execution timed out after {elapsed_time:.2f} seconds (timeout: {timeout_msg})" + ) + # Cancel remaining futures + for future in future_to_name: + if not future.done(): + future.cancel() + task_name = future_to_name[future] + logger.warning(f"Cancelled task '{task_name}' due to timeout") + results[task_name] = None + timeout_seconds = "infinite" if timeout is None else timeout + logger.error(f"Tasks execution timed out after {timeout_seconds} seconds") return results diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index fb5f4ce7c..c8e2eb59e 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any +import json + +from typing import TYPE_CHECKING from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig @@ -8,18 +10,20 @@ from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( API_MIX_SEARCH_LABEL, - QUERY_LABEL, MemCubeID, SearchMode, UserID, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.utils.api_utils import format_textual_memory_item +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.types import UserContext if TYPE_CHECKING: from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem + from memos.reranker.http_bge import HTTPBGEReranker logger = get_logger(__name__) @@ -31,30 +35,18 @@ class OptimizedScheduler(GeneralScheduler): def __init__(self, config: GeneralSchedulerConfig): super().__init__(config) self.api_module = SchedulerAPIModule() - self.message_consumers = { - API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, - } - - def _format_memory_item(self, memory_data: Any) -> dict[str, Any]: - """Format a single memory item for API response.""" - memory = memory_data.model_dump() - memory_id = memory["id"] - ref_id = f"[{memory_id.split('-')[0]}]" - - memory["ref_id"] = ref_id - memory["metadata"]["embedding"] = [] - memory["metadata"]["sources"] = [] - memory["metadata"]["ref_id"] = ref_id - memory["metadata"]["id"] = memory_id - memory["metadata"]["memory"] = memory["memory"] - - return memory + self.register_handlers( + { + API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, + } + ) - def fine_search_memories( + def search_memories( self, search_req: APISearchRequest, user_context: UserContext, mem_cube: GeneralMemCube, + mode: SearchMode, ): """Fine search memories function copied from server_router to avoid circular import""" target_session_id = search_req.session_id @@ -67,7 +59,7 @@ def fine_search_memories( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=SearchMode.FINE, + mode=mode, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, @@ -77,42 +69,145 @@ def fine_search_memories( "chat_history": search_req.chat_history, }, ) - formatted_memories = [self._format_memory_item(data) for data in search_results] + return search_results + + def submit_memory_history_async_task( + self, + search_req: APISearchRequest, + user_context: UserContext, + ): + # Create message for async fine search + message_content = { + "search_req": { + "query": search_req.query, + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "top_k": search_req.top_k, + "internet_search": search_req.internet_search, + "moscube": search_req.moscube, + "chat_history": search_req.chat_history, + }, + "user_context": {"mem_cube_id": user_context.mem_cube_id}, + } + + async_task_id = f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}" + + # Get mem_cube for the message + mem_cube = self.current_mem_cube + + message = ScheduleMessageItem( + item_id=async_task_id, + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + label=API_MIX_SEARCH_LABEL, + mem_cube=mem_cube, + content=json.dumps(message_content), + timestamp=get_utc_now(), + ) + + # Submit async task + self.submit_messages([message]) + logger.info(f"Submitted async fine search task for user {search_req.user_id}") + return async_task_id + + def mix_search_memories( + self, + search_req: APISearchRequest, + user_context: UserContext, + ): + """ + Mix search memories: fast search + async fine search + """ + + # Get mem_cube for fast search + mem_cube = self.current_mem_cube + + # Perform fast search + fast_memories = self.search_memories( + search_req=search_req, + user_context=user_context, + mem_cube=mem_cube, + mode=SearchMode.FAST, + ) + + self.submit_memory_history_async_task( + search_req=search_req, + user_context=user_context, + ) + + # Try to get pre-computed fine memories if available + pre_fine_memories = self.api_module.get_pre_memories( + user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id + ) + if not pre_fine_memories: + # Format fast memories for return + formatted_memories = [format_textual_memory_item(data) for data in fast_memories] + return formatted_memories + + # Merge fast and pre-computed fine memories (both are TextualMemoryItem objects) + combined_memories = fast_memories + pre_fine_memories + # Remove duplicates based on memory content + seen_contents = set() + unique_memories = [] + for memory in combined_memories: + # Both fast_memories and pre_fine_memories are TextualMemoryItem objects + content_key = memory.memory # Use .memory attribute instead of .get("content", "") + if content_key not in seen_contents: + seen_contents.add(content_key) + unique_memories.append(memory) + + # Rerank Memories - reranker expects TextualMemoryItem objects + reranker: HTTPBGEReranker = mem_cube.text_mem.reranker + + # Use search_req parameters for reranking + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + sorted_results = reranker.rerank( + query=search_req.query, # Use search_req.query instead of undefined query + graph_results=unique_memories, # Pass TextualMemoryItem objects directly + top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k + search_filter=search_filter, + ) + + formatted_memories = [ + format_textual_memory_item(item) for item, score in sorted_results[: search_req.top_k] + ] return formatted_memories def update_search_memories_to_redis( - self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] + self, + user_id: str, + mem_cube_id: str, + messages: list[ScheduleMessageItem], ): mem_cube = messages[0].mem_cube - # for status update - self._set_current_context_from_message(msg=messages[0]) - - # update query monitors for msg in messages: - self.monitor.register_query_monitor_if_not_exists( - user_id=user_id, mem_cube_id=mem_cube_id - ) - - content_dict = msg.content + content_dict = json.loads(msg.content) search_req = content_dict["search_req"] user_context = content_dict["user_context"] - formatted_memories = self.fine_search_memories( - search_req=search_req, user_context=user_context, mem_cube=mem_cube + fine_memories: list[TextualMemoryItem] = self.search_memories( + search_req=APISearchRequest(**content_dict["search_req"]), + user_context=UserContext(**content_dict["user_context"]), + mem_cube=mem_cube, + mode=SearchMode.FINE, ) + formatted_memories = [format_textual_memory_item(data) for data in fine_memories] # Sync search data to Redis - try: - self.api_module.sync_search_data( - user_id=search_req.user_id, - mem_cube_id=user_context.mem_cube_id, - query=search_req.query, - formatted_memories=formatted_memories, - ) - except Exception as e: - logger.error(f"Failed to sync search data: {e}") + self.api_module.sync_search_data( + item_id=msg.item_id, + user_id=search_req["user_id"], + mem_cube_id=user_context["mem_cube_id"], + query=search_req["query"], + memories=fine_memories, + formatted_memories=formatted_memories, + ) def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ @@ -121,12 +216,12 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) Args: messages: List of query messages to process """ - logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_LABEL} handler.") # Process the query in a session turn grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) - self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) + self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: diff --git a/src/memos/mem_scheduler/orm_modules/api_redis_model.py b/src/memos/mem_scheduler/orm_modules/api_redis_model.py new file mode 100644 index 000000000..04cd7e833 --- /dev/null +++ b/src/memos/mem_scheduler/orm_modules/api_redis_model.py @@ -0,0 +1,517 @@ +import os +import time + +from typing import Any + +from sqlalchemy.orm import declarative_base + +from memos.log import get_logger +from memos.mem_scheduler.orm_modules.base_model import DatabaseError +from memos.mem_scheduler.schemas.api_schemas import ( + APISearchHistoryManager, +) +from memos.mem_scheduler.utils.db_utils import get_utc_now + + +logger = get_logger(__name__) + +Base = declarative_base() + + +class APIRedisDBManager: + """Redis-based database manager for any serializable object + + This class handles persistence, synchronization, and locking + for any object that implements to_json/from_json methods using Redis as the backend storage. + """ + + # Add orm_class attribute for compatibility + orm_class = None + + def __init__( + self, + user_id: str | None = None, + mem_cube_id: str | None = None, + obj: APISearchHistoryManager | None = None, + lock_timeout: int = 10, + redis_client=None, + redis_config: dict | None = None, + window_size: int = 5, + ): + """Initialize the Redis database manager + + Args: + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + obj: Optional object instance to manage (must have to_json/from_json methods) + lock_timeout: Timeout in seconds for lock acquisition + redis_client: Redis client instance (optional) + redis_config: Redis configuration dictionary (optional) + """ + # Initialize Redis client + self.redis_client = redis_client + self.redis_config = redis_config or {} + + if self.redis_client is None: + self._init_redis_client() + + # Initialize base attributes without calling parent's init_manager + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.obj = obj + self.lock_timeout = lock_timeout + self.engine = None # Keep for compatibility but not used + self.SessionLocal = None # Not used for Redis + self.window_size = window_size + self.lock_key = f"{self._get_key_prefix()}:lock" + + logger.info( + f"RedisDBManager initialized for user_id: {user_id}, mem_cube_id: {mem_cube_id}" + ) + logger.info(f"Redis client: {type(self.redis_client).__name__}") + + # Test Redis connection + try: + self.redis_client.ping() + logger.info("Redis connection successful") + except Exception as e: + logger.warning(f"Redis ping failed: {e}") + # Don't raise error here as it might be a mock client in tests + + def _get_key_prefix(self) -> str: + """Generate Redis key prefix for this user and memory cube + + Returns: + Redis key prefix string + """ + return f"redis_api:{self.user_id}:{self.mem_cube_id}" + + def _get_data_key(self) -> str: + """Generate Redis key for storing serialized data + + Returns: + Redis data key string + """ + return f"{self._get_key_prefix()}:data" + + def _init_redis_client(self): + """Initialize Redis client from config or environment""" + try: + import redis + except ImportError: + logger.error("Redis package not installed. Install with: pip install redis") + raise + + # Try to get Redis client from environment first + if not self.redis_client: + self.redis_client = APIRedisDBManager.load_redis_engine_from_env() + + # If still no client, try from config + if not self.redis_client and self.redis_config: + redis_kwargs = { + "host": self.redis_config.get("host"), + "port": self.redis_config.get("port"), + "db": self.redis_config.get("db"), + "decode_responses": True, + } + + if self.redis_config.get("password"): + redis_kwargs["password"] = self.redis_config["password"] + + self.redis_client = redis.Redis(**redis_kwargs) + + # Final fallback to localhost + if not self.redis_client: + logger.warning("No Redis configuration found, using localhost defaults") + self.redis_client = redis.Redis( + host="localhost", port=6379, db=0, decode_responses=True + ) + + # Test connection + if not self.redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info("Redis client initialized successfully") + + def acquire_lock(self, block: bool = True, **kwargs) -> bool: + """Acquire a distributed lock using Redis with atomic operations + + Args: + block: Whether to block until lock is acquired + **kwargs: Additional filter criteria (ignored for Redis) + + Returns: + True if lock was acquired, False otherwise + """ + + now = get_utc_now() + + # Use Redis SET with NX (only if not exists) and EX (expiry) for atomic lock acquisition + lock_value = f"{self._get_key_prefix()}:{now.timestamp()}" + + while True: + result = self.redis_client.get(self.lock_key) + if result: + # Wait a bit before retrying + logger.info( + f"Waiting for Redis lock to be released for {self.user_id}/{self.mem_cube_id}" + ) + if not block: + logger.warning( + f"Redis lock is held for {self.user_id}/{self.mem_cube_id}, cannot acquire" + ) + return False + else: + time.sleep(0.1) + continue + else: + # Try to acquire lock atomically + result = self.redis_client.set( + self.lock_key, + lock_value, + ex=self.lock_timeout, # Set expiry in seconds + ) + logger.info(f"Redis lock acquired for {self._get_key_prefix()}") + return True + + def release_locks(self, **kwargs): + # Delete the lock key to release the lock + result = self.redis_client.delete(self.lock_key) + + # Redis DELETE returns the number of keys deleted (0 or 1) + if result > 0: + logger.info(f"Redis lock released for {self._get_key_prefix()}") + else: + logger.info(f"No Redis lock found to release for {self._get_key_prefix()}") + + def merge_items( + self, + redis_data: str, + obj_instance: APISearchHistoryManager, + size_limit: int, + ): + """Merge Redis data with current object instance + + Args: + redis_data: JSON string from Redis containing serialized APISearchHistoryManager + obj_instance: Current APISearchHistoryManager instance + size_limit: Maximum number of completed entries to keep + + Returns: + APISearchHistoryManager: Merged and synchronized manager instance + """ + + # Parse Redis data + redis_manager = APISearchHistoryManager.from_json(redis_data) + logger.debug( + f"Loaded Redis manager with {len(redis_manager.completed_entries)} completed and {len(redis_manager.running_item_ids)} running task IDs" + ) + + # Create a new merged manager with the original window size from obj_instance + # Use size_limit only for limiting entries, not as window_size + original_window_size = obj_instance.window_size + merged_manager = APISearchHistoryManager(window_size=original_window_size) + + # Merge completed entries - combine both sources and deduplicate by task_id + # Ensure all entries are APIMemoryHistoryEntryItem instances + from memos.mem_scheduler.schemas.api_schemas import APIMemoryHistoryEntryItem + + all_completed = {} + + # Add Redis completed entries + for entry in redis_manager.completed_entries: + if isinstance(entry, dict): + # Convert dict to APIMemoryHistoryEntryItem instance + try: + entry_obj = APIMemoryHistoryEntryItem(**entry) + task_id = entry_obj.item_id + all_completed[task_id] = entry_obj + except Exception as e: + logger.warning( + f"Failed to convert dict entry to APIMemoryHistoryEntryItem: {e}" + ) + continue + else: + task_id = entry.item_id + all_completed[task_id] = entry + + # Add current instance completed entries (these take priority if duplicated) + for entry in obj_instance.completed_entries: + if isinstance(entry, dict): + # Convert dict to APIMemoryHistoryEntryItem instance + try: + entry_obj = APIMemoryHistoryEntryItem(**entry) + task_id = entry_obj.item_id + all_completed[task_id] = entry_obj + except Exception as e: + logger.warning( + f"Failed to convert dict entry to APIMemoryHistoryEntryItem: {e}" + ) + continue + else: + task_id = entry.item_id + all_completed[task_id] = entry + + # Sort by created_time and apply size limit + completed_list = list(all_completed.values()) + + def get_created_time(entry): + """Helper function to safely extract created_time for sorting""" + from datetime import datetime + + # All entries should now be APIMemoryHistoryEntryItem instances + return getattr(entry, "created_time", datetime.min) + + completed_list.sort(key=get_created_time, reverse=True) + merged_manager.completed_entries = completed_list[:size_limit] + + # Merge running task IDs - combine both sources and deduplicate + all_running_item_ids = set() + + # Add Redis running task IDs + all_running_item_ids.update(redis_manager.running_item_ids) + + # Add current instance running task IDs + all_running_item_ids.update(obj_instance.running_item_ids) + + merged_manager.running_item_ids = list(all_running_item_ids) + + logger.info( + f"Merged manager: {len(merged_manager.completed_entries)} completed, {len(merged_manager.running_item_ids)} running task IDs" + ) + return merged_manager + + def sync_with_redis(self, size_limit: int | None = None) -> None: + """Synchronize data between Redis and the business object + + Args: + size_limit: Optional maximum number of items to keep after synchronization + """ + + # Use window_size from the object if size_limit is not provided + if size_limit is None: + size_limit = self.window_size + + # Acquire lock before operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for synchronization") + return + + # Load existing data from Redis + data_key = self._get_data_key() + redis_data = self.redis_client.get(data_key) + + if redis_data: + # Merge Redis data with current object + merged_obj = self.merge_items( + redis_data=redis_data, obj_instance=self.obj, size_limit=size_limit + ) + + # Update the current object with merged data + self.obj = merged_obj + logger.info( + f"Successfully synchronized with Redis data for {self.user_id}/{self.mem_cube_id}" + ) + else: + logger.info( + f"No existing Redis data found for {self.user_id}/{self.mem_cube_id}, using current object" + ) + + # Save the synchronized object back to Redis + self.save_to_db(self.obj) + + self.release_locks() + + def save_to_db(self, obj_instance: Any) -> None: + """Save the current state of the business object to Redis + + Args: + obj_instance: The object instance to save (must have to_json method) + """ + + data_key = self._get_data_key() + + self.redis_client.set(data_key, obj_instance.to_json()) + + logger.info(f"Updated existing Redis record for {data_key}") + + def load_from_db(self) -> Any | None: + data_key = self._get_data_key() + + # Load from Redis + serialized_data = self.redis_client.get(data_key) + + if not serialized_data: + logger.info(f"No Redis record found for {data_key}") + return None + + # Deserialize the business object using the actual object type + if hasattr(self, "obj_type") and self.obj_type is not None: + db_instance = self.obj_type.from_json(serialized_data) + else: + # Default to APISearchHistoryManager for this class + db_instance = APISearchHistoryManager.from_json(serialized_data) + + logger.info(f"Successfully loaded object from Redis for {data_key} ") + + return db_instance + + @classmethod + def from_env( + cls, + user_id: str, + mem_cube_id: str, + obj: Any | None = None, + lock_timeout: int = 10, + env_file_path: str | None = None, + ) -> "APIRedisDBManager": + """Create RedisDBManager from environment variables + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + obj: Optional MemoryMonitorManager instance + lock_timeout: Lock timeout in seconds + env_file_path: Optional path to .env file + + Returns: + RedisDBManager instance + """ + + redis_client = APIRedisDBManager.load_redis_engine_from_env(env_file_path) + return cls( + user_id=user_id, + mem_cube_id=mem_cube_id, + obj=obj, + lock_timeout=lock_timeout, + redis_client=redis_client, + ) + + def close(self): + """Close the Redis connection and clean up resources""" + try: + if hasattr(self.redis_client, "close"): + self.redis_client.close() + logger.info( + f"Redis connection closed for user_id: {self.user_id}, mem_cube_id: {self.mem_cube_id}" + ) + except Exception as e: + logger.warning(f"Error closing Redis connection: {e}") + + @staticmethod + def load_redis_engine_from_env(env_file_path: str | None = None) -> Any: + """Load Redis connection from environment variables + + Args: + env_file_path: Path to .env file (optional, defaults to loading from current environment) + + Returns: + Redis connection instance + + Raises: + DatabaseError: If required environment variables are missing or connection fails + """ + try: + import redis + except ImportError as e: + error_msg = "Redis package not installed. Install with: pip install redis" + logger.error(error_msg) + raise DatabaseError(error_msg) from e + + # Load environment variables from file if provided + if env_file_path: + if os.path.exists(env_file_path): + from dotenv import load_dotenv + + load_dotenv(env_file_path) + logger.info(f"Loaded environment variables from {env_file_path}") + else: + logger.warning( + f"Environment file not found: {env_file_path}, using current environment variables" + ) + else: + logger.info("Using current environment variables (no env_file_path provided)") + + # Get Redis configuration from environment variables + redis_host = os.getenv("REDIS_HOST") or os.getenv("MEMSCHEDULER_REDIS_HOST") + redis_port_str = os.getenv("REDIS_PORT") or os.getenv("MEMSCHEDULER_REDIS_PORT") + redis_db_str = os.getenv("REDIS_DB") or os.getenv("MEMSCHEDULER_REDIS_DB") + redis_password = os.getenv("REDIS_PASSWORD") or os.getenv("MEMSCHEDULER_REDIS_PASSWORD") + + # Check required environment variables + if not redis_host: + error_msg = ( + "Missing required Redis environment variable: REDIS_HOST or MEMSCHEDULER_REDIS_HOST" + ) + logger.error(error_msg) + return None + + # Parse port with validation + try: + redis_port = int(redis_port_str) if redis_port_str else 6379 + except ValueError: + error_msg = f"Invalid REDIS_PORT value: {redis_port_str}. Must be a valid integer." + logger.error(error_msg) + return None + + # Parse database with validation + try: + redis_db = int(redis_db_str) if redis_db_str else 0 + except ValueError: + error_msg = f"Invalid REDIS_DB value: {redis_db_str}. Must be a valid integer." + logger.error(error_msg) + return None + + # Optional timeout settings + socket_timeout = os.getenv( + "REDIS_SOCKET_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None) + ) + socket_connect_timeout = os.getenv( + "REDIS_SOCKET_CONNECT_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None) + ) + + try: + # Build Redis connection parameters + redis_kwargs = { + "host": redis_host, + "port": redis_port, + "db": redis_db, + "decode_responses": True, + } + + if redis_password: + redis_kwargs["password"] = redis_password + + if socket_timeout: + try: + redis_kwargs["socket_timeout"] = float(socket_timeout) + except ValueError: + logger.warning( + f"Invalid REDIS_SOCKET_TIMEOUT value: {socket_timeout}, ignoring" + ) + + if socket_connect_timeout: + try: + redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout) + except ValueError: + logger.warning( + f"Invalid REDIS_SOCKET_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring" + ) + + # Create Redis connection + redis_client = redis.Redis(**redis_kwargs) + + # Test connection + if not redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info( + f"Successfully created Redis connection: {redis_host}:{redis_port}/{redis_db}" + ) + return redis_client + + except Exception as e: + error_msg = f"Failed to create Redis connection from environment variables: {e}" + logger.error(error_msg) + raise DatabaseError(error_msg) from e diff --git a/src/memos/mem_scheduler/orm_modules/base_model.py b/src/memos/mem_scheduler/orm_modules/base_model.py index cf3fc904c..9783cea82 100644 --- a/src/memos/mem_scheduler/orm_modules/base_model.py +++ b/src/memos/mem_scheduler/orm_modules/base_model.py @@ -727,120 +727,3 @@ def load_mysql_engine_from_env(env_file_path: str | None = None) -> Engine | Non error_msg = f"Failed to create MySQL engine from environment variables: {e}" logger.error(error_msg) raise DatabaseError(error_msg) from e - - @staticmethod - def load_redis_engine_from_env(env_file_path: str | None = None) -> Any: - """Load Redis connection from environment variables - - Args: - env_file_path: Path to .env file (optional, defaults to loading from current environment) - - Returns: - Redis connection instance - - Raises: - DatabaseError: If required environment variables are missing or connection fails - """ - try: - import redis - except ImportError as e: - error_msg = "Redis package not installed. Install with: pip install redis" - logger.error(error_msg) - raise DatabaseError(error_msg) from e - - # Load environment variables from file if provided - if env_file_path: - if os.path.exists(env_file_path): - from dotenv import load_dotenv - - load_dotenv(env_file_path) - logger.info(f"Loaded environment variables from {env_file_path}") - else: - logger.warning( - f"Environment file not found: {env_file_path}, using current environment variables" - ) - else: - logger.info("Using current environment variables (no env_file_path provided)") - - # Get Redis configuration from environment variables - redis_host = os.getenv("REDIS_HOST") or os.getenv("MEMSCHEDULER_REDIS_HOST") - redis_port_str = os.getenv("REDIS_PORT") or os.getenv("MEMSCHEDULER_REDIS_PORT") - redis_db_str = os.getenv("REDIS_DB") or os.getenv("MEMSCHEDULER_REDIS_DB") - redis_password = os.getenv("REDIS_PASSWORD") or os.getenv("MEMSCHEDULER_REDIS_PASSWORD") - - # Check required environment variables - if not redis_host: - error_msg = ( - "Missing required Redis environment variable: REDIS_HOST or MEMSCHEDULER_REDIS_HOST" - ) - logger.error(error_msg) - return None - - # Parse port with validation - try: - redis_port = int(redis_port_str) if redis_port_str else 6379 - except ValueError: - error_msg = f"Invalid REDIS_PORT value: {redis_port_str}. Must be a valid integer." - logger.error(error_msg) - return None - - # Parse database with validation - try: - redis_db = int(redis_db_str) if redis_db_str else 0 - except ValueError: - error_msg = f"Invalid REDIS_DB value: {redis_db_str}. Must be a valid integer." - logger.error(error_msg) - return None - - # Optional timeout settings - socket_timeout = os.getenv( - "REDIS_SOCKET_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None) - ) - socket_connect_timeout = os.getenv( - "REDIS_SOCKET_CONNECT_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None) - ) - - try: - # Build Redis connection parameters - redis_kwargs = { - "host": redis_host, - "port": redis_port, - "db": redis_db, - "decode_responses": True, - } - - if redis_password: - redis_kwargs["password"] = redis_password - - if socket_timeout: - try: - redis_kwargs["socket_timeout"] = float(socket_timeout) - except ValueError: - logger.warning( - f"Invalid REDIS_SOCKET_TIMEOUT value: {socket_timeout}, ignoring" - ) - - if socket_connect_timeout: - try: - redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout) - except ValueError: - logger.warning( - f"Invalid REDIS_SOCKET_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring" - ) - - # Create Redis connection - redis_client = redis.Redis(**redis_kwargs) - - # Test connection - if not redis_client.ping(): - raise ConnectionError("Redis ping failed") - - logger.info( - f"Successfully created Redis connection: {redis_host}:{redis_port}/{redis_db}" - ) - return redis_client - - except Exception as e: - error_msg = f"Failed to create Redis connection from environment variables: {e}" - logger.error(error_msg) - raise DatabaseError(error_msg) from e diff --git a/src/memos/mem_scheduler/schemas/api_schemas.py b/src/memos/mem_scheduler/schemas/api_schemas.py new file mode 100644 index 000000000..23eb5a848 --- /dev/null +++ b/src/memos/mem_scheduler/schemas/api_schemas.py @@ -0,0 +1,232 @@ +from datetime import datetime +from enum import Enum +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, field_serializer + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.misc import DictConversionMixin +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.memories.textual.item import TextualMemoryItem + + +logger = get_logger(__name__) + + +class TaskRunningStatus(str, Enum): + """Enumeration for task running status values.""" + + RUNNING = "running" + COMPLETED = "completed" + + +class APIMemoryHistoryEntryItem(BaseModel, DictConversionMixin): + """Data class for search entry items stored in Redis.""" + + item_id: str = Field( + description="Unique identifier for the task", default_factory=lambda: str(uuid4()) + ) + query: str = Field(..., description="Search query string") + formatted_memories: Any = Field(..., description="Formatted search results") + memories: list[TextualMemoryItem] = Field( + default_factory=list, description="List of TextualMemoryItem objects" + ) + task_status: str = Field( + default="running", description="Task status: running, completed, failed" + ) + conversation_id: str | None = Field( + default=None, description="Optional conversation identifier" + ) + created_time: datetime = Field(description="Entry creation time", default_factory=get_utc_now) + timestamp: datetime | None = Field(default=None, description="Timestamp for the entry") + + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=True, + ) + + @field_serializer("created_time") + def serialize_created_time(self, value: datetime) -> str: + """Serialize datetime to ISO format string.""" + return value.isoformat() + + def get(self, key: str, default: Any | None = None) -> Any: + """ + Get attribute value by key name, similar to dict.get(). + + Args: + key: The attribute name to retrieve + default: Default value to return if attribute doesn't exist + + Returns: + The attribute value or default if not found + """ + return getattr(self, key, default) + + +class APISearchHistoryManager(BaseModel, DictConversionMixin): + """ + Data structure for managing search history with separate completed and running entries. + Supports window_size to limit the number of completed entries. + """ + + window_size: int = Field(default=5, description="Maximum number of completed entries to keep") + completed_entries: list[APIMemoryHistoryEntryItem] = Field( + default_factory=list, description="List of completed search entries" + ) + running_item_ids: list[str] = Field( + default_factory=list, description="List of running task ids" + ) + + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=True, + ) + + def complete_entry(self, task_id: str) -> bool: + """ + Remove task_id from running list when completed. + Note: The actual entry data should be managed separately. + + Args: + task_id: The task ID to complete + + Returns: + True if task_id was found and removed, False otherwise + """ + if task_id in self.running_item_ids: + self.running_item_ids.remove(task_id) + logger.debug(f"Completed task_id: {task_id}") + return True + + logger.warning(f"Task ID {task_id} not found in running task ids") + return False + + def get_running_item_ids(self) -> list[str]: + """Get all running task IDs""" + return self.running_item_ids.copy() + + def get_completed_entries(self) -> list[dict[str, Any]]: + """Get all completed entries""" + return self.completed_entries.copy() + + def get_history_memory_entries(self, turns: int | None = None) -> list[dict[str, Any]]: + """ + Get the most recent n completed search entries, sorted by created_time. + + Args: + turns: Number of entries to return. If None, returns all completed entries. + + Returns: + List of completed search entries, sorted by created_time (newest first) + """ + if not self.completed_entries: + return [] + + # Sort by created_time (newest first) + sorted_entries = sorted(self.completed_entries, key=lambda x: x.created_time, reverse=True) + + if turns is None: + return sorted_entries + + return sorted_entries[:turns] + + def get_history_memories(self, turns: int | None = None) -> list[TextualMemoryItem]: + """ + Get the most recent n completed search entries, sorted by created_time. + + Args: + turns: Number of entries to return. If None, returns all completed entries. + + Returns: + List of TextualMemoryItem objects from completed entries, sorted by created_time (newest first) + """ + sorted_entries = self.get_history_memory_entries(turns=turns) + + memories = [] + for one in sorted_entries: + memories.extend(one.memories) + return memories + + def find_entry_by_item_id(self, item_id: str) -> tuple[dict[str, Any] | None, str]: + """ + Find an entry by item_id in completed list only. + Running entries are now just task IDs, so we can only search completed entries. + + Args: + item_id: The item ID to search for + + Returns: + Tuple of (entry_dict, location) where location is 'completed' or 'not_found' + """ + # Check completed entries + for entry in self.completed_entries: + try: + if hasattr(entry, "item_id") and entry.item_id == item_id: + return entry.to_dict(), "completed" + elif isinstance(entry, dict) and entry.get("item_id") == item_id: + return entry, "completed" + except AttributeError as e: + logger.warning(f"Entry missing item_id attribute: {e}, entry type: {type(entry)}") + continue + + return None, "not_found" + + def update_entry_by_item_id( + self, + item_id: str, + query: str, + formatted_memories: Any, + task_status: TaskRunningStatus, + conversation_id: str | None = None, + memories: list[TextualMemoryItem] | None = None, + ) -> bool: + """ + Update an existing entry by item_id. Since running entries are now just IDs, + this method can only update completed entries. + + Args: + item_id: The item ID to update + query: New query string + formatted_memories: New formatted memories + task_status: New task status + conversation_id: New conversation ID + memories: List of TextualMemoryItem objects + + Returns: + True if entry was found and updated, False otherwise + """ + # Find the entry in completed list + for entry in self.completed_entries: + if entry.item_id == item_id: + # Update the entry content + entry.query = query + entry.formatted_memories = formatted_memories + entry.task_status = task_status + if conversation_id is not None: + entry.conversation_id = conversation_id + if memories is not None: + entry.memories = memories + + logger.debug(f"Updated entry with item_id: {item_id}, new status: {task_status}") + return True + + logger.warning(f"Entry with item_id: {item_id} not found in completed entries") + return False + + def get_total_count(self) -> dict[str, int]: + """Get count of entries by status""" + return { + "completed": len(self.completed_entries), + "running": len(self.running_item_ids), + "total": len(self.completed_entries) + len(self.running_item_ids), + } + + def __len__(self) -> int: + """Return total number of entries (completed + running)""" + return len(self.completed_entries) + len(self.running_item_ids) + + +# Alias for easier usage +SearchHistoryManager = APISearchHistoryManager diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 2bc7a3b98..a2c6434fe 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -39,6 +39,7 @@ class SearchMode(str, Enum): DEFAULT_TOP_K = 10 DEFAULT_CONTEXT_WINDOW_SIZE = 5 DEFAULT_USE_REDIS_QUEUE = False +DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 # startup mode configuration STARTUP_BY_THREAD = "thread" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 541d2486d..bd3155a96 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -6,7 +6,7 @@ from typing_extensions import TypedDict from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube +from memos.mem_cube.base import BaseMemCube from memos.mem_scheduler.general_modules.misc import DictConversionMixin from memos.mem_scheduler.utils.db_utils import get_utc_now @@ -35,10 +35,9 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) user_id: str = Field(..., description="user id") - session_id: str | None = Field(default=None, description="session id") mem_cube_id: str = Field(..., description="memcube id") label: str = Field(..., description="Label of the schedule message") - mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") + mem_cube: BaseMemCube | str = Field(..., description="memcube for schedule") content: str = Field(..., description="Content of the schedule message") timestamp: datetime = Field( default_factory=get_utc_now, description="submit time for schedule_messages" @@ -56,7 +55,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "example": { "item_id": "123e4567-e89b-12d3-a456-426614174000", # Sample UUID "user_id": "user123", # Example user identifier - "session_id": "session123", # Example session identifier "mem_cube_id": "cube456", # Sample memory cube ID "label": "sample_label", # Demonstration label value "mem_cube": "obj of GeneralMemCube", # Added mem_cube example @@ -67,18 +65,17 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): ) @field_serializer("mem_cube") - def serialize_mem_cube(self, cube: GeneralMemCube | str, _info) -> str: - """Custom serializer for GeneralMemCube objects to string representation""" + def serialize_mem_cube(self, cube: BaseMemCube | str, _info) -> str: + """Custom serializer for BaseMemCube objects to string representation""" if isinstance(cube, str): return cube - return f"" + return f"<{type(cube).__name__}:{id(cube)}>" def to_dict(self) -> dict: """Convert model to dictionary suitable for Redis Stream""" return { "item_id": self.item_id, "user_id": self.user_id, - "session_id": self.session_id, "cube_id": self.mem_cube_id, "label": self.label, "cube": "Not Applicable", # Custom cube serialization @@ -93,8 +90,6 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": item_id=data.get("item_id", str(uuid4())), user_id=data["user_id"], mem_cube_id=data["cube_id"], - session_id=data["session_id"], - cube_id=data["cube_id"], label=data["label"], mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], diff --git a/src/memos/mem_scheduler/utils/api_utils.py b/src/memos/mem_scheduler/utils/api_utils.py new file mode 100644 index 000000000..c8d096517 --- /dev/null +++ b/src/memos/mem_scheduler/utils/api_utils.py @@ -0,0 +1,76 @@ +import uuid + +from typing import Any + +from memos.memories.textual.item import TreeNodeTextualMemoryMetadata +from memos.memories.textual.tree import TextualMemoryItem + + +def format_textual_memory_item(memory_data: Any) -> dict[str, Any]: + """Format a single memory item for API response.""" + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + +def make_textual_item(memory_data): + return memory_data + + +def text_to_textual_memory_item( + text: str, + user_id: str | None = None, + session_id: str | None = None, + memory_type: str = "WorkingMemory", + tags: list[str] | None = None, + key: str | None = None, + sources: list | None = None, + background: str = "", + confidence: float = 0.99, + embedding: list[float] | None = None, +) -> TextualMemoryItem: + """ + Convert text into a TextualMemoryItem object. + + Args: + text: Memory content text + user_id: User ID + session_id: Session ID + memory_type: Memory type, defaults to "WorkingMemory" + tags: List of tags + key: Memory key or title + sources: List of sources + background: Background information + confidence: Confidence score (0-1) + embedding: Vector embedding + + Returns: + TextualMemoryItem: Wrapped memory item + """ + return TextualMemoryItem( + id=str(uuid.uuid4()), + memory=text, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type=memory_type, + status="activated", + tags=tags or [], + key=key, + embedding=embedding or [], + usage=[], + sources=sources or [], + background=background, + confidence=confidence, + type="fact", + ), + ) diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index 239557bc9..d86911e82 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -273,7 +273,7 @@ def _cleanup_redis_resources(self): self._cleanup_local_redis() - async def redis_add_message_stream(self, message: dict): + def redis_add_message_stream(self, message: dict): logger.debug(f"add_message_stream: {message}") return self._redis_conn.xadd("user:queries:stream", message) diff --git a/src/memos/memories/activation/item.py b/src/memos/memories/activation/item.py index ba1619371..9267e6920 100644 --- a/src/memos/memories/activation/item.py +++ b/src/memos/memories/activation/item.py @@ -6,6 +6,8 @@ from pydantic import BaseModel, ConfigDict, Field from transformers import DynamicCache +from memos.mem_scheduler.utils.db_utils import get_utc_now + class ActivationMemoryItem(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) @@ -23,7 +25,7 @@ class KVCacheRecords(BaseModel): description="Single string combining all text_memories using assembly template", ) timestamp: datetime = Field( - default_factory=datetime.utcnow, description="submit time for schedule_messages" + default_factory=get_utc_now, description="submit time for schedule_messages" ) diff --git a/tests/mem_scheduler/test_orm.py b/tests/mem_scheduler/test_orm.py deleted file mode 100644 index a43231e4a..000000000 --- a/tests/mem_scheduler/test_orm.py +++ /dev/null @@ -1,447 +0,0 @@ -import os -import tempfile -import time - -from datetime import datetime, timedelta - -import pytest - -from memos.mem_scheduler.orm_modules.base_model import BaseDBManager - -# Import the classes to test -from memos.mem_scheduler.orm_modules.monitor_models import ( - DBManagerForMemoryMonitorManager, - DBManagerForQueryMonitorQueue, -) -from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager -from memos.mem_scheduler.schemas.monitor_schemas import ( - MemoryMonitorItem, - MemoryMonitorManager, - QueryMonitorItem, - QueryMonitorQueue, -) - - -# Test data -TEST_USER_ID = "test_user" -TEST_MEM_CUBE_ID = "test_mem_cube" -TEST_QUEUE_ID = "test_queue" - - -class TestBaseDBManager: - """Base class for DBManager tests with common fixtures""" - - @pytest.fixture - def temp_db(self): - """Create a temporary database for testing.""" - temp_dir = tempfile.mkdtemp() - db_path = os.path.join(temp_dir, "test_scheduler_orm.db") - yield db_path - # Cleanup - try: - if os.path.exists(db_path): - os.remove(db_path) - os.rmdir(temp_dir) - except (OSError, PermissionError): - pass # Ignore cleanup errors (e.g., file locked on Windows) - - @pytest.fixture - def memory_manager_obj(self): - """Create a MemoryMonitorManager object for testing""" - return MemoryMonitorManager( - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - items=[ - MemoryMonitorItem( - item_id="custom-id-123", - memory_text="Full test memory", - tree_memory_item=None, - tree_memory_item_mapping_key="full_test_key", - keywords_score=0.8, - sorting_score=0.9, - importance_score=0.7, - recording_count=3, - ) - ], - ) - - @pytest.fixture - def query_queue_obj(self): - """Create a QueryMonitorQueue object for testing""" - queue = QueryMonitorQueue() - queue.put( - QueryMonitorItem( - item_id="query1", - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - query_text="How are you?", - timestamp=datetime.now(), - keywords=["how", "you"], - ) - ) - return queue - - @pytest.fixture - def query_monitor_manager(self, temp_db, query_queue_obj): - """Create DBManagerForQueryMonitorQueue instance with temp DB.""" - engine = BaseDBManager.create_engine_from_db_path(temp_db) - manager = DBManagerForQueryMonitorQueue( - engine=engine, - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=query_queue_obj, - lock_timeout=10, - ) - - assert manager.engine is not None - assert manager.SessionLocal is not None - assert os.path.exists(temp_db) - - yield manager - manager.close() - - @pytest.fixture - def memory_monitor_manager(self, temp_db, memory_manager_obj): - """Create DBManagerForMemoryMonitorManager instance with temp DB.""" - engine = BaseDBManager.create_engine_from_db_path(temp_db) - manager = DBManagerForMemoryMonitorManager( - engine=engine, - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=memory_manager_obj, - lock_timeout=10, - ) - - assert manager.engine is not None - assert manager.SessionLocal is not None - assert os.path.exists(temp_db) - - yield manager - manager.close() - - def test_save_and_load_query_queue(self, query_monitor_manager, query_queue_obj): - """Test saving and loading QueryMonitorQueue.""" - # Save to database - query_monitor_manager.save_to_db(query_queue_obj) - - # Load in a new manager - engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database) - new_manager = DBManagerForQueryMonitorQueue( - engine=engine, - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=None, - lock_timeout=10, - ) - loaded_queue = new_manager.load_from_db(acquire_lock=True) - - assert loaded_queue is not None - items = loaded_queue.get_queue_content_without_pop() - assert len(items) == 1 - assert items[0].item_id == "query1" - assert items[0].query_text == "How are you?" - new_manager.close() - - def test_lock_mechanism(self, query_monitor_manager, query_queue_obj): - """Test lock acquisition and release.""" - # Save current state - query_monitor_manager.save_to_db(query_queue_obj) - - # Acquire lock - acquired = query_monitor_manager.acquire_lock(block=True) - assert acquired - - # Try to acquire again (should fail without blocking) - assert not query_monitor_manager.acquire_lock(block=False) - - # Release lock - query_monitor_manager.release_locks( - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - ) - - # Should be able to acquire again - assert query_monitor_manager.acquire_lock(block=False) - - def test_lock_timeout(self, query_monitor_manager, query_queue_obj): - """Test lock timeout mechanism.""" - # Save current state - query_monitor_manager.save_to_db(query_queue_obj) - - query_monitor_manager.lock_timeout = 1 - - # Acquire lock - assert query_monitor_manager.acquire_lock(block=True) - - # Wait for lock to expire - time.sleep(1.1) - - # Should be able to acquire again - assert query_monitor_manager.acquire_lock(block=False) - - def test_sync_with_orm(self, query_monitor_manager, query_queue_obj): - """Test synchronization between ORM and object.""" - query_queue_obj.put( - QueryMonitorItem( - item_id="query2", - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - query_text="What's your name?", - timestamp=datetime.now(), - keywords=["name"], - ) - ) - - # Save current state - query_monitor_manager.save_to_db(query_queue_obj) - - # Create sync manager with empty queue - empty_queue = QueryMonitorQueue(maxsize=10) - engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database) - sync_manager = DBManagerForQueryMonitorQueue( - engine=engine, - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=empty_queue, - lock_timeout=10, - ) - - # First sync - should create a new record with empty queue - sync_manager.sync_with_orm(size_limit=None) - items = sync_manager.obj.get_queue_content_without_pop() - assert len(items) == 0 # Empty queue since no existing data to merge - - # Now save the empty queue to create a record - sync_manager.save_to_db(empty_queue) - - # Test that sync_with_orm correctly handles version control - # The sync should increment version but not merge data when versions are the same - sync_manager.sync_with_orm(size_limit=None) - items = sync_manager.obj.get_queue_content_without_pop() - assert len(items) == 0 # Should remain empty since no merge occurred - - # Verify that the version was incremented - assert sync_manager.last_version_control == "3" # Should increment from 2 to 3 - - sync_manager.close() - - def test_sync_with_size_limit(self, query_monitor_manager, query_queue_obj): - """Test synchronization with size limit.""" - now = datetime.now() - item_size = 1 - for i in range(2, 6): - item_size += 1 - query_queue_obj.put( - QueryMonitorItem( - item_id=f"query{i}", - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - query_text=f"Question {i}", - timestamp=now + timedelta(minutes=i), - keywords=[f"kw{i}"], - ) - ) - - # First sync - should create a new record (size_limit not applied for new records) - size_limit = 3 - query_monitor_manager.sync_with_orm(size_limit=size_limit) - items = query_monitor_manager.obj.get_queue_content_without_pop() - assert len(items) == item_size # All items since size_limit not applied for new records - - # Save to create the record - query_monitor_manager.save_to_db(query_monitor_manager.obj) - - # Test that sync_with_orm correctly handles version control - # The sync should increment version but not merge data when versions are the same - query_monitor_manager.sync_with_orm(size_limit=size_limit) - items = query_monitor_manager.obj.get_queue_content_without_pop() - assert len(items) == item_size # Should remain the same since no merge occurred - - # Verify that the version was incremented - assert query_monitor_manager.last_version_control == "2" - - def test_concurrent_access(self, temp_db, query_queue_obj): - """Test concurrent access to the same database.""" - - # Manager 1 - engine1 = BaseDBManager.create_engine_from_db_path(temp_db) - manager1 = DBManagerForQueryMonitorQueue( - engine=engine1, - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=query_queue_obj, - lock_timeout=10, - ) - manager1.save_to_db(query_queue_obj) - - # Manager 2 - engine2 = BaseDBManager.create_engine_from_db_path(temp_db) - manager2 = DBManagerForQueryMonitorQueue( - engine=engine2, - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=query_queue_obj, - lock_timeout=10, - ) - - # Manager1 acquires lock - assert manager1.acquire_lock(block=True) - - # Manager2 fails to acquire - assert not manager2.acquire_lock(block=False) - - # Manager1 releases - manager1.release_locks(user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID) - - # Manager2 can now acquire - assert manager2.acquire_lock(block=False) - - manager1.close() - manager2.close() - - -class TestRedisDBManager: - """Test class for RedisDBManager functionality""" - - @pytest.fixture - def memory_manager_obj(self): - """Create a MemoryMonitorManager object for testing""" - return MemoryMonitorManager( - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - memories=[ - MemoryMonitorItem( - item_id="redis-test-123", - memory_text="Redis test memory", - tree_memory_item=None, - tree_memory_item_mapping_key="redis_test_key", - keywords_score=0.8, - sorting_score=0.9, - importance_score=0.7, - recording_count=3, - ) - ], - ) - - @pytest.fixture - def mock_redis_client(self): - """Create a mock Redis client for testing""" - try: - from unittest.mock import MagicMock - - # Create a mock Redis client - mock_client = MagicMock() - - # Mock Redis data storage - mock_data = {} - - def mock_set(key, value, nx=False, ex=None, **kwargs): - if nx and key in mock_data: - # NX means "only set if not exists" - return False # Redis returns False when NX fails - mock_data[key] = value - return True - - def mock_get(key): - return mock_data.get(key) - - def mock_hset(key, mapping=None, **kwargs): - if key not in mock_data: - mock_data[key] = {} - if mapping: - mock_data[key].update(mapping) - if kwargs: - mock_data[key].update(kwargs) - return len(mapping) if mapping else len(kwargs) - - def mock_hgetall(key): - return mock_data.get(key, {}) - - def mock_delete(*keys): - deleted = 0 - for key in keys: - if key in mock_data: - del mock_data[key] - deleted += 1 - return deleted - - def mock_keys(pattern): - import fnmatch - - return [key for key in mock_data if fnmatch.fnmatch(key, pattern)] - - def mock_ping(): - return True - - def mock_close(): - pass - - # Configure mock methods - mock_client.set = mock_set - mock_client.get = mock_get - mock_client.hset = mock_hset - mock_client.hgetall = mock_hgetall - mock_client.delete = mock_delete - mock_client.keys = mock_keys - mock_client.ping = mock_ping - mock_client.close = mock_close - - return mock_client - - except ImportError: - pytest.skip("Redis package not available for testing") - - @pytest.fixture - def redis_manager(self, mock_redis_client, memory_manager_obj): - """Create RedisDBManager instance with mock Redis client""" - manager = RedisDBManager( - user_id=TEST_USER_ID, - mem_cube_id=TEST_MEM_CUBE_ID, - obj=memory_manager_obj, - lock_timeout=10, - redis_client=mock_redis_client, - ) - yield manager - manager.close() - - def test_redis_manager_initialization(self, mock_redis_client): - """Test RedisDBManager initialization""" - manager = RedisDBManager( - user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID, redis_client=mock_redis_client - ) - - assert manager.user_id == TEST_USER_ID - assert manager.mem_cube_id == TEST_MEM_CUBE_ID - assert manager.redis_client is mock_redis_client - assert manager.orm_class.__name__ == "RedisLockableORM" - assert manager.obj_class == MemoryMonitorManager - - manager.close() - - def test_redis_lockable_orm_save_load(self, mock_redis_client): - """Test RedisLockableORM save and load operations""" - from memos.mem_scheduler.orm_modules.redis_model import RedisLockableORM - - orm = RedisLockableORM( - redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID - ) - - # Test save - orm.serialized_data = '{"test": "data"}' - orm.version_control = "1" - orm.lock_acquired = True - orm.lock_expiry = datetime.now() - - orm.save() - - # Test load - new_orm = RedisLockableORM( - redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID - ) - - exists = new_orm.load() - assert exists - assert new_orm.serialized_data == '{"test": "data"}' - assert new_orm.version_control == "1" - # Note: lock_acquired is False after load by design - locks are managed separately - assert not new_orm.lock_acquired diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 369b4a6f1..03a8e4318 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -204,7 +204,6 @@ def test_scheduler_startup_mode_thread(self): def test_redis_message_queue(self): """Test Redis message queue functionality for sending and receiving messages.""" - import asyncio import time from unittest.mock import MagicMock, patch @@ -244,7 +243,7 @@ def redis_handler(messages: list[ScheduleMessageItem]) -> None: ) # Submit message to Redis queue - asyncio.run(self.scheduler.submit_messages(redis_message)) + self.scheduler.submit_messages(redis_message) # Verify Redis xadd was called mock_redis.xadd.assert_called_once() @@ -529,55 +528,3 @@ def test_get_running_tasks_multiple_tasks(self): # Verify dispatcher method was called mock_get_running_tasks.assert_called_once_with(filter_func=None) - - def test_message_handler_receives_submitted_message(self): - """Test that handlers receive messages after scheduler startup and message submission.""" - # Create a mock handler that tracks received messages - received_messages = [] - - def mock_handler(messages: list[ScheduleMessageItem]) -> None: - """Mock handler that records received messages.""" - received_messages.extend(messages) - - # Register the mock handler - test_label = "test_handler" - handlers = {test_label: mock_handler} - self.scheduler.register_handlers(handlers) - - # Verify handler is registered - self.assertIn(test_label, self.scheduler.handlers) - self.assertEqual(self.scheduler.handlers[test_label], mock_handler) - - # Start the scheduler - self.scheduler.start() - - # Create and submit a test message - test_message = ScheduleMessageItem( - label=test_label, - content="Test message content", - user_id="test_user", - mem_cube_id="test_mem_cube", - mem_cube="test_mem_cube_obj", # Required field - can be string or GeneralMemCube - timestamp=datetime.now(), - ) - - import asyncio - - asyncio.run(self.scheduler.submit_messages(test_message)) - - # Wait for message processing to complete - import time - - time.sleep(2.0) # Allow sufficient time for message processing - - # Verify the handler received the message - self.assertEqual( - len(received_messages), 1, f"Expected 1 message, got {len(received_messages)}" - ) - self.assertEqual(received_messages[0].label, test_label) - self.assertEqual(received_messages[0].content, "Test message content") - self.assertEqual(received_messages[0].user_id, "test_user") - self.assertEqual(received_messages[0].mem_cube_id, "test_mem_cube") - - # Stop the scheduler - self.scheduler.stop() From 5ff29d117ad85be915193622aa5466c283b9872d Mon Sep 17 00:00:00 2001 From: Hao <42795704+Nyakult@users.noreply.github.com> Date: Tue, 28 Oct 2025 18:56:40 +0800 Subject: [PATCH 13/64] memos online api eval scripts and readme (#403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: check nodes existence * feat: use different template for different language input * feat: use different template for different language input * fix: eval script * feat: memos-api eval scripts * feat: mem reader * feat: 实现äºprefeval memos-api evaluation scripts * refactor:format code * feat: add PersonaMem eval scripts * docs(evaluation): update PersonaMem eval readme * feat:memos-api ingest batch message * feat: refactor search * feat: refactor search * update: add api for memory * feat: add memory api return memory and memory type * refactor(server):重构服务器路由模块以优化内存管理 * format: ruff format code * feat(server): 增加LLM最大令牌数 * test * fix: user query embedding for search * count memory_size by user * fix(server):修复记忆读取逻辑中的列表展开问题 * feat(nebular):优化图数据库查询性能 * refactor(memory): - 移除了对 `_refresh_memory_size` 方法的调用- 保留原有逻辑以便后续恢复或重构 * feat: remove user idx_memory_user_name * feat(graph):优化Nebula图数据库查询性能 * feat: rollback remove_oldest_memory * feat:nebula gql add index * feat: align code * feat: update memos_api * feat: update memos_api * feat: 更新默认选项 * feat:memory client * feat:refactor lme * feat: memu & supermemory client * feat: locomo memu * feat: locomo supermemory * New 'add' and 'process' modes. * feat: lme supermemory & memu * feat: default args * api and local * api and local * memobase fix * memos fix * default args * fix memos-api search data * prefeval pipeline * fix lme memos-api * personamem pipeline * personamem pipeline * lme scrips * align dev * format code * refactor: remove old files * format code * pm and prefeval pipeline * format code * format code * pm and prefeval pipeline * pm and prefeval pipeline * pm and prefeval pipeline * format code * format code * pref pipeline * add search response mode * add search response mode * update readme and example * update mem0 api * pm mem0 * fix MEMOBASE api * update pm and prefeval pipepline for frames * update pm and prefeval readme * format code * fix memobase api * fix memobase api * format code * format code * fix format * fix format * fix format * mem0 api * memos batch add * add memos-api-online * add memos-api-online update readme * rollback manager * memos online api pref mem --------- Co-authored-by: 2Rant Co-authored-by: fridayL Co-authored-by: CaralHsi --- evaluation/README.md | 23 ++++- evaluation/scripts/locomo/locomo_eval.py | 10 +- evaluation/scripts/locomo/locomo_ingestion.py | 41 +++++--- evaluation/scripts/locomo/locomo_metric.py | 10 +- evaluation/scripts/locomo/locomo_responses.py | 10 +- evaluation/scripts/locomo/locomo_search.py | 16 ++- evaluation/scripts/longmemeval/lme_eval.py | 12 ++- .../scripts/longmemeval/lme_ingestion.py | 22 ++++- evaluation/scripts/longmemeval/lme_metric.py | 10 +- .../scripts/longmemeval/lme_responses.py | 10 +- evaluation/scripts/longmemeval/lme_search.py | 15 ++- evaluation/scripts/utils/client.py | 98 +++++++++---------- 12 files changed, 198 insertions(+), 79 deletions(-) diff --git a/evaluation/README.md b/evaluation/README.md index f0bd166e1..47cfeedc0 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -22,17 +22,32 @@ This repository provides tools and scripts for evaluating the LoCoMo dataset usi 2. Copy the `configs-example/` directory to a new directory named `configs/`, and modify the configuration files inside it as needed. This directory contains model and API-specific settings. ## Setup MemOS +### local server ```bash -#start server +# modify {project_dir}/.env file and start server uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8001 --workers 8 -# modify .env file +# configure {project_dir}/evaluation/.env file MEMOS_URL="http://127.0.0.1:8001" ``` +### online service +```bash +# get your api key at https://memos-dashboard.openmem.net/cn/quickstart/ +# configure {project_dir}/evaluation/.env file +MEMOS_KEY="Token mpg-xxxxx" +MEMOS_ONLINE_URL="https://memos.memtensor.cn/api/openmem/v1" + +``` + +## Supported frameworks +We support `memos-api` and `memos-api-online` in our scripts. +And give unofficial implementations for the following memory frameworks:`zep`, `mem0`, `memobase`, `supermemory`, `memu`. + + ## Evaluation Scripts ### LoCoMo Evaluation -⚙️ To evaluate the **LoCoMo** dataset using one of the supported memory frameworks — `memos`, `mem0`, or `zep` — run the following [script](./scripts/run_locomo_eval.sh): +⚙️ To evaluate the **LoCoMo** dataset using one of the supported memory frameworks — run the following [script](./scripts/run_locomo_eval.sh): ```bash # Edit the configuration in ./scripts/run_locomo_eval.sh @@ -53,7 +68,7 @@ First prepare the dataset `longmemeval_s` from https://huggingface.co/datasets/x ``` ### PrefEval Evaluation -To evaluate the **Prefeval** dataset using one of the supported memory frameworks — `memos`, `mem0`, or `zep` — run the following [script](./scripts/run_prefeval_eval.sh): +To evaluate the **Prefeval** dataset using one of the supported memory frameworks — run the following [script](./scripts/run_prefeval_eval.sh): ```bash # Edit the configuration in ./scripts/run_prefeval_eval.sh diff --git a/evaluation/scripts/locomo/locomo_eval.py b/evaluation/scripts/locomo/locomo_eval.py index f142fe130..b431e7768 100644 --- a/evaluation/scripts/locomo/locomo_eval.py +++ b/evaluation/scripts/locomo/locomo_eval.py @@ -363,7 +363,15 @@ async def limited_task(task): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "openai", "memos-api", "memobase"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/locomo/locomo_ingestion.py b/evaluation/scripts/locomo/locomo_ingestion.py index fe7aa86f7..518d90c4c 100644 --- a/evaluation/scripts/locomo/locomo_ingestion.py +++ b/evaluation/scripts/locomo/locomo_ingestion.py @@ -44,26 +44,33 @@ def ingest_session(client, session, frame, version, metadata): speaker_a_messages.append({"role": "assistant", "content": data}) speaker_b_messages.append({"role": "user", "content": data}) - if frame == "memos-api": + if "memos-api" in frame: for m in speaker_a_messages: m["chat_time"] = iso_date for m in speaker_b_messages: m["chat_time"] = iso_date - client.add(speaker_a_messages, speaker_a_user_id, f"{conv_id}_{metadata['session_key']}") - client.add(speaker_b_messages, speaker_b_user_id, f"{conv_id}_{metadata['session_key']}") + client.add( + speaker_a_messages, + speaker_a_user_id, + f"{conv_id}_{metadata['session_key']}", + batch_size=2, + ) + client.add( + speaker_b_messages, + speaker_b_user_id, + f"{conv_id}_{metadata['session_key']}", + batch_size=2, + ) elif "mem0" in frame: - for i in range(0, len(speaker_a_messages), 2): - batch_messages_a = speaker_a_messages[i : i + 2] - batch_messages_b = speaker_b_messages[i : i + 2] - client.add(batch_messages_a, speaker_a_user_id, timestamp) - client.add(batch_messages_b, speaker_b_user_id, timestamp) + client.add(speaker_a_messages, speaker_a_user_id, timestamp, batch_size=2) + client.add(speaker_b_messages, speaker_b_user_id, timestamp, batch_size=2) elif frame == "memobase": for m in speaker_a_messages: m["created_at"] = iso_date for m in speaker_b_messages: m["created_at"] = iso_date - client.add(speaker_a_messages, speaker_a_user_id) - client.add(speaker_b_messages, speaker_b_user_id) + client.add(speaker_a_messages, speaker_a_user_id, batch_size=2) + client.add(speaker_b_messages, speaker_b_user_id, batch_size=2) elif frame == "memu": client.add(speaker_a_messages, speaker_a_user_id, iso_date) client.add(speaker_b_messages, speaker_b_user_id, iso_date) @@ -103,6 +110,10 @@ def process_user(conv_idx, frame, locomo_df, version): from utils.client import MemosApiClient client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() elif frame == "memobase": from utils.client import MemobaseClient @@ -187,7 +198,15 @@ def main(frame, version="default", num_workers=4): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/locomo/locomo_metric.py b/evaluation/scripts/locomo/locomo_metric.py index 6ddcdf127..e63888d45 100644 --- a/evaluation/scripts/locomo/locomo_metric.py +++ b/evaluation/scripts/locomo/locomo_metric.py @@ -9,7 +9,15 @@ parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "openai", "memos-api", "memobase"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 35a444b7d..6c082b31d 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -134,7 +134,15 @@ async def main(frame, version="default"): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "openai", "memos-api", "memobase"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index c629124dd..1ddf0d933 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -198,7 +198,7 @@ def search_query(client, query, metadata, frame, version, top_k=20): context, duration_ms = mem0_graph_search( client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b ) - elif frame == "memos-api": + elif "memos-api" in frame: context, duration_ms = memos_api_search( client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b ) @@ -257,6 +257,10 @@ def process_user(conv_idx, locomo_df, frame, version, top_k=20, num_workers=1): from utils.client import MemosApiClient client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() elif frame == "memobase": from utils.client import MemobaseClient @@ -336,7 +340,15 @@ def main(frame, version="default", num_workers=1, top_k=20): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/longmemeval/lme_eval.py b/evaluation/scripts/longmemeval/lme_eval.py index 73117b925..20681ac2c 100644 --- a/evaluation/scripts/longmemeval/lme_eval.py +++ b/evaluation/scripts/longmemeval/lme_eval.py @@ -344,7 +344,15 @@ async def main(frame, version, nlp_options, num_runs=3, num_workers=5): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( @@ -355,7 +363,7 @@ async def main(frame, version, nlp_options, num_runs=3, num_workers=5): type=str, nargs="+", default=["lexical"], - choices=["lexical", "semantic"], + choices=["lexical"], help="NLP options to use for evaluation.", ) parser.add_argument( diff --git a/evaluation/scripts/longmemeval/lme_ingestion.py b/evaluation/scripts/longmemeval/lme_ingestion.py index 325178292..e846a254c 100644 --- a/evaluation/scripts/longmemeval/lme_ingestion.py +++ b/evaluation/scripts/longmemeval/lme_ingestion.py @@ -18,7 +18,7 @@ def ingest_session(session, date, user_id, session_id, frame, client): if "mem0" in frame: for _idx, msg in enumerate(session): messages.append({"role": msg["role"], "content": msg["content"][:8000]}) - client.add(messages, user_id, int(date.timestamp())) + client.add(messages, user_id, int(date.timestamp()), batch_size=2) elif frame == "memobase": for _idx, msg in enumerate(session): messages.append( @@ -28,8 +28,8 @@ def ingest_session(session, date, user_id, session_id, frame, client): "created_at": date.isoformat(), } ) - client.add(messages, user_id) - elif frame == "memos-api": + client.add(messages, user_id, batch_size=2) + elif "memos-api" in frame: for msg in session: messages.append( { @@ -39,7 +39,7 @@ def ingest_session(session, date, user_id, session_id, frame, client): } ) if messages: - client.add(messages=messages, user_id=user_id, conv_id=session_id) + client.add(messages=messages, user_id=user_id, conv_id=session_id, batch_size=2) elif frame == "memu": for _idx, msg in enumerate(session): messages.append({"role": msg["role"], "content": msg["content"][:8000]}) @@ -80,6 +80,10 @@ def ingest_conv(lme_df, version, conv_idx, frame, success_records, f): from utils.client import MemosApiClient client = MemosApiClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() elif frame == "memobase": from utils.client import MemobaseClient @@ -167,7 +171,15 @@ def main(frame, version, num_workers=2): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/longmemeval/lme_metric.py b/evaluation/scripts/longmemeval/lme_metric.py index 93fa1de21..3664b47ba 100644 --- a/evaluation/scripts/longmemeval/lme_metric.py +++ b/evaluation/scripts/longmemeval/lme_metric.py @@ -258,7 +258,15 @@ def calculate_scores(data, grade_path, output_path): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py index a4adf90b5..7d82358d6 100644 --- a/evaluation/scripts/longmemeval/lme_responses.py +++ b/evaluation/scripts/longmemeval/lme_responses.py @@ -132,7 +132,15 @@ def main(frame, version, num_workers=4): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index c02518083..60b2146f6 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -123,6 +123,11 @@ def process_user(lme_df, conv_idx, frame, version, top_k=20): client = MemosApiClient() context, duration_ms = memos_search(client, question, user_id, top_k) + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() + context, duration_ms = memos_search(client, question, user_id, top_k) elif frame == "memu": from utils.client import MemuClient @@ -218,7 +223,15 @@ def main(frame, version, top_k=20, num_workers=2): parser.add_argument( "--lib", type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=[ + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], default="memos-api", ) parser.add_argument( diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 91d695acc..4117cba56 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -82,21 +82,13 @@ def add(self, messages, user_id, timestamp, batch_size=2): raise e def search(self, query, user_id, top_k): - if self.enable_graph: - res = self.client.search( - query=query, - top_k=top_k, - user_id=user_id, - enable_graph=True, - filters={"AND": [{"user_id": f"{user_id}"}]}, - ) - else: - res = self.client.search( - query=query, - top_k=top_k, - user_id=user_id, - filters={"AND": [{"user_id": f"{user_id}"}]}, - ) + res = self.client.search( + query=query, + top_k=top_k, + user_id=user_id, + enable_graph=self.enable_graph, + filters={"AND": [{"user_id": f"{user_id}"}]}, + ) return res @@ -155,23 +147,29 @@ def __init__(self): self.memos_url = os.getenv("MEMOS_URL") self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} - def add(self, messages, user_id, conv_id): + def add(self, messages, user_id, conv_id, batch_size: int = 9999): """ messages = [{"role": "assistant", "content": data, "chat_time": date_str}] """ url = f"{self.memos_url}/product/add" - payload = json.dumps( - { - "messages": messages, - "user_id": user_id, - "mem_cube_id": user_id, - "conversation_id": conv_id, - } - ) - response = requests.request("POST", url, data=payload, headers=self.headers) - assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "Memory added successfully", response.text - return response.text + added_memories = [] + for i in range(0, len(messages), batch_size): + batch_messages = messages[i : i + batch_size] + payload = json.dumps( + { + "messages": batch_messages, + "user_id": user_id, + "mem_cube_id": user_id, + "conversation_id": conv_id, + } + ) + response = requests.request("POST", url, data=payload, headers=self.headers) + assert response.status_code == 200, response.text + assert json.loads(response.text)["message"] == "Memory added successfully", ( + response.text + ) + added_memories += json.loads(response.text)["data"] + return added_memories def search(self, query, user_id, top_k): """Search memories.""" @@ -200,28 +198,30 @@ def __init__(self): self.memos_url = os.getenv("MEMOS_ONLINE_URL") self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} - def add(self, messages, user_id, conv_id=None): + def add(self, messages, user_id, conv_id=None, batch_size: int = 9999): url = f"{self.memos_url}/add/message" - payload = json.dumps( - { - "messages": messages, - "user_id": user_id, - "conversation_id": conv_id, - } - ) + for i in range(0, len(messages), batch_size): + batch_messages = messages[i : i + batch_size] + payload = json.dumps( + { + "messages": batch_messages, + "user_id": user_id, + "conversation_id": conv_id, + } + ) - max_retries = 5 - for attempt in range(max_retries): - try: - response = requests.request("POST", url, data=payload, headers=self.headers) - assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "ok", response.text - return response.text - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + max_retries = 5 + for attempt in range(max_retries): + try: + response = requests.request("POST", url, data=payload, headers=self.headers) + assert response.status_code == 200, response.text + assert json.loads(response.text)["message"] == "ok", response.text + break + except Exception as e: + if attempt < max_retries - 1: + time.sleep(2**attempt) + else: + raise e def search(self, query, user_id, top_k): """Search memories.""" @@ -244,7 +244,7 @@ def search(self, query, user_id, top_k): res = json.loads(response.text)["data"]["memory_detail_list"] for i in res: i.update({"memory": i.pop("memory_value")}) - return {"text_mem": [{"memories": res}]} + return {"text_mem": [{"memories": res}], "pref_mem": ""} except Exception as e: if attempt < max_retries - 1: time.sleep(2**attempt) From 1f6757d2c26c24a63f0ebdded14aeff0bf3fd47b Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Tue, 28 Oct 2025 20:44:06 +0800 Subject: [PATCH 14/64] feat: fix sources (#404) * feat: fix sources * feat: fix sources --------- Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/memories/textual/item.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index f6254efbb..48f42fa4c 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -1,6 +1,7 @@ """Defines memory item types for textual memory.""" import json +import logging import uuid from datetime import datetime @@ -123,6 +124,25 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata): def coerce_sources(cls, v): if v is None: return v + # Handle string representation of sources (e.g., from PostgreSQL array or malformed data) + if isinstance(v, str): + logging.info(f"[coerce_sources] v: {v} type: {type(v)}") + # If it's a string that looks like a list representation, try to parse it + # This handles cases like: "[uuid1, uuid2, uuid3]" or "[item1, item2]" + v_stripped = v.strip() + if v_stripped.startswith("[") and v_stripped.endswith("]"): + # Remove brackets and split by comma + content = v_stripped[1:-1].strip() + if content: + # Split by comma and clean up each item + items = [item.strip() for item in content.split(",")] + # Convert to list of strings + v = items + else: + v = [] + else: + # Single string, wrap in list + v = [v] if not isinstance(v, list): raise TypeError("sources must be a list") out = [] From e21f5bb8e176acb8d741f3ea30fab45575066ad2 Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Wed, 29 Oct 2025 12:47:01 +0800 Subject: [PATCH 15/64] fix porlar (#406) * feat: fix sources * feat: fix sources * feat: fix nebular * feat: fix polardb edges * feat: format polardb --------- Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/graph_dbs/nebular.py | 2 +- src/memos/graph_dbs/polardb.py | 92 +++++++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 12b493e58..00bd04e6d 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -1551,7 +1551,7 @@ def _ensure_database_exists(self): """ self.execute_query(create_tag, auto_set_db=False) else: - describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name};" + describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}" desc_result = self.execute_query(describe_query, auto_set_db=False) memory_fields = [] diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 38e71298f..b9bc2c8e5 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1776,24 +1776,61 @@ def export_graph( for row in edge_results: source_agtype, target_agtype, edge_agtype = row + + # Extract and clean source + source_raw = ( + source_agtype.value + if hasattr(source_agtype, "value") + else str(source_agtype) + ) + if ( + isinstance(source_raw, str) + and source_raw.startswith('"') + and source_raw.endswith('"') + ): + source = source_raw[1:-1] + else: + source = str(source_raw) + + # Extract and clean target + target_raw = ( + target_agtype.value + if hasattr(target_agtype, "value") + else str(target_agtype) + ) + if ( + isinstance(target_raw, str) + and target_raw.startswith('"') + and target_raw.endswith('"') + ): + target = target_raw[1:-1] + else: + target = str(target_raw) + + # Extract and clean edge type + type_raw = ( + edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) + ) + if ( + isinstance(type_raw, str) + and type_raw.startswith('"') + and type_raw.endswith('"') + ): + edge_type = type_raw[1:-1] + else: + edge_type = str(type_raw) + edges.append( { - "source": source_agtype.value - if hasattr(source_agtype, "value") - else str(source_agtype), - "target": target_agtype.value - if hasattr(target_agtype, "value") - else str(target_agtype), - "type": edge_agtype.value - if hasattr(edge_agtype, "value") - else str(edge_agtype), + "source": source, + "target": target, + "type": edge_type, } ) except Exception as e: logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - return {"nodes": nodes, "edges": edges} @timed @@ -2765,9 +2802,38 @@ def get_edges( edges = [] for row in results: - from_id = row[0].value if hasattr(row[0], "value") else row[0] - to_id = row[1].value if hasattr(row[1], "value") else row[1] - edge_type = row[2].value if hasattr(row[2], "value") else row[2] + # Extract and clean from_id + from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] + if ( + isinstance(from_id_raw, str) + and from_id_raw.startswith('"') + and from_id_raw.endswith('"') + ): + from_id = from_id_raw[1:-1] + else: + from_id = str(from_id_raw) + + # Extract and clean to_id + to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] + if ( + isinstance(to_id_raw, str) + and to_id_raw.startswith('"') + and to_id_raw.endswith('"') + ): + to_id = to_id_raw[1:-1] + else: + to_id = str(to_id_raw) + + # Extract and clean edge_type + edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] + if ( + isinstance(edge_type_raw, str) + and edge_type_raw.startswith('"') + and edge_type_raw.endswith('"') + ): + edge_type = edge_type_raw[1:-1] + else: + edge_type = str(edge_type_raw) edges.append({"from": from_id, "to": to_id, "type": edge_type}) return edges From d79647ee493ae0f99980515b37e1c2400a01c0a1 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Wed, 29 Oct 2025 15:02:27 +0800 Subject: [PATCH 16/64] Feat/arms (#402) * feat: update log context * feat: update log context * feat: update mcp * feat: update mcp * feat: add error log * feat: add error log * feat: add error log * feat: update log * feat: add chat_time * feat: add chat_time * feat: add chat_time * feat: update log * feat: update log * feat: update log * feat: update log * feat: update log * feat: add arms * fix: format * fix: format * feat: add dockerfile * feat: add dockerfile * feat: add arms config * feat: update log * feat: add sleep time * feat: add sleep time * feat: update log * feat: delete dockerfile * feat: delete dockerfile * feat: update dockerfile * fix: conflict * feat: replace ThreadPool to context * feat: add timed log --------- Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> Co-authored-by: harvey_xiang --- src/memos/api/exceptions.py | 27 ++++- src/memos/api/middleware/request_context.py | 41 ++++++-- src/memos/api/routers/server_router.py | 6 +- src/memos/api/server_api.py | 10 +- src/memos/context/context.py | 96 ++++++++++++++++-- src/memos/embedders/universal_api.py | 23 +++-- src/memos/graph_dbs/polardb.py | 98 +++++++++---------- src/memos/llms/openai.py | 3 + src/memos/log.py | 51 +++++++--- src/memos/mem_os/core.py | 6 +- src/memos/mem_scheduler/general_scheduler.py | 7 +- .../textual/prefer_text_memory/adder.py | 7 +- .../textual/prefer_text_memory/extractor.py | 5 +- .../textual/prefer_text_memory/retrievers.py | 4 +- src/memos/reranker/http_bge.py | 4 +- src/memos/reranker/http_bge_strategy.py | 2 + src/memos/utils.py | 28 ++++-- 17 files changed, 303 insertions(+), 115 deletions(-) diff --git a/src/memos/api/exceptions.py b/src/memos/api/exceptions.py index 2fd22ad52..10a14b4d1 100644 --- a/src/memos/api/exceptions.py +++ b/src/memos/api/exceptions.py @@ -1,5 +1,6 @@ import logging +from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.requests import Request from fastapi.responses import JSONResponse @@ -10,9 +11,24 @@ class APIExceptionHandler: """Centralized exception handling for MemOS APIs.""" + @staticmethod + async def validation_error_handler(request: Request, exc: RequestValidationError): + """Handle request validation errors.""" + logger.error(f"Validation error: {exc.errors()}") + return JSONResponse( + status_code=422, + content={ + "code": 422, + "message": "Parameter validation error", + "detail": exc.errors(), + "data": None, + }, + ) + @staticmethod async def value_error_handler(request: Request, exc: ValueError): """Handle ValueError exceptions globally.""" + logger.error(f"ValueError: {exc}") return JSONResponse( status_code=400, content={"code": 400, "message": str(exc), "data": None}, @@ -21,8 +37,17 @@ async def value_error_handler(request: Request, exc: ValueError): @staticmethod async def global_exception_handler(request: Request, exc: Exception): """Handle all unhandled exceptions globally.""" - logger.exception("Unhandled error:") + logger.error(f"Exception: {exc}") return JSONResponse( status_code=500, content={"code": 500, "message": str(exc), "data": None}, ) + + @staticmethod + async def http_error_handler(request: Request, exc: HTTPException): + """Handle HTTP exceptions globally.""" + logger.error(f"HTTP error {exc.status_code}: {exc.detail}") + return JSONResponse( + status_code=exc.status_code, + content={"code": exc.status_code, "message": str(exc.detail), "data": None}, + ) diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index cb41428d4..2922ab3eb 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,6 +2,8 @@ Request context middleware for automatic trace_id injection. """ +import time + from collections.abc import Callable from starlette.middleware.base import BaseHTTPMiddleware @@ -38,8 +40,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Extract or generate trace_id trace_id = extract_trace_id_from_headers(request) or generate_trace_id() + env = request.headers.get("x-env") + user_type = request.headers.get("x-user-type") + user_name = request.headers.get("x-user-name") + start_time = time.time() + # Create and set request context - context = RequestContext(trace_id=trace_id, api_path=request.url.path) + context = RequestContext( + trace_id=trace_id, + api_path=request.url.path, + env=env, + user_type=user_type, + user_name=user_name, + ) set_request_context(context) # Log request start with parameters @@ -49,15 +62,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: if request.query_params: params_log["query_params"] = dict(request.query_params) - logger.info(f"Request started: {request.method} {request.url.path}, {params_log}") + logger.info(f"Request started, params: {params_log}, headers: {request.headers}") # Process the request - response = await call_next(request) - - # Log request completion with output - logger.info(f"Request completed: {request.url.path}, status: {response.status_code}") - - # Add trace_id to response headers for debugging - response.headers["x-trace-id"] = trace_id + try: + response = await call_next(request) + end_time = time.time() + if response.status_code == 200: + logger.info( + f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + else: + logger.error( + f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + except Exception as e: + end_time = time.time() + logger.error( + f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + raise e return response diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index f50d3ad75..3ba12c1ce 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,7 +1,6 @@ import os import traceback -from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException @@ -22,6 +21,7 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory +from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory @@ -370,7 +370,7 @@ def _search_pref(): ) return [_format_memory_item(data) for data in results] - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_search_text) pref_future = executor.submit(_search_pref) text_formatted_memories = text_future.result() @@ -532,7 +532,7 @@ def _process_pref_mem() -> list[dict[str, str]]: for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) ] - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_process_text_mem) pref_future = executor.submit(_process_pref_mem) text_response_data = text_future.result() diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 78e05ef85..24c67de48 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -1,6 +1,7 @@ import logging -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware @@ -21,8 +22,13 @@ # Include routers app.include_router(server_router) -# Exception handlers +# Request validation failed +app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) +# Invalid business code parameters app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) +# Business layer manual exception +app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler) +# Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) diff --git a/src/memos/context/context.py b/src/memos/context/context.py index 4f54348fb..d6a0f3bf1 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -29,9 +29,19 @@ class RequestContext: This provides a Flask g-like object for FastAPI applications. """ - def __init__(self, trace_id: str | None = None, api_path: str | None = None): + def __init__( + self, + trace_id: str | None = None, + api_path: str | None = None, + env: str | None = None, + user_type: str | None = None, + user_name: str | None = None, + ): self.trace_id = trace_id or "trace-id" self.api_path = api_path + self.env = env + self.user_type = user_type + self.user_name = user_name self._data: dict[str, Any] = {} def set(self, key: str, value: Any) -> None: @@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any: return self._data.get(key, default) def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_") or name in ("trace_id", "api_path"): + if name.startswith("_") or name in ( + "trace_id", + "api_path", + "env", + "user_type", + "user_name", + ): super().__setattr__(name, value) else: if not hasattr(self, "_data"): @@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any: def to_dict(self) -> dict[str, Any]: """Convert context to dictionary.""" - return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()} + return { + "trace_id": self.trace_id, + "api_path": self.api_path, + "env": self.env, + "user_type": self.user_type, + "user_name": self.user_name, + "data": self._data.copy(), + } def set_request_context(context: RequestContext) -> None: @@ -93,6 +116,36 @@ def get_current_api_path() -> str | None: return None +def get_current_env() -> str | None: + """ + Get the current request's env. + """ + context = _request_context.get() + if context: + return context.get("env") + return "prod" + + +def get_current_user_type() -> str | None: + """ + Get the current request's user type. + """ + context = _request_context.get() + if context: + return context.get("user_type") + return "opensource" + + +def get_current_user_name() -> str | None: + """ + Get the current request's user name. + """ + context = _request_context.get() + if context: + return context.get("user_name") + return "memos" + + def get_current_context() -> RequestContext | None: """ Get the current request context. @@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None: context_dict = _request_context.get() if context_dict: ctx = RequestContext( - trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path") + trace_id=context_dict.get("trace_id"), + api_path=context_dict.get("api_path"), + env=context_dict.get("env"), + user_type=context_dict.get("user_type"), + user_name=context_dict.get("user_name"), ) ctx._data = context_dict.get("data", {}).copy() return ctx @@ -141,6 +198,9 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs): self.main_trace_id = get_current_trace_id() self.main_api_path = get_current_api_path() + self.main_env = get_current_env() + self.main_user_type = get_current_user_type() + self.main_user_name = get_current_user_name() self.main_context = get_current_context() def run(self): @@ -148,7 +208,11 @@ def run(self): if self.main_context: # Copy the context data child_context = RequestContext( - trace_id=self.main_trace_id, api_path=self.main_context.api_path + trace_id=self.main_trace_id, + api_path=self.main_api_path, + env=self.main_env, + user_type=self.main_user_type, + user_name=self.main_user_name, ) child_context._data = self.main_context._data.copy() @@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any: """ main_trace_id = get_current_trace_id() main_api_path = get_current_api_path() + main_env = get_current_env() + main_user_type = get_current_user_type() + main_user_name = get_current_user_name() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) + child_context = RequestContext( + trace_id=main_trace_id, + api_path=main_api_path, + env=main_env, + user_type=main_user_type, + user_name=main_user_name, + ) child_context._data = main_context._data.copy() set_request_context(child_context) @@ -198,13 +271,22 @@ def map( """ main_trace_id = get_current_trace_id() main_api_path = get_current_api_path() + main_env = get_current_env() + main_user_type = get_current_user_type() + main_user_name = get_current_user_name() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) + child_context = RequestContext( + trace_id=main_trace_id, + api_path=main_api_path, + env=main_env, + user_type=main_user_type, + user_name=main_user_name, + ) child_context._data = main_context._data.copy() set_request_context(child_context) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 72116cf05..fc51cf073 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -3,6 +3,11 @@ from memos.configs.embedder import UniversalAPIEmbedderConfig from memos.embedders.base import BaseEmbedder +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) class UniversalAPIEmbedder(BaseEmbedder): @@ -19,14 +24,18 @@ def __init__(self, config: UniversalAPIEmbedderConfig): api_key=config.api_key, ) else: - raise ValueError(f"Unsupported provider: {self.provider}") + raise ValueError(f"Embeddings unsupported provider: {self.provider}") + @timed(log=True, log_prefix="EmbedderAPI") def embed(self, texts: list[str]) -> list[list[float]]: if self.provider == "openai" or self.provider == "azure": - response = self.client.embeddings.create( - model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"), - input=texts, - ) - return [r.embedding for r in response.data] + try: + response = self.client.embeddings.create( + model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"), + input=texts, + ) + return [r.embedding for r in response.data] + except Exception as e: + raise Exception(f"Embeddings request ended with error: {e}") from e else: - raise ValueError(f"Unsupported provider: {self.provider}") + raise ValueError(f"Embeddings unsupported provider: {self.provider}") diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index b9bc2c8e5..88aef6d33 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1,18 +1,18 @@ import json -import time import random + from datetime import datetime from typing import Any, Literal import numpy as np - from memos.configs.graph_db import PolarDBGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB from memos.log import get_logger from memos.utils import timed + logger = get_logger(__name__) # Graph database configuration @@ -200,31 +200,31 @@ def _create_graph(self): # Add embedding column if it doesn't exist (using JSONB for compatibility) try: cursor.execute(f""" - ALTER TABLE "{self.db_name}_graph"."Memory" + ALTER TABLE "{self.db_name}_graph"."Memory" ADD COLUMN IF NOT EXISTS embedding JSONB; """) - logger.info(f"Embedding column added to Memory table.") + logger.info("Embedding column added to Memory table.") except Exception as e: logger.warning(f"Failed to add embedding column: {e}") # Create indexes cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties + CREATE INDEX IF NOT EXISTS idx_memory_properties ON "{self.db_name}_graph"."Memory" USING GIN (properties); """) # Create vector index for embedding field try: cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding + CREATE INDEX IF NOT EXISTS idx_memory_embedding ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); """) - logger.info(f"Vector index created for Memory table.") + logger.info("Vector index created for Memory table.") except Exception as e: logger.warning(f"Vector index creation failed (might not be supported): {e}") - logger.info(f"Indexes created for Memory table.") + logger.info("Indexes created for Memory table.") except Exception as e: logger.error(f"Failed to create graph schema: {e}") @@ -246,20 +246,20 @@ def create_index( # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties + CREATE INDEX IF NOT EXISTS idx_memory_properties ON "{self.db_name}_graph"."Memory" USING GIN (properties); """) # Try to create vector index, but don't fail if it doesn't work try: cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding + CREATE INDEX IF NOT EXISTS idx_memory_embedding ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); """) except Exception as ve: logger.warning(f"Vector index creation failed (might not be supported): {ve}") - logger.debug(f"Indexes created successfully.") + logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") @@ -267,8 +267,8 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in """Get count of memory nodes by type.""" user_name = user_name if user_name else self._get_config_value("user_name") query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype """ query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" @@ -290,8 +290,8 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: """Check if a node with given scope exists.""" user_name = user_name if user_name else self._get_config_value("user_name") query = f""" - SELECT id - FROM "{self.db_name}_graph"."Memory" + SELECT id + FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype """ query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" @@ -327,15 +327,13 @@ def remove_oldest_memory( # Use actual OFFSET logic, consistent with nebular.py # First find IDs to delete, then delete them select_query = f""" - SELECT id FROM "{self.db_name}_graph"."Memory" + SELECT id FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype AND ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = %s::agtype - ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC + ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC OFFSET %s """ select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] - print(f"[remove_oldest_memory] Select query: {select_query}") - print(f"[remove_oldest_memory] Select params: {select_params}") try: with self.connection.cursor() as cursor: @@ -403,14 +401,14 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N # Build update query if embedding_vector is not None: query = f""" - UPDATE "{self.db_name}_graph"."Memory" + UPDATE "{self.db_name}_graph"."Memory" SET properties = %s, embedding = %s WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"'] else: query = f""" - UPDATE "{self.db_name}_graph"."Memory" + UPDATE "{self.db_name}_graph"."Memory" SET properties = %s WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ @@ -438,7 +436,7 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: user_name (str, optional): User name for filtering in non-multi-db mode """ query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" + DELETE FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ params = [f'"{id}"'] @@ -462,7 +460,7 @@ def create_extension(self): try: with self.connection.cursor() as cursor: # Ensure in the correct database context - cursor.execute(f"SELECT current_database();") + cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] print(f"Current database context: {current_db}") @@ -487,7 +485,7 @@ def create_graph(self): try: with self.connection.cursor() as cursor: cursor.execute(f""" - SELECT COUNT(*) FROM ag_catalog.ag_graph + SELECT COUNT(*) FROM ag_catalog.ag_graph WHERE name = '{self.db_name}_graph'; """) graph_exists = cursor.fetchone()[0] > 0 @@ -664,11 +662,11 @@ def edge_exists( # Prepare the match pattern with direction if direction == "OUTGOING": - pattern = f"(a:Memory)-[r]->(b:Memory)" + pattern = "(a:Memory)-[r]->(b:Memory)" elif direction == "INCOMING": - pattern = f"(a:Memory)<-[r]-(b:Memory)" + pattern = "(a:Memory)<-[r]-(b:Memory)" elif direction == "ANY": - pattern = f"(a:Memory)-[r]-(b:Memory)" + pattern = "(a:Memory)-[r]-(b:Memory)" else: raise ValueError( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." @@ -720,7 +718,7 @@ def format_param_value(value: str) -> str: query = f""" SELECT {select_fields} - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ params = [format_param_value(id)] @@ -806,7 +804,7 @@ def get_nodes( query = f""" SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE ({where_clause}) """ @@ -893,15 +891,15 @@ def get_edges_old( # Create indexes cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_source + CREATE INDEX IF NOT EXISTS idx_edges_source ON "{self.db_name}_graph"."Edges" (source_id); """) cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_target + CREATE INDEX IF NOT EXISTS idx_edges_target ON "{self.db_name}_graph"."Edges" (target_id); """) cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_type + CREATE INDEX IF NOT EXISTS idx_edges_type ON "{self.db_name}_graph"."Edges" (edge_type); """) except Exception as e: @@ -998,7 +996,7 @@ def get_neighbors_by_tag_old( # Get all candidate nodes query = f""" SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE {where_clause} """ @@ -1061,7 +1059,7 @@ def get_children_with_embeddings( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (p:Memory)-[r:PARENT]->(c:Memory) - WHERE p.id = '{id}' {where_user} + WHERE p.id = '{id}' {where_user} RETURN id(c) as cid, c.id AS id, c.memory AS memory $$) as (cid agtype, id agtype, memory agtype) ) @@ -1518,7 +1516,7 @@ def get_grouped_counts1( MATCH (n:Memory) {where_clause} RETURN {group_fields_cypher}, COUNT(n) AS count1 - $$ ) as ({group_fields_cypher_polardb}, count1 agtype); + $$ ) as ({group_fields_cypher_polardb}, count1 agtype); """ print("get_grouped_counts:" + query) try: @@ -1673,8 +1671,8 @@ def clear(self, user_name: str | None = None) -> None: try: query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.user_name = '{user_name}' + MATCH (n:Memory) + WHERE n.user_name = '{user_name}' DETACH DELETE n $$) AS (result agtype) """ @@ -1765,7 +1763,7 @@ def export_graph( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (a:Memory)-[r]->(b:Memory) WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}' - RETURN a.id AS source, b.id AS target, type(r) as edge + RETURN a.id AS source, b.id AS target, type(r) as edge $$) AS (source agtype, target agtype, edge agtype) """ @@ -1840,7 +1838,7 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (n:Memory) - WHERE n.memory_type = '{scope}' + WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' RETURN count(n) $$) AS (count agtype) @@ -1879,8 +1877,8 @@ def get_all_memory_items( LIMIT 100 $$) AS (id1 agtype,n agtype) ) - SELECT - m.embedding, + SELECT + m.embedding, t.n FROM t, {self.db_name}_graph."Memory" m @@ -1976,8 +1974,8 @@ def get_all_memory_items_old( LIMIT 100 $$) AS (id1 agtype,n agtype) ) - SELECT - m.embedding, + SELECT + m.embedding, t.n FROM t, {self.db_name}_graph."Memory" m @@ -2144,8 +2142,8 @@ def get_structure_optimization_candidates( WITH t as ( {cypher_query} ) - SELECT - m.embedding, + SELECT + m.embedding, t.n FROM t, {self.db_name}_graph."Memory" m @@ -2358,7 +2356,7 @@ def add_node( with self.connection.cursor() as cursor: # Delete existing record first (if any) delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" + DELETE FROM {self.db_name}_graph."Memory" WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ cursor.execute(delete_query, (id,)) @@ -2493,7 +2491,7 @@ def get_neighbors_by_tag( # Fetch all candidate nodes query = f""" SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE {where_clause} """ @@ -2769,13 +2767,13 @@ def get_edges( user_name = user_name if user_name else self._get_config_value("user_name") if direction == "OUTGOING": - pattern = f"(a:Memory)-[r]->(b:Memory)" + pattern = "(a:Memory)-[r]->(b:Memory)" where_clause = f"a.id = '{id}'" elif direction == "INCOMING": - pattern = f"(a:Memory)<-[r]-(b:Memory)" + pattern = "(a:Memory)<-[r]-(b:Memory)" where_clause = f"a.id = '{id}'" elif direction == "ANY": - pattern = f"(a:Memory)-[r]-(b:Memory)" + pattern = "(a:Memory)-[r]-(b:Memory)" where_clause = f"a.id = '{id}' OR b.id = '{id}'" else: raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 698bc3265..ca1df5c1f 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -11,6 +11,7 @@ from memos.llms.utils import remove_thinking_tags from memos.log import get_logger from memos.types import MessageList +from memos.utils import timed logger = get_logger(__name__) @@ -56,6 +57,7 @@ def clear_cache(cls): cls._instances.clear() logger.info("OpenAI LLM instance cache cleared") + @timed(log=True, log_prefix="OpenAI LLM") def generate(self, messages: MessageList) -> str: """Generate a response from OpenAI LLM.""" response = self.client.chat.completions.create( @@ -73,6 +75,7 @@ def generate(self, messages: MessageList) -> str: else: return response_content + @timed(log=True, log_prefix="OpenAI LLM") def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from OpenAI LLM with optional reasoning support.""" response = self.client.chat.completions.create( diff --git a/src/memos/log.py b/src/memos/log.py index 339d13f26..2a538fdde 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -14,7 +14,13 @@ from dotenv import load_dotenv from memos import settings -from memos.context.context import get_current_api_path, get_current_trace_id +from memos.context.context import ( + get_current_api_path, + get_current_env, + get_current_trace_id, + get_current_user_name, + get_current_user_type, +) # Load environment variables @@ -34,15 +40,22 @@ def _setup_logfile() -> Path: return logfile -class TraceIDFilter(logging.Filter): - """add trace_id to the log record""" +class ContextFilter(logging.Filter): + """add context to the log record""" def filter(self, record): try: trace_id = get_current_trace_id() record.trace_id = trace_id if trace_id else "trace-id" + record.env = get_current_env() + record.user_type = get_current_user_type() + record.user_name = get_current_user_name() + record.api_path = get_current_api_path() except Exception: record.trace_id = "trace-id" + record.env = "prod" + record.user_type = "normal" + record.user_name = "unknown" return True @@ -86,13 +99,24 @@ def emit(self, record): try: trace_id = get_current_trace_id() or "trace-id" api_path = get_current_api_path() + env = get_current_env() + user_type = get_current_user_type() + user_name = get_current_user_name() if api_path is not None: - self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path) + self._executor.submit( + self._send_log_sync, + record.getMessage(), + trace_id, + api_path, + env, + user_type, + user_name, + ) except Exception as e: if not self._is_shutting_down.is_set(): print(f"Error sending log: {e}") - def _send_log_sync(self, message, trace_id, api_path): + def _send_log_sync(self, message, trace_id, api_path, env, user_type, user_name): """Send log message synchronously in a separate thread""" try: logger_url = os.getenv("CUSTOM_LOGGER_URL") @@ -104,6 +128,9 @@ def _send_log_sync(self, message, trace_id, api_path): "trace_id": trace_id, "action": api_path, "current_time": round(time.time(), 3), + "env": env, + "user_type": user_type, + "user_name": user_name, } # Add auth token if exists @@ -145,26 +172,26 @@ def close(self): "disable_existing_loggers": False, "formatters": { "standard": { - "format": "%(asctime)s [%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" + "format": "%(asctime)s | %(trace_id)s | path=%(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" }, "no_datetime": { - "format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" + "format": "%(trace_id)s | path=%(api_path)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" }, "simplified": { - "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s" + "format": "%(asctime)s | %(trace_id)s | path=%(api_path)s | % %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s" }, }, "filters": { "package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX}, - "trace_id_filter": {"()": "memos.log.TraceIDFilter"}, + "context_filter": {"()": "memos.log.ContextFilter"}, }, "handlers": { "console": { - "level": selected_log_level, + "level": "DEBUG", "class": "logging.StreamHandler", "stream": stdout, "formatter": "no_datetime", - "filters": ["package_tree_filter", "trace_id_filter"], + "filters": ["package_tree_filter", "context_filter"], }, "file": { "level": "DEBUG", @@ -173,7 +200,7 @@ def close(self): "maxBytes": 1024**2 * 10, "backupCount": 10, "formatter": "standard", - "filters": ["trace_id_filter"], + "filters": ["context_filter"], }, "custom_logger": { "level": "INFO", diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index ec8a673d7..939b0c68d 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -2,13 +2,13 @@ import os import time -from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from threading import Lock from typing import Any, Literal from memos.configs.mem_os import MOSConfig +from memos.context.context import ContextThreadPoolExecutor from memos.llms.factory import LLMFactory from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube @@ -665,7 +665,7 @@ def search_preference_memory(cube_id, cube): return None # Execute both search functions in parallel - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(search_textual_memory, mem_cube_id, mem_cube) pref_future = executor.submit(search_preference_memory, mem_cube_id, mem_cube) @@ -824,7 +824,7 @@ def process_preference_memory(): self.mem_scheduler.submit_messages(messages=[message_item]) # Execute both memory processing functions in parallel - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(process_textual_memory) pref_future = executor.submit(process_preference_memory) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index d84ebb242..434cef3e9 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -3,6 +3,7 @@ import traceback from memos.configs.mem_scheduler import GeneralSchedulerConfig +from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler @@ -281,7 +282,7 @@ def process_message(message: ScheduleMessageItem): except Exception as e: logger.error(f"Error processing mem_read message: {e}", exc_info=True) - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: futures = [executor.submit(process_message, msg) for msg in messages] for future in concurrent.futures.as_completed(futures): try: @@ -413,7 +414,7 @@ def process_message(message: ScheduleMessageItem): except Exception as e: logger.error(f"Error processing mem_read message: {e}", exc_info=True) - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: futures = [executor.submit(process_message, msg) for msg in messages] for future in concurrent.futures.as_completed(futures): try: @@ -506,7 +507,7 @@ def process_message(message: ScheduleMessageItem): except Exception as e: logger.error(f"Error processing pref_add message: {e}", exc_info=True) - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: futures = [executor.submit(process_message, msg) for msg in messages] for future in concurrent.futures.as_completed(futures): try: diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 390f048ef..eb284cd6d 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -1,9 +1,10 @@ import json from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed from typing import Any +from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem from memos.templates.prefer_complete_prompt import ( @@ -162,7 +163,7 @@ def execute_op(op): self.vector_db.delete(collection_name, [op["target_id"]]) return None - with ThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: + with ContextThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]} added_ids = [] for future in as_completed(future_to_op): @@ -263,7 +264,7 @@ def add( return [] added_ids = [] - with ThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: + with ContextThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: future_to_memory = { executor.submit(self._process_single_memory, memory): memory for memory in memories } diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 460b31f4f..41d90d10e 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -2,10 +2,11 @@ import uuid from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed from datetime import datetime from typing import Any +from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.prefer_text_memory.spliter import Splitter @@ -150,7 +151,7 @@ def extract( return [] memories = [] - with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: + with ContextThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: futures = { executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ( "explicit", diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 7f70bac3b..807a8b55e 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from typing import Any +from memos.context.context import ContextThreadPoolExecutor from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem @@ -42,7 +42,7 @@ def retrieve( query_embedding = query_embeddings[0] # Get the first (and only) embedding # Use thread pool to parallelize the searches - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: # Submit all search tasks future_explicit = executor.submit( self.vector_db.search, query_embedding, "explicit_preference", top_k * 2, info diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index 2c423e6b6..41011df14 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -9,10 +9,10 @@ import requests from memos.log import get_logger +from memos.utils import timed from .base import BaseReranker from .concat import concat_original_source -from memos.utils import timed logger = get_logger(__name__) @@ -119,7 +119,7 @@ def __init__( self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys) self._warned_missing_keys: set[str] = set() - @timed + @timed(log=True, log_prefix="RerankerAPI") def rerank( self, query: str, diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py index 8cbf633a6..b0567698c 100644 --- a/src/memos/reranker/http_bge_strategy.py +++ b/src/memos/reranker/http_bge_strategy.py @@ -10,6 +10,7 @@ from memos.log import get_logger from memos.reranker.strategies import RerankerStrategyFactory +from memos.utils import timed from .base import BaseReranker @@ -119,6 +120,7 @@ def __init__( self._warned_missing_keys: set[str] = set() self.reranker_strategy = RerankerStrategyFactory.from_config(reranker_strategy) + @timed(log=True, log_prefix="RerankerStrategy") def rerank( self, query: str, diff --git a/src/memos/utils.py b/src/memos/utils.py index 6a1d42558..08934ed34 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -6,14 +6,24 @@ logger = get_logger(__name__) -def timed(func): - """Decorator to measure and log time of retrieval steps.""" +def timed(func=None, *, log=False, log_prefix=""): + """Decorator to measure and optionally log time of retrieval steps. - def wrapper(*args, **kwargs): - start = time.perf_counter() - result = func(*args, **kwargs) - elapsed = time.perf_counter() - start - logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s") - return result + Can be used as @timed or @timed(log=True) + """ - return wrapper + def decorator(fn): + def wrapper(*args, **kwargs): + start = time.perf_counter() + result = fn(*args, **kwargs) + elapsed = time.perf_counter() - start + if log: + logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed:.2f} seconds") + return result + + return wrapper + + # Handle both @timed and @timed(log=True) cases + if func is None: + return decorator + return decorator(func) From f8859f1eb0df2274786e3c01ac48bb6a1f5cdd02 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:25:12 +0800 Subject: [PATCH 17/64] Hotfix: memos playground prompt reverse (#408) hotfix: memos playground --- src/memos/mem_os/product.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index fed8f7278..89e468bd7 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1044,22 +1044,15 @@ def chat( m.metadata.embedding = [] new_memories_list.append(m) memories_list = new_memories_list - # Build base system prompt without memory - system_prompt = self._build_base_system_prompt(base_prompt, mode="base") - - # Build memory context to be included in user message - memory_context = self._build_memory_context(memories_list, mode="base") - - # Combine memory context with user query - user_content = memory_context + query if memory_context else query + system_prompt = super()._build_system_prompt(memories_list, base_prompt) history_info = [] if history: history_info = history[-20:] current_messages = [ {"role": "system", "content": system_prompt}, *history_info, - {"role": "user", "content": user_content}, + {"role": "user", "content": query}, ] response = self.chat_llm.generate(current_messages) time_end = time.time() @@ -1129,16 +1122,8 @@ def chat_with_references( reference = prepare_reference_data(memories_list) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" - - # Build base system prompt without memory - system_prompt = self._build_base_system_prompt(mode="enhance") - - # Build memory context to be included in user message - memory_context = self._build_memory_context(memories_list, mode="enhance") - - # Combine memory context with user query - user_content = memory_context + query if memory_context else query - + # Build custom system prompt with relevant memories) + system_prompt = self._build_enhance_system_prompt(user_id, memories_list) # Get chat history if user_id not in self.chat_history_manager: self._register_chat_history(user_id, session_id) @@ -1149,7 +1134,7 @@ def chat_with_references( current_messages = [ {"role": "system", "content": system_prompt}, *chat_history.chat_history, - {"role": "user", "content": user_content}, + {"role": "user", "content": query}, ] logger.info( f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}" From 7eb531b295b8f3d2455e9450ea4f612d9ab340a8 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:59:24 +0800 Subject: [PATCH 18/64] Feat/pref optimize update (#409) * add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit --------- Co-authored-by: yuan.wang --- docker/requirements.txt | 2 +- evaluation/scripts/PrefEval/pref_memos.py | 6 +- evaluation/scripts/locomo/locomo_search.py | 4 +- evaluation/scripts/longmemeval/lme_search.py | 2 +- evaluation/scripts/personamem/pm_ingestion.py | 4 +- evaluation/scripts/personamem/pm_search.py | 2 +- evaluation/scripts/utils/client.py | 1 + src/memos/api/routers/server_router.py | 12 +- src/memos/memories/textual/item.py | 2 +- .../textual/prefer_text_memory/adder.py | 294 +++++++++---- .../textual/prefer_text_memory/extractor.py | 27 +- .../textual/prefer_text_memory/retrievers.py | 68 ++- .../textual/prefer_text_memory/spliter.py | 6 +- src/memos/reranker/cosine_local.py | 2 +- src/memos/reranker/noop.py | 4 +- src/memos/templates/instruction_completion.py | 23 +- src/memos/templates/prefer_complete_prompt.py | 401 ++++++++++++++++-- src/memos/vec_dbs/item.py | 1 + src/memos/vec_dbs/milvus.py | 130 +++++- 19 files changed, 833 insertions(+), 158 deletions(-) diff --git a/docker/requirements.txt b/docker/requirements.txt index d20c0b36e..4846f1832 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6 watchfiles==1.1.0 websockets==15.0.1 xlrd==2.0.2 -xlsxwriter==3.2.5 \ No newline at end of file +xlsxwriter==3.2.5 diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 7336d4612..fc358dc36 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -53,9 +53,9 @@ def add_memory_for_line( if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": for chunk_start in range(0, len(conversation), turns_add * 2): chunk = conversation[chunk_start : chunk_start + turns_add * 2] - mem_client.add(messages=chunk, user_id=user_id, conv_id=None) + mem_client.add(messages=chunk, user_id=user_id, conv_id=None, batch_size=2) else: - mem_client.add(messages=conversation, user_id=user_id, conv_id=None) + mem_client.add(messages=conversation, user_id=user_id, conv_id=None, batch_size=2) end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -98,7 +98,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] ) - + f"\n{relevant_memories['pref_mem']}" + + f"\n{relevant_memories['pref_string']}" ) memory_tokens_used = len(tokenizer.encode(memories_str)) diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index 1ddf0d933..0b610d574 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -107,11 +107,11 @@ def memos_api_search( speaker_a_context = ( "\n".join([i["memory"] for i in search_a_results["text_mem"][0]["memories"]]) - + f"\n{search_a_results['pref_mem']}" + + f"\n{search_a_results['pref_string']}" ) speaker_b_context = ( "\n".join([i["memory"] for i in search_b_results["text_mem"][0]["memories"]]) - + f"\n{search_b_results['pref_mem']}" + + f"\n{search_b_results['pref_string']}" ) context = TEMPLATE_MEMOS.format( diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index 60b2146f6..89c02aaea 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -46,7 +46,7 @@ def memos_search(client, query, user_id, top_k): results = client.search(query=query, user_id=user_id, top_k=top_k) context = ( "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]]) - + f"\n{results['pref_mem']}" + + f"\n{results['pref_string']}" ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context) duration_ms = (time() - start) * 1000 diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index 5204b5c2a..cab0fbeb5 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -31,10 +31,10 @@ def ingest_session(session, user_id, session_id, frame, client): if os.getenv("PRE_SPLIT_CHUNK") == "true": for i in range(0, len(session), 10): messages = session[i : i + 10] - client.add(messages=messages, user_id=user_id, conv_id=session_id) + client.add(messages=messages, user_id=user_id, conv_id=session_id, batch_size=2) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") else: - client.add(messages=session, user_id=user_id, conv_id=session_id) + client.add(messages=session, user_id=user_id, conv_id=session_id, batch_size=2) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages") elif frame == "memobase": for _idx, msg in enumerate(session): diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index c18e05623..441474c7c 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -84,7 +84,7 @@ def memos_search(client, user_id, query, top_k): results = client.search(query=query, user_id=user_id, top_k=top_k) search_memories = ( "\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"]) - + f"\n{results['pref_mem']}" + + f"\n{results['pref_string']}" ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 4117cba56..e1bdd54e9 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -182,6 +182,7 @@ def search(self, query, user_id, top_k): "conversation_id": "", "top_k": top_k, "mode": "mixture", + "handle_pref_mem": False, }, ensure_ascii=False, ) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 3ba12c1ce..bb98f04ba 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -303,18 +303,15 @@ def _post_process_pref_mem( mem_cube_id: str, handle_pref_mem: bool, ): - if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true" and pref_formatted_mem: - memories_result["prefs"] = [] - memories_result["prefs"].append( + if handle_pref_mem: + memories_result["pref_mem"].append( { "cube_id": mem_cube_id, "memories": pref_formatted_mem, } ) - - if handle_pref_mem: pref_instruction: str = instruct_completion(pref_formatted_mem) - memories_result["pref_mem"] = pref_instruction + memories_result["pref_string"] = pref_instruction return memories_result @@ -333,7 +330,8 @@ def search_memories(search_req: APISearchRequest): "text_mem": [], "act_mem": [], "para_mem": [], - "pref_mem": "", + "pref_mem": [], + "pref_string": "", } search_mode = search_req.mode diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 48f42fa4c..9b9059d26 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -194,7 +194,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): default="explicit_preference", description="Type of preference." ) dialog_id: str | None = Field(default=None, description="ID of the dialog.") - dialog_str: str | None = Field(default=None, description="String of the dialog.") + original_text: str | None = Field(default=None, description="String of the dialog.") embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") explicit_preference: str | None = Field(default=None, description="Explicit preference.") created_at: str | None = Field(default=None, description="Timestamp of the dialog.") diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index eb284cd6d..052ae30c2 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from concurrent.futures import as_completed +from datetime import datetime from typing import Any from memos.context.context import ContextThreadPoolExecutor @@ -9,6 +10,7 @@ from memos.memories.textual.item import TextualMemoryItem from memos.templates.prefer_complete_prompt import ( NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE, ) from memos.vec_dbs.item import MilvusVecDBItem @@ -57,18 +59,35 @@ def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) response = result.get("is_same", False) - return response if isinstance(response, bool) else response == "true" + return response if isinstance(response, bool) else response.lower() == "true" except Exception as e: logger.error(f"Error in judge_update_or_add: {e}") # Fallback to simple string comparison return old_msg == new_msg - def _judge_update_or_add_trace_op( - self, new_mem: str, retrieved_mems: str - ) -> dict[str, Any] | None: - prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace("{new_memory}", new_mem).replace( + def _judge_update_or_add_fine(self, new_mem: str, retrieved_mems: str) -> dict[str, Any] | None: + if not retrieved_mems: + return None + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE.replace("{new_memory}", new_mem).replace( "{retrieved_memories}", retrieved_mems ) + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + return result + except Exception as e: + logger.error(f"Error in judge_update_or_add_fine: {e}") + return None + + def _judge_update_or_add_trace_op( + self, new_mems: str, retrieved_mems: str + ) -> dict[str, Any] | None: + if not retrieved_mems: + return None + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace( + "{new_memories}", new_mems + ).replace("{retrieved_memories}", retrieved_mems) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() @@ -80,30 +99,34 @@ def _judge_update_or_add_trace_op( def _update_memory_op_trace( self, - new_memory: TextualMemoryItem, + new_memories: list[TextualMemoryItem], retrieved_memories: list[MilvusVecDBItem], collection_name: str, preference_type: str, ) -> list[str] | str: - if not retrieved_memories: + # create new vec db items + new_vec_db_items: list[MilvusVecDBItem] = [] + for new_memory in new_memories: payload = new_memory.to_dict()["metadata"] - fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + fields_to_remove = {"dialog_id", "original_text", "embedding"} payload = {k: v for k, v in payload.items() if k not in fields_to_remove} - vec_db_item = MilvusVecDBItem( + new_vec_db_item = MilvusVecDBItem( id=new_memory.id, memory=new_memory.memory, + original_text=new_memory.metadata.original_text, vector=new_memory.metadata.embedding, payload=payload, ) - self.vector_db.add(collection_name, [vec_db_item]) - return new_memory.id + new_vec_db_items.append(new_vec_db_item) - new_mem_input = { - "context_summary": new_memory.memory, - "preference": new_memory.metadata.explicit_preference - if preference_type == "explicit_preference" - else new_memory.metadata.implicit_preference, - } + new_mem_inputs = [ + { + "id": new_memory.id, + "context_summary": new_memory.memory, + "preference": new_memory.payload[preference_type], + } + for new_memory in new_vec_db_items + ] retrieved_mem_inputs = [ { "id": mem.id, @@ -114,57 +137,53 @@ def _update_memory_op_trace( ] rsp = self._judge_update_or_add_trace_op( - new_mem=json.dumps(new_mem_input), retrieved_mems=json.dumps(retrieved_mem_inputs) + new_mems=json.dumps(new_mem_inputs), + retrieved_mems=json.dumps(retrieved_mem_inputs) if retrieved_mem_inputs else "", ) if not rsp: - payload = new_memory.to_dict()["metadata"] - fields_to_remove = {"dialog_id", "dialog_str", "embedding"} - payload = {k: v for k, v in payload.items() if k not in fields_to_remove} - vec_db_item = MilvusVecDBItem( - id=new_memory.id, - memory=new_memory.memory, - vector=new_memory.metadata.embedding, - payload=payload, - ) - self.vector_db.add(collection_name, [vec_db_item]) - return new_memory.id + with ContextThreadPoolExecutor(max_workers=min(len(new_vec_db_items), 5)) as executor: + futures = { + executor.submit(self.vector_db.add, collection_name, [db_item]): db_item + for db_item in new_vec_db_items + } + for future in as_completed(futures): + result = future.result() + return [db_item.id for db_item in new_vec_db_items] - def execute_op(op): + new_mem_db_item_map = {db_item.id: db_item for db_item in new_vec_db_items} + retrieved_mem_db_item_map = {db_item.id: db_item for db_item in retrieved_memories} + + def execute_op( + op, + new_mem_db_item_map: dict[str, MilvusVecDBItem], + retrieved_mem_db_item_map: dict[str, MilvusVecDBItem], + ) -> str | None: op_type = op["type"].lower() if op_type == "add": - payload = new_memory.to_dict()["metadata"] - payload = { - k: v - for k, v in payload.items() - if k not in {"dialog_id", "dialog_str", "embedding"} - } - vec_db_item = MilvusVecDBItem( - id=new_memory.id, - memory=new_memory.memory, - vector=new_memory.metadata.embedding, - payload=payload, - ) - self.vector_db.add(collection_name, [vec_db_item]) - return new_memory.id + if op["target_id"] in new_mem_db_item_map: + self.vector_db.add(collection_name, [new_mem_db_item_map[op["target_id"]]]) + return new_mem_db_item_map[op["target_id"]].id + return None elif op_type == "update": - payload = { - "preference_type": preference_type, - preference_type: op["new_preference"], - } - vec_db_item = MilvusVecDBItem( - id=op["target_id"], - memory=op["new_context_summary"], - vector=self.embedder.embed([op["new_context_summary"]])[0], - payload=payload, - ) - self.vector_db.update(collection_name, op["target_id"], vec_db_item) - return op["target_id"] + if op["target_id"] in retrieved_mem_db_item_map: + update_mem_db_item = retrieved_mem_db_item_map[op["target_id"]] + update_mem_db_item.payload[preference_type] = op["new_preference"] + update_mem_db_item.payload["updated_at"] = datetime.now().isoformat() + update_mem_db_item.memory = op["new_context_summary"] + update_mem_db_item.original_text = op["new_context_summary"] + update_mem_db_item.vector = self.embedder.embed([op["new_context_summary"]])[0] + self.vector_db.update(collection_name, op["target_id"], update_mem_db_item) + return op["target_id"] + return None elif op_type == "delete": self.vector_db.delete(collection_name, [op["target_id"]]) return None with ContextThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: - future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]} + future_to_op = { + executor.submit(execute_op, op, new_mem_db_item_map, retrieved_mem_db_item_map): op + for op in rsp["trace"] + } added_ids = [] for future in as_completed(future_to_op): result = future.result() @@ -173,6 +192,61 @@ def execute_op(op): return added_ids + def _update_memory_fine( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + ) -> str: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "original_text", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + original_text=new_memory.metadata.original_text, + vector=new_memory.metadata.embedding, + payload=payload, + ) + + new_mem_input = { + "memory": new_memory.memory, + "preference": new_memory.metadata.explicit_preference + if preference_type == "explicit_preference" + else new_memory.metadata.implicit_preference, + } + retrieved_mem_inputs = [ + { + "id": mem.id, + "memory": mem.memory, + "preference": mem.payload[preference_type], + } + for mem in retrieved_memories + ] + rsp = self._judge_update_or_add_fine( + new_mem=json.dumps(new_mem_input), + retrieved_mems=json.dumps(retrieved_mem_inputs) if retrieved_mem_inputs else "", + ) + need_update = rsp.get("need_update", False) if rsp else False + need_update = ( + need_update if isinstance(need_update, bool) else need_update.lower() == "true" + ) + update_item = [mem for mem in retrieved_memories if mem.id == rsp["id"]] + if need_update and update_item: + update_vec_db_item = update_item[0] + update_vec_db_item.payload[preference_type] = rsp["new_preference"] + update_vec_db_item.payload["updated_at"] = vec_db_item.payload["updated_at"] + update_vec_db_item.memory = rsp["new_memory"] + update_vec_db_item.original_text = vec_db_item.original_text + update_vec_db_item.vector = self.embedder.embed([rsp["new_memory"]])[0] + + self.vector_db.update(collection_name, rsp["id"], update_vec_db_item) + return rsp["id"] + else: + self.vector_db.add(collection_name, [vec_db_item]) + return vec_db_item.id + def _update_memory_fast( self, new_memory: TextualMemoryItem, @@ -180,11 +254,12 @@ def _update_memory_fast( collection_name: str, ) -> str: payload = new_memory.to_dict()["metadata"] - fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + fields_to_remove = {"dialog_id", "original_text", "embedding"} payload = {k: v for k, v in payload.items() if k not in fields_to_remove} vec_db_item = MilvusVecDBItem( id=new_memory.id, memory=new_memory.memory, + original_text=new_memory.metadata.original_text, vector=new_memory.metadata.embedding, payload=payload, ) @@ -197,8 +272,9 @@ def _update_memory_fast( new_msg_str = new_memory.memory is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str) if is_same: - self.vector_db.delete(collection_name, [recall.id]) - self.vector_db.update(collection_name, new_memory.id, vec_db_item) + vec_db_item.id = recall.id + self.vector_db.update(collection_name, recall.id, vec_db_item) + self.vector_db.add(collection_name, [vec_db_item]) return new_memory.id def _update_memory( @@ -207,7 +283,7 @@ def _update_memory( retrieved_memories: list[MilvusVecDBItem], collection_name: str, preference_type: str, - update_mode: str = "op_trace", + update_mode: str = "fine", ) -> list[str] | str | None: """Update the memory. Args: @@ -215,14 +291,14 @@ def _update_memory( retrieved_memories: list[MilvusVecDBItem] collection_name: str preference_type: str - update_mode: str, "op_trace" or "fast" + update_mode: str, "fast" or "fine" """ - if update_mode == "op_trace": - return self._update_memory_op_trace( + if update_mode == "fast": + return self._update_memory_fast(new_memory, retrieved_memories, collection_name) + elif update_mode == "fine": + return self._update_memory_fine( new_memory, retrieved_memories, collection_name, preference_type ) - elif update_mode == "fast": - return self._update_memory_fast(new_memory, retrieved_memories, collection_name) else: raise ValueError(f"Invalid update mode: {update_mode}") @@ -237,33 +313,71 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | collection_name = pref_type_collection_map[preference_type] search_results = self.vector_db.search( - memory.metadata.embedding, - collection_name, + query_vector=memory.metadata.embedding, + query=memory.memory, + collection_name=collection_name, top_k=5, filter={"user_id": memory.metadata.user_id}, ) search_results.sort(key=lambda x: x.score, reverse=True) return self._update_memory( - memory, search_results, collection_name, preference_type, update_mode="fast" + memory, search_results, collection_name, preference_type, update_mode="fine" ) except Exception as e: logger.error(f"Error processing memory {memory.id}: {e}") return None - def add( - self, - memories: list[TextualMemoryItem | dict[str, Any]], - max_workers: int = 8, - *args, - **kwargs, - ) -> list[str]: - """Add the instruct preference memories using thread pool for acceleration.""" - if not memories: - return [] + def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwargs) -> list[str]: + pref_type_collection_map = { + "explicit_preference": "explicit_preference", + "implicit_preference": "implicit_preference", + } - added_ids = [] + explicit_new_mems = [] + implicit_new_mems = [] + explicit_recalls = [] + implicit_recalls = [] + + for memory in memories: + preference_type = memory.metadata.preference_type + collection_name = pref_type_collection_map[preference_type] + search_results = self.vector_db.search( + query_vector=memory.metadata.embedding, + query=memory.memory, + collection_name=collection_name, + top_k=5, + filter={"user_id": memory.metadata.user_id}, + ) + if preference_type == "explicit_preference": + explicit_recalls.extend(search_results) + explicit_new_mems.append(memory) + elif preference_type == "implicit_preference": + implicit_recalls.extend(search_results) + implicit_new_mems.append(memory) + + explicit_recalls = list({recall.id: recall for recall in explicit_recalls}.values()) + implicit_recalls = list({recall.id: recall for recall in implicit_recalls}.values()) + + explicit_added_ids = self._update_memory_op_trace( + explicit_new_mems, + explicit_recalls, + pref_type_collection_map["explicit_preference"], + "explicit_preference", + ) + implicit_added_ids = self._update_memory_op_trace( + implicit_new_mems, + implicit_recalls, + pref_type_collection_map["implicit_preference"], + "implicit_preference", + ) + return explicit_added_ids + implicit_added_ids + + def process_memory_single( + self, memories: list[TextualMemoryItem], max_workers: int = 8, *args, **kwargs + ) -> list[str]: + added_ids: list[str] = [] with ContextThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: future_to_memory = { executor.submit(self._process_single_memory, memory): memory for memory in memories @@ -281,5 +395,23 @@ def add( memory = future_to_memory[future] logger.error(f"Error processing memory {memory.id}: {e}") continue - return added_ids + + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + max_workers: int = 8, + *args, + **kwargs, + ) -> list[str]: + """Add the instruct preference memories using thread pool for acceleration.""" + if not memories: + return [] + + process_map = { + "single": self.process_memory_single, + "batch": self.process_memory_batch, + } + + process_func = process_map["single"] + return process_func(memories, max_workers) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 41d90d10e..61629b38a 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -8,12 +8,15 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger +from memos.mem_reader.simple_struct import detect_lang from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.prefer_text_memory.spliter import Splitter from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string from memos.templates.prefer_complete_prompt import ( NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, ) from memos.types import MessageList @@ -44,7 +47,7 @@ def extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]: """Extract basic information from a QA pair (no LLM needed).""" basic_info = { "dialog_id": str(uuid.uuid4()), - "dialog_str": convert_messages_to_string(qa_pair), + "original_text": convert_messages_to_string(qa_pair), "created_at": datetime.now().isoformat(), } @@ -53,7 +56,12 @@ def extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]: def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None: """Extract explicit preference from a QA pair.""" qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair - prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) + lang = detect_lang(qa_pair_str) + _map = { + "zh": NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + "en": NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + } + prompt = _map[lang].replace("{qa_pair}", qa_pair_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) @@ -69,7 +77,12 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A if not qa_pair: return None qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair - prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) + lang = detect_lang(qa_pair_str) + _map = { + "zh": NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + "en": NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + } + prompt = _map[lang].replace("{qa_pair}", qa_pair_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) @@ -85,10 +98,10 @@ def _process_single_chunk_explicit( ) -> TextualMemoryItem | None: """Process a single chunk and return a TextualMemoryItem.""" basic_info = self.extract_basic_info(chunk) - if not basic_info["dialog_str"]: + if not basic_info["original_text"]: return None - explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) + explicit_pref = self.extract_explicit_preference(basic_info["original_text"]) if not explicit_pref: return None @@ -114,9 +127,9 @@ def _process_single_chunk_implicit( self, chunk: MessageList, msg_type: str, info: dict[str, Any] ) -> TextualMemoryItem | None: basic_info = self.extract_basic_info(chunk) - if not basic_info["dialog_str"]: + if not basic_info["original_text"]: return None - implicit_pref = self.extract_implicit_preference(basic_info["dialog_str"]) + implicit_pref = self.extract_implicit_preference(basic_info["original_text"]) if not implicit_pref: return None diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 807a8b55e..f09d646b1 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -3,6 +3,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.vec_dbs.item import MilvusVecDBItem class BaseRetriever(ABC): @@ -29,6 +30,35 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No self.vector_db = vector_db self.embedder = embedder + def _naive_reranker( + self, query: str, prefs_mem: list[TextualMemoryItem], top_k: int, **kwargs: Any + ) -> list[TextualMemoryItem]: + if self.reranker: + prefs_mem = self.reranker.rerank(query, prefs_mem, top_k) + return [item for item, _ in prefs_mem] + return prefs_mem + + def _original_text_reranker( + self, + query: str, + prefs_mem: list[TextualMemoryItem], + prefs: list[MilvusVecDBItem], + top_k: int, + **kwargs: Any, + ) -> list[TextualMemoryItem]: + if self.reranker: + from copy import deepcopy + + prefs_mem_for_reranker = deepcopy(prefs_mem) + for pref_mem, pref in zip(prefs_mem_for_reranker, prefs, strict=False): + pref_mem.memory = pref_mem.memory + "\n" + pref.original_text + prefs_mem_for_reranker = self.reranker.rerank(query, prefs_mem_for_reranker, top_k) + prefs_mem_for_reranker = [item for item, _ in prefs_mem_for_reranker] + prefs_ids = [item.id for item in prefs_mem_for_reranker] + prefs_dict = {item.id: item for item in prefs_mem} + return [prefs_dict[item_id] for item_id in prefs_ids if item_id in prefs_dict] + return prefs_mem + def retrieve( self, query: str, top_k: int, info: dict[str, Any] | None = None ) -> list[TextualMemoryItem]: @@ -45,10 +75,20 @@ def retrieve( with ContextThreadPoolExecutor(max_workers=2) as executor: # Submit all search tasks future_explicit = executor.submit( - self.vector_db.search, query_embedding, "explicit_preference", top_k * 2, info + self.vector_db.search, + query_embedding, + query, + "explicit_preference", + top_k * 2, + info, ) future_implicit = executor.submit( - self.vector_db.search, query_embedding, "implicit_preference", top_k * 2, info + self.vector_db.search, + query_embedding, + query, + "implicit_preference", + top_k * 2, + info, ) # Wait for all results @@ -59,7 +99,7 @@ def retrieve( explicit_prefs.sort(key=lambda x: x.score, reverse=True) implicit_prefs.sort(key=lambda x: x.score, reverse=True) - explicit_prefs = [ + explicit_prefs_mem = [ TextualMemoryItem( id=pref.id, memory=pref.memory, @@ -69,7 +109,7 @@ def retrieve( if pref.payload["explicit_preference"] ] - implicit_prefs = [ + implicit_prefs_mem = [ TextualMemoryItem( id=pref.id, memory=pref.memory, @@ -79,10 +119,16 @@ def retrieve( if pref.payload["implicit_preference"] ] - if self.reranker: - explicit_prefs = self.reranker.rerank(query, explicit_prefs, top_k) - implicit_prefs = self.reranker.rerank(query, implicit_prefs, top_k) - explicit_prefs = [item for item, _ in explicit_prefs] - implicit_prefs = [item for item, _ in implicit_prefs] - - return explicit_prefs + implicit_prefs + reranker_map = { + "naive": self._naive_reranker, + "original_text": self._original_text_reranker, + } + reranker_func = reranker_map["naive"] + explicit_prefs_mem = reranker_func( + query=query, prefs_mem=explicit_prefs_mem, prefs=explicit_prefs, top_k=top_k + ) + implicit_prefs_mem = reranker_func( + query=query, prefs_mem=implicit_prefs_mem, prefs=implicit_prefs, top_k=top_k + ) + + return explicit_prefs_mem + implicit_prefs_mem diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py index 59a6b0052..3059d611b 100644 --- a/src/memos/memories/textual/prefer_text_memory/spliter.py +++ b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -79,15 +79,15 @@ def _split_with_overlap(self, data: MessageList) -> list[MessageList]: adjacent chunk with low duplicate rate""" chunks = [] chunk = [] - for item in data: + for i, item in enumerate(data): chunk.append(item) # 5 turns (Q + A = 10) each chunk if len(chunk) >= 10: chunks.append(chunk) # overlap 1 turns (Q + A = 2) - context = copy.deepcopy(chunk[-2:]) + context = copy.deepcopy(chunk[-2:]) if i + 1 < len(data) else [] chunk = context - if chunk: + if chunk and len(chunk) % 2 == 0: chunks.append(chunk) return chunks diff --git a/src/memos/reranker/cosine_local.py b/src/memos/reranker/cosine_local.py index 38ace458f..318cd744a 100644 --- a/src/memos/reranker/cosine_local.py +++ b/src/memos/reranker/cosine_local.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING from memos.log import get_logger +from memos.utils import timed from .base import BaseReranker -from memos.utils import timed if TYPE_CHECKING: diff --git a/src/memos/reranker/noop.py b/src/memos/reranker/noop.py index 4f6ba0438..04250bef7 100644 --- a/src/memos/reranker/noop.py +++ b/src/memos/reranker/noop.py @@ -2,9 +2,11 @@ from typing import TYPE_CHECKING -from .base import BaseReranker from memos.utils import timed +from .base import BaseReranker + + if TYPE_CHECKING: from memos.memories.textual.item import TextualMemoryItem diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index 7ad0fe190..c2a7f58c7 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -1,6 +1,7 @@ from typing import Any -from memos.templates.prefer_complete_prompt import PREF_INSTRUCTIONS +from memos.mem_reader.simple_struct import detect_lang +from memos.templates.prefer_complete_prompt import PREF_INSTRUCTIONS, PREF_INSTRUCTIONS_ZH def instruct_completion( @@ -33,11 +34,25 @@ def instruct_completion( else "" ) + _prompt_map = { + "zh": PREF_INSTRUCTIONS_ZH, + "en": PREF_INSTRUCTIONS, + } + _remove_exp_map = { + "zh": "显式偏好 > ", + "en": "explicit preference > ", + } + _remove_imp_map = { + "zh": "隐式偏好 > ", + "en": "implicit preference > ", + } + lang = detect_lang(explicit_pref_str + implicit_pref_str) + if not explicit_pref_str and not implicit_pref_str: return "" if not explicit_pref_str: - return implicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("explicit preferences > ", "") + return implicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_exp_map[lang], "") if not implicit_pref_str: - return explicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("implicit preferences > ", "") + return explicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_imp_map[lang], "") - return explicit_pref_str + "\n" + implicit_pref_str + "\n" + PREF_INSTRUCTIONS + return explicit_pref_str + "\n" + implicit_pref_str + "\n" + _prompt_map[lang] diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index d40b7b778..b98e65d54 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -9,9 +9,9 @@ - When the user modifies or updates their preferences for the same topic or event, extract the complete evolution process of their preference changes, including both the original and updated preferences. Requirements: -1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. -2. Output should be a list of concise natural language summaries and the corresponding context summary, context summary must contain complete information of the conversation fragment that the preference is mentioned. -3. If multiple preferences are mentioned within the same topic, you need to merge the preferences and context summary. +1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. If the user mentions reasons for their preferences, include those reasons as well. +2. Output should be a list of entries concise natural language summaries and the corresponding context summary, context summary must contain complete information of the conversation fragment that the preference is mentioned. +3. If multiple preferences are mentioned within the same topic or domain, you MUST combine them into a single entry, keep each entry information complete. Conversation: {qa_pair} @@ -29,6 +29,37 @@ """ +NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH = """ +你是一个偏好提取助手。 +请从以下对话中提取用户明确提及的偏好。 + +注意事项: +- 偏好是指用户对某事物的明确态度或选择,不仅限于"喜欢/不喜欢/想要/不想要/偏好"等词汇。 +- 包括但不限于用户明确表达的任何倾向、渴望、拒绝或优先级,这些都算作显式偏好。 +- 重点提取用户在查询中的偏好。不要从助手的回复中提取偏好,除非用户明确同意或认可助手的建议。 +- 当用户针对同一主题或事件修改或更新其偏好时,提取其偏好变化的完整演变过程,包括原始偏好和更新后的偏好。 + +要求: +1. 只保留用户明确提到的偏好,不要推断或假设。如果用户提到了偏好的原因,也要包含这些原因。 +2. 输出应该是一个条目列表,包含简洁的自然语言摘要和相应的上下文摘要,上下文摘要必须包含提到偏好的对话片段的完整信息。 +3. 如果在同一主题或领域内提到了多个偏好,你必须将它们合并为一个条目,保持每个条目信息完整。 + +对话: +{qa_pair} + +找出所有显式偏好。如果没有找到显式偏好,返回[]。仅输出JSON: +```json +[ + { + "explicit_preference": "偏好的简短自然语言摘要", + "context_summary": "对应的上下文摘要,即对应对话的摘要,不要遗漏任何场景信息", + "reasoning": "寻找显式偏好的推理过程" + }, +] +``` +""" + + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ You are a preference inference assistant. Please extract **implicit preferences** from the following conversation (preferences that the user did not explicitly state but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices). @@ -39,10 +70,9 @@ Requirements: 1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. -2. Output a concise natural language statement; do not use lists, categories, or include the reasoning process. -3. Inferred implicit preferences must not conflict with explicit preferences. -4. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. -5. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). +2. Inferred implicit preferences must not conflict with explicit preferences. +3. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. +4. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). Conversation: {qa_pair} @@ -59,6 +89,35 @@ """ +NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH = """ +你是一个偏好推理助手。请从以下对话中提取**隐式偏好** +(用户没有明确表述,但可以从上下文、行为、频率、比较、排除或场景选择中合理推断出的偏好)。 + +注意事项: +- 隐式偏好是指用户未直接表达,但可以从对话中的事实线索合理推断出的倾向或选择。 +- 不要将明确陈述的偏好视为隐式偏好;此提示仅用于推断未直接提及的偏好。 + +要求: +1. 仅在对话中有充分证据时进行推断;避免无根据或牵强的猜测。 +2. 推断的隐式偏好不得与显式偏好冲突。 +3. 对于 implicit_preference:仅输出偏好陈述本身;不要包含任何额外的解释、推理或置信度信息。将所有推理和解释放在 reasoning 字段中。 +4. 如果无法合理推断出隐式偏好,则将 implicit_preference 字段留空(不要输出其他任何内容)。 + +对话: +{qa_pair} + +输出格式: +```json +{ + "implicit_preference": "从对话中合理推断出的隐式偏好的简洁自然语言陈述,或空字符串", + "context_summary": "对应的上下文摘要,即对应对话的摘要,不要遗漏任何场景信息", + "reasoning": "简要解释隐式偏好的推理过程" +} +``` +除JSON外不要输出任何其他内容。 +""" + + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """ You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description. Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows: @@ -81,6 +140,104 @@ """ +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_ZH = """ +你是一个内容比较专家。现在给你旧信息和新信息,每个信息都包含问题、答案主题名称和主题描述。 +请判断这两个信息是否表达**相同的问题或核心内容**,不考虑表达差异、细节或示例差异。判断标准如下: + +- 核心内容一致,即要解决的问题本质、目标或核心概念相同,算作"相同"。 +- 表达方式不同、示例不同,但核心含义一致,也算作"相同"。 +- 如果问题目标、涉及的概念或解决思路不同,则算作"不同"。 + +请输出JSON格式: +{ + "is_same": true/false, + "reasoning": "简要解释判断依据,突出核心内容是否一致" +} + +**旧信息:** +{old_information} + +**新信息:** +{new_information} +""" + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE = """ +You are a preference memory comparison expert. Analyze if the new preference memory describes the same topic as any retrieved memories by considering BOTH the memory field and preference field. At most one retrieved memory can match the new memory. + +**Task:** Compare the new preference memory with retrieved memories to determine if they discuss the same topic and whether an update is needed. + +**Comparison Criteria:** +- **Memory field**: Compare the core topics, scenarios, and contexts described +- **Preference field**: Compare the actual preference statements, choices, and attitudes expressed +- **Same topic**: Both memory AND preference content relate to the same subject matter +- **Different topics**: Either memory OR preference content differs significantly +- **Content evolution**: Same topic but preference has changed/evolved or memory has been updated +- **Identical content**: Both memory and preference fields are essentially the same + +**Decision Logic:** +- Same core topic (both memory and preference) = need to check if update is needed +- Different topics (either memory or preference differs) = no update needed +- If same topic but content has changed/evolved = update needed +- If same topic and content is identical = update needed + +**Output JSON:** +```json +{ + "need_update": true/false, + "id": "ID of the memory being updated (empty string if no update needed)", + "new_memory": "Updated memory field with merged/evolved memory content (empty string if no update needed)", + "new_preference": "Updated preference field with merged/evolved preference content (empty string if no update needed)", + "reasoning": "Brief explanation of the comparison considering both memory and preference fields" +} +``` + +**New preference memory:** +{new_memory} + +**Retrieved preference memories:** +{retrieved_memories} +""" + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE_ZH = """ +你是一个偏好记忆比较专家。通过同时考虑 memory 字段和 preference 字段,分析新的偏好记忆是否与任何召回记忆描述相同的主题。最多只有一个召回记忆可以与新记忆匹配。 + +**任务:** 比较新的偏好记忆与召回记忆,以确定它们是否讨论相同的主题以及是否需要更新。 + +**比较标准:** +- **Memory 字段**:比较所描述的核心主题、场景和上下文 +- **Preference 字段**:比较表达的实际偏好陈述、选择和态度 +- **相同主题**:memory 和 preference 内容都涉及相同的主题 +- **不同主题**:memory 或 preference 内容有显著差异 +- **内容演变**:相同主题但偏好已改变/演变或记忆已更新 +- **内容相同**:memory 和 preference 字段本质上相同 + +**决策逻辑:** +- 核心主题相同(memory 和 preference 都相同)= 需要检查是否需要更新 +- 主题不同(memory 或 preference 有差异)= 不需要更新 +- 如果主题相同但内容已改变/演变 = 需要更新 +- 如果主题相同且内容完全相同 = 需要更新 + +**输出 JSON:** +```json +{ + "need_update": true/false, + "id": "正在更新的记忆的ID(如果不需要更新则为空字符串)", + "new_memory": "合并/演变后的更新 memory 字段(如果不需要更新则为空字符串)", + "new_preference": "合并/演变后的更新 preference 字段(如果不需要更新则为空字符串)", + "reasoning": "简要解释比较结果,同时考虑 memory 和 preference 字段" +} +``` + +**新的偏好记忆:** +{new_memory} + +**召回的偏好记忆:** +{retrieved_memories} +""" + + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE = """ # User Preference Memory Management Agent @@ -94,32 +251,168 @@ When updating a preference, you should also integrate and update the corresponding `context_summary` to ensure both fields stay semantically consistent. -You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations. +You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**. ## Input Format -New preference memory (new_memory): -{new_memory} +New preference memories (new_memories): +{new_memories} Retrieved preference memories (retrieved_memories): {retrieved_memories} +## Task Instructions + +1. For each new memory, analyze its relationship with the retrieved memories: + - If a new memory is **unrelated** to all retrieved memories → perform `"ADD"` (insert as a new independent memory); + - If a new memory is **related** to one or more retrieved memories → perform `"UPDATE"` on those related retrieved memories (refine, supplement, or merge both the `preference` and the `context_summary`, while preserving change history trajectory information); + - If one or more retrieved memories are merged into one updated memory → perform `"DELETE"` on those retrieved memories. + +2. **Important**: Only retrieved memories that are related to the new memories should be updated or deleted. Retrieved memories that are unrelated to any new memory must be preserved. + +3. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry, combining both their `preference` information and their `context_summary` in a coherent and concise way. + +4. Output a structured list of **operation traces**, each explicitly stating: + - which memory (by ID) is affected, + - what operation is performed, + - the before/after `preference` and `context_summary`, + - and the reasoning behind it. + +## Output Format (JSON) +{ + "trace": [ + { + "op_id": "op_1", + "type": "ADD" | "UPDATE" | "DELETE", + "target_id": "(the old memory ID; null if ADD)", + "old_preference": "(the old preference text; null if ADD)", + "old_context_summary": "(the old context summary; null if ADD)", + "new_preference": "(the updated or newly created preference, if applicable)", + "new_context_summary": "(the updated or newly created context summary, if applicable)", + "reason": "(brief natural-language explanation for the decision)" + } + ] +} + +## Output Requirements + +- The output **must** be valid JSON. +- Each operation must include both `preference` and `context_summary` updates where applicable. +- Each operation must include a clear `reason`. +- Multiple retrieved memories may be merged into one unified updated memory. +- Do **not** include any explanatory text outside the JSON. +""" + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE_ZH = """ +# 用户偏好记忆管理代理 + +你是一个**用户偏好记忆管理代理**。 +你的目标是通过分析新的偏好信息并确定如何更新现有记忆,来维护用户的长期**偏好记忆库**。 + +每个记忆条目包含三个字段: +- **id**:记忆的唯一标识符。 +- **context_summary**:从中提取偏好的对话或情境的事实摘要。 +- **preference**:描述用户偏好或倾向的提取陈述。 + +更新偏好时,你还应该整合并更新相应的 `context_summary`,以确保两个字段保持语义一致。 + +你必须生成完整的**操作跟踪**,显示应该**添加**、**更新**或**删除**哪些记忆条目(通过唯一 ID 标识)。 + +## 输入格式 + +新的偏好记忆 (new_memories): +{new_memories} + +召回的偏好记忆 (retrieved_memories): +{retrieved_memories} +## 任务说明 + +1. 对于每个新记忆,分析其与召回记忆的关系: + - 如果新记忆与所有召回记忆**无关** → 执行 `"ADD"`(作为新的独立记忆插入); + - 如果新记忆与一个或多个召回记忆**相关** → 对这些相关的召回记忆执行 `"UPDATE"`(细化、补充或合并 `preference` 和 `context_summary`,同时保留变化历史轨迹信息); + - 如果一个或多个召回记忆被合并到一个更新的记忆中 → 对这些召回记忆执行 `"DELETE"`。 + +2. **重要**:只有与新记忆相关的召回记忆才应该被更新或删除。与任何新记忆都无关的召回记忆必须保留。 + +3. 如果多个召回记忆描述相同的偏好主题,将它们合并为一个更新的记忆条目,以连贯简洁的方式结合它们的 `preference` 信息和 `context_summary`。 + +4. 输出结构化的**操作跟踪**列表,每个操作明确说明: + - 受影响的记忆(通过 ID); + - 执行的操作类型; + - 更新前后的 `preference` 和 `context_summary`; + - 以及决策的原因。 + +## 输出格式 (JSON) + +{ + "trace": [ + { + "op_id": "op_1", + "type": "ADD" | "UPDATE" | "DELETE", + "target_id": "(旧记忆 ID;如果是 ADD 则为 null)", + "old_preference": "(旧的偏好文本;如果是 ADD 则为 null)", + "old_context_summary": "(旧的上下文摘要;如果是 ADD 则为 null)", + "new_preference": "(更新或新创建的偏好,如果适用)", + "new_context_summary": "(更新或新创建的上下文摘要,如果适用)", + "reason": "(决策的简要自然语言解释)" + } + ] +} + +## 输出要求 + +- 输出**必须**是有效的 JSON。 +- 每个操作必须包含 `preference` 和 `context_summary` 的更新(如果适用)。 +- 每个操作必须包含清晰的 `reason`。 +- 多个召回记忆可以合并为一个统一的更新记忆。 +- **不要**在 JSON 之外包含任何解释性文本。 +""" + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE_WITH_ONE_SHOT = """ +# User Preference Memory Management Agent + +You are a **User Preference Memory Management Agent**. +Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories. + +Each memory entry contains three fields: +- **id**: a unique identifier for the memory. +- **context_summary**: a factual summary of the dialogue or situation from which the preference was extracted. +- **preference**: the extracted statement describing the user's preference or tendency. + +When updating a preference, you should also integrate and update the corresponding `context_summary` to ensure both fields stay semantically consistent. + +You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations. + +## Input Format + +New preference memories (new_memories): +{new_memories} + +Retrieved preference memories (retrieved_memories): +{retrieved_memories} ## Task Instructions -1. Analyze each retrieved memory and determine its relationship to the new memory: - - **Unrelated** → perform `"ADD"` (insert as a new independent memory); - - **Related** → perform `"UPDATE"` (refine, supplement, or merge both the `preference` and the `context_summary`); - - **Conflicting or outdated** → perform `"DELETE"` (remove obsolete or contradictory memory). +1. For each new memory, analyze its relationship with the retrieved memories: + - If a new memory is **unrelated** to all retrieved memories → perform `"ADD"` (insert as a new independent memory); + - If a new memory is **related** to one or more retrieved memories → perform `"UPDATE"` on those related retrieved memories (refine, supplement, or merge both the `preference` and the `context_summary`, while preserving change history trajectory information); + - If one or more retrieved memories are merged into one updated memory → perform `"DELETE"` on those retrieved memories. + +2. **Important**: Only retrieved memories that are related to the new memories should be updated or deleted. Retrieved memories that are unrelated to any new memory must be preserved as-is in the final state. -2. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry, combining both their `preference` information and their `context_summary` in a coherent and concise way. +3. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry, combining both their `preference` information and their `context_summary` in a coherent and concise way. -3. Output a structured list of **operation traces**, each explicitly stating: +4. Output a structured list of **operation traces**, each explicitly stating: - which memory (by ID) is affected, - what operation is performed, - the before/after `preference` and `context_summary`, - and the reasoning behind it. -4. Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. +5. Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. This must include: + - All newly added memories (from ADD operations) + - All updated memories (from UPDATE operations) + - All unrelated retrieved memories that were preserved unchanged ## Output Format (JSON) @@ -148,11 +441,24 @@ ## Example **Input:** -new_memory: -{ - "context_summary": "During a recent chat about study habits, the user mentioned that he often studies in quiet coffee shops and has started preferring lattes over Americanos, which he only drinks occasionally.", - "preference": "User now prefers lattes but occasionally drinks Americanos; he also enjoys studying in quiet coffee shops." -} +new_memories: +[ + { + "id": "new_id1", + "context_summary": "During a recent chat about study habits, the user mentioned that he often studies in quiet coffee shops and has started preferring lattes over Americanos, which he only drinks occasionally.", + "preference": "User now prefers lattes but occasionally drinks Americanos; he also enjoys studying in quiet coffee shops." + }, + { + "id": "new_id2", + "context_summary": "The user mentioned in a conversation about beverages that he has recently started enjoying green tea in the morning.", + "preference": "User now enjoys drinking green tea in the morning." + }, + { + "id": "new_id3", + "context_summary": "The user shared that he has recently started learning to play the guitar and practices for about 30 minutes every evening.", + "preference": "User enjoys playing guitar and practices regularly in the evenings." + } +] retrieved_memories: [ @@ -175,6 +481,11 @@ "id": "id4", "context_summary": "The user noted he doesn't drink tea very often.", "preference": "User has no particular interest in tea." + }, + { + "id": "id5", + "context_summary": "The user mentioned he enjoys running in the park on weekends.", + "preference": "User likes running outdoors on weekends." } ] @@ -189,7 +500,7 @@ "old_context_summary": "The user previously said he likes coffee in general.", "new_preference": "User likes coffee, especially lattes, but occasionally drinks Americanos.", "new_context_summary": "The user discussed his coffee habits, stating he now prefers lattes but only occasionally drinks Americanos", - "reason": "The new memory refines and expands the coffee preference and context while preserving frequency semantics ('occasionally')." + "reason": "New memory new_id1 refines and expands the coffee preference and context while preserving frequency semantics ('occasionally')." }, { "op_id": "op_2", @@ -209,7 +520,27 @@ "old_context_summary": "The user said he often works from home.", "new_preference": "User now prefers studying in quiet coffee shops instead of working from home.", "new_context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.", - "reason": "The preference has changed for the working environment." + "reason": "New memory new_id1 indicates a preference change for the working environment." + }, + { + "op_id": "op_4", + "type": "UPDATE", + "target_id": "id4", + "old_preference": "User has no particular interest in tea.", + "old_context_summary": "The user noted he doesn't drink tea very often.", + "new_preference": "The user does not drink tea very often before, but now enjoys drinking green tea in the morning.", + "new_context_summary": "The user mentioned that he has recently started enjoying green tea in the morning.", + "reason": "New memory new_id2 indicates a preference change for tea consumption." + }, + { + "op_id": "op_5", + "type": "ADD", + "target_id": "new_id3", + "old_preference": null, + "old_context_summary": null, + "new_preference": "User enjoys playing guitar and practices regularly in the evenings.", + "new_context_summary": "The user shared that he has recently started learning to play the guitar and practices for about 30 minutes every evening.", + "reason": "This is a completely new preference unrelated to any existing memories, so it should be added as a new entry." } ], "after_update_state": [ @@ -225,8 +556,18 @@ }, { "id": "id4", - "context_summary": "The user noted he doesn't drink tea very often.", - "preference": "User has no particular interest in tea." + "context_summary": "The user mentioned that he has recently started enjoying green tea in the morning.", + "preference": "The user does not drink tea very often before, but now enjoys drinking green tea in the morning." + }, + { + "id": "id5", + "context_summary": "The user mentioned he enjoys running in the park on weekends.", + "preference": "User likes running outdoors on weekends." + }, + { + "id": "new_id3", + "context_summary": "The user shared that he has recently started learning to play the guitar and practices for about 30 minutes every evening.", + "preference": "User enjoys playing guitar and practices regularly in the evenings." } ] } @@ -248,3 +589,11 @@ Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. """ + + +PREF_INSTRUCTIONS_ZH = """ +# 注意: +明文记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 +你的回复不得违反用户的任何偏好,无论是显式偏好还是隐式偏好,并简要解释你为什么这样回答以避免冲突。 +当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 明文记忆。 +""" diff --git a/src/memos/vec_dbs/item.py b/src/memos/vec_dbs/item.py index 081400f15..c6aa1c9c2 100644 --- a/src/memos/vec_dbs/item.py +++ b/src/memos/vec_dbs/item.py @@ -47,3 +47,4 @@ class MilvusVecDBItem(VecDBItem): """Represents a single item in the Milvus vector database.""" memory: str | None = Field(default=None, description="Memory string") + original_text: str | None = Field(default=None, description="Original text content") diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index c1cb26362..e50c8ce18 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -34,18 +34,36 @@ def __init__(self, config: MilvusVecDBConfig): def create_schema(self): """Create schema for the milvus collection.""" - from pymilvus import DataType + from pymilvus import DataType, Function, FunctionType schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True ) - schema.add_field(field_name="memory", datatype=DataType.VARCHAR, max_length=65535) + analyzer_params = {"tokenizer": "standard", "filter": ["lowercase"]} + schema.add_field( + field_name="memory", + datatype=DataType.VARCHAR, + max_length=65535, + analyzer_params=analyzer_params, + enable_match=True, + enable_analyzer=True, + ) + schema.add_field(field_name="original_text", datatype=DataType.VARCHAR, max_length=65535) schema.add_field( field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension ) schema.add_field(field_name="payload", datatype=DataType.JSON) + schema.add_field(field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR) + bm25_function = Function( + name="bm25", + function_type=FunctionType.BM25, + input_field_names=["memory"], + output_field_names="sparse_vector", + ) + schema.add_function(bm25_function) + return schema def create_index(self): @@ -54,6 +72,11 @@ def create_index(self): index_params.add_index( field_name="vector", index_type="FLAT", metric_type=self._get_metric_type() ) + index_params.add_index( + field_name="sparse_vector", + index_type="SPARSE_INVERTED_INDEX", + metric_type="BM25", + ) return index_params @@ -102,12 +125,96 @@ def collection_exists(self, name: str) -> bool: """Check if a collection exists.""" return self.client.has_collection(collection_name=name) + def _dense_search( + self, + collection_name: str, + query_vector: list[float], + top_k: int, + filter: str = "", + **kwargs: Any, + ) -> list[list[dict]]: + """Dense search for similar items in the database.""" + results = self.client.search( + collection_name=collection_name, + data=[query_vector], + limit=top_k, + filter=filter, + output_fields=["*"], + anns_field="vector", + ) + return results + + def _sparse_search( + self, + collection_name: str, + query: str, + top_k: int, + filter: str = "", + **kwargs: Any, + ) -> list[list[dict]]: + """Sparse search for similar items in the database.""" + results = self.client.search( + collection_name=collection_name, + data=[query], + limit=top_k, + filter=filter, + output_fields=["*"], + anns_field="sparse_vector", + ) + return results + + def _hybrid_search( + self, + collection_name: str, + query_vector: list[float], + query: str, + top_k: int, + filter: str | None = None, + ranker_type: str = "rrf", # rrf, weighted + sparse_weight=1.0, + dense_weight=1.0, + **kwargs: Any, + ) -> list[list[dict]]: + """Hybrid search for similar items in the database.""" + from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker + + # Set up BM25 search request + expr = filter if filter else None + sparse_request = AnnSearchRequest( + data=[query], + anns_field="sparse_vector", + param={"metric_type": "BM25"}, + limit=top_k, + expr=expr, + ) + # Set up dense vector search request + dense_request = AnnSearchRequest( + data=[query_vector], + anns_field="vector", + param={"metric_type": self._get_metric_type()}, + limit=top_k, + expr=expr, + ) + ranker = ( + RRFRanker() if ranker_type == "rrf" else WeightedRanker(sparse_weight, dense_weight) + ) + results = self.client.hybrid_search( + collection_name=collection_name, + reqs=[sparse_request, dense_request], + ranker=ranker, + limit=top_k, + output_fields=["*"], + ) + return results + def search( self, query_vector: list[float], + query: str, collection_name: str, top_k: int, filter: dict[str, Any] | None = None, + search_type: str = "dense", # dense, sparse, hybrid ) -> list[MilvusVecDBItem]: """ Search for similar items in the database. @@ -124,12 +231,18 @@ def search( # Convert filter to Milvus expression expr = self._dict_to_expr(filter) if filter else "" - results = self.client.search( + search_func_map = { + "dense": self._dense_search, + "sparse": self._sparse_search, + "hybrid": self._hybrid_search, + } + + results = search_func_map[search_type]( collection_name=collection_name, - data=[query_vector], - limit=top_k, + query_vector=query_vector, + query=query, + top_k=top_k, filter=expr, - output_fields=["*"], # Return all fields ) items = [] @@ -140,6 +253,7 @@ def search( MilvusVecDBItem( id=str(entity.get("id")), memory=entity.get("memory"), + original_text=entity.get("original_text"), vector=entity.get("vector"), payload=entity.get("payload", {}), score=1 - float(hit["distance"]), @@ -196,6 +310,7 @@ def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None: return MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), + original_text=entity.get("original_text"), vector=entity.get("vector"), payload=payload, ) @@ -217,6 +332,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), + original_text=entity.get("original_text"), vector=entity.get("vector"), payload=payload, ) @@ -264,6 +380,7 @@ def get_by_filter( MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), + original_text=entity.get("original_text"), vector=entity.get("vector"), payload=payload, ) @@ -321,6 +438,7 @@ def add(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]] entity = { "id": item.id, "memory": item.memory, + "original_text": item.original_text, "vector": item.vector, "payload": item.payload if item.payload else {}, } From 4ed7574e936b6f1f2ce9a10ca6501f553274f9f9 Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Wed, 29 Oct 2025 18:51:13 +0800 Subject: [PATCH 19/64] feat: fix polardb graph (#411) --- src/memos/graph_dbs/polardb.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 88aef6d33..3f059e8ad 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1178,6 +1178,8 @@ def get_subgraph( MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory) WHERE center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' RETURN collect(DISTINCT center), collect(DISTINCT @@ -1255,7 +1257,9 @@ def get_subgraph( } ) - return {"core_node": core_node, "neighbors": neighbors, "edges": edges} + return self._convert_graph_edges( + {"core_node": core_node, "neighbors": neighbors, "edges": edges} + ) except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) @@ -2839,3 +2843,25 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] + + def _convert_graph_edges(self, core_node: dict) -> dict: + import copy + + data = copy.deepcopy(core_node) + id_map = {} + core_node = data.get("core_node", {}) + core_meta = core_node.get("metadata", {}) + if "graph_id" in core_meta and "id" in core_node: + id_map[core_meta["graph_id"]] = core_node["id"] + for neighbor in data.get("neighbors", []): + n_meta = neighbor.get("metadata", {}) + if "graph_id" in n_meta and "id" in neighbor: + id_map[n_meta["graph_id"]] = neighbor["id"] + for edge in data.get("edges", []): + src = edge.get("source") + tgt = edge.get("target") + if src in id_map: + edge["source"] = id_map[src] + if tgt in id_map: + edge["target"] = id_map[tgt] + return data From fef40e9e905836fa5de0c849ad82528758b82bf4 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 29 Oct 2025 19:27:00 +0800 Subject: [PATCH 20/64] feat: async add api (#410) * feat: update manager for async add * feat: modify tree and simple_tree, TODO: STILL NOT ALIGN IN SOME FUNCTIONS * feat: modify schedule: add optional user_name in schedule message; modify user-name related graph query in scheduler * feat: finishe server router for async mode * feat: update graph db * fix: add label in core * feat: add tree mode in config * feat: default llm token 8000 * fix: thread * feat: search mode in client: fast * tests: fix --- evaluation/scripts/utils/client.py | 2 +- src/memos/api/config.py | 3 +- src/memos/api/routers/server_router.py | 255 ++++++++++++++++-- src/memos/graph_dbs/nebular.py | 13 +- src/memos/graph_dbs/neo4j.py | 1 + src/memos/mem_os/core.py | 39 +-- src/memos/mem_scheduler/base_scheduler.py | 4 + src/memos/mem_scheduler/general_scheduler.py | 47 +++- .../mem_scheduler/schemas/message_schemas.py | 7 + src/memos/memories/textual/simple_tree.py | 27 +- src/memos/memories/textual/tree.py | 20 +- .../tree_text_memory/organize/manager.py | 106 ++++++-- tests/memories/textual/test_tree.py | 8 +- 13 files changed, 419 insertions(+), 113 deletions(-) diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index e1bdd54e9..9b686a131 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -181,7 +181,7 @@ def search(self, query, user_id, top_k): "mem_cube_id": user_id, "conversation_id": "", "top_k": top_k, - "mode": "mixture", + "mode": "fast", "handle_pref_mem": False, }, ensure_ascii=False, diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 6de013313..92df1ecf8 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -23,7 +23,7 @@ def get_openai_config() -> dict[str, Any]: return { "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"), "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")), - "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")), + "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")), "top_p": float(os.getenv("MOS_TOP_P", "0.9")), "top_k": int(os.getenv("MOS_TOP_K", "50")), "remove_think_prefix": True, @@ -672,6 +672,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "mode": os.getenv("ASYNC_MODE", "sync"), }, }, "act_mem": {} diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index bb98f04ba..e9df292ad 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,9 +1,13 @@ +import json import os +import time import traceback +from datetime import datetime from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse from memos.api.config import APIConfig from memos.api.product_models import ( @@ -32,8 +36,12 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, SearchMode, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, ExtractorConfigFactory, @@ -233,6 +241,7 @@ def init_server(): chat_llm=llm, process_llm=mem_reader.llm, db_engine=BaseDBManager.create_default_sqlite_engine(), + mem_reader=mem_reader, ) mem_scheduler.current_mem_cube = naive_mem_cube mem_scheduler.start() @@ -477,6 +486,13 @@ def add_memories(add_req: APIADDRequest): if not target_session_id: target_session_id = "default_session" + # If text memory backend works in async mode, submit tasks to scheduler + try: + sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") + except Exception: + sync_mode = "sync" + logger.info(f"Add sync_mode mode is: {sync_mode}") + def _process_text_mem() -> list[dict[str, str]]: memories_local = mem_reader.get_memory( [add_req.messages], @@ -485,6 +501,7 @@ def _process_text_mem() -> list[dict[str, str]]: "user_id": add_req.user_id, "session_id": target_session_id, }, + mode="fast" if sync_mode == "async" else "fine", ) flattened_local = [mm for m in memories_local for mm in m] logger.info(f"Memory extraction completed for user {add_req.user_id}") @@ -496,6 +513,34 @@ def _process_text_mem() -> list[dict[str, str]]: f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " f"in session {add_req.session_id}: {mem_ids_local}" ) + if sync_mode == "async": + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_read]) + logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}") + except Exception as e: + logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) + else: + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_add]) return [ { "memory": memory.memory, @@ -508,27 +553,46 @@ def _process_text_mem() -> list[dict[str, str]]: def _process_pref_mem() -> list[dict[str, str]]: if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - pref_memories_local = naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) - pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) - logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] + # Follow async behavior similar to core.py: enqueue when async + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + mem_scheduler.submit_messages(messages=[message_item_pref]) + logger.info("Submitted preference add to scheduler (async mode)") + except Exception as e: + logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) + return [] + else: + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_process_text_mem) @@ -542,6 +606,155 @@ def _process_pref_mem() -> list[dict[str, str]]: ) +@router.get("/scheduler/status", summary="Get scheduler running task count") +def scheduler_status(): + """ + Return current running tasks count from scheduler dispatcher. + Shape is consistent with /scheduler/wait. + """ + try: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + now_ts = time.time() + + return { + "message": "ok", + "data": { + "running_tasks": running_count, + "timestamp": now_ts, + }, + } + + except Exception as err: + logger.error("Failed to get scheduler status: %s", traceback.format_exc()) + + raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + + +@router.post("/scheduler/wait", summary="Wait until scheduler is idle") +def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2): + """ + Block until scheduler has no running tasks, or timeout. + We return a consistent structured payload so callers can + tell whether this was a clean flush or a timeout. + + Args: + timeout_seconds: max seconds to wait + poll_interval: seconds between polls + """ + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + elapsed = time.time() - start + + # success -> scheduler is idle + if running_count == 0: + return { + "message": "idle", + "data": { + "running_tasks": 0, + "waited_seconds": round(elapsed, 3), + "timed_out": False, + }, + } + + # timeout check + if elapsed > timeout_seconds: + return { + "message": "timeout", + "data": { + "running_tasks": running_count, + "waited_seconds": round(elapsed, 3), + "timed_out": True, + }, + } + + time.sleep(poll_interval) + + except Exception as err: + logger.error( + "Failed while waiting for scheduler: %s", + traceback.format_exc(), + ) + raise HTTPException( + status_code=500, + detail="Failed while waiting for scheduler", + ) from err + + +@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)") +def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2): + """ + Stream scheduler progress via Server-Sent Events (SSE). + + Contract: + - We emit periodic heartbeat frames while tasks are still running. + - Each heartbeat frame is JSON, prefixed with "data: ". + - On final frame, we include status = "idle" or "timeout" and timed_out flag, + with the same semantics as /scheduler/wait. + + Example curl: + curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5" + """ + + def event_generator(): + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + elapsed = time.time() - start + + # heartbeat frame + heartbeat_payload = { + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "running" if running_count > 0 else "idle", + } + yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n" + + # scheduler is idle -> final frame + break + if running_count == 0: + final_payload = { + "running_tasks": 0, + "elapsed_seconds": round(elapsed, 3), + "status": "idle", + "timed_out": False, + } + yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + break + + # timeout -> final frame + break + if elapsed > timeout_seconds: + final_payload = { + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "timeout", + "timed_out": True, + } + yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + break + + time.sleep(poll_interval) + + except Exception as e: + err_payload = { + "status": "error", + "detail": "stream_failed", + "exception": str(e), + } + logger.error( + "Failed streaming scheduler wait: %s: %s", + e, + traceback.format_exc(), + ) + yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + @router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") def chat_complete(chat_req: APIChatCompleteRequest): """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 00bd04e6d..89b58f417 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -439,6 +439,7 @@ def remove_oldest_memory( Args: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name(str): optional user_name. """ try: user_name = user_name if user_name else self.config.user_name @@ -685,8 +686,7 @@ def get_node( Returns: dict: Node properties as key-value pairs, or None if not found. """ - user_name = user_name if user_name else self.config.user_name - filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' + filter_clause = f'n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) @@ -730,16 +730,13 @@ def get_nodes( """ if not ids: return [] - - user_name = user_name if user_name else self.config.user_name - where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.id IN [{id_list}] {where_user} + WHERE n.id IN [{id_list}] RETURN {return_fields} """ nodes = [] @@ -1497,10 +1494,10 @@ def _ensure_space_exists(cls, tmp_client, cfg): return try: - res = tmp_client.execute("SHOW GRAPHS;") + res = tmp_client.execute("SHOW GRAPHS") existing = {row.values()[0].as_string() for row in res} if db_name not in existing: - tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;") + tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type") logger.info(f"✅ Graph `{db_name}` created before session binding.") else: logger.debug(f"Graph `{db_name}` already exists.") diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index fd3a1ba22..f3a36a887 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -149,6 +149,7 @@ def remove_oldest_memory( Args: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name(str): optional user_name. """ user_name = user_name if user_name else self.config.user_name query = f""" diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 939b0c68d..97ff9879f 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -779,16 +779,16 @@ def process_textual_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + else: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) def process_preference_memory(): if ( @@ -878,15 +878,16 @@ def process_preference_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + else: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) # user doc input if ( diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 3958ee382..0360396af 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -134,6 +134,7 @@ def initialize_modules( chat_llm: BaseLLM, process_llm: BaseLLM | None = None, db_engine: Engine | None = None, + mem_reader=None, ): if process_llm is None: process_llm = chat_llm @@ -150,6 +151,9 @@ def initialize_modules( self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + if mem_reader: + self.mem_reader = mem_reader + if self.enable_parallel_dispatch: self.dispatcher_monitor.initialize(dispatcher=self.dispatcher) self.dispatcher_monitor.start() diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 434cef3e9..6840adc2b 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -250,6 +250,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id mem_cube = message.mem_cube content = message.content + user_name = message.user_name # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -273,6 +274,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, mem_cube=mem_cube, text_mem=text_mem, + user_name=user_name, ) logger.info( @@ -297,6 +299,7 @@ def _process_memories_with_reader( mem_cube_id: str, mem_cube: GeneralMemCube, text_mem: TreeTextMemory, + user_name: str, ) -> None: """ Process memories using mem_reader for enhanced memory processing. @@ -330,6 +333,18 @@ def _process_memories_with_reader( logger.warning("No valid memory items found for processing") return + # parse working_binding ids from the *original* memory_items (the raw items created in /add) + # these still carry metadata.background with "[working_binding:...]" so we can know + # which WorkingMemory clones should be cleaned up later. + from memos.memories.textual.tree_text_memory.organize.manager import ( + extract_working_binding_ids, + ) + + bindings_to_delete = extract_working_binding_ids(memory_items) + logger.info( + f"Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" + ) + # Use mem_reader to process the memories logger.info(f"Processing {len(memory_items)} memories with mem_reader") @@ -353,7 +368,7 @@ def _process_memories_with_reader( # Add the enhanced memories back to the memory system if flattened_memories: - enhanced_mem_ids = text_mem.add(flattened_memories) + enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) logger.info( f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" ) @@ -362,9 +377,26 @@ def _process_memories_with_reader( else: logger.info("mem_reader returned no processed memories") - text_mem.delete(mem_ids) - logger.info("Delete raw mem_ids") - text_mem.memory_manager.remove_and_refresh_memory() + # build full delete list: + # - original raw mem_ids (temporary fast memories) + # - any bound working memories referenced by the enhanced memories + delete_ids = list(mem_ids) + if bindings_to_delete: + delete_ids.extend(list(bindings_to_delete)) + # deduplicate + delete_ids = list(dict.fromkeys(delete_ids)) + if delete_ids: + try: + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + f"Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + ) + except Exception as e: + logger.warning(f"Failed to delete some mem_ids {delete_ids}: {e}") + else: + logger.info("No mem_ids to delete (nothing to cleanup)") + + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) logger.info("Remove and Refresh Memories") logger.debug(f"Finished add {user_id} memory: {mem_ids}") @@ -382,6 +414,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id mem_cube = message.mem_cube content = message.content + user_name = message.user_name # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -405,6 +438,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, mem_cube=mem_cube, text_mem=text_mem, + user_name=user_name, ) logger.info( @@ -429,6 +463,7 @@ def _process_memories_with_reorganize( mem_cube_id: str, mem_cube: GeneralMemCube, text_mem: TreeTextMemory, + user_name: str, ) -> None: """ Process memories using mem_reorganize for enhanced memory processing. @@ -455,7 +490,7 @@ def _process_memories_with_reorganize( memory_item = text_mem.get(mem_id) memory_items.append(memory_item) except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}") + logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}") continue if not memory_items: @@ -464,7 +499,7 @@ def _process_memories_with_reorganize( # Use mem_reader to process the memories logger.info(f"Processing {len(memory_items)} memories with mem_reader") - text_mem.memory_manager.remove_and_refresh_memory() + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) logger.info("Remove and Refresh Memories") logger.debug(f"Finished add {user_id} memory: {mem_ids}") diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index bd3155a96..9cdb6823d 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -42,6 +42,10 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): timestamp: datetime = Field( default_factory=get_utc_now, description="submit time for schedule_messages" ) + user_name: str | None = Field( + default=None, + description="user name / display name (optional)", + ) # Pydantic V2 model configuration model_config = ConfigDict( @@ -60,6 +64,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "mem_cube": "obj of GeneralMemCube", # Added mem_cube example "content": "sample content", # Example message content "timestamp": "2024-07-22T12:00:00Z", # Added timestamp example + "user_name": "Alice", # Added username example } }, ) @@ -81,6 +86,7 @@ def to_dict(self) -> dict: "cube": "Not Applicable", # Custom cube serialization "content": self.content, "timestamp": self.timestamp.isoformat(), + "user_name": self.user_name, } @classmethod @@ -94,6 +100,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), + user_name=data.get("user_name"), ) diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 52bf62c6d..8d07522cd 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -44,6 +44,8 @@ def __init__( """Initialize memory with the given configuration.""" time_start = time.time() self.config: TreeTextMemoryConfig = config + self.mode = self.config.mode + logger.info(f"Tree mode is {self.mode}") self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") @@ -79,20 +81,6 @@ def __init__( logger.info("No internet retriever configured") logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") - def add( - self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None - ) -> list[str]: - """Add memories. - Args: - memories: List of TextualMemoryItem objects or dictionaries to add. - Later: - memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] - metadata = extract_metadata(memory_items, self.extractor_llm) - plan = plan_memory_operations(memory_items, metadata, self.graph_store) - execute_plan(memory_items, metadata, plan, self.graph_store) - """ - return self.memory_manager.add(memories, user_name=user_name) - def replace_working_memory( self, memories: list[TextualMemoryItem], user_name: str | None = None ) -> None: @@ -271,17 +259,6 @@ def get(self, memory_id: str) -> TextualMemoryItem: def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self) -> dict: - """Get all memories. - Returns: - list[TextualMemoryItem]: List of all memories. - """ - all_items = self.graph_store.export_graph() - return all_items - - def delete(self, memory_ids: list[str]) -> None: - raise NotImplementedError - def delete_all(self) -> None: """Delete all memories and their relationships from the graph store.""" try: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index fccd83fa6..472bed219 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -34,6 +34,8 @@ def __init__(self, config: TreeTextMemoryConfig): """Initialize memory with the given configuration.""" # Set mode from class default or override if needed self.mode = config.mode + logger.info(f"Tree mode is {self.mode}") + self.config: TreeTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm @@ -81,12 +83,18 @@ def __init__(self, config: TreeTextMemoryConfig): else: logger.info("No internet retriever configured") - def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + user_name: str | None = None, + **kwargs, + ) -> list[str]: """Add memories. Args: memories: List of TextualMemoryItem objects or dictionaries to add. + user_name: optional user_name """ - return self.memory_manager.add(memories, mode=self.mode) + return self.memory_manager.add(memories, user_name=user_name, mode=self.mode) def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: self.memory_manager.replace_working_memory(memories) @@ -262,21 +270,21 @@ def get(self, memory_id: str) -> TextualMemoryItem: def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self) -> dict: + def get_all(self, user_name: str | None = None) -> dict: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ - all_items = self.graph_store.export_graph() + all_items = self.graph_store.export_graph(user_name=user_name) return all_items - def delete(self, memory_ids: list[str]) -> None: + def delete(self, memory_ids: list[str], user_name: str | None = None) -> None: """Hard delete: permanently remove nodes and their edges from the graph.""" if not memory_ids: return for mid in memory_ids: try: - self.graph_store.delete_node(mid) + self.graph_store.delete_node(mid, user_name=user_name) except Exception as e: logger.warning(f"TreeTextMemory.delete_hard: failed to delete {mid}: {e}") diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 54776134b..47cbf4ed1 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -1,3 +1,4 @@ +import re import traceback import uuid @@ -19,6 +20,37 @@ logger = get_logger(__name__) +def extract_working_binding_ids(mem_items: list[TextualMemoryItem]) -> set[str]: + """ + Scan enhanced memory items for background hints like + "[working_binding:]" and collect those working memory IDs. + + We store the working<->long binding inside metadata.background when + initially adding memories in async mode, so we can later clean up + the temporary WorkingMemory nodes after mem_reader produces the + final LongTermMemory/UserMemory. + + Args: + mem_items: list of TextualMemoryItem we just added (enhanced memories) + + Returns: + A set of working memory IDs (as strings) that should be deleted. + """ + bindings: set[str] = set() + pattern = re.compile(r"\[working_binding:([0-9a-fA-F-]{36})\]") + for item in mem_items: + try: + bg = getattr(item.metadata, "background", "") or "" + except Exception: + bg = "" + if not isinstance(bg, str): + continue + match = pattern.search(bg) + if match: + bindings.add(match.group(1)) + return bindings + + class MemoryManager: def __init__( self, @@ -129,15 +161,28 @@ def _refresh_memory_size(self, user_name: str | None = None) -> None: def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): """ - Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). - This method runs asynchronously to process each memory item. + Process and add memory to different memory types. + + Behavior: + 1. Always create a WorkingMemory node from `memory` and get its node id. + 2. If `memory.metadata.memory_type` is "LongTermMemory" or "UserMemory", + also create a corresponding long/user node. + - In async mode, that long/user node's metadata will include + `working_binding` in `background` which records the WorkingMemory + node id created in step 1. + 3. Return ONLY the ids of the long/user nodes (NOT the working node id), + which preserves the previous external contract of `add()`. """ ids: list[str] = [] futures = [] + working_id = str(uuid.uuid4()) + with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: - f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) - futures.append(f_working) + f_working = ex.submit( + self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id + ) + futures.append(("working", f_working)) if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"): f_graph = ex.submit( @@ -145,13 +190,14 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name, + working_binding=working_id, ) - futures.append(f_graph) + futures.append(("long", f_graph)) - for fut in as_completed(futures): + for kind, fut in futures: try: res = fut.result() - if isinstance(res, str) and res: + if kind != "working" and isinstance(res, str) and res: ids.append(res) except Exception: logger.warning("Parallel memory processing failed:\n%s", traceback.format_exc()) @@ -159,39 +205,51 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non return ids def _add_memory_to_db( - self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + self, + memory: TextualMemoryItem, + memory_type: str, + user_name: str | None = None, + forced_id: str | None = None, ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. + If forced_id is provided, use that as the node id. """ metadata = memory.metadata.model_copy(update={"memory_type": memory_type}).model_dump( exclude_none=True ) metadata["updated_at"] = datetime.now().isoformat() - working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) - + node_id = forced_id or str(uuid.uuid4()) + working_memory = TextualMemoryItem(id=node_id, memory=memory.memory, metadata=metadata) # Insert node into graph self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) + return node_id def _add_to_graph_memory( - self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + self, + memory: TextualMemoryItem, + memory_type: str, + user_name: str | None = None, + working_binding: str | None = None, ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). - - Parameters: - - memory: memory item to insert - - memory_type: "LongTermMemory" | "UserMemory" - - similarity_threshold: deduplication threshold - - topic_summary_prefix: summary node id prefix if applicable - - enable_summary_link: whether to auto-link to a summary node """ node_id = str(uuid.uuid4()) # Step 2: Add new node to graph + metadata_dict = memory.metadata.model_dump(exclude_none=True) + tags = metadata_dict.get("tags") or [] + if working_binding and ("mode:fast" in tags): + prev_bg = metadata_dict.get("background", "") or "" + binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" + if prev_bg: + metadata_dict["background"] = prev_bg + " || " + binding_line + else: + metadata_dict["background"] = binding_line self.graph_store.add_node( node_id, memory.memory, - memory.metadata.model_dump(exclude_none=True), + metadata_dict, user_name=user_name, ) self.reorganizer.add_message( @@ -282,11 +340,11 @@ def _ensure_structure_path( # Step 3: Return this structure node ID as the parent_id return node_id - def remove_and_refresh_memory(self): - self._cleanup_memories_if_needed() - self._refresh_memory_size() + def remove_and_refresh_memory(self, user_name: str | None = None): + self._cleanup_memories_if_needed(user_name=user_name) + self._refresh_memory_size(user_name=user_name) - def _cleanup_memories_if_needed(self) -> None: + def _cleanup_memories_if_needed(self, user_name: str | None = None) -> None: """ Only clean up memories if we're close to or over the limit. This reduces unnecessary database operations. @@ -301,7 +359,7 @@ def _cleanup_memories_if_needed(self) -> None: if current_count >= threshold: try: self.graph_store.remove_oldest_memory( - memory_type=memory_type, keep_latest=limit + memory_type=memory_type, keep_latest=limit, user_name=user_name ) logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") except Exception: diff --git a/tests/memories/textual/test_tree.py b/tests/memories/textual/test_tree.py index 772a79d78..a72709ec5 100644 --- a/tests/memories/textual/test_tree.py +++ b/tests/memories/textual/test_tree.py @@ -66,7 +66,9 @@ def test_add_calls_manager(mock_tree_text_memory): metadata=TreeNodeTextualMemoryMetadata(updated_at=None), ) mock_tree_text_memory.add([mock_item]) - mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + [mock_item], user_name=None, mode="sync" + ) def test_get_working_memory_sorted(mock_tree_text_memory): @@ -161,4 +163,6 @@ def test_add_returns_ids(mock_tree_text_memory): result = mock_tree_text_memory.add(mock_items) assert result == dummy_ids - mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + mock_items, user_name=None, mode="sync" + ) From 6e219c4a0811917c33ae20c9209da08fbdda25ea Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Wed, 29 Oct 2025 21:19:35 +0800 Subject: [PATCH 21/64] use nacos (#407) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * fix:nacos * feat: fix config Exception * feat: format config * feat: format config --------- Co-authored-by: ccl <13282138256@163.com> --- src/memos/api/config.py | 240 +++++++++++++++++++++++++++++++++ src/memos/graph_dbs/polardb.py | 1 - 2 files changed, 240 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 92df1ecf8..7ac882d6c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -1,8 +1,16 @@ +import base64 +import hashlib +import hmac import json +import logging import os +import re +import time from typing import Any +import requests + from dotenv import load_dotenv from memos.configs.mem_cube import GeneralMemCubeConfig @@ -13,6 +21,238 @@ # Load environment variables load_dotenv() +logger = logging.getLogger(__name__) + + +def _update_env_from_dict(data: dict[str, Any]) -> None: + """Apply a dict to environment variables, with change logging.""" + + def _is_sensitive(name: str) -> bool: + n = name.upper() + return any(s in n for s in ["PASSWORD", "SECRET", "AK", "SK", "TOKEN", "KEY"]) + + for k, v in data.items(): + if isinstance(v, dict): + new_val = json.dumps(v, ensure_ascii=False) + elif isinstance(v, bool): + new_val = "true" if v else "false" + elif v is None: + new_val = "" + else: + new_val = str(v) + + old_val = os.environ.get(k) + os.environ[k] = new_val + + try: + log_old = "***" if _is_sensitive(k) else (old_val if old_val is not None else "") + log_new = "***" if _is_sensitive(k) else new_val + if old_val != new_val: + logger.info(f"Nacos config update: {k}={log_new} (was {log_old})") + except Exception as e: + # Avoid logging failures blocking config updates + logger.debug(f"Skip logging change for {k}: {e}") + + +def get_config_json(name: str, default: Any | None = None) -> Any: + """Read JSON object/array from env and parse. Returns default on missing/invalid.""" + raw = os.getenv(name) + if not raw: + return default + try: + return json.loads(raw) + except Exception: + logger.warning(f"Invalid JSON in env '{name}', returning default.") + return default + + +def get_config_value(path: str, default: Any | None = None) -> Any: + """Read value from env with optional dot-path for structured configs. + + Examples: + - get_config_value("MONGODB_CONFIG.base_uri") + - get_config_value("MONGODB_BASE_URI") + """ + if "." not in path: + val = os.getenv(path) + return val if val is not None else default + root, *subkeys = path.split(".") + data = get_config_json(root, default=None) + if not isinstance(data, dict): + return default + cur: Any = data + for key in subkeys: + if isinstance(cur, dict) and key in cur: + cur = cur[key] + else: + return default + return cur + + +class NacosConfigManager: + _client = None + _data_id = None + _group = None + _enabled = False + + # Pre-compile regex patterns for better performance + _KEY_VALUE_PATTERN = re.compile(r"^([^=]+)=(.*)$") + _INTEGER_PATTERN = re.compile(r"^[+-]?\d+$") + _FLOAT_PATTERN = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$") + + @classmethod + def _sign(cls, secret_key: str, data: str) -> str: + """HMAC-SHA1 sgin""" + signature = hmac.new(secret_key.encode("utf-8"), data.encode("utf-8"), hashlib.sha1) + return base64.b64encode(signature.digest()).decode() + + @staticmethod + def _parse_value(value: str) -> Any: + """Parse string value to appropriate Python type. + + Supports: bool, int, float, and string. + """ + if not value: + return value + + val_lower = value.lower() + + # Boolean + if val_lower in ("true", "false"): + return val_lower == "true" + + # Integer + if NacosConfigManager._INTEGER_PATTERN.match(value): + try: + return int(value) + except (ValueError, OverflowError): + return value + + # Float + if NacosConfigManager._FLOAT_PATTERN.match(value): + try: + return float(value) + except (ValueError, OverflowError): + return value + + # Default to string + return value + + @staticmethod + def parse_properties(content: str) -> dict[str, Any]: + """Parse properties file content to dictionary with type inference. + + Supports: + - Comments (lines starting with #) + - Key-value pairs (KEY=VALUE) + - Type inference (bool, int, float, string) + """ + data: dict[str, Any] = {} + + for line in content.splitlines(): + line = line.strip() + + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + + # Parse key-value pair + match = NacosConfigManager._KEY_VALUE_PATTERN.match(line) + if match: + key = match.group(1).strip() + value = match.group(2).strip() + data[key] = NacosConfigManager._parse_value(value) + + return data + + @classmethod + def start_config_watch(cls): + while True: + cls.init() + time.sleep(60) + + @classmethod + def start_watch_if_enabled(cls) -> None: + enable = os.getenv("NACOS_ENABLE_WATCH", "false").lower() == "true" + print("enable:", enable) + if not enable: + return + interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60")) + import threading + + def _loop() -> None: + while True: + try: + cls.init() + except Exception as e: + logger.error(f"❌ Nacos watch loop error: {e}") + time.sleep(interval) + + threading.Thread(target=_loop, daemon=True).start() + logger.info(f"Nacos watch thread started (interval={interval}s).") + + @classmethod + def init(cls) -> None: + server_addr = os.getenv("NACOS_SERVER_ADDR") + data_id = os.getenv("NACOS_DATA_ID") + group = os.getenv("NACOS_GROUP", "DEFAULT_GROUP") + namespace = os.getenv("NACOS_NAMESPACE", "") + ak = os.getenv("AK") + sk = os.getenv("SK") + + if not (server_addr and data_id and ak and sk): + logger.warning("❌ missing NACOS_SERVER_ADDR / AK / SK / DATA_ID") + return + + base_url = f"http://{server_addr}/nacos/v1/cs/configs" + + def _auth_headers(): + ts = str(int(time.time() * 1000)) + + sign_data = namespace + "+" + group + "+" + ts if namespace else group + "+" + ts + signature = cls._sign(sk, sign_data) + return { + "Spas-AccessKey": ak, + "Spas-Signature": signature, + "timeStamp": ts, + } + + try: + params = { + "dataId": data_id, + "group": group, + "tenant": namespace, + } + + headers = _auth_headers() + resp = requests.get(base_url, headers=headers, params=params, timeout=10) + + if resp.status_code != 200: + logger.error(f"Nacos AK/SK fail: {resp.status_code} {resp.text}") + return + + content = resp.text.strip() + if not content: + logger.warning("⚠️ Nacos is empty") + return + try: + data_props = cls.parse_properties(content) + logger.info("nacos config:", data_props) + _update_env_from_dict(data_props) + logger.info("✅ parse Nacos setting is Properties ") + except Exception as e: + logger.error(f"⚠️ Nacos parse fail(not JSON/YAML/Properties): {e}") + raise Exception(f"Nacos configuration parsing failed: {e}") from e + + except Exception as e: + logger.error(f"❌ Nacos AK/SK init fail: {e}") + raise Exception(f"❌ Nacos AK/SK init fail: {e}") from e + + +# init Nacos +NacosConfigManager.init() +NacosConfigManager.start_watch_if_enabled() + class APIConfig: """Centralized configuration management for MemOS APIs.""" diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 3f059e8ad..971a56e04 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -833,7 +833,6 @@ def get_nodes( # Parse embedding from JSONB if it exists if embedding_json is not None: try: - print("embedding_json:", embedding_json) # remove embedding """ embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json From f74ea76a31cb26ee462f162b681f1fb3026d4281 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 29 Oct 2025 21:28:42 +0800 Subject: [PATCH 22/64] feat: async add api (#413) * feat: update manager for async add * feat: modify tree and simple_tree, TODO: STILL NOT ALIGN IN SOME FUNCTIONS * feat: modify schedule: add optional user_name in schedule message; modify user-name related graph query in scheduler * feat: finishe server router for async mode * feat: update graph db * fix: add label in core * feat: add tree mode in config * feat: default llm token 8000 * fix: thread * feat: search mode in client: fast * tests: fix * fix: add some log for memory_size in manager --------- Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- .../memories/textual/tree_text_memory/organize/manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 47cbf4ed1..01ccc382b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -351,9 +351,10 @@ def _cleanup_memories_if_needed(self, user_name: str | None = None) -> None: """ cleanup_threshold = 0.8 # Clean up when 80% full + logger.info(f"self.memory_size: {self.memory_size}") for memory_type, limit in self.memory_size.items(): current_count = self.current_memory_size.get(memory_type, 0) - threshold = int(limit * cleanup_threshold) + threshold = int(int(limit) * cleanup_threshold) # Only clean up if we're at or above the threshold if current_count >= threshold: From 59230010919f6807a7d2c0773bb0f66850b1b781 Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Wed, 29 Oct 2025 21:29:05 +0800 Subject: [PATCH 23/64] revision of mixture api: add conversation turn and reduce 2 stage ranking to 1 stage (#405) * debug an error function name * feat: Add DynamicCache compatibility for different transformers versions - Fix build_kv_cache method in hf.py to handle both old and new DynamicCache structures - Support new 'layers' attribute with key_cache/value_cache or keys/values - Maintain backward compatibility with direct key_cache/value_cache attributes - Add comprehensive error handling and logging for unsupported structures - Update move_dynamic_cache_htod function in kv.py for cross-version compatibility - Handle layers-based structure in newer transformers versions - Support alternative attribute names (keys/values vs key_cache/value_cache) - Preserve original functionality for older transformers versions - Add comprehensive tests for DynamicCache compatibility - Test activation memory update with mock DynamicCache layers - Verify layers attribute access across different transformers versions - Fix scheduler logger mock to include memory_manager attribute This resolves AttributeError issues when using different versions of the transformers library and ensures robust handling of DynamicCache objects. debug * feat: implement APIAnalyzerForScheduler for memory operations - Add APIAnalyzerForScheduler class with search/add operations - Support requests and http.client with connection reuse - Include comprehensive error handling and dynamic configuration - Add English test suite with real-world conversation scenarios * feat: Add search_ws API endpoint and enhance API analyzer functionality - Add search_ws endpoint in server_router.py for scheduler-enabled search - Fix missing imports: time module, SearchRequest class, and get_mos_product_instance function - Implement search_ws method in api_analyzer.py with HTTP client support - Add _search_ws_with_requests and _search_ws_with_http_client private methods - Include search_ws usage example in demonstration code - Enhance scheduler and dispatcher capabilities for improved memory management - Expand test coverage to ensure functionality stability This update primarily strengthens the memory scheduling system's search capabilities, providing users with more flexible API interface options. * fix: resolve test failures and warnings in test suite - Fix Pydantic serialization warning in test_memos_chen_tang_hello_world * Add warnings filter to suppress UserWarning from Pydantic serialization - Fix KeyError: 'past_key_values' in test_build_kv_cache_and_generation * Update mock configuration to properly return forward_output with past_key_values * Add DynamicCache version compatibility handling in test mocks * Support both old and new transformers versions with layers/key_cache attributes * Improve assertion logic to check all model calls for required parameters - Update base_scheduler.py to use centralized DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE constant * Add import for DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE from general_schemas * Replace hardcoded value 100 with configurable constant (1000) All tests now pass successfully with proper version compatibility handling. * feat: add a test_robustness execution to test thread pool execution * feat: optimize scheduler configuration and API search functionality - Add DEFAULT_TOP_K and DEFAULT_CONTEXT_WINDOW_SIZE global constants in general_schemas.py - Update base_scheduler.py to use global default values instead of hardcoded numbers - Fix SchedulerConfigFactory initialization issue by using keyword argument expansion - Resolve UnboundLocalError variable conflict in search_memories_ws function - Fix indentation and parameter issues in OptimizedScheduler search_for_api method - Improve code standardization and maintainability * feat: Add Redis auto-initialization with fallback strategies - Add auto_initialize_redis() with config/env/local fallback - Move Redis logic from dispatcher_monitor to redis_service - Update base_scheduler to use auto initialization - Add proper resource cleanup and error handling * feat: add database connection management to ORM module - Add MySQL engine loading from environment variables in BaseDBManager - Add Redis connection loading from environment variables in BaseDBManager - Enhance database configuration validation and error handling - Complete database adapter infrastructure for ORM module - Provide unified database connection management interface This update provides comprehensive database connection management capabilities for the mem_scheduler module, supporting dynamic MySQL and Redis configuration loading from environment variables, establishing reliable data persistence foundation for scheduling services and API services. * remove part of test * feat: add Redis-based ORM with multiprocess synchronization - Add RedisDBManager and RedisLockableORM classes - Implement atomic locking mechanism for concurrent access - Add merge functionality for different object types - Include comprehensive test suite and examples - Fix Redis key type conflicts in lock operations * fix: resolve scheduler module import and Redis integration issues * revise naive memcube creation in server router * remove long-time tests in test_scheduler * remove redis test which needs .env * refactor all codes about mixture search with scheduler * fix: resolve Redis API synchronization issues and implement search API with reranker - Fix running_entries to running_task_ids migration across codebase - Update sync_search_data method to properly handle TaskRunningStatus - Correct variable naming and logic in API synchronization flow - Implement search API endpoint with reranker functionality - Update test files to reflect new running_task_ids convention - Ensure proper Redis state management for concurrent tasks * remove a test for api module * revise to pass the test suite * address some bugs to make mix_search normally running * modify codes according to evaluation logs * feat: Optimize mixture search and enhance API client * feat: Add conversation_turn tracking for session-based memory search - Add conversation_turn field to APIMemoryHistoryEntryItem schema with default value 0 - Implement session counter in OptimizedScheduler to track turn count per session_id - Update sync_search_data method to accept and store conversation_turn parameter - Maintain session history with LRU eviction (max 5 sessions) - Rename conversation_id to session_id for consistency with request object - Enable direct access to session_id from search requests This feature allows tracking conversation turns within the same session, providing better context for memory retrieval and search history management. * adress time bug in monitor * revise simple tree * add mode to evaluation client; rewrite print to logger.info in db files --------- Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- evaluation/scripts/utils/client.py | 4 +- src/memos/graph_dbs/neo4j.py | 2 +- src/memos/graph_dbs/polardb.py | 91 +++----- src/memos/mem_scheduler/base_scheduler.py | 7 +- .../mem_scheduler/general_modules/api_misc.py | 60 +++--- .../mem_scheduler/monitors/general_monitor.py | 4 +- .../mem_scheduler/optimized_scheduler.py | 194 +++++++++++------- .../mem_scheduler/schemas/api_schemas.py | 19 +- src/memos/memories/textual/simple_tree.py | 28 +++ src/memos/memories/textual/tree.py | 28 +++ .../tree_text_memory/retrieve/searcher.py | 75 +++++-- 11 files changed, 307 insertions(+), 205 deletions(-) diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 9b686a131..4e7cfdbca 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -181,7 +181,7 @@ def search(self, query, user_id, top_k): "mem_cube_id": user_id, "conversation_id": "", "top_k": top_k, - "mode": "fast", + "mode": os.getenv("SEARCH_MODE", "fast"), "handle_pref_mem": False, }, ensure_ascii=False, @@ -232,7 +232,7 @@ def search(self, query, user_id, top_k): "query": query, "user_id": user_id, "memory_limit_number": top_k, - "mode": "mixture", + "mode": os.getenv("SEARCH_MODE", "fast"), } ) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index f3a36a887..367b486cd 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1072,7 +1072,7 @@ def drop_database(self) -> None: with self.driver.session(database=self.system_db_name) as session: session.run(f"DROP DATABASE {self.db_name} IF EXISTS") - print(f"Database '{self.db_name}' has been dropped.") + logger.info(f"Database '{self.db_name}' has been dropped.") else: raise ValueError( f"Refusing to drop protected database: {self.db_name} in " diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 971a56e04..5d50cf68f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -72,7 +72,7 @@ def detect_embedding_field(embedding_list): if dim == 1024: return "embedding" else: - print(f"⚠️ Unknown embedding dimension {dim}, skipping this vector") + logger.warning(f"Unknown embedding dimension {dim}, skipping this vector") return None @@ -274,8 +274,6 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params = [f'"{memory_type}"', f'"{user_name}"'] - print(f"[get_memory_count] Query: {query}, Params: {params}") - try: with self.connection.cursor() as cursor: cursor.execute(query, params) @@ -298,13 +296,10 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: query += "\nLIMIT 1" params = [f'"{scope}"', f'"{user_name}"'] - print(f"[node_not_exist] Query: {query}, Params: {params}") - try: with self.connection.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() - print(f"[node_not_exist] Query result: {result}") return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) @@ -419,7 +414,6 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') - print(f"[update_node] query: {query}, params: {params}") try: with self.connection.cursor() as cursor: cursor.execute(query, params) @@ -446,7 +440,6 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') - print(f"[delete_node] query: {query}, params: {params}") try: with self.connection.cursor() as cursor: cursor.execute(query, params) @@ -462,22 +455,24 @@ def create_extension(self): # Ensure in the correct database context cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] - print(f"Current database context: {current_db}") + logger.info(f"Current database context: {current_db}") for ext_name, ext_desc in extensions: try: cursor.execute(f"create extension if not exists {ext_name};") - print(f"✅ Extension '{ext_name}' ({ext_desc}) ensured.") + logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") except Exception as e: if "already exists" in str(e): - print(f"ℹ️ Extension '{ext_name}' ({ext_desc}) already exists.") + logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") else: - print(f"⚠️ Failed to create extension '{ext_name}' ({ext_desc}): {e}") + logger.warning( + f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" + ) logger.error( f"Failed to create extension '{ext_name}': {e}", exc_info=True ) except Exception as e: - print(f"⚠️ Failed to access database context: {e}") + logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) @timed @@ -491,12 +486,12 @@ def create_graph(self): graph_exists = cursor.fetchone()[0] > 0 if graph_exists: - print(f"ℹ️ Graph '{self.db_name}_graph' already exists.") + logger.info(f"Graph '{self.db_name}_graph' already exists.") else: cursor.execute(f"select create_graph('{self.db_name}_graph');") - print(f"✅ Graph database '{self.db_name}_graph' created.") + logger.info(f"Graph database '{self.db_name}_graph' created.") except Exception as e: - print(f"⚠️ Failed to create graph '{self.db_name}_graph': {e}") + logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) @timed @@ -506,16 +501,16 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: - print(f"🪶 Creating elabel: {label_name}") + logger.info(f"Creating elabel: {label_name}") try: with self.connection.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") - print(f"✅ Successfully created elabel: {label_name}") + logger.info(f"Successfully created elabel: {label_name}") except Exception as e: if "already exists" in str(e): - print(f"ℹ️ Label '{label_name}' already exists, skipping.") + logger.info(f"Label '{label_name}' already exists, skipping.") else: - print(f"⚠️ Failed to create label {label_name}: {e}") + logger.warning(f"Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) @timed @@ -547,7 +542,6 @@ def add_edge( AND end_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring) ); """ - print(f"Executing add_edge: {query}") try: with self.connection.cursor() as cursor: @@ -658,7 +652,6 @@ def edge_exists( # Prepare the relationship pattern user_name = user_name if user_name else self.config.user_name - print(f"edge_exists direction: {direction}") # Prepare the match pattern with direction if direction == "OUTGOING": @@ -681,7 +674,6 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - print(f"edge_exists query: {query}") with self.connection.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() @@ -728,7 +720,6 @@ def format_param_value(value: str) -> str: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(format_param_value(user_name)) - print(f"[get_node] query: {query}, params: {params}") try: with self.connection.cursor() as cursor: cursor.execute(query, params) @@ -812,7 +803,6 @@ def get_nodes( query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') - print(f"[get_nodes] query: {query}, params: {params}") with self.connection.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1067,8 +1057,6 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ - print("[get_children_with_embeddings] query:", query) - try: with self.connection.cursor() as cursor: cursor.execute(query) @@ -1191,7 +1179,6 @@ def get_subgraph( with self.connection.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() - print("[get_subgraph] result:", result) if not result or not result[0]: return {"core_node": None, "neighbors": [], "edges": []} @@ -1346,9 +1333,6 @@ def search_by_embedding( """ params = [vector] - print( - f"[search_by_embedding] query: {query}, params: {params}, where_clause: {where_clause}" - ) with self.connection.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1417,7 +1401,6 @@ def get_by_metadata( escaped_value = f"[{', '.join(list_items)}]" else: escaped_value = f"'{value}'" if isinstance(value, str) else str(value) - print("op=============:", op) # Build WHERE conditions if op == "=": where_conditions.append(f"n.{field} = {escaped_value}") @@ -1455,16 +1438,13 @@ def get_by_metadata( $$) AS (id agtype) """ - print(f"[get_by_metadata] query: {cypher_query}, where_str: {where_str}") ids = [] try: with self.connection.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() - print("[get_by_metadata] result:", results) ids = [str(item[0]).strip('"') for item in results] except Exception as e: - print("Failed to get metadata:", {e}) logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") return ids @@ -1494,7 +1474,6 @@ def get_grouped_counts1( raise ValueError("group_fields cannot be empty") final_params = params.copy() if params else {} - print("username:" + user_name) if not self.config.use_multi_db and (self.config.user_name or user_name): user_clause = "n.user_name = $user_name" final_params["user_name"] = user_name @@ -1506,14 +1485,12 @@ def get_grouped_counts1( where_clause = f"WHERE {where_clause} AND {user_clause}" else: where_clause = f"WHERE {user_clause}" - print("where_clause:" + where_clause) # Force RETURN field AS field to guarantee key match group_fields_cypher = ", ".join([f"n.{field} AS {field}" for field in group_fields]) """ # group_fields_cypher_polardb = "agtype, ".join([f"{field}" for field in group_fields]) """ group_fields_cypher_polardb = ", ".join([f"{field} agtype" for field in group_fields]) - print("group_fields_cypher_polardb:" + group_fields_cypher_polardb) query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (n:Memory) @@ -1521,7 +1498,6 @@ def get_grouped_counts1( RETURN {group_fields_cypher}, COUNT(n) AS count1 $$ ) as ({group_fields_cypher_polardb}, count1 agtype); """ - print("get_grouped_counts:" + query) try: with self.connection.cursor() as cursor: # Handle parameterized query @@ -1620,8 +1596,6 @@ def get_grouped_counts( GROUP BY {", ".join(group_by_fields)} """ - print("[get_grouped_counts] query:", query) - try: with self.connection.cursor() as cursor: # Handle parameterized query @@ -1889,7 +1863,6 @@ def get_all_memory_items( """ nodes = [] node_ids = set() - print("[get_all_memory_items embedding true ] cypher_query:", cypher_query) try: with self.connection.cursor() as cursor: cursor.execute(cypher_query) @@ -1924,7 +1897,6 @@ def get_all_memory_items( LIMIT 100 $$) AS (nprops agtype) """ - print("[get_all_memory_items embedding false ] cypher_query:", cypher_query) nodes = [] try: @@ -1993,14 +1965,12 @@ def get_all_memory_items_old( LIMIT 100 $$) AS (nprops agtype) """ - print("[get_all_memory_items] cypher_query:", cypher_query) nodes = [] try: with self.connection.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() - print("[get_all_memory_items] results:", results) for row in results: node_agtype = row[0] @@ -2025,16 +1995,14 @@ def get_all_memory_items_old( parsed_node_data["embedding"] = properties["embedding"] nodes.append(self._parse_node(parsed_node_data)) - print( - f"[get_all_memory_items] ✅ Parsed node successfully: {properties.get('id', '')}" + logger.debug( + f"[get_all_memory_items] Parsed node successfully: {properties.get('id', '')}" ) else: - print( - f"[get_all_memory_items] ❌ Invalid node data format: {node_data}" - ) + logger.warning(f"Invalid node data format: {node_data}") except (json.JSONDecodeError, TypeError) as e: - print(f"[get_all_memory_items] ❌ JSON parsing failed: {e}") + logger.error(f"JSON parsing failed: {e}") elif node_agtype and hasattr(node_agtype, "value"): # Handle agtype object node_props = node_agtype.value @@ -2050,13 +2018,8 @@ def get_all_memory_items_old( node_data["embedding"] = node_props["embedding"] nodes.append(self._parse_node(node_data)) - print( - f"[get_all_memory_items] ✅ Parsed agtype node successfully: {node_props.get('id', '')}" - ) else: - print( - f"[get_all_memory_items] ❌ Unknown data format: {type(node_agtype)}" - ) + logger.warning(f"Unknown data format: {type(node_agtype)}") except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) @@ -2152,7 +2115,7 @@ def get_structure_optimization_candidates( {self.db_name}_graph."Memory" m WHERE t.id1 = m.id """ - print("[get_structure_optimization_candidates] query:", cypher_query) + logger.info(f"[get_structure_optimization_candidates] query: {cypher_query}") candidates = [] node_ids = set() @@ -2160,7 +2123,7 @@ def get_structure_optimization_candidates( with self.connection.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() - print("result------", len(results)) + logger.info(f"Found {len(results)} structure optimization candidates") for row in results: if include_embedding: # When include_embedding=True, return full node object @@ -2228,9 +2191,9 @@ def get_structure_optimization_candidates( if node_id not in node_ids: candidates.append(node) node_ids.add(node_id) - print(f"✅ Parsed node successfully: {node_id}") + logger.debug(f"Parsed node successfully: {node_id}") except Exception as e: - print(f"❌ Failed to parse node: {e}") + logger.error(f"Failed to parse node: {e}") except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) @@ -2243,7 +2206,7 @@ def drop_database(self) -> None: if self._get_config_value("use_multi_db", True): with self.connection.cursor() as cursor: cursor.execute(f"SELECT drop_graph('{self.db_name}_graph', true)") - print(f"Graph '{self.db_name}_graph' has been dropped.") + logger.info(f"Graph '{self.db_name}_graph' has been dropped.") else: raise ValueError( f"Refusing to drop graph '{self.db_name}_graph' in " @@ -2498,7 +2461,7 @@ def get_neighbors_by_tag( WHERE {where_clause} """ - print(f"[get_neighbors_by_tag] query: {query}, params: {params}") + logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") try: with self.connection.cursor() as cursor: @@ -2646,7 +2609,7 @@ def get_neighbors_by_tag_ccl( ORDER BY (overlap_count::integer) DESC LIMIT {top_k} """ - print("get_neighbors_by_tag:", query) + logger.debug(f"get_neighbors_by_tag: {query}") try: with self.connection.cursor() as cursor: cursor.execute(query) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 0360396af..d679eba9c 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -6,6 +6,7 @@ from collections.abc import Callable from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING from sqlalchemy.engine import Engine @@ -50,6 +51,10 @@ from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE +if TYPE_CHECKING: + from memos.mem_cube.base import BaseMemCube + + logger = get_logger(__name__) @@ -124,7 +129,7 @@ def __init__(self, config: BaseSchedulerConfig): self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None self.current_mem_cube_id: MemCubeID | str | None = None - self.current_mem_cube: GeneralMemCube | None = None + self.current_mem_cube: BaseMemCube | None = None self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None) self.auth_config = None self.rabbitmq_config = None diff --git a/src/memos/mem_scheduler/general_modules/api_misc.py b/src/memos/mem_scheduler/general_modules/api_misc.py index bb993de38..1b10804fc 100644 --- a/src/memos/mem_scheduler/general_modules/api_misc.py +++ b/src/memos/mem_scheduler/general_modules/api_misc.py @@ -8,7 +8,6 @@ APISearchHistoryManager, TaskRunningStatus, ) -from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.item import TextualMemoryItem @@ -16,16 +15,20 @@ class SchedulerAPIModule(BaseSchedulerModule): - def __init__(self, window_size=5): + def __init__(self, window_size: int | None = None, history_memory_turns: int | None = None): super().__init__() self.window_size = window_size + self.history_memory_turns = history_memory_turns self.search_history_managers: dict[str, APIRedisDBManager] = {} - self.pre_memory_turns = 5 def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> APIRedisDBManager: """Get or create a Redis manager for search history.""" + logger.info( + f"Getting search history manager for user_id: {user_id}, mem_cube_id: {mem_cube_id}" + ) key = f"search_history:{user_id}:{mem_cube_id}" if key not in self.search_history_managers: + logger.info(f"Creating new search history manager for key: {key}") self.search_history_managers[key] = APIRedisDBManager( user_id=user_id, mem_cube_id=mem_cube_id, @@ -41,8 +44,12 @@ def sync_search_data( query: str, memories: list[TextualMemoryItem], formatted_memories: Any, - conversation_id: str | None = None, + session_id: str | None = None, + conversation_turn: int = 0, ) -> Any: + logger.info( + f"Syncing search data for item_id: {item_id}, user_id: {user_id}, mem_cube_id: {mem_cube_id}" + ) # Get the search history manager manager = self.get_search_history_manager(user_id, mem_cube_id) manager.sync_with_redis(size_limit=self.window_size) @@ -59,7 +66,7 @@ def sync_search_data( query=query, formatted_memories=formatted_memories, task_status=TaskRunningStatus.COMPLETED, # Use the provided running_status - conversation_id=conversation_id, + session_id=session_id, memories=memories, ) @@ -69,18 +76,18 @@ def sync_search_data( logger.warning(f"Failed to update entry with item_id: {item_id}") else: # Add new entry based on running_status - search_entry = APIMemoryHistoryEntryItem( + entry_item = APIMemoryHistoryEntryItem( item_id=item_id, query=query, formatted_memories=formatted_memories, memories=memories, task_status=TaskRunningStatus.COMPLETED, - conversation_id=conversation_id, - created_time=get_utc_now(), + session_id=session_id, + conversation_turn=conversation_turn, ) # Add directly to completed list as APIMemoryHistoryEntryItem instance - search_history.completed_entries.append(search_entry) + search_history.completed_entries.append(entry_item) # Maintain window size if len(search_history.completed_entries) > search_history.window_size: @@ -101,37 +108,22 @@ def sync_search_data( manager.sync_with_redis(size_limit=self.window_size) return manager - def get_pre_memories(self, user_id: str, mem_cube_id: str) -> list: - """ - Get pre-computed memories from the most recent completed search entry. - - Args: - user_id: User identifier - mem_cube_id: Memory cube identifier - - Returns: - List of TextualMemoryItem objects from the most recent completed search - """ - manager = self.get_search_history_manager(user_id, mem_cube_id) - - existing_data = manager.load_from_db() - if existing_data is None: - return [] - - search_history: APISearchHistoryManager = existing_data - - # Get memories from the most recent completed entry - history_memories = search_history.get_history_memories(turns=self.pre_memory_turns) - return history_memories - - def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list: + def get_history_memories( + self, user_id: str, mem_cube_id: str, turns: int | None = None + ) -> list: """Get history memories for backward compatibility with tests.""" + logger.info( + f"Getting history memories for user_id: {user_id}, mem_cube_id: {mem_cube_id}, turns: {turns}" + ) manager = self.get_search_history_manager(user_id, mem_cube_id) existing_data = manager.load_from_db() if existing_data is None: return [] + if turns is None: + turns = self.history_memory_turns + # Handle different data formats if isinstance(existing_data, APISearchHistoryManager): search_history = existing_data @@ -142,4 +134,4 @@ def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list: except Exception: return [] - return search_history.get_history_memories(turns=n) + return search_history.get_history_memories(turns=turns) diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index 22fb78445..a789d581e 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -76,8 +76,8 @@ def __init__( ] = {} # Lifecycle monitor - self.last_activation_mem_update_time = datetime.min - self.last_query_consume_time = datetime.min + self.last_activation_mem_update_time = get_utc_now() + self.last_query_consume_time = get_utc_now() self._register_lock = Lock() self._process_llm = process_llm diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index c8e2eb59e..a087ab2df 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -1,11 +1,14 @@ import json +import os +from collections import OrderedDict from typing import TYPE_CHECKING from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube +from memos.mem_cube.navie import NaiveMemCube from memos.mem_scheduler.general_modules.api_misc import SchedulerAPIModule from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( @@ -23,6 +26,7 @@ if TYPE_CHECKING: from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.http_bge import HTTPBGEReranker @@ -34,54 +38,33 @@ class OptimizedScheduler(GeneralScheduler): def __init__(self, config: GeneralSchedulerConfig): super().__init__(config) - self.api_module = SchedulerAPIModule() + self.window_size = int(os.getenv("API_SEARCH_WINDOW_SIZE", 5)) + self.history_memory_turns = int(os.getenv("API_SEARCH_HISTORY_TURNS", 5)) + self.session_counter = OrderedDict() + self.max_session_history = 5 + + self.api_module = SchedulerAPIModule( + window_size=self.window_size, + history_memory_turns=self.history_memory_turns, + ) self.register_handlers( { API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, } ) - def search_memories( - self, - search_req: APISearchRequest, - user_context: UserContext, - mem_cube: GeneralMemCube, - mode: SearchMode, - ): - """Fine search memories function copied from server_router to avoid circular import""" - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - # Create MemCube and perform search - search_results = mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=mode, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - return search_results - def submit_memory_history_async_task( self, search_req: APISearchRequest, user_context: UserContext, + session_id: str | None = None, ): # Create message for async fine search message_content = { "search_req": { "query": search_req.query, "user_id": search_req.user_id, - "session_id": search_req.session_id, + "session_id": session_id, "top_k": search_req.top_k, "internet_search": search_req.internet_search, "moscube": search_req.moscube, @@ -110,6 +93,36 @@ def submit_memory_history_async_task( logger.info(f"Submitted async fine search task for user {search_req.user_id}") return async_task_id + def search_memories( + self, + search_req: APISearchRequest, + user_context: UserContext, + mem_cube: NaiveMemCube, + mode: SearchMode, + ): + """Fine search memories function copied from server_router to avoid circular import""" + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + search_results = mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=mode, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + return search_results + def mix_search_memories( self, search_req: APISearchRequest, @@ -122,82 +135,115 @@ def mix_search_memories( # Get mem_cube for fast search mem_cube = self.current_mem_cube - # Perform fast search - fast_memories = self.search_memories( - search_req=search_req, - user_context=user_context, - mem_cube=mem_cube, + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + text_mem: TreeTextMemory = mem_cube.text_mem + searcher: Searcher = text_mem.get_searcher( + manual_close_internet=not search_req.internet_search, + moscube=False, + ) + # Rerank Memories - reranker expects TextualMemoryItem objects + reranker: HTTPBGEReranker = text_mem.reranker + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + fast_retrieved_memories = searcher.retrieve( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info=info, ) self.submit_memory_history_async_task( search_req=search_req, user_context=user_context, + session_id=search_req.session_id, ) # Try to get pre-computed fine memories if available - pre_fine_memories = self.api_module.get_pre_memories( - user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id + history_memories = self.api_module.get_history_memories( + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + turns=self.history_memory_turns, ) - if not pre_fine_memories: + + if not history_memories: + fast_memories = searcher.post_retrieve( + retrieved_results=fast_retrieved_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) # Format fast memories for return formatted_memories = [format_textual_memory_item(data) for data in fast_memories] return formatted_memories - # Merge fast and pre-computed fine memories (both are TextualMemoryItem objects) - combined_memories = fast_memories + pre_fine_memories - # Remove duplicates based on memory content - seen_contents = set() - unique_memories = [] - for memory in combined_memories: - # Both fast_memories and pre_fine_memories are TextualMemoryItem objects - content_key = memory.memory # Use .memory attribute instead of .get("content", "") - if content_key not in seen_contents: - seen_contents.add(content_key) - unique_memories.append(memory) - - # Rerank Memories - reranker expects TextualMemoryItem objects - reranker: HTTPBGEReranker = mem_cube.text_mem.reranker - - # Use search_req parameters for reranking - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - sorted_results = reranker.rerank( + sorted_history_memories = reranker.rerank( query=search_req.query, # Use search_req.query instead of undefined query - graph_results=unique_memories, # Pass TextualMemoryItem objects directly + graph_results=history_memories, # Pass TextualMemoryItem objects directly top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k search_filter=search_filter, ) + sorted_results = fast_retrieved_memories + sorted_history_memories + final_results = searcher.post_retrieve( + retrieved_results=sorted_results, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + formatted_memories = [ - format_textual_memory_item(item) for item, score in sorted_results[: search_req.top_k] + format_textual_memory_item(item) for item in final_results[: search_req.top_k] ] return formatted_memories def update_search_memories_to_redis( self, - user_id: str, - mem_cube_id: str, messages: list[ScheduleMessageItem], ): - mem_cube = messages[0].mem_cube + mem_cube: NaiveMemCube = self.current_mem_cube for msg in messages: content_dict = json.loads(msg.content) search_req = content_dict["search_req"] user_context = content_dict["user_context"] - fine_memories: list[TextualMemoryItem] = self.search_memories( + session_id = search_req.get("session_id") + if session_id: + if session_id not in self.session_counter: + self.session_counter[session_id] = 0 + else: + self.session_counter[session_id] += 1 + session_turn = self.session_counter[session_id] + + # Move the current session to the end to mark it as recently used + self.session_counter.move_to_end(session_id) + + # If the counter exceeds the max size, remove the oldest item + if len(self.session_counter) > self.max_session_history: + self.session_counter.popitem(last=False) + else: + session_turn = 0 + + memories: list[TextualMemoryItem] = self.search_memories( search_req=APISearchRequest(**content_dict["search_req"]), user_context=UserContext(**content_dict["user_context"]), mem_cube=mem_cube, - mode=SearchMode.FINE, + mode=SearchMode.FAST, ) - formatted_memories = [format_textual_memory_item(data) for data in fine_memories] + formatted_memories = [format_textual_memory_item(data) for data in memories] # Sync search data to Redis self.api_module.sync_search_data( @@ -205,8 +251,10 @@ def update_search_memories_to_redis( user_id=search_req["user_id"], mem_cube_id=user_context["mem_cube_id"], query=search_req["query"], - memories=fine_memories, + memories=memories, formatted_memories=formatted_memories, + session_id=session_id, + conversation_turn=session_turn, ) def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: @@ -228,9 +276,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) messages = grouped_messages[user_id][mem_cube_id] if len(messages) == 0: return - self.update_search_memories_to_redis( - user_id=user_id, mem_cube_id=mem_cube_id, messages=messages - ) + self.update_search_memories_to_redis(messages=messages) def replace_working_memory( self, diff --git a/src/memos/mem_scheduler/schemas/api_schemas.py b/src/memos/mem_scheduler/schemas/api_schemas.py index 23eb5a848..6d0de49c4 100644 --- a/src/memos/mem_scheduler/schemas/api_schemas.py +++ b/src/memos/mem_scheduler/schemas/api_schemas.py @@ -35,11 +35,10 @@ class APIMemoryHistoryEntryItem(BaseModel, DictConversionMixin): task_status: str = Field( default="running", description="Task status: running, completed, failed" ) - conversation_id: str | None = Field( - default=None, description="Optional conversation identifier" - ) + session_id: str | None = Field(default=None, description="Optional conversation identifier") created_time: datetime = Field(description="Entry creation time", default_factory=get_utc_now) timestamp: datetime | None = Field(default=None, description="Timestamp for the entry") + conversation_turn: int = Field(default=0, description="Turn count for the same session_id") model_config = ConfigDict( arbitrary_types_allowed=True, @@ -107,11 +106,13 @@ def get_running_item_ids(self) -> list[str]: """Get all running task IDs""" return self.running_item_ids.copy() - def get_completed_entries(self) -> list[dict[str, Any]]: + def get_completed_entries(self) -> list[APIMemoryHistoryEntryItem]: """Get all completed entries""" return self.completed_entries.copy() - def get_history_memory_entries(self, turns: int | None = None) -> list[dict[str, Any]]: + def get_history_memory_entries( + self, turns: int | None = None + ) -> list[APIMemoryHistoryEntryItem]: """ Get the most recent n completed search entries, sorted by created_time. @@ -179,7 +180,7 @@ def update_entry_by_item_id( query: str, formatted_memories: Any, task_status: TaskRunningStatus, - conversation_id: str | None = None, + session_id: str | None = None, memories: list[TextualMemoryItem] | None = None, ) -> bool: """ @@ -191,7 +192,7 @@ def update_entry_by_item_id( query: New query string formatted_memories: New formatted memories task_status: New task status - conversation_id: New conversation ID + session_id: New conversation ID memories: List of TextualMemoryItem objects Returns: @@ -204,8 +205,8 @@ def update_entry_by_item_id( entry.query = query entry.formatted_memories = formatted_memories entry.task_status = task_status - if conversation_id is not None: - entry.conversation_id = conversation_id + if session_id is not None: + entry.session_id = session_id if memories is not None: entry.memories = memories diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 8d07522cd..8ce81a8bd 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -104,6 +104,34 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int """ return self.memory_manager.get_current_memory_size(user_name=user_name) + def get_searcher( + self, + manual_close_internet: bool = False, + moscube: bool = False, + ): + if (self.internet_retriever is not None) and manual_close_internet: + logger.warning( + "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" + ) + searcher = Searcher( + self.dispatcher_llm, + self.graph_store, + self.embedder, + self.reranker, + internet_retriever=None, + moscube=moscube, + ) + else: + searcher = Searcher( + self.dispatcher_llm, + self.graph_store, + self.embedder, + self.reranker, + internet_retriever=self.internet_retriever, + moscube=moscube, + ) + return searcher + def search( self, query: str, diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 472bed219..56c8117e9 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -115,6 +115,34 @@ def get_current_memory_size(self) -> dict[str, int]: """ return self.memory_manager.get_current_memory_size() + def get_searcher( + self, + manual_close_internet: bool = False, + moscube: bool = False, + ): + if (self.internet_retriever is not None) and manual_close_internet: + logger.warning( + "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" + ) + searcher = Searcher( + self.dispatcher_llm, + self.graph_store, + self.embedder, + self.reranker, + internet_retriever=None, + moscube=moscube, + ) + else: + searcher = Searcher( + self.dispatcher_llm, + self.graph_store, + self.embedder, + self.reranker, + internet_retriever=self.internet_retriever, + moscube=moscube, + ) + return searcher + def search( self, query: str, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 96c6c97f1..9d540b311 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -44,6 +44,49 @@ def __init__( self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") + @timed + def retrieve( + self, + query: str, + top_k: int, + info=None, + mode="fast", + memory_type="All", + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ) -> list[TextualMemoryItem]: + logger.info( + f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" + ) + parsed_goal, query_embedding, context, query = self._parse_task( + query, info, mode, search_filter=search_filter, user_name=user_name + ) + results = self._retrieve_paths( + query, + parsed_goal, + query_embedding, + info, + top_k, + mode, + memory_type, + search_filter, + user_name, + ) + return results + + def post_retrieve( + self, + retrieved_results: list[TextualMemoryItem], + top_k: int, + user_name: str | None = None, + info=None, + ): + deduped = self._deduplicate_results(retrieved_results) + final_results = self._sort_and_trim(deduped, top_k) + self._update_usage_history(final_results, info, user_name) + return final_results + @timed def search( self, @@ -72,9 +115,6 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ - logger.info( - f"[SEARCH] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" - ) if not info: logger.warning( "Please input 'info' when use tree.search so that " @@ -84,23 +124,22 @@ def search( else: logger.debug(f"[SEARCH] Received info dict: {info}") - parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter, user_name=user_name + retrieved_results = self.retrieve( + query=query, + top_k=top_k, + info=info, + mode=mode, + memory_type=memory_type, + search_filter=search_filter, + user_name=user_name, ) - results = self._retrieve_paths( - query, - parsed_goal, - query_embedding, - info, - top_k, - mode, - memory_type, - search_filter, - user_name, + + final_results = self.post_retrieve( + retrieved_results=retrieved_results, + top_k=top_k, + user_name=user_name, + info=None, ) - deduped = self._deduplicate_results(results) - final_results = self._sort_and_trim(deduped, top_k) - self._update_usage_history(final_results, info, user_name) logger.info(f"[SEARCH] Done. Total {len(final_results)} results.") res_results = "" From a375911827b4c6fe3fd82758a23a6b6cb0c9adec Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Wed, 29 Oct 2025 22:31:00 +0800 Subject: [PATCH 24/64] Feat: add recall strategy (#414) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi --- poetry.lock | 50 ++- pyproject.toml | 4 +- src/memos/api/config.py | 34 +- src/memos/configs/mem_reader.py | 9 + src/memos/configs/memory.py | 7 + src/memos/llms/openai.py | 13 +- src/memos/mem_reader/factory.py | 2 + src/memos/mem_reader/strategy_struct.py | 138 +++++++ src/memos/memories/textual/simple_tree.py | 18 + src/memos/memories/textual/tree.py | 16 + .../tree_text_memory/retrieve/bm25_util.py | 186 +++++++++ .../tree_text_memory/retrieve/recall.py | 107 ++++- .../retrieve/retrieve_utils.py | 378 ++++++++++++++++++ .../tree_text_memory/retrieve/searcher.py | 88 +++- .../retrieve/task_goal_parser.py | 7 +- .../templates/mem_reader_strategy_prompts.py | 279 +++++++++++++ src/memos/templates/mem_search_prompts.py | 93 +++++ .../textual/test_tree_task_goal_parser.py | 5 - 18 files changed, 1393 insertions(+), 41 deletions(-) create mode 100644 src/memos/mem_reader/strategy_struct.py create mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py create mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py create mode 100644 src/memos/templates/mem_reader_strategy_prompts.py create mode 100644 src/memos/templates/mem_search_prompts.py diff --git a/poetry.lock b/poetry.lock index 44265bca8..926d580fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -192,6 +192,19 @@ torch = ">=1.0.0" tqdm = ">=4.31.1" transformers = ">=3.0.0" +[[package]] +name = "cachetools" +version = "6.2.1" +description = "Extensible memoizing collections and decorators" +optional = true +python-versions = ">=3.9" +groups = ["main"] +markers = "extra == \"all\"" +files = [ + {file = "cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701"}, + {file = "cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201"}, +] + [[package]] name = "certifi" version = "2025.7.14" @@ -1553,6 +1566,18 @@ files = [ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] +[[package]] +name = "jieba" +version = "0.42" +description = "Chinese Words Segmentation Utilities" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"all\"" +files = [ + {file = "jieba-0.42.tar.gz", hash = "sha256:34a3c960cc2943d9da16d6d2565110cf5f305921a67413dddf04f84de69c939b"}, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -4123,6 +4148,25 @@ urllib3 = ">=1.26.14,<3" fastembed = ["fastembed (>=0.7,<0.8)"] fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"] +[[package]] +name = "rank-bm25" +version = "0.2.2" +description = "Various BM25 algorithms for document ranking" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"all\"" +files = [ + {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"}, + {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"}, +] + +[package.dependencies] +numpy = "*" + +[package.extras] +dev = ["pytest"] + [[package]] name = "redis" version = "6.2.0" @@ -6352,7 +6396,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["cachetools", "chonkie", "datasketch", "jieba", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] @@ -6362,4 +6406,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" +content-hash = "ec17679a44205ada4494fbc485ac592883281fde273d5e73d6b8cbc6f7f9ed10" diff --git a/pyproject.toml b/pyproject.toml index 3745582f6..2f88797a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,9 @@ all = [ "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", "pymilvus (>=2.6.1,<3.0.0)", "datasketch (>=1.6.5,<2.0.0)", - + "jieba (>=0.38.1,<0.42.1)", + "rank-bm25 (>=0.2.2)", + "cachetools (>=6.0.0)", # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. # We kindof don't want users to install them. diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7ac882d6c..405e8068d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -419,9 +419,23 @@ def get_embedder_config() -> dict[str, Any]: }, } + @staticmethod + def get_reader_config() -> dict[str, Any]: + """Get reader configuration.""" + return { + "backend": os.getenv("MEM_READER_BACKEND", "simple_struct"), + "config": { + "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), + "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), + "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), + "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), + }, + } + @staticmethod def get_internet_config() -> dict[str, Any]: """Get embedder configuration.""" + reader_config = APIConfig.get_reader_config() return { "backend": "bocha", "config": { @@ -429,7 +443,7 @@ def get_internet_config() -> dict[str, Any]: "max_results": 15, "num_per_request": 10, "reader": { - "backend": "simple_struct", + "backend": reader_config["backend"], "config": { "llm": { "backend": "openai", @@ -455,6 +469,7 @@ def get_internet_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, + "chat_chunker": reader_config, }, }, }, @@ -656,6 +671,8 @@ def get_product_default_config() -> dict[str, Any]: openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() + reader_config = APIConfig.get_reader_config() + backend_model = { "openai": openai_config, "huggingface": qwen_config, @@ -667,7 +684,7 @@ def get_product_default_config() -> dict[str, Any]: "user_id": os.getenv("MOS_USER_ID", "root"), "chat_model": {"backend": backend, "config": backend_model[backend]}, "mem_reader": { - "backend": "simple_struct", + "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -680,6 +697,7 @@ def get_product_default_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, + "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -750,6 +768,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() mysql_config = APIConfig.get_mysql_config() + reader_config = APIConfig.get_reader_config() backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") backend_model = { "openai": openai_config, @@ -764,7 +783,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "config": backend_model[backend], }, "mem_reader": { - "backend": "simple_struct", + "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -777,6 +796,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "min_sentences_per_chunk": 1, }, }, + "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -845,6 +865,10 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "search_strategy": { + "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), + "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), + }, }, }, "act_mem": {} @@ -912,6 +936,10 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "search_strategy": { + "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), + "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), + }, "mode": os.getenv("ASYNC_MODE", "sync"), }, }, diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 1c62087a3..dc8d37a35 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -36,11 +36,19 @@ def parse_datetime(cls, value): description="whether remove example in memory extraction prompt to save token", ) + chat_chunker: dict[str, Any] = Field( + default=None, description="Configuration for the MemReader chat chunk strategy" + ) + class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" +class StrategyStructMemReaderConfig(BaseMemReaderConfig): + """StrategyStruct MemReader configuration class.""" + + class MemReaderConfigFactory(BaseConfig): """Factory class for creating MemReader configurations.""" @@ -49,6 +57,7 @@ class MemReaderConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, + "strategy_struct": StrategyStructMemReaderConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index bf2493567..49320fbf5 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,6 +184,13 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) + search_strategy: dict[str, bool] | None = Field( + default=None, + description=( + 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' + ), + ) + mode: str | None = Field( default="sync", description=("whether use asynchronous mode in memory add"), diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index ca1df5c1f..1a1703340 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -58,15 +58,18 @@ def clear_cache(cls): logger.info("OpenAI LLM instance cache cleared") @timed(log=True, log_prefix="OpenAI LLM") - def generate(self, messages: MessageList) -> str: - """Generate a response from OpenAI LLM.""" + def generate(self, messages: MessageList, **kwargs) -> str: + """Generate a response from OpenAI LLM, optionally overriding generation params.""" + temperature = kwargs.get("temperature", self.config.temperature) + max_tokens = kwargs.get("max_tokens", self.config.max_tokens) + top_p = kwargs.get("top_p", self.config.top_p) response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, extra_body=self.config.extra_body, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, ) logger.info(f"Response from OpenAI: {response.model_dump_json()}") response_content = response.choices[0].message.content diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 52eed8d9d..2205a0215 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -3,6 +3,7 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.mem_reader.strategy_struct import StrategyStructMemReader from memos.memos_tools.singleton import singleton_factory @@ -11,6 +12,7 @@ class MemReaderFactory(BaseMemReader): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReader, + "strategy_struct": StrategyStructMemReader, } @classmethod diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py new file mode 100644 index 000000000..2cac1652a --- /dev/null +++ b/src/memos/mem_reader/strategy_struct.py @@ -0,0 +1,138 @@ +import os + +from abc import ABC + +from memos import log +from memos.configs.mem_reader import StrategyStructMemReaderConfig +from memos.configs.parser import ParserConfigFactory +from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang +from memos.parsers.factory import ParserFactory +from memos.templates.mem_reader_prompts import ( + SIMPLE_STRUCT_DOC_READER_PROMPT, + SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, + SIMPLE_STRUCT_MEM_READER_EXAMPLE, + SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, +) +from memos.templates.mem_reader_strategy_prompts import ( + STRATEGY_STRUCT_MEM_READER_PROMPT, + STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, +) + + +logger = log.get_logger(__name__) +STRATEGY_PROMPT_DICT = { + "chat": { + "en": STRATEGY_STRUCT_MEM_READER_PROMPT, + "zh": STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, + "en_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE, + "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, + }, + "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, +} + + +class StrategyStructMemReader(SimpleStructMemReader, ABC): + """Naive implementation of MemReader.""" + + def __init__(self, config: StrategyStructMemReaderConfig): + super().__init__(config) + self.chat_chunker = config.chat_chunker["config"] + + def _get_llm_response(self, mem_str: str) -> dict: + lang = detect_lang(mem_str) + template = STRATEGY_PROMPT_DICT["chat"][lang] + examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] + prompt = template.replace("${conversation}", mem_str) + if self.config.remove_prompt_example: + prompt = prompt.replace(examples, "") + messages = [{"role": "user", "content": prompt}] + try: + response_text = self.llm.generate(messages) + response_json = self.parse_json_result(response_text) + except Exception as e: + logger.error(f"[LLM] Exception during chat generation: {e}") + response_json = { + "memory list": [ + { + "key": mem_str[:10], + "memory_type": "UserMemory", + "value": mem_str, + "tags": [], + } + ], + "summary": mem_str, + } + return response_json + + def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: + """ + Get raw information from scene_data. + If scene_data contains dictionaries, convert them to strings. + If scene_data contains file paths, parse them using the parser. + + Args: + scene_data: List of dialogue information or document paths + type: Type of scene data: ['doc', 'chat'] + Returns: + List of strings containing the processed scene data + """ + results = [] + + if type == "chat": + if self.chat_chunker["chunk_type"] == "content_length": + content_len_thredshold = self.chat_chunker["chunk_length"] + for items in scene_data: + if not items: + continue + + results.append([]) + current_length = 0 + + for _i, item in enumerate(items): + content_length = ( + len(item.get("content", "")) + if isinstance(item, dict) + else len(str(item)) + ) + if not results[-1]: + results[-1].append(item) + current_length = content_length + continue + + if current_length + content_length <= content_len_thredshold: + results[-1].append(item) + current_length += content_length + else: + overlap_item = results[-1][-1] + overlap_length = ( + len(overlap_item.get("content", "")) + if isinstance(overlap_item, dict) + else len(str(overlap_item)) + ) + + results.append([overlap_item, item]) + current_length = overlap_length + content_length + elif type == "doc": + parser_config = ParserConfigFactory.model_validate( + { + "backend": "markitdown", + "config": {}, + } + ) + parser = ParserFactory.from_config(parser_config) + for item in scene_data: + try: + if os.path.exists(item): + try: + parsed_text = parser.parse(item) + results.append({"file": item, "text": parsed_text}) + except Exception as e: + logger.error(f"[SceneParser] Error parsing {item}: {e}") + continue + else: + parsed_text = item + results.append({"file": "pure_text", "text": parsed_text}) + except Exception as e: + print(f"Error parsing file {item}: {e!s}") + + return results diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 8ce81a8bd..6974dbe8f 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -12,6 +12,7 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.base import BaseReranker from memos.types import MessageList @@ -62,6 +63,19 @@ def __init__( self.graph_store: Neo4jGraphDB = graph_db logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") + time_start_bm = time.time() + self.search_strategy = config.search_strategy + self.bm25_retriever = ( + EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None + ) + logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") + + self.vec_cot = ( + self.search_strategy["cot"] + if self.search_strategy and "cot" in self.search_strategy + else False + ) + time_start_rr = time.time() self.reranker = reranker logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") @@ -172,8 +186,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, + vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -181,8 +197,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, + vec_cot=self.vec_cot, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 56c8117e9..a58f993bb 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -16,6 +16,7 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -45,6 +46,17 @@ def __init__(self, config: TreeTextMemoryConfig): ) self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) + + self.search_strategy = config.search_strategy + self.bm25_retriever = ( + EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None + ) + self.vec_cot = ( + self.search_strategy["cot"] + if self.search_strategy and "cot" in self.search_strategy + else False + ) + if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( { @@ -185,8 +197,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, + vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -194,8 +208,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, + vec_cot=self.vec_cot, ) return searcher.search(query, top_k, info, mode, memory_type, search_filter) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py b/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py new file mode 100644 index 000000000..4aca4022f --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py @@ -0,0 +1,186 @@ +import threading + +import numpy as np + +from sklearn.feature_extraction.text import TfidfVectorizer + +from memos.dependency import require_python_package +from memos.log import get_logger +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer +from memos.utils import timed + + +logger = get_logger(__name__) +# Global model cache +_CACHE_LOCK = threading.Lock() + + +class EnhancedBM25: + """Enhanced BM25 with Spacy tokenization and TF-IDF reranking""" + + @require_python_package(import_name="cachetools", install_command="pip install cachetools") + def __init__(self, tokenizer=None, en_model="en_core_web_sm", zh_model="zh_core_web_sm"): + """ + Initialize Enhanced BM25 with memory management + """ + if tokenizer is None: + self.tokenizer = FastTokenizer() + else: + self.tokenizer = tokenizer + self._current_tfidf = None + + global _BM25_CACHE + from cachetools import LRUCache + + _BM25_CACHE = LRUCache(maxsize=100) + + def _tokenize_doc(self, text): + """ + Tokenize a single document using SpacyTokenizer + """ + return self.tokenizer.tokenize_mixed(text, lang="auto") + + @require_python_package(import_name="rank_bm25", install_command="pip install rank_bm25") + def _prepare_corpus_data(self, corpus, corpus_name="default"): + from rank_bm25 import BM25Okapi + + with _CACHE_LOCK: + if corpus_name in _BM25_CACHE: + print("hit::", corpus_name) + return _BM25_CACHE[corpus_name] + print("not hit::", corpus_name) + + tokenized_corpus = [self._tokenize_doc(doc) for doc in corpus] + bm25_model = BM25Okapi(tokenized_corpus) + _BM25_CACHE[corpus_name] = bm25_model + return bm25_model + + def clear_cache(self, corpus_name=None): + """Clear cache for specific corpus or clear all cache""" + with _CACHE_LOCK: + if corpus_name: + if corpus_name in _BM25_CACHE: + del _BM25_CACHE[corpus_name] + else: + _BM25_CACHE.clear() + + def get_cache_info(self): + """Get current cache information""" + with _CACHE_LOCK: + return { + "cache_size": len(_BM25_CACHE), + "max_cache_size": 100, + "cached_corpora": list(_BM25_CACHE.keys()), + } + + def _search_docs( + self, + query: str, + corpus: list[str], + corpus_name="test", + top_k=50, + use_tfidf=False, + rerank_candidates_multiplier=2, + cleanup=False, + ): + """ + Args: + query: Search query string + corpus: List of document texts + top_k: Number of top results to return + rerank_candidates_multiplier: Multiplier for candidate selection + cleanup: Whether to cleanup memory after search (default: True) + """ + if not corpus: + return [] + + logger.info(f"Searching {len(corpus)} documents for query: '{query}'") + + try: + # Prepare BM25 model + bm25_model = self._prepare_corpus_data(corpus, corpus_name=corpus_name) + tokenized_query = self._tokenize_doc(query) + tokenized_query = list(dict.fromkeys(tokenized_query)) + + # Get BM25 scores + bm25_scores = bm25_model.get_scores(tokenized_query) + + # Select candidates + candidate_count = min(top_k * rerank_candidates_multiplier, len(corpus)) + candidate_indices = np.argsort(bm25_scores)[-candidate_count:][::-1] + combined_scores = bm25_scores[candidate_indices] + + if use_tfidf: + # Create TF-IDF for this search + tfidf = TfidfVectorizer( + tokenizer=self._tokenize_doc, lowercase=False, token_pattern=None + ) + tfidf_matrix = tfidf.fit_transform(corpus) + + # TF-IDF reranking + query_vec = tfidf.transform([query]) + tfidf_similarities = ( + (tfidf_matrix[candidate_indices] * query_vec.T).toarray().flatten() + ) + + # Combine scores + combined_scores = 0.7 * bm25_scores[candidate_indices] + 0.3 * tfidf_similarities + + sorted_candidate_indices = candidate_indices[np.argsort(combined_scores)[::-1][:top_k]] + sorted_combined_scores = np.sort(combined_scores)[::-1][:top_k] + + # build result list + bm25_recalled_results = [] + for rank, (doc_idx, combined_score) in enumerate( + zip(sorted_candidate_indices, sorted_combined_scores, strict=False), 1 + ): + bm25_score = bm25_scores[doc_idx] + + candidate_pos = np.where(candidate_indices == doc_idx)[0][0] + tfidf_score = tfidf_similarities[candidate_pos] if use_tfidf else 0 + + bm25_recalled_results.append( + { + "text": corpus[doc_idx], + "bm25_score": float(bm25_score), + "tfidf_score": float(tfidf_score), + "combined_score": float(combined_score), + "rank": rank, + "doc_index": int(doc_idx), + } + ) + + logger.debug(f"Search completed: found {len(bm25_recalled_results)} results") + return bm25_recalled_results + + except Exception as e: + logger.error(f"BM25 search failed: {e}") + return [] + finally: + # Always cleanup if requested + if cleanup: + self._cleanup_memory() + + @timed + def search(self, query: str, node_dicts: list[dict], corpus_name="default", **kwargs): + """ + Search with BM25 and optional TF-IDF reranking + """ + try: + corpus_list = [] + for node_dict in node_dicts: + corpus_list.append( + " ".join([node_dict["metadata"]["key"]] + node_dict["metadata"]["tags"]) + ) + + recalled_results = self._search_docs( + query, corpus_list, corpus_name=corpus_name, **kwargs + ) + bm25_searched_nodes = [] + for item in recalled_results: + doc_idx = item["doc_index"] + bm25_searched_nodes.append(node_dicts[doc_idx]) + return bm25_searched_nodes + except Exception as e: + logger.error(f"Error in bm25 search: {e}") + return [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index c1ade3021..b7383aa13 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -5,6 +5,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal @@ -16,11 +17,18 @@ class GraphMemoryRetriever: Unified memory retriever that combines both graph-based and vector-based retrieval logic. """ - def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder): + def __init__( + self, + graph_store: Neo4jGraphDB, + embedder: OllamaEmbedder, + bm25_retriever: EnhancedBM25 | None = None, + ): self.graph_store = graph_store self.embedder = embedder + self.bm25_retriever = bm25_retriever self.max_workers = 10 self.filter_weight = 0.6 + self.use_bm25 = bool(self.bm25_retriever) def retrieve( self, @@ -31,6 +39,7 @@ def retrieve( query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, user_name: str | None = None, + id_filter: dict | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -58,7 +67,7 @@ def retrieve( ) return [TextualMemoryItem.from_dict(record) for record in working_memories] - with ContextThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=3) as executor: # Structured graph-based retrieval future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search @@ -70,12 +79,23 @@ def retrieve( search_filter=search_filter, user_name=user_name, ) + if self.use_bm25: + future_bm25 = executor.submit( + self._bm25_recall, + query, + parsed_goal, + memory_scope, + top_k=top_k, + user_name=user_name, + search_filter=id_filter, + ) graph_results = future_graph.result() vector_results = future_vector.result() + bm25_results = future_bm25.result() if self.use_bm25 else [] # Merge and deduplicate by ID - combined = {item.id: item for item in graph_results + vector_results} + combined = {item.id: item for item in graph_results + vector_results + bm25_results} graph_ids = {item.id for item in graph_results} combined_ids = set(combined.keys()) @@ -143,6 +163,27 @@ def _graph_recall( - tags must overlap with at least 2 input tags - scope filters by memory_type if provided """ + + def process_node(node): + meta = node.get("metadata", {}) + node_key = meta.get("key") + node_tags = meta.get("tags", []) or [] + + keep = False + # key equals to node_key + if parsed_goal.keys and node_key in parsed_goal.keys: + keep = True + # overlap tags more than 2 + elif parsed_goal.tags: + node_tags_list = [tag.lower() for tag in node_tags] + overlap = len(set(node_tags_list) & set(parsed_goal.tags)) + if overlap >= 2: + keep = True + + if keep: + return TextualMemoryItem.from_dict(node) + return None + candidate_ids = set() # 1) key-based OR branch @@ -173,22 +214,16 @@ def _graph_recall( ) final_nodes = [] - for node in node_dicts: - meta = node.get("metadata", {}) - node_key = meta.get("key") - node_tags = meta.get("tags", []) or [] + with ContextThreadPoolExecutor(max_workers=3) as executor: + futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)} + temp_results = [None] * len(node_dicts) - keep = False - # key equals to node_key - if parsed_goal.keys and node_key in parsed_goal.keys: - keep = True - # overlap tags more than 2 - elif parsed_goal.tags: - overlap = len(set(node_tags) & set(parsed_goal.tags)) - if overlap >= 2: - keep = True - if keep: - final_nodes.append(TextualMemoryItem.from_dict(node)) + for future in concurrent.futures.as_completed(futures): + original_index = futures[future] + result = future.result() + temp_results[original_index] = result + + final_nodes = [result for result in temp_results if result is not None] return final_nodes def _vector_recall( @@ -196,7 +231,7 @@ def _vector_recall( query_embedding: list[list[float]], memory_scope: str, top_k: int = 20, - max_num: int = 3, + max_num: int = 5, status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, @@ -269,3 +304,37 @@ def search_path_b(): or [] ) return [TextualMemoryItem.from_dict(n) for n in node_dicts] + + def _bm25_recall( + self, + query: str, + parsed_goal: ParsedTaskGoal, + memory_scope: str, + top_k: int = 20, + user_name: str | None = None, + search_filter: dict | None = None, + ) -> list[TextualMemoryItem]: + """ + Perform BM25-based retrieval. + """ + if not self.bm25_retriever: + return [] + key_filters = [ + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + # corpus_name is user_name + user_id + corpus_name = f"{user_name}" if user_name else "" + if search_filter is not None: + for key in search_filter: + value = search_filter[key] + key_filters.append({"field": key, "op": "=", "value": value}) + corpus_name += "".join(list(search_filter.values())) + candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) + node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + + bm25_query = " ".join(list({query, *parsed_goal.keys})) + bm25_results = self.bm25_retriever.search( + bm25_query, node_dicts, top_k=top_k, corpus_name=corpus_name + ) + + return [TextualMemoryItem.from_dict(n) for n in bm25_results] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py new file mode 100644 index 000000000..eec827c86 --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -0,0 +1,378 @@ +import json +import re + +from pathlib import Path + +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +def find_project_root(marker=".git"): + """Find the project root directory by marking the file""" + current = Path(__file__).resolve() + while current != current.parent: + if (current / marker).exists(): + return current + current = current.parent + logger.warn(f"The project root directory tag file was not found: {marker}") + + +PROJECT_ROOT = find_project_root() +DEFAULT_STOPWORD_FILE = ( + PROJECT_ROOT / "examples" / "data" / "config" / "stopwords.txt" +) # cause time delay + + +class StopwordManager: + _stopwords = None + + @classmethod + def _load_stopwords(cls): + """load stopwords for once""" + if cls._stopwords is not None: + return cls._stopwords + + stopwords = set() + try: + with open(DEFAULT_STOPWORD_FILE, encoding="utf-8") as f: + stopwords = {line.strip() for line in f if line.strip()} + logger.info("Stopwords loaded successfully.") + except Exception as e: + logger.warning(f"Error loading stopwords: {e}, using default stopwords.") + stopwords = cls._load_default_stopwords() + + cls._stopwords = stopwords + return stopwords + + @classmethod + def _load_default_stopwords(cls): + """load stop words""" + chinese_stop_words = { + "的", + "了", + "在", + "是", + "我", + "有", + "和", + "就", + "不", + "人", + "都", + "一", + "一个", + "上", + "也", + "很", + "到", + "说", + "要", + "去", + "你", + "会", + "着", + "没有", + "看", + "好", + "自己", + "这", + "那", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "这个", + "那个", + "这些", + "那些", + "怎么", + "什么", + "为什么", + "如何", + "哪里", + "谁", + "几", + "多少", + "这样", + "那样", + "这么", + "那么", + } + english_stop_words = { + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "as", + "is", + "are", + "was", + "were", + "be", + "been", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "could", + "should", + "may", + "might", + "must", + "this", + "that", + "these", + "those", + "i", + "you", + "he", + "she", + "it", + "we", + "they", + "me", + "him", + "her", + "us", + "them", + "my", + "your", + "his", + "its", + "our", + "their", + "mine", + "yours", + "hers", + "ours", + "theirs", + } + chinese_punctuation = { + ",", + "。", + "!", + "?", + ";", + ":", + "「", + "」", + "『", + "』", + "【", + "】", + "(", + ")", + "《", + "》", + "—", + "…", + "~", + "·", + "、", + "“", + "”", + "‘", + "’", + "〈", + "〉", + "〖", + "〗", + "〝", + "〞", + "{", + "}", + "〔", + "〕", + "¡", + "¿", + } + english_punctuation = { + ",", + ".", + "!", + "?", + ";", + ":", + '"', + "'", + "(", + ")", + "[", + "]", + "{", + "}", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "=", + "+", + "@", + "#", + "$", + "%", + "^", + "&", + "*", + "~", + "`", + "¡", + "¿", + } + numbers = { + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "零", + "一", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "十", + "百", + "千", + "万", + "亿", + } + whitespace = {" ", "\t", "\n", "\r", "\f", "\v"} + + return ( + chinese_stop_words + | english_stop_words + | chinese_punctuation + | english_punctuation + | numbers + | whitespace + ) + + @classmethod + def get_stopwords(cls): + if cls._stopwords is None: + cls._load_stopwords() + return cls._stopwords + + @classmethod + def filter_words(cls, words): + if cls._stopwords is None: + cls._load_stopwords() + return [word for word in words if word not in cls._stopwords and word.strip()] + + @classmethod + def is_stopword(cls, word): + if cls._stopwords is None: + cls._load_stopwords() + return word in cls._stopwords + + @classmethod + def reload_stopwords(cls, file_path=None): + cls._stopwords = None + if file_path: + global DEFAULT_STOPWORD_FILE + DEFAULT_STOPWORD_FILE = file_path + cls._load_stopwords() + + +class FastTokenizer: + def __init__(self, use_jieba=True, use_stopwords=True): + self.use_jieba = use_jieba + self.use_stopwords = use_stopwords + if self.use_stopwords: + self.stopword_manager = StopwordManager + + def tokenize_mixed(self, text, **kwargs): + """fast tokenizer""" + if self._is_chinese(text): + return self._tokenize_chinese(text) + else: + return self._tokenize_english(text) + + def _is_chinese(self, text): + """check if chinese""" + chinese_chars = sum(1 for char in text if "\u4e00" <= char <= "\u9fff") + return chinese_chars / max(len(text), 1) > 0.3 + + @require_python_package( + import_name="jieba", + install_command="pip install jieba", + install_link="https://github.com/fxsjy/jieba", + ) + def _tokenize_chinese(self, text): + """split zh jieba""" + import jieba + + tokens = jieba.lcut(text) if self.use_jieba else list(text) + tokens = [token.strip() for token in tokens if token.strip()] + if self.use_stopwords: + return self.stopword_manager.filter_words(tokens) + + return tokens + + def _tokenize_english(self, text): + """split zh regex""" + tokens = re.findall(r"\b[a-zA-Z0-9]+\b", text.lower()) + if self.use_stopwords: + return self.stopword_manager.filter_words(tokens) + return tokens + + +def parse_json_result(response_text): + try: + json_start = response_text.find("{") + response_text = response_text[json_start:] + response_text = response_text.replace("```", "").strip() + if not response_text.endswith("}"): + response_text += "}" + return json.loads(response_text) + except json.JSONDecodeError as e: + logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw:\n{response_text}") + return {} + except Exception as e: + logger.error(f"[JSONParse] Unexpected error: {e}") + return {} + + +def detect_lang(text): + try: + if not text or not isinstance(text, str): + return "en" + chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" + chinese_chars = re.findall(chinese_pattern, text) + if len(chinese_chars) / len(re.sub(r"[\s\d\W]", "", text)) > 0.3: + return "zh" + return "en" + except Exception: + return "en" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9d540b311..563695c68 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -9,7 +9,18 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + detect_lang, + parse_json_result, +) from memos.reranker.base import BaseReranker +from memos.templates.mem_search_prompts import ( + COT_PROMPT, + COT_PROMPT_ZH, + SIMPLE_COT_PROMPT, + SIMPLE_COT_PROMPT_ZH, +) from memos.utils import timed from .reasoner import MemoryReasoner @@ -18,6 +29,10 @@ logger = get_logger(__name__) +COT_DICT = { + "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, + "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, +} class Searcher: @@ -27,20 +42,24 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, + bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, moscube: bool = False, + vec_cot: bool = False, ): self.graph_store = graph_store self.embedder = embedder + self.llm = dispatcher_llm self.task_goal_parser = TaskGoalParser(dispatcher_llm) - self.graph_retriever = GraphMemoryRetriever(self.graph_store, self.embedder) + self.graph_retriever = GraphMemoryRetriever(graph_store, embedder, bm25_retriever) self.reranker = reranker self.reasoner = MemoryReasoner(dispatcher_llm) # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube + self.vec_cot = vec_cot self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -231,6 +250,12 @@ def _retrieve_paths( ): """Run A/B/C retrieval paths in parallel""" tasks = [] + id_filter = { + "user_id": info.get("user_id", None), + "session_id": info.get("session_id", None), + } + id_filter = {k: v for k, v in id_filter.items() if v is not None} + with ContextThreadPoolExecutor(max_workers=3) as executor: tasks.append( executor.submit( @@ -242,6 +267,7 @@ def _retrieve_paths( memory_type, search_filter, user_name, + id_filter, ) ) tasks.append( @@ -254,6 +280,7 @@ def _retrieve_paths( memory_type, search_filter, user_name, + id_filter, ) ) tasks.append( @@ -299,6 +326,7 @@ def _retrieve_from_working_memory( memory_type, search_filter: dict | None = None, user_name: str | None = None, + id_filter: dict | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -311,6 +339,7 @@ def _retrieve_from_working_memory( memory_scope="WorkingMemory", search_filter=search_filter, user_name=user_name, + id_filter=id_filter, ) return self.reranker.rerank( query=query, @@ -332,11 +361,22 @@ def _retrieve_from_long_term_and_user( memory_type, search_filter: dict | None = None, user_name: str | None = None, + id_filter: dict | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] tasks = [] + # chain of thinking + cot_embeddings = [] + if self.vec_cot: + queries = self._cot_query(query) + if len(queries) > 1: + cot_embeddings = self.embedder.embed(queries) + cot_embeddings.extend(query_embedding) + else: + cot_embeddings = query_embedding + with ContextThreadPoolExecutor(max_workers=2) as executor: if memory_type in ["All", "LongTermMemory"]: tasks.append( @@ -344,11 +384,12 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=query_embedding, + query_embedding=cot_embeddings, top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, user_name=user_name, + id_filter=id_filter, ) ) if memory_type in ["All", "UserMemory"]: @@ -357,11 +398,12 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=query_embedding, + query_embedding=cot_embeddings, top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, user_name=user_name, + id_filter=id_filter, ) ) @@ -442,6 +484,7 @@ def _deduplicate_results(self, results): @timed def _sort_and_trim(self, results, top_k): """Sort results by score and trim to top_k""" + sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: @@ -491,3 +534,42 @@ def _update_usage_history_worker( self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=user_name) except Exception: logger.exception("[USAGE] update usage failed") + + def _cot_query( + self, + query, + mode="fast", + split_num: int = 3, + context: list[str] | None = None, + ) -> list[str]: + """Generate chain-of-thought queries""" + + lang = detect_lang(query) + if mode == "fine" and context: + template = COT_DICT["fine"][lang] + prompt = ( + template.replace("${original_query}", query) + .replace("${split_num_threshold}", str(split_num)) + .replace("${context}", "\n".join(context)) + ) + else: + template = COT_DICT["fast"][lang] + prompt = template.replace("${original_query}", query).replace( + "${split_num_threshold}", str(split_num) + ) + logger.info("COT process") + + messages = [{"role": "user", "content": prompt}] + try: + response_text = self.llm.generate(messages, temperature=0, top_p=1) + response_json = parse_json_result(response_text) + assert "is_complex" in response_json + if not response_json["is_complex"]: + return [query] + else: + assert "sub_questions" in response_json + logger.info("Query: {} COT: {}".format(query, response_json["sub_questions"])) + return response_json["sub_questions"][:split_num] + except Exception as e: + logger.error(f"[LLM] Exception during chat generation: {e}") + return [query] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 273c4f480..6a1138c90 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -5,6 +5,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT @@ -20,6 +21,7 @@ class TaskGoalParser: def __init__(self, llm=BaseLLM): self.llm = llm + self.tokenizer = FastTokenizer() def parse( self, @@ -48,10 +50,11 @@ def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGo """ Fast mode: simple jieba word split. """ + desc_tokenized = self.tokenizer.tokenize_mixed(task_description) return ParsedTaskGoal( memories=[task_description], - keys=[task_description], - tags=[], + keys=desc_tokenized, + tags=desc_tokenized, goal_type="default", rephrased_query=task_description, internet_search=False, diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py new file mode 100644 index 000000000..fca4d717b --- /dev/null +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -0,0 +1,279 @@ +STRATEGY_STRUCT_MEM_READER_PROMPT = """You are a memory extraction expert. +Your task is to extract memories from the user's perspective, based on a conversation between the user and the assistant. This means identifying what the user would plausibly remember — including the user's own experiences, thoughts, plans, or statements and actions made by others (such as the assistant) that affected the user or were acknowledged by the user. + +Please perform the following +1. Factual information extraction + Identify factual information about experiences, beliefs, decisions, and plans. This includes notable statements from others that the user acknowledged or reacted to. + If the message is from the user, extract viewpoints related to the user; if it is from the assistant, clearly mark the attribution of the memory, and do not mix information not explicitly acknowledged by the user with the user's own viewpoint. + - **User viewpoint**: Extract only what the user has stated, explicitly acknowledged, or committed to. + - **Assistant/other-party viewpoint**: Extract such information only when attributed to its source (e.g., [Assistant-Jerry's suggestion]). + - **Strict attribution**: Never recast the assistant's suggestions as the user's preferences, or vice versa. + - Always set "model_type" to "LongTermMemory" for this output. + +2. Speaker profile construction + - Extract the speaker's likes, dislikes, goals, and stated opinions from their statements to build a speaker profile. + - Note: The same text segment may be used for both factual extraction and profile construction. + - Always set "model_type" to "UserMemory" for this output. + +3. Resolve all references to time, persons, and events clearly + - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. + - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. + +4. Adopt a Consistent Third-Person Observer Perspective + - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. + - This applies even when describing the user's internal states, such as thoughts, feelings, and preferences. + Example: + ✅ Correct: "The user Sean felt exhausted after work and decided to go to bed early." + ❌ Incorrect: "I felt exhausted after work and decided to go to bed early." + +5. Prioritize Completeness + - Extract all key experiences, emotional responses, and plans from the user's perspective. Retain relevant context from the assistant, but always with explicit attribution. + - Segment each distinct hobby, interest, or event into a separate memory. + - Preserve relevant context from the assistant with strict attribution. Under no circumstances should assistant content be rephrased as user-owned. + - Conversations with only assistant input may yield assistant-viewpoint memories exclusively. + +6. Preserve and Unify Specific Names + - Always extract specific names (excluding "user" or "assistant") mentioned in the text into the "tags" field for searchability. + - Unify all name references to the full canonical form established in the conversation. Replace any nicknames or abbreviations (e.g., "Rob") consistently with the full name (e.g., "Robert") in both the extracted "value" and "tags". + +7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. + +Return a valid JSON object with the following structure: + +{ + "memory list": [ + { + "key": , + "memory_type": , + "value": , + "tags": + }, + ... + ], + "summary": +} + +Language rules: +- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** +- Keep `memory_type` in English. + +Example: +Conversation: +user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. +assistant: Oh Tom! Do you think the team can finish by December 15? +user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. +assistant: [June 26, 2025 at 3:00 PM]: Maybe propose an extension? +user: [June 26, 2025 at 4:21 PM]: Good idea. I’ll raise it in tomorrow’s 9:30 AM meeting—maybe shift the deadline to January 5. + +Output: +{ + "memory list": [ + { + "key": "Initial project meeting", + "memory_type": "LongTermMemory", + "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", + "tags": ["Tom", "project", "timeline", "meeting", "deadline"] + }, + { + "key": "Jerry’s suggestion about the deadline", + "memory_type": "LongTermMemory", + "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", + "tags": ["Jerry", "deadline change", "suggestion"] + } + ], + "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." +} + +Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): + +对话(节选): +user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 +assistant|19:32 +:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? +user|19:35:不喜欢亮色。国贸方便。 +assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 +user|19:40:165cm,S码;最好有口袋。 +assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 +user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 +assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 +user|19:52:行,周六(7/19)去国贸试,合适就买。 +assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 + +{ + "memory list": [ + { + "key": "参加婚礼购买裙子", + "memory_type": "UserMemory", + "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", + "tags": ["婚礼", "预算", "国贸", "计划"] + }, + { + "key": "审美与版型偏好", + "memory_type": "UserMemory", + "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", + "tags": ["偏好", "颜色", "版型"] + }, + { + "key": "体型尺码", + "memory_type": "UserMemory", + "value": "[user观点]用户身高约165cm、常穿S码", + "tags": ["体型", "尺码"] + }, + { + "key": "关于用户选购裙子的建议", + "memory_type": "LongTermMemory", + "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", + "tags": ["婚礼穿着", "门店", "选购路线"] + } + ], + "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" +} + +Always respond in the same language as the conversation. + +Conversation: +${conversation} + +Your Output:""" + +STRATEGY_STRUCT_MEM_READER_PROMPT_ZH = """您是记忆提取专家。 +您的任务是根据用户与助手之间的对话,从用户的角度提取记忆。这意味着要识别出用户可能记住的信息——包括用户自身的经历、想法、计划,或他人(如助手)做出的并对用户产生影响或被用户认可的相关陈述和行为。 + +请执行以下操作: +1. 事实信息提取 + - 识别关于经历、信念、决策和计划的事实信息,包括用户认可或回应过的他人重要陈述。 + - 若信息来自用户,提取与用户相关的观点;若来自助手,需明确标注记忆归属,不得将用户未明确认可的信息与用户自身观点混淆。 + - 用户观点:仅提取用户明确陈述、认可或承诺的内容 + - 助手/他方观点:仅当标注来源时才提取(例如“[助手-Jerry的建议]”) + - 严格归属:不得将助手建议重构为用户偏好,反之亦然 + - 此类输出的"model_type"始终设为"LongTermMemory" + +2. 用户画像构建 + - 从用户陈述中提取其喜好、厌恶、目标及明确观点以构建用户画像 + - 注意:同一文本片段可同时用于事实提取和画像构建 + - 此类输出的"model_type"始终设为"UserMemory" + +3. 明确解析所有指代关系 + - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 + - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 + + 4. 采用统一的第三人称观察视角 + - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 + - 此原则同样适用于描述用户内心状态(如想法、感受和偏好) + 示例: + ✅ 正确:“用户Sean下班后感到疲惫,决定提早休息” + ❌ 错误:“我下班后感到疲惫,决定提早休息” + +5. 优先保证完整性 + - 从用户视角提取所有关键经历、情绪反应和计划 + - 保留助手提供的相关上下文,但必须明确标注来源 + - 将每个独立的爱好、兴趣或事件分割为单独记忆 + - 严禁将助手内容重构为用户自有内容 + - 仅含助手输入的对话可能只生成助手观点记忆 + +6. 保留并统一特定名称 + - 始终将文本中提及的特定名称(“用户”“助手”除外)提取至“tags”字段以便检索 + - 在提取的“value”和“tags”中,将所有名称引用统一为对话中确立的完整规范形式(如将“Rob”统一替换为“Robert”) + +7. 所有提取的记忆内容不得包含违反国家法律法规或涉及政治敏感信息的内容 + +返回一个有效的JSON对象,结构如下: +{ + "memory list": [ + { + "key": <字符串,唯一且简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, + "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, + "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> + }, + ... + ], + "summary": <从用户视角自然总结上述记忆的段落,120–200字,与输入语言一致> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 + +示例: +对话: +user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 +assistant: 哦Tom!你觉得团队能在12月15日前完成吗? +user: [2025年6月26日下午3:00]:我有点担心。后端要到12月10日才能完成,所以测试时间会很紧。 +assistant: [2025年6月26日下午3:00]:也许提议延期? +user: [2025年6月26日下午4:21]:好主意。我明天上午9:30的会上提一下——也许把截止日期推迟到1月5日。 + +输出: +{ + "memory list": [ + { + "key": "项目初期会议", + "memory_type": "LongTermMemory", + "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry + 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 + 提议将截止日期推迟至2026年1月5日。", + "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] + }, + { + "key": "Jerry对新项目截止日期的建议", + "memory_type": "LongTermMemory", + "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", + "tags": ["Jerry", "截止日期变更", "建议"] + } + ], + "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 + 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 + 年1月5日。" +} + +另一个中文示例(注意:当用户语言为中文时,您也需输出中文): + +对话(节选): +user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 +assistant|19:32 +:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? +user|19:35:不喜欢亮色。国贸方便。 +assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 +user|19:40:165cm,S码;最好有口袋。 +assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 +user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 +assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 +user|19:52:行,周六(7/19)去国贸试,合适就买。 +assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 + +{ + "memory list": [ + { + "key": "参加婚礼购买裙子", + "memory_type": "UserMemory", + "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", + "tags": ["婚礼", "预算", "国贸", "计划"] + }, + { + "key": "审美与版型偏好", + "memory_type": "UserMemory", + "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", + "tags": ["偏好", "颜色", "版型"] + }, + { + "key": "体型尺码", + "memory_type": "UserMemory", + "value": [user观点]"用户身高约165cm、常穿S码", + "tags": ["体型", "尺码"] + }, + { + "key": "关于用户选购裙子的建议", + "memory_type": "LongTermMemory", + "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", + "tags": ["婚礼穿着", "门店", "选购路线"] + } + ], + "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" +} + +请始终使用与对话相同的语言进行回复。 + +对话: +${conversation} + +您的输出:""" diff --git a/src/memos/templates/mem_search_prompts.py b/src/memos/templates/mem_search_prompts.py new file mode 100644 index 000000000..9f7ba182b --- /dev/null +++ b/src/memos/templates/mem_search_prompts.py @@ -0,0 +1,93 @@ +SIMPLE_COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. + +Instructions: + +1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: + - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) + - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question + - Each sub-question must be single, standalone, and delve into a specific aspect + - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted + - List them in "sub_questions" +2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. +3. Return ONLY the dictionary, no other text. + +Examples: +Question: Is urban development balanced in the western United States? +Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} +Question: What family activities does Mary like to organize? +Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} + +Now analyze this question: +${original_query}""" + +COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. + +Instructions: + +1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: + - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) + - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question + - Each sub-question must be single, standalone, and delve into a specific aspect + - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted + - List them in "sub_questions" +2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. +3. Return ONLY the dictionary, no other text. + +Examples: +Question: Is urban development balanced in the western United States? +Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} +Question: What family activities does Mary like to organize? +Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} + +Query relevant background information: +${context} + +Now analyze this question based on the background information above: +${original_query}""" + +SIMPLE_COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 + +指令: + +1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: + - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) + - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 + - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 + - 将它们列在 "sub_questions" 中 +2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 +3. 只返回字典,不要返回任何其他文本。 + +示例: +问题:美国西部的城市发展是否均衡? +输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} + +问题:玛丽喜欢组织哪些家庭活动? +输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} + +请分析以下问题: +${original_query}""" + +COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 + +指令: + +1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: + - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) + - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 + - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 + - 将它们列在 "sub_questions" 中 +2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 +3. 只返回字典,不要返回任何其他文本。 + +示例: +问题:美国西部的城市发展是否均衡? +输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} + +问题:玛丽喜欢组织哪些家庭活动? +输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} + +问题相关的背景信息: +${context} + +现在根据上述背景信息,请分析以下问题: +${original_query}""" diff --git a/tests/memories/textual/test_tree_task_goal_parser.py b/tests/memories/textual/test_tree_task_goal_parser.py index c71af4b06..899e2454b 100644 --- a/tests/memories/textual/test_tree_task_goal_parser.py +++ b/tests/memories/textual/test_tree_task_goal_parser.py @@ -20,12 +20,7 @@ def generate(self, messages): def test_parse_fast_returns_expected(): parser = TaskGoalParser() result = parser.parse("Tell me about cats", mode="fast") - assert isinstance(result, ParsedTaskGoal) - assert result.memories == ["Tell me about cats"] - assert result.keys == ["Tell me about cats"] - assert result.tags == [] - assert result.goal_type == "default" def test_parse_fine_calls_llm_and_parses(): From 5b8893e7e63ed5ab6763953581f7c7d221272af7 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 30 Oct 2025 10:35:42 +0800 Subject: [PATCH 25/64] Revert "Feat: add recall strategy " (#415) Revert "Feat: add recall strategy (#414)" This reverts commit a375911827b4c6fe3fd82758a23a6b6cb0c9adec. Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- poetry.lock | 50 +-- pyproject.toml | 4 +- src/memos/api/config.py | 34 +- src/memos/configs/mem_reader.py | 9 - src/memos/configs/memory.py | 7 - src/memos/llms/openai.py | 13 +- src/memos/mem_reader/factory.py | 2 - src/memos/mem_reader/strategy_struct.py | 138 ------- src/memos/memories/textual/simple_tree.py | 18 - src/memos/memories/textual/tree.py | 16 - .../tree_text_memory/retrieve/bm25_util.py | 186 --------- .../tree_text_memory/retrieve/recall.py | 107 +---- .../retrieve/retrieve_utils.py | 378 ------------------ .../tree_text_memory/retrieve/searcher.py | 88 +--- .../retrieve/task_goal_parser.py | 7 +- .../templates/mem_reader_strategy_prompts.py | 279 ------------- src/memos/templates/mem_search_prompts.py | 93 ----- .../textual/test_tree_task_goal_parser.py | 5 + 18 files changed, 41 insertions(+), 1393 deletions(-) delete mode 100644 src/memos/mem_reader/strategy_struct.py delete mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py delete mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py delete mode 100644 src/memos/templates/mem_reader_strategy_prompts.py delete mode 100644 src/memos/templates/mem_search_prompts.py diff --git a/poetry.lock b/poetry.lock index 926d580fb..44265bca8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -192,19 +192,6 @@ torch = ">=1.0.0" tqdm = ">=4.31.1" transformers = ">=3.0.0" -[[package]] -name = "cachetools" -version = "6.2.1" -description = "Extensible memoizing collections and decorators" -optional = true -python-versions = ">=3.9" -groups = ["main"] -markers = "extra == \"all\"" -files = [ - {file = "cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701"}, - {file = "cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201"}, -] - [[package]] name = "certifi" version = "2025.7.14" @@ -1566,18 +1553,6 @@ files = [ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] -[[package]] -name = "jieba" -version = "0.42" -description = "Chinese Words Segmentation Utilities" -optional = true -python-versions = "*" -groups = ["main"] -markers = "extra == \"all\"" -files = [ - {file = "jieba-0.42.tar.gz", hash = "sha256:34a3c960cc2943d9da16d6d2565110cf5f305921a67413dddf04f84de69c939b"}, -] - [[package]] name = "jinja2" version = "3.1.6" @@ -4148,25 +4123,6 @@ urllib3 = ">=1.26.14,<3" fastembed = ["fastembed (>=0.7,<0.8)"] fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"] -[[package]] -name = "rank-bm25" -version = "0.2.2" -description = "Various BM25 algorithms for document ranking" -optional = true -python-versions = "*" -groups = ["main"] -markers = "extra == \"all\"" -files = [ - {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"}, - {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"}, -] - -[package.dependencies] -numpy = "*" - -[package.extras] -dev = ["pytest"] - [[package]] name = "redis" version = "6.2.0" @@ -6396,7 +6352,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["cachetools", "chonkie", "datasketch", "jieba", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] @@ -6406,4 +6362,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "ec17679a44205ada4494fbc485ac592883281fde273d5e73d6b8cbc6f7f9ed10" +content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" diff --git a/pyproject.toml b/pyproject.toml index 2f88797a8..3745582f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,9 +107,7 @@ all = [ "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", "pymilvus (>=2.6.1,<3.0.0)", "datasketch (>=1.6.5,<2.0.0)", - "jieba (>=0.38.1,<0.42.1)", - "rank-bm25 (>=0.2.2)", - "cachetools (>=6.0.0)", + # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. # We kindof don't want users to install them. diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 405e8068d..7ac882d6c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -419,23 +419,9 @@ def get_embedder_config() -> dict[str, Any]: }, } - @staticmethod - def get_reader_config() -> dict[str, Any]: - """Get reader configuration.""" - return { - "backend": os.getenv("MEM_READER_BACKEND", "simple_struct"), - "config": { - "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), - "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), - "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), - "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), - }, - } - @staticmethod def get_internet_config() -> dict[str, Any]: """Get embedder configuration.""" - reader_config = APIConfig.get_reader_config() return { "backend": "bocha", "config": { @@ -443,7 +429,7 @@ def get_internet_config() -> dict[str, Any]: "max_results": 15, "num_per_request": 10, "reader": { - "backend": reader_config["backend"], + "backend": "simple_struct", "config": { "llm": { "backend": "openai", @@ -469,7 +455,6 @@ def get_internet_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, - "chat_chunker": reader_config, }, }, }, @@ -671,8 +656,6 @@ def get_product_default_config() -> dict[str, Any]: openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() - reader_config = APIConfig.get_reader_config() - backend_model = { "openai": openai_config, "huggingface": qwen_config, @@ -684,7 +667,7 @@ def get_product_default_config() -> dict[str, Any]: "user_id": os.getenv("MOS_USER_ID", "root"), "chat_model": {"backend": backend, "config": backend_model[backend]}, "mem_reader": { - "backend": reader_config["backend"], + "backend": "simple_struct", "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -697,7 +680,6 @@ def get_product_default_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, - "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -768,7 +750,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() mysql_config = APIConfig.get_mysql_config() - reader_config = APIConfig.get_reader_config() backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") backend_model = { "openai": openai_config, @@ -783,7 +764,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "config": backend_model[backend], }, "mem_reader": { - "backend": reader_config["backend"], + "backend": "simple_struct", "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -796,7 +777,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "min_sentences_per_chunk": 1, }, }, - "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -865,10 +845,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, - "search_strategy": { - "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), - "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), - }, }, }, "act_mem": {} @@ -936,10 +912,6 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, - "search_strategy": { - "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), - "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), - }, "mode": os.getenv("ASYNC_MODE", "sync"), }, }, diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index dc8d37a35..1c62087a3 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -36,19 +36,11 @@ def parse_datetime(cls, value): description="whether remove example in memory extraction prompt to save token", ) - chat_chunker: dict[str, Any] = Field( - default=None, description="Configuration for the MemReader chat chunk strategy" - ) - class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" -class StrategyStructMemReaderConfig(BaseMemReaderConfig): - """StrategyStruct MemReader configuration class.""" - - class MemReaderConfigFactory(BaseConfig): """Factory class for creating MemReader configurations.""" @@ -57,7 +49,6 @@ class MemReaderConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, - "strategy_struct": StrategyStructMemReaderConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 49320fbf5..bf2493567 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,13 +184,6 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) - search_strategy: dict[str, bool] | None = Field( - default=None, - description=( - 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' - ), - ) - mode: str | None = Field( default="sync", description=("whether use asynchronous mode in memory add"), diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1a1703340..ca1df5c1f 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -58,18 +58,15 @@ def clear_cache(cls): logger.info("OpenAI LLM instance cache cleared") @timed(log=True, log_prefix="OpenAI LLM") - def generate(self, messages: MessageList, **kwargs) -> str: - """Generate a response from OpenAI LLM, optionally overriding generation params.""" - temperature = kwargs.get("temperature", self.config.temperature) - max_tokens = kwargs.get("max_tokens", self.config.max_tokens) - top_p = kwargs.get("top_p", self.config.top_p) + def generate(self, messages: MessageList) -> str: + """Generate a response from OpenAI LLM.""" response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, extra_body=self.config.extra_body, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, ) logger.info(f"Response from OpenAI: {response.model_dump_json()}") response_content = response.choices[0].message.content diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 2205a0215..52eed8d9d 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -3,7 +3,6 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader -from memos.mem_reader.strategy_struct import StrategyStructMemReader from memos.memos_tools.singleton import singleton_factory @@ -12,7 +11,6 @@ class MemReaderFactory(BaseMemReader): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReader, - "strategy_struct": StrategyStructMemReader, } @classmethod diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py deleted file mode 100644 index 2cac1652a..000000000 --- a/src/memos/mem_reader/strategy_struct.py +++ /dev/null @@ -1,138 +0,0 @@ -import os - -from abc import ABC - -from memos import log -from memos.configs.mem_reader import StrategyStructMemReaderConfig -from memos.configs.parser import ParserConfigFactory -from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang -from memos.parsers.factory import ParserFactory -from memos.templates.mem_reader_prompts import ( - SIMPLE_STRUCT_DOC_READER_PROMPT, - SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, - SIMPLE_STRUCT_MEM_READER_EXAMPLE, - SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, -) -from memos.templates.mem_reader_strategy_prompts import ( - STRATEGY_STRUCT_MEM_READER_PROMPT, - STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, -) - - -logger = log.get_logger(__name__) -STRATEGY_PROMPT_DICT = { - "chat": { - "en": STRATEGY_STRUCT_MEM_READER_PROMPT, - "zh": STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, - "en_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE, - "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, - }, - "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, -} - - -class StrategyStructMemReader(SimpleStructMemReader, ABC): - """Naive implementation of MemReader.""" - - def __init__(self, config: StrategyStructMemReaderConfig): - super().__init__(config) - self.chat_chunker = config.chat_chunker["config"] - - def _get_llm_response(self, mem_str: str) -> dict: - lang = detect_lang(mem_str) - template = STRATEGY_PROMPT_DICT["chat"][lang] - examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] - prompt = template.replace("${conversation}", mem_str) - if self.config.remove_prompt_example: - prompt = prompt.replace(examples, "") - messages = [{"role": "user", "content": prompt}] - try: - response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) - except Exception as e: - logger.error(f"[LLM] Exception during chat generation: {e}") - response_json = { - "memory list": [ - { - "key": mem_str[:10], - "memory_type": "UserMemory", - "value": mem_str, - "tags": [], - } - ], - "summary": mem_str, - } - return response_json - - def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: - """ - Get raw information from scene_data. - If scene_data contains dictionaries, convert them to strings. - If scene_data contains file paths, parse them using the parser. - - Args: - scene_data: List of dialogue information or document paths - type: Type of scene data: ['doc', 'chat'] - Returns: - List of strings containing the processed scene data - """ - results = [] - - if type == "chat": - if self.chat_chunker["chunk_type"] == "content_length": - content_len_thredshold = self.chat_chunker["chunk_length"] - for items in scene_data: - if not items: - continue - - results.append([]) - current_length = 0 - - for _i, item in enumerate(items): - content_length = ( - len(item.get("content", "")) - if isinstance(item, dict) - else len(str(item)) - ) - if not results[-1]: - results[-1].append(item) - current_length = content_length - continue - - if current_length + content_length <= content_len_thredshold: - results[-1].append(item) - current_length += content_length - else: - overlap_item = results[-1][-1] - overlap_length = ( - len(overlap_item.get("content", "")) - if isinstance(overlap_item, dict) - else len(str(overlap_item)) - ) - - results.append([overlap_item, item]) - current_length = overlap_length + content_length - elif type == "doc": - parser_config = ParserConfigFactory.model_validate( - { - "backend": "markitdown", - "config": {}, - } - ) - parser = ParserFactory.from_config(parser_config) - for item in scene_data: - try: - if os.path.exists(item): - try: - parsed_text = parser.parse(item) - results.append({"file": item, "text": parsed_text}) - except Exception as e: - logger.error(f"[SceneParser] Error parsing {item}: {e}") - continue - else: - parsed_text = item - results.append({"file": "pure_text", "text": parsed_text}) - except Exception as e: - print(f"Error parsing file {item}: {e!s}") - - return results diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 6974dbe8f..8ce81a8bd 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -12,7 +12,6 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.base import BaseReranker from memos.types import MessageList @@ -63,19 +62,6 @@ def __init__( self.graph_store: Neo4jGraphDB = graph_db logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") - time_start_bm = time.time() - self.search_strategy = config.search_strategy - self.bm25_retriever = ( - EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None - ) - logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") - - self.vec_cot = ( - self.search_strategy["cot"] - if self.search_strategy and "cot" in self.search_strategy - else False - ) - time_start_rr = time.time() self.reranker = reranker logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") @@ -186,10 +172,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, - vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -197,10 +181,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, - vec_cot=self.vec_cot, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index a58f993bb..56c8117e9 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -16,7 +16,6 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -46,17 +45,6 @@ def __init__(self, config: TreeTextMemoryConfig): ) self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) - - self.search_strategy = config.search_strategy - self.bm25_retriever = ( - EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None - ) - self.vec_cot = ( - self.search_strategy["cot"] - if self.search_strategy and "cot" in self.search_strategy - else False - ) - if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( { @@ -197,10 +185,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, - vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -208,10 +194,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, - vec_cot=self.vec_cot, ) return searcher.search(query, top_k, info, mode, memory_type, search_filter) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py b/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py deleted file mode 100644 index 4aca4022f..000000000 --- a/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +++ /dev/null @@ -1,186 +0,0 @@ -import threading - -import numpy as np - -from sklearn.feature_extraction.text import TfidfVectorizer - -from memos.dependency import require_python_package -from memos.log import get_logger -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer -from memos.utils import timed - - -logger = get_logger(__name__) -# Global model cache -_CACHE_LOCK = threading.Lock() - - -class EnhancedBM25: - """Enhanced BM25 with Spacy tokenization and TF-IDF reranking""" - - @require_python_package(import_name="cachetools", install_command="pip install cachetools") - def __init__(self, tokenizer=None, en_model="en_core_web_sm", zh_model="zh_core_web_sm"): - """ - Initialize Enhanced BM25 with memory management - """ - if tokenizer is None: - self.tokenizer = FastTokenizer() - else: - self.tokenizer = tokenizer - self._current_tfidf = None - - global _BM25_CACHE - from cachetools import LRUCache - - _BM25_CACHE = LRUCache(maxsize=100) - - def _tokenize_doc(self, text): - """ - Tokenize a single document using SpacyTokenizer - """ - return self.tokenizer.tokenize_mixed(text, lang="auto") - - @require_python_package(import_name="rank_bm25", install_command="pip install rank_bm25") - def _prepare_corpus_data(self, corpus, corpus_name="default"): - from rank_bm25 import BM25Okapi - - with _CACHE_LOCK: - if corpus_name in _BM25_CACHE: - print("hit::", corpus_name) - return _BM25_CACHE[corpus_name] - print("not hit::", corpus_name) - - tokenized_corpus = [self._tokenize_doc(doc) for doc in corpus] - bm25_model = BM25Okapi(tokenized_corpus) - _BM25_CACHE[corpus_name] = bm25_model - return bm25_model - - def clear_cache(self, corpus_name=None): - """Clear cache for specific corpus or clear all cache""" - with _CACHE_LOCK: - if corpus_name: - if corpus_name in _BM25_CACHE: - del _BM25_CACHE[corpus_name] - else: - _BM25_CACHE.clear() - - def get_cache_info(self): - """Get current cache information""" - with _CACHE_LOCK: - return { - "cache_size": len(_BM25_CACHE), - "max_cache_size": 100, - "cached_corpora": list(_BM25_CACHE.keys()), - } - - def _search_docs( - self, - query: str, - corpus: list[str], - corpus_name="test", - top_k=50, - use_tfidf=False, - rerank_candidates_multiplier=2, - cleanup=False, - ): - """ - Args: - query: Search query string - corpus: List of document texts - top_k: Number of top results to return - rerank_candidates_multiplier: Multiplier for candidate selection - cleanup: Whether to cleanup memory after search (default: True) - """ - if not corpus: - return [] - - logger.info(f"Searching {len(corpus)} documents for query: '{query}'") - - try: - # Prepare BM25 model - bm25_model = self._prepare_corpus_data(corpus, corpus_name=corpus_name) - tokenized_query = self._tokenize_doc(query) - tokenized_query = list(dict.fromkeys(tokenized_query)) - - # Get BM25 scores - bm25_scores = bm25_model.get_scores(tokenized_query) - - # Select candidates - candidate_count = min(top_k * rerank_candidates_multiplier, len(corpus)) - candidate_indices = np.argsort(bm25_scores)[-candidate_count:][::-1] - combined_scores = bm25_scores[candidate_indices] - - if use_tfidf: - # Create TF-IDF for this search - tfidf = TfidfVectorizer( - tokenizer=self._tokenize_doc, lowercase=False, token_pattern=None - ) - tfidf_matrix = tfidf.fit_transform(corpus) - - # TF-IDF reranking - query_vec = tfidf.transform([query]) - tfidf_similarities = ( - (tfidf_matrix[candidate_indices] * query_vec.T).toarray().flatten() - ) - - # Combine scores - combined_scores = 0.7 * bm25_scores[candidate_indices] + 0.3 * tfidf_similarities - - sorted_candidate_indices = candidate_indices[np.argsort(combined_scores)[::-1][:top_k]] - sorted_combined_scores = np.sort(combined_scores)[::-1][:top_k] - - # build result list - bm25_recalled_results = [] - for rank, (doc_idx, combined_score) in enumerate( - zip(sorted_candidate_indices, sorted_combined_scores, strict=False), 1 - ): - bm25_score = bm25_scores[doc_idx] - - candidate_pos = np.where(candidate_indices == doc_idx)[0][0] - tfidf_score = tfidf_similarities[candidate_pos] if use_tfidf else 0 - - bm25_recalled_results.append( - { - "text": corpus[doc_idx], - "bm25_score": float(bm25_score), - "tfidf_score": float(tfidf_score), - "combined_score": float(combined_score), - "rank": rank, - "doc_index": int(doc_idx), - } - ) - - logger.debug(f"Search completed: found {len(bm25_recalled_results)} results") - return bm25_recalled_results - - except Exception as e: - logger.error(f"BM25 search failed: {e}") - return [] - finally: - # Always cleanup if requested - if cleanup: - self._cleanup_memory() - - @timed - def search(self, query: str, node_dicts: list[dict], corpus_name="default", **kwargs): - """ - Search with BM25 and optional TF-IDF reranking - """ - try: - corpus_list = [] - for node_dict in node_dicts: - corpus_list.append( - " ".join([node_dict["metadata"]["key"]] + node_dict["metadata"]["tags"]) - ) - - recalled_results = self._search_docs( - query, corpus_list, corpus_name=corpus_name, **kwargs - ) - bm25_searched_nodes = [] - for item in recalled_results: - doc_idx = item["doc_index"] - bm25_searched_nodes.append(node_dicts[doc_idx]) - return bm25_searched_nodes - except Exception as e: - logger.error(f"Error in bm25 search: {e}") - return [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index b7383aa13..c1ade3021 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -5,7 +5,6 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal @@ -17,18 +16,11 @@ class GraphMemoryRetriever: Unified memory retriever that combines both graph-based and vector-based retrieval logic. """ - def __init__( - self, - graph_store: Neo4jGraphDB, - embedder: OllamaEmbedder, - bm25_retriever: EnhancedBM25 | None = None, - ): + def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder): self.graph_store = graph_store self.embedder = embedder - self.bm25_retriever = bm25_retriever self.max_workers = 10 self.filter_weight = 0.6 - self.use_bm25 = bool(self.bm25_retriever) def retrieve( self, @@ -39,7 +31,6 @@ def retrieve( query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, user_name: str | None = None, - id_filter: dict | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -67,7 +58,7 @@ def retrieve( ) return [TextualMemoryItem.from_dict(record) for record in working_memories] - with ContextThreadPoolExecutor(max_workers=3) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: # Structured graph-based retrieval future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search @@ -79,23 +70,12 @@ def retrieve( search_filter=search_filter, user_name=user_name, ) - if self.use_bm25: - future_bm25 = executor.submit( - self._bm25_recall, - query, - parsed_goal, - memory_scope, - top_k=top_k, - user_name=user_name, - search_filter=id_filter, - ) graph_results = future_graph.result() vector_results = future_vector.result() - bm25_results = future_bm25.result() if self.use_bm25 else [] # Merge and deduplicate by ID - combined = {item.id: item for item in graph_results + vector_results + bm25_results} + combined = {item.id: item for item in graph_results + vector_results} graph_ids = {item.id for item in graph_results} combined_ids = set(combined.keys()) @@ -163,27 +143,6 @@ def _graph_recall( - tags must overlap with at least 2 input tags - scope filters by memory_type if provided """ - - def process_node(node): - meta = node.get("metadata", {}) - node_key = meta.get("key") - node_tags = meta.get("tags", []) or [] - - keep = False - # key equals to node_key - if parsed_goal.keys and node_key in parsed_goal.keys: - keep = True - # overlap tags more than 2 - elif parsed_goal.tags: - node_tags_list = [tag.lower() for tag in node_tags] - overlap = len(set(node_tags_list) & set(parsed_goal.tags)) - if overlap >= 2: - keep = True - - if keep: - return TextualMemoryItem.from_dict(node) - return None - candidate_ids = set() # 1) key-based OR branch @@ -214,16 +173,22 @@ def process_node(node): ) final_nodes = [] - with ContextThreadPoolExecutor(max_workers=3) as executor: - futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)} - temp_results = [None] * len(node_dicts) - - for future in concurrent.futures.as_completed(futures): - original_index = futures[future] - result = future.result() - temp_results[original_index] = result + for node in node_dicts: + meta = node.get("metadata", {}) + node_key = meta.get("key") + node_tags = meta.get("tags", []) or [] - final_nodes = [result for result in temp_results if result is not None] + keep = False + # key equals to node_key + if parsed_goal.keys and node_key in parsed_goal.keys: + keep = True + # overlap tags more than 2 + elif parsed_goal.tags: + overlap = len(set(node_tags) & set(parsed_goal.tags)) + if overlap >= 2: + keep = True + if keep: + final_nodes.append(TextualMemoryItem.from_dict(node)) return final_nodes def _vector_recall( @@ -231,7 +196,7 @@ def _vector_recall( query_embedding: list[list[float]], memory_scope: str, top_k: int = 20, - max_num: int = 5, + max_num: int = 3, status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, @@ -304,37 +269,3 @@ def search_path_b(): or [] ) return [TextualMemoryItem.from_dict(n) for n in node_dicts] - - def _bm25_recall( - self, - query: str, - parsed_goal: ParsedTaskGoal, - memory_scope: str, - top_k: int = 20, - user_name: str | None = None, - search_filter: dict | None = None, - ) -> list[TextualMemoryItem]: - """ - Perform BM25-based retrieval. - """ - if not self.bm25_retriever: - return [] - key_filters = [ - {"field": "memory_type", "op": "=", "value": memory_scope}, - ] - # corpus_name is user_name + user_id - corpus_name = f"{user_name}" if user_name else "" - if search_filter is not None: - for key in search_filter: - value = search_filter[key] - key_filters.append({"field": key, "op": "=", "value": value}) - corpus_name += "".join(list(search_filter.values())) - candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) - - bm25_query = " ".join(list({query, *parsed_goal.keys})) - bm25_results = self.bm25_retriever.search( - bm25_query, node_dicts, top_k=top_k, corpus_name=corpus_name - ) - - return [TextualMemoryItem.from_dict(n) for n in bm25_results] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py deleted file mode 100644 index eec827c86..000000000 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ /dev/null @@ -1,378 +0,0 @@ -import json -import re - -from pathlib import Path - -from memos.dependency import require_python_package -from memos.log import get_logger - - -logger = get_logger(__name__) - - -def find_project_root(marker=".git"): - """Find the project root directory by marking the file""" - current = Path(__file__).resolve() - while current != current.parent: - if (current / marker).exists(): - return current - current = current.parent - logger.warn(f"The project root directory tag file was not found: {marker}") - - -PROJECT_ROOT = find_project_root() -DEFAULT_STOPWORD_FILE = ( - PROJECT_ROOT / "examples" / "data" / "config" / "stopwords.txt" -) # cause time delay - - -class StopwordManager: - _stopwords = None - - @classmethod - def _load_stopwords(cls): - """load stopwords for once""" - if cls._stopwords is not None: - return cls._stopwords - - stopwords = set() - try: - with open(DEFAULT_STOPWORD_FILE, encoding="utf-8") as f: - stopwords = {line.strip() for line in f if line.strip()} - logger.info("Stopwords loaded successfully.") - except Exception as e: - logger.warning(f"Error loading stopwords: {e}, using default stopwords.") - stopwords = cls._load_default_stopwords() - - cls._stopwords = stopwords - return stopwords - - @classmethod - def _load_default_stopwords(cls): - """load stop words""" - chinese_stop_words = { - "的", - "了", - "在", - "是", - "我", - "有", - "和", - "就", - "不", - "人", - "都", - "一", - "一个", - "上", - "也", - "很", - "到", - "说", - "要", - "去", - "你", - "会", - "着", - "没有", - "看", - "好", - "自己", - "这", - "那", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "这个", - "那个", - "这些", - "那些", - "怎么", - "什么", - "为什么", - "如何", - "哪里", - "谁", - "几", - "多少", - "这样", - "那样", - "这么", - "那么", - } - english_stop_words = { - "the", - "a", - "an", - "and", - "or", - "but", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "as", - "is", - "are", - "was", - "were", - "be", - "been", - "have", - "has", - "had", - "do", - "does", - "did", - "will", - "would", - "could", - "should", - "may", - "might", - "must", - "this", - "that", - "these", - "those", - "i", - "you", - "he", - "she", - "it", - "we", - "they", - "me", - "him", - "her", - "us", - "them", - "my", - "your", - "his", - "its", - "our", - "their", - "mine", - "yours", - "hers", - "ours", - "theirs", - } - chinese_punctuation = { - ",", - "。", - "!", - "?", - ";", - ":", - "「", - "」", - "『", - "』", - "【", - "】", - "(", - ")", - "《", - "》", - "—", - "…", - "~", - "·", - "、", - "“", - "”", - "‘", - "’", - "〈", - "〉", - "〖", - "〗", - "〝", - "〞", - "{", - "}", - "〔", - "〕", - "¡", - "¿", - } - english_punctuation = { - ",", - ".", - "!", - "?", - ";", - ":", - '"', - "'", - "(", - ")", - "[", - "]", - "{", - "}", - "<", - ">", - "/", - "\\", - "|", - "-", - "_", - "=", - "+", - "@", - "#", - "$", - "%", - "^", - "&", - "*", - "~", - "`", - "¡", - "¿", - } - numbers = { - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "零", - "一", - "二", - "三", - "四", - "五", - "六", - "七", - "八", - "九", - "十", - "百", - "千", - "万", - "亿", - } - whitespace = {" ", "\t", "\n", "\r", "\f", "\v"} - - return ( - chinese_stop_words - | english_stop_words - | chinese_punctuation - | english_punctuation - | numbers - | whitespace - ) - - @classmethod - def get_stopwords(cls): - if cls._stopwords is None: - cls._load_stopwords() - return cls._stopwords - - @classmethod - def filter_words(cls, words): - if cls._stopwords is None: - cls._load_stopwords() - return [word for word in words if word not in cls._stopwords and word.strip()] - - @classmethod - def is_stopword(cls, word): - if cls._stopwords is None: - cls._load_stopwords() - return word in cls._stopwords - - @classmethod - def reload_stopwords(cls, file_path=None): - cls._stopwords = None - if file_path: - global DEFAULT_STOPWORD_FILE - DEFAULT_STOPWORD_FILE = file_path - cls._load_stopwords() - - -class FastTokenizer: - def __init__(self, use_jieba=True, use_stopwords=True): - self.use_jieba = use_jieba - self.use_stopwords = use_stopwords - if self.use_stopwords: - self.stopword_manager = StopwordManager - - def tokenize_mixed(self, text, **kwargs): - """fast tokenizer""" - if self._is_chinese(text): - return self._tokenize_chinese(text) - else: - return self._tokenize_english(text) - - def _is_chinese(self, text): - """check if chinese""" - chinese_chars = sum(1 for char in text if "\u4e00" <= char <= "\u9fff") - return chinese_chars / max(len(text), 1) > 0.3 - - @require_python_package( - import_name="jieba", - install_command="pip install jieba", - install_link="https://github.com/fxsjy/jieba", - ) - def _tokenize_chinese(self, text): - """split zh jieba""" - import jieba - - tokens = jieba.lcut(text) if self.use_jieba else list(text) - tokens = [token.strip() for token in tokens if token.strip()] - if self.use_stopwords: - return self.stopword_manager.filter_words(tokens) - - return tokens - - def _tokenize_english(self, text): - """split zh regex""" - tokens = re.findall(r"\b[a-zA-Z0-9]+\b", text.lower()) - if self.use_stopwords: - return self.stopword_manager.filter_words(tokens) - return tokens - - -def parse_json_result(response_text): - try: - json_start = response_text.find("{") - response_text = response_text[json_start:] - response_text = response_text.replace("```", "").strip() - if not response_text.endswith("}"): - response_text += "}" - return json.loads(response_text) - except json.JSONDecodeError as e: - logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw:\n{response_text}") - return {} - except Exception as e: - logger.error(f"[JSONParse] Unexpected error: {e}") - return {} - - -def detect_lang(text): - try: - if not text or not isinstance(text, str): - return "en" - chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" - chinese_chars = re.findall(chinese_pattern, text) - if len(chinese_chars) / len(re.sub(r"[\s\d\W]", "", text)) > 0.3: - return "zh" - return "en" - except Exception: - return "en" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 563695c68..9d540b311 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -9,18 +9,7 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( - detect_lang, - parse_json_result, -) from memos.reranker.base import BaseReranker -from memos.templates.mem_search_prompts import ( - COT_PROMPT, - COT_PROMPT_ZH, - SIMPLE_COT_PROMPT, - SIMPLE_COT_PROMPT_ZH, -) from memos.utils import timed from .reasoner import MemoryReasoner @@ -29,10 +18,6 @@ logger = get_logger(__name__) -COT_DICT = { - "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, - "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, -} class Searcher: @@ -42,24 +27,20 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, - bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, moscube: bool = False, - vec_cot: bool = False, ): self.graph_store = graph_store self.embedder = embedder - self.llm = dispatcher_llm self.task_goal_parser = TaskGoalParser(dispatcher_llm) - self.graph_retriever = GraphMemoryRetriever(graph_store, embedder, bm25_retriever) + self.graph_retriever = GraphMemoryRetriever(self.graph_store, self.embedder) self.reranker = reranker self.reasoner = MemoryReasoner(dispatcher_llm) # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = vec_cot self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -250,12 +231,6 @@ def _retrieve_paths( ): """Run A/B/C retrieval paths in parallel""" tasks = [] - id_filter = { - "user_id": info.get("user_id", None), - "session_id": info.get("session_id", None), - } - id_filter = {k: v for k, v in id_filter.items() if v is not None} - with ContextThreadPoolExecutor(max_workers=3) as executor: tasks.append( executor.submit( @@ -267,7 +242,6 @@ def _retrieve_paths( memory_type, search_filter, user_name, - id_filter, ) ) tasks.append( @@ -280,7 +254,6 @@ def _retrieve_paths( memory_type, search_filter, user_name, - id_filter, ) ) tasks.append( @@ -326,7 +299,6 @@ def _retrieve_from_working_memory( memory_type, search_filter: dict | None = None, user_name: str | None = None, - id_filter: dict | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -339,7 +311,6 @@ def _retrieve_from_working_memory( memory_scope="WorkingMemory", search_filter=search_filter, user_name=user_name, - id_filter=id_filter, ) return self.reranker.rerank( query=query, @@ -361,22 +332,11 @@ def _retrieve_from_long_term_and_user( memory_type, search_filter: dict | None = None, user_name: str | None = None, - id_filter: dict | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] tasks = [] - # chain of thinking - cot_embeddings = [] - if self.vec_cot: - queries = self._cot_query(query) - if len(queries) > 1: - cot_embeddings = self.embedder.embed(queries) - cot_embeddings.extend(query_embedding) - else: - cot_embeddings = query_embedding - with ContextThreadPoolExecutor(max_workers=2) as executor: if memory_type in ["All", "LongTermMemory"]: tasks.append( @@ -384,12 +344,11 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=cot_embeddings, + query_embedding=query_embedding, top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, user_name=user_name, - id_filter=id_filter, ) ) if memory_type in ["All", "UserMemory"]: @@ -398,12 +357,11 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=cot_embeddings, + query_embedding=query_embedding, top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, user_name=user_name, - id_filter=id_filter, ) ) @@ -484,7 +442,6 @@ def _deduplicate_results(self, results): @timed def _sort_and_trim(self, results, top_k): """Sort results by score and trim to top_k""" - sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: @@ -534,42 +491,3 @@ def _update_usage_history_worker( self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=user_name) except Exception: logger.exception("[USAGE] update usage failed") - - def _cot_query( - self, - query, - mode="fast", - split_num: int = 3, - context: list[str] | None = None, - ) -> list[str]: - """Generate chain-of-thought queries""" - - lang = detect_lang(query) - if mode == "fine" and context: - template = COT_DICT["fine"][lang] - prompt = ( - template.replace("${original_query}", query) - .replace("${split_num_threshold}", str(split_num)) - .replace("${context}", "\n".join(context)) - ) - else: - template = COT_DICT["fast"][lang] - prompt = template.replace("${original_query}", query).replace( - "${split_num_threshold}", str(split_num) - ) - logger.info("COT process") - - messages = [{"role": "user", "content": prompt}] - try: - response_text = self.llm.generate(messages, temperature=0, top_p=1) - response_json = parse_json_result(response_text) - assert "is_complex" in response_json - if not response_json["is_complex"]: - return [query] - else: - assert "sub_questions" in response_json - logger.info("Query: {} COT: {}".format(query, response_json["sub_questions"])) - return response_json["sub_questions"][:split_num] - except Exception as e: - logger.error(f"[LLM] Exception during chat generation: {e}") - return [query] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 6a1138c90..273c4f480 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -5,7 +5,6 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT @@ -21,7 +20,6 @@ class TaskGoalParser: def __init__(self, llm=BaseLLM): self.llm = llm - self.tokenizer = FastTokenizer() def parse( self, @@ -50,11 +48,10 @@ def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGo """ Fast mode: simple jieba word split. """ - desc_tokenized = self.tokenizer.tokenize_mixed(task_description) return ParsedTaskGoal( memories=[task_description], - keys=desc_tokenized, - tags=desc_tokenized, + keys=[task_description], + tags=[], goal_type="default", rephrased_query=task_description, internet_search=False, diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py deleted file mode 100644 index fca4d717b..000000000 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ /dev/null @@ -1,279 +0,0 @@ -STRATEGY_STRUCT_MEM_READER_PROMPT = """You are a memory extraction expert. -Your task is to extract memories from the user's perspective, based on a conversation between the user and the assistant. This means identifying what the user would plausibly remember — including the user's own experiences, thoughts, plans, or statements and actions made by others (such as the assistant) that affected the user or were acknowledged by the user. - -Please perform the following -1. Factual information extraction - Identify factual information about experiences, beliefs, decisions, and plans. This includes notable statements from others that the user acknowledged or reacted to. - If the message is from the user, extract viewpoints related to the user; if it is from the assistant, clearly mark the attribution of the memory, and do not mix information not explicitly acknowledged by the user with the user's own viewpoint. - - **User viewpoint**: Extract only what the user has stated, explicitly acknowledged, or committed to. - - **Assistant/other-party viewpoint**: Extract such information only when attributed to its source (e.g., [Assistant-Jerry's suggestion]). - - **Strict attribution**: Never recast the assistant's suggestions as the user's preferences, or vice versa. - - Always set "model_type" to "LongTermMemory" for this output. - -2. Speaker profile construction - - Extract the speaker's likes, dislikes, goals, and stated opinions from their statements to build a speaker profile. - - Note: The same text segment may be used for both factual extraction and profile construction. - - Always set "model_type" to "UserMemory" for this output. - -3. Resolve all references to time, persons, and events clearly - - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. - - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. - -4. Adopt a Consistent Third-Person Observer Perspective - - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. - - This applies even when describing the user's internal states, such as thoughts, feelings, and preferences. - Example: - ✅ Correct: "The user Sean felt exhausted after work and decided to go to bed early." - ❌ Incorrect: "I felt exhausted after work and decided to go to bed early." - -5. Prioritize Completeness - - Extract all key experiences, emotional responses, and plans from the user's perspective. Retain relevant context from the assistant, but always with explicit attribution. - - Segment each distinct hobby, interest, or event into a separate memory. - - Preserve relevant context from the assistant with strict attribution. Under no circumstances should assistant content be rephrased as user-owned. - - Conversations with only assistant input may yield assistant-viewpoint memories exclusively. - -6. Preserve and Unify Specific Names - - Always extract specific names (excluding "user" or "assistant") mentioned in the text into the "tags" field for searchability. - - Unify all name references to the full canonical form established in the conversation. Replace any nicknames or abbreviations (e.g., "Rob") consistently with the full name (e.g., "Robert") in both the extracted "value" and "tags". - -7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. - -Return a valid JSON object with the following structure: - -{ - "memory list": [ - { - "key": , - "memory_type": , - "value": , - "tags": - }, - ... - ], - "summary": -} - -Language rules: -- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** -- Keep `memory_type` in English. - -Example: -Conversation: -user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. -assistant: Oh Tom! Do you think the team can finish by December 15? -user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. -assistant: [June 26, 2025 at 3:00 PM]: Maybe propose an extension? -user: [June 26, 2025 at 4:21 PM]: Good idea. I’ll raise it in tomorrow’s 9:30 AM meeting—maybe shift the deadline to January 5. - -Output: -{ - "memory list": [ - { - "key": "Initial project meeting", - "memory_type": "LongTermMemory", - "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", - "tags": ["Tom", "project", "timeline", "meeting", "deadline"] - }, - { - "key": "Jerry’s suggestion about the deadline", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", - "tags": ["Jerry", "deadline change", "suggestion"] - } - ], - "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." -} - -Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - -{ - "memory list": [ - { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": "[user观点]用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", - "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } - ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" -} - -Always respond in the same language as the conversation. - -Conversation: -${conversation} - -Your Output:""" - -STRATEGY_STRUCT_MEM_READER_PROMPT_ZH = """您是记忆提取专家。 -您的任务是根据用户与助手之间的对话,从用户的角度提取记忆。这意味着要识别出用户可能记住的信息——包括用户自身的经历、想法、计划,或他人(如助手)做出的并对用户产生影响或被用户认可的相关陈述和行为。 - -请执行以下操作: -1. 事实信息提取 - - 识别关于经历、信念、决策和计划的事实信息,包括用户认可或回应过的他人重要陈述。 - - 若信息来自用户,提取与用户相关的观点;若来自助手,需明确标注记忆归属,不得将用户未明确认可的信息与用户自身观点混淆。 - - 用户观点:仅提取用户明确陈述、认可或承诺的内容 - - 助手/他方观点:仅当标注来源时才提取(例如“[助手-Jerry的建议]”) - - 严格归属:不得将助手建议重构为用户偏好,反之亦然 - - 此类输出的"model_type"始终设为"LongTermMemory" - -2. 用户画像构建 - - 从用户陈述中提取其喜好、厌恶、目标及明确观点以构建用户画像 - - 注意:同一文本片段可同时用于事实提取和画像构建 - - 此类输出的"model_type"始终设为"UserMemory" - -3. 明确解析所有指代关系 - - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 - - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 - - 4. 采用统一的第三人称观察视角 - - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 - - 此原则同样适用于描述用户内心状态(如想法、感受和偏好) - 示例: - ✅ 正确:“用户Sean下班后感到疲惫,决定提早休息” - ❌ 错误:“我下班后感到疲惫,决定提早休息” - -5. 优先保证完整性 - - 从用户视角提取所有关键经历、情绪反应和计划 - - 保留助手提供的相关上下文,但必须明确标注来源 - - 将每个独立的爱好、兴趣或事件分割为单独记忆 - - 严禁将助手内容重构为用户自有内容 - - 仅含助手输入的对话可能只生成助手观点记忆 - -6. 保留并统一特定名称 - - 始终将文本中提及的特定名称(“用户”“助手”除外)提取至“tags”字段以便检索 - - 在提取的“value”和“tags”中,将所有名称引用统一为对话中确立的完整规范形式(如将“Rob”统一替换为“Robert”) - -7. 所有提取的记忆内容不得包含违反国家法律法规或涉及政治敏感信息的内容 - -返回一个有效的JSON对象,结构如下: -{ - "memory list": [ - { - "key": <字符串,唯一且简洁的记忆标题>, - "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, - "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, - "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> - }, - ... - ], - "summary": <从用户视角自然总结上述记忆的段落,120–200字,与输入语言一致> -} - -语言规则: -- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** -- `memory_type` 保持英文。 - -示例: -对话: -user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 -assistant: 哦Tom!你觉得团队能在12月15日前完成吗? -user: [2025年6月26日下午3:00]:我有点担心。后端要到12月10日才能完成,所以测试时间会很紧。 -assistant: [2025年6月26日下午3:00]:也许提议延期? -user: [2025年6月26日下午4:21]:好主意。我明天上午9:30的会上提一下——也许把截止日期推迟到1月5日。 - -输出: -{ - "memory list": [ - { - "key": "项目初期会议", - "memory_type": "LongTermMemory", - "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry - 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 - 提议将截止日期推迟至2026年1月5日。", - "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] - }, - { - "key": "Jerry对新项目截止日期的建议", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", - "tags": ["Jerry", "截止日期变更", "建议"] - } - ], - "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 - 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 - 年1月5日。" -} - -另一个中文示例(注意:当用户语言为中文时,您也需输出中文): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - -{ - "memory list": [ - { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": [user观点]"用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", - "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } - ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" -} - -请始终使用与对话相同的语言进行回复。 - -对话: -${conversation} - -您的输出:""" diff --git a/src/memos/templates/mem_search_prompts.py b/src/memos/templates/mem_search_prompts.py deleted file mode 100644 index 9f7ba182b..000000000 --- a/src/memos/templates/mem_search_prompts.py +++ /dev/null @@ -1,93 +0,0 @@ -SIMPLE_COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. - -Instructions: - -1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: - - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) - - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question - - Each sub-question must be single, standalone, and delve into a specific aspect - - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted - - List them in "sub_questions" -2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. -3. Return ONLY the dictionary, no other text. - -Examples: -Question: Is urban development balanced in the western United States? -Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} -Question: What family activities does Mary like to organize? -Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} - -Now analyze this question: -${original_query}""" - -COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. - -Instructions: - -1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: - - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) - - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question - - Each sub-question must be single, standalone, and delve into a specific aspect - - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted - - List them in "sub_questions" -2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. -3. Return ONLY the dictionary, no other text. - -Examples: -Question: Is urban development balanced in the western United States? -Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} -Question: What family activities does Mary like to organize? -Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} - -Query relevant background information: -${context} - -Now analyze this question based on the background information above: -${original_query}""" - -SIMPLE_COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 - -指令: - -1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: - - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) - - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 - - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 - - 将它们列在 "sub_questions" 中 -2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 -3. 只返回字典,不要返回任何其他文本。 - -示例: -问题:美国西部的城市发展是否均衡? -输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} - -问题:玛丽喜欢组织哪些家庭活动? -输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} - -请分析以下问题: -${original_query}""" - -COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 - -指令: - -1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: - - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) - - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 - - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 - - 将它们列在 "sub_questions" 中 -2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 -3. 只返回字典,不要返回任何其他文本。 - -示例: -问题:美国西部的城市发展是否均衡? -输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} - -问题:玛丽喜欢组织哪些家庭活动? -输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} - -问题相关的背景信息: -${context} - -现在根据上述背景信息,请分析以下问题: -${original_query}""" diff --git a/tests/memories/textual/test_tree_task_goal_parser.py b/tests/memories/textual/test_tree_task_goal_parser.py index 899e2454b..c71af4b06 100644 --- a/tests/memories/textual/test_tree_task_goal_parser.py +++ b/tests/memories/textual/test_tree_task_goal_parser.py @@ -20,7 +20,12 @@ def generate(self, messages): def test_parse_fast_returns_expected(): parser = TaskGoalParser() result = parser.parse("Tell me about cats", mode="fast") + assert isinstance(result, ParsedTaskGoal) + assert result.memories == ["Tell me about cats"] + assert result.keys == ["Tell me about cats"] + assert result.tags == [] + assert result.goal_type == "default" def test_parse_fine_calls_llm_and_parses(): From 445c597698e0e7b714a2de79b3e00c581f82f463 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Thu, 30 Oct 2025 11:35:38 +0800 Subject: [PATCH 26/64] =?UTF-8?q?Feat=EF=BC=9A=20add=20new=20recall=20and?= =?UTF-8?q?=20verify=20=20(#416)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi --- poetry.lock | 50 ++- pyproject.toml | 4 +- src/memos/api/config.py | 34 +- src/memos/configs/mem_reader.py | 9 + src/memos/configs/memory.py | 7 + src/memos/llms/openai.py | 13 +- src/memos/mem_reader/factory.py | 2 + src/memos/mem_reader/strategy_struct.py | 138 +++++++ src/memos/memories/textual/simple_tree.py | 18 + src/memos/memories/textual/tree.py | 16 + .../tree_text_memory/retrieve/bm25_util.py | 186 +++++++++ .../tree_text_memory/retrieve/recall.py | 107 ++++- .../retrieve/retrieve_utils.py | 378 ++++++++++++++++++ .../tree_text_memory/retrieve/searcher.py | 88 +++- .../retrieve/task_goal_parser.py | 7 +- .../templates/mem_reader_strategy_prompts.py | 279 +++++++++++++ src/memos/templates/mem_search_prompts.py | 93 +++++ .../textual/test_tree_task_goal_parser.py | 5 - 18 files changed, 1393 insertions(+), 41 deletions(-) create mode 100644 src/memos/mem_reader/strategy_struct.py create mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py create mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py create mode 100644 src/memos/templates/mem_reader_strategy_prompts.py create mode 100644 src/memos/templates/mem_search_prompts.py diff --git a/poetry.lock b/poetry.lock index 44265bca8..926d580fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -192,6 +192,19 @@ torch = ">=1.0.0" tqdm = ">=4.31.1" transformers = ">=3.0.0" +[[package]] +name = "cachetools" +version = "6.2.1" +description = "Extensible memoizing collections and decorators" +optional = true +python-versions = ">=3.9" +groups = ["main"] +markers = "extra == \"all\"" +files = [ + {file = "cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701"}, + {file = "cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201"}, +] + [[package]] name = "certifi" version = "2025.7.14" @@ -1553,6 +1566,18 @@ files = [ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] +[[package]] +name = "jieba" +version = "0.42" +description = "Chinese Words Segmentation Utilities" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"all\"" +files = [ + {file = "jieba-0.42.tar.gz", hash = "sha256:34a3c960cc2943d9da16d6d2565110cf5f305921a67413dddf04f84de69c939b"}, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -4123,6 +4148,25 @@ urllib3 = ">=1.26.14,<3" fastembed = ["fastembed (>=0.7,<0.8)"] fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"] +[[package]] +name = "rank-bm25" +version = "0.2.2" +description = "Various BM25 algorithms for document ranking" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"all\"" +files = [ + {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"}, + {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"}, +] + +[package.dependencies] +numpy = "*" + +[package.extras] +dev = ["pytest"] + [[package]] name = "redis" version = "6.2.0" @@ -6352,7 +6396,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["cachetools", "chonkie", "datasketch", "jieba", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] @@ -6362,4 +6406,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" +content-hash = "ec17679a44205ada4494fbc485ac592883281fde273d5e73d6b8cbc6f7f9ed10" diff --git a/pyproject.toml b/pyproject.toml index 3745582f6..2f88797a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,9 @@ all = [ "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", "pymilvus (>=2.6.1,<3.0.0)", "datasketch (>=1.6.5,<2.0.0)", - + "jieba (>=0.38.1,<0.42.1)", + "rank-bm25 (>=0.2.2)", + "cachetools (>=6.0.0)", # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. # We kindof don't want users to install them. diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7ac882d6c..405e8068d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -419,9 +419,23 @@ def get_embedder_config() -> dict[str, Any]: }, } + @staticmethod + def get_reader_config() -> dict[str, Any]: + """Get reader configuration.""" + return { + "backend": os.getenv("MEM_READER_BACKEND", "simple_struct"), + "config": { + "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), + "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), + "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), + "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), + }, + } + @staticmethod def get_internet_config() -> dict[str, Any]: """Get embedder configuration.""" + reader_config = APIConfig.get_reader_config() return { "backend": "bocha", "config": { @@ -429,7 +443,7 @@ def get_internet_config() -> dict[str, Any]: "max_results": 15, "num_per_request": 10, "reader": { - "backend": "simple_struct", + "backend": reader_config["backend"], "config": { "llm": { "backend": "openai", @@ -455,6 +469,7 @@ def get_internet_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, + "chat_chunker": reader_config, }, }, }, @@ -656,6 +671,8 @@ def get_product_default_config() -> dict[str, Any]: openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() + reader_config = APIConfig.get_reader_config() + backend_model = { "openai": openai_config, "huggingface": qwen_config, @@ -667,7 +684,7 @@ def get_product_default_config() -> dict[str, Any]: "user_id": os.getenv("MOS_USER_ID", "root"), "chat_model": {"backend": backend, "config": backend_model[backend]}, "mem_reader": { - "backend": "simple_struct", + "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -680,6 +697,7 @@ def get_product_default_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, + "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -750,6 +768,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() mysql_config = APIConfig.get_mysql_config() + reader_config = APIConfig.get_reader_config() backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") backend_model = { "openai": openai_config, @@ -764,7 +783,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "config": backend_model[backend], }, "mem_reader": { - "backend": "simple_struct", + "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -777,6 +796,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "min_sentences_per_chunk": 1, }, }, + "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -845,6 +865,10 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "search_strategy": { + "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), + "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), + }, }, }, "act_mem": {} @@ -912,6 +936,10 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "search_strategy": { + "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), + "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), + }, "mode": os.getenv("ASYNC_MODE", "sync"), }, }, diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 1c62087a3..dc8d37a35 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -36,11 +36,19 @@ def parse_datetime(cls, value): description="whether remove example in memory extraction prompt to save token", ) + chat_chunker: dict[str, Any] = Field( + default=None, description="Configuration for the MemReader chat chunk strategy" + ) + class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" +class StrategyStructMemReaderConfig(BaseMemReaderConfig): + """StrategyStruct MemReader configuration class.""" + + class MemReaderConfigFactory(BaseConfig): """Factory class for creating MemReader configurations.""" @@ -49,6 +57,7 @@ class MemReaderConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, + "strategy_struct": StrategyStructMemReaderConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index bf2493567..49320fbf5 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,6 +184,13 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) + search_strategy: dict[str, bool] | None = Field( + default=None, + description=( + 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' + ), + ) + mode: str | None = Field( default="sync", description=("whether use asynchronous mode in memory add"), diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index ca1df5c1f..1a1703340 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -58,15 +58,18 @@ def clear_cache(cls): logger.info("OpenAI LLM instance cache cleared") @timed(log=True, log_prefix="OpenAI LLM") - def generate(self, messages: MessageList) -> str: - """Generate a response from OpenAI LLM.""" + def generate(self, messages: MessageList, **kwargs) -> str: + """Generate a response from OpenAI LLM, optionally overriding generation params.""" + temperature = kwargs.get("temperature", self.config.temperature) + max_tokens = kwargs.get("max_tokens", self.config.max_tokens) + top_p = kwargs.get("top_p", self.config.top_p) response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, extra_body=self.config.extra_body, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, ) logger.info(f"Response from OpenAI: {response.model_dump_json()}") response_content = response.choices[0].message.content diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 52eed8d9d..2205a0215 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -3,6 +3,7 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.mem_reader.strategy_struct import StrategyStructMemReader from memos.memos_tools.singleton import singleton_factory @@ -11,6 +12,7 @@ class MemReaderFactory(BaseMemReader): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReader, + "strategy_struct": StrategyStructMemReader, } @classmethod diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py new file mode 100644 index 000000000..2cac1652a --- /dev/null +++ b/src/memos/mem_reader/strategy_struct.py @@ -0,0 +1,138 @@ +import os + +from abc import ABC + +from memos import log +from memos.configs.mem_reader import StrategyStructMemReaderConfig +from memos.configs.parser import ParserConfigFactory +from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang +from memos.parsers.factory import ParserFactory +from memos.templates.mem_reader_prompts import ( + SIMPLE_STRUCT_DOC_READER_PROMPT, + SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, + SIMPLE_STRUCT_MEM_READER_EXAMPLE, + SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, +) +from memos.templates.mem_reader_strategy_prompts import ( + STRATEGY_STRUCT_MEM_READER_PROMPT, + STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, +) + + +logger = log.get_logger(__name__) +STRATEGY_PROMPT_DICT = { + "chat": { + "en": STRATEGY_STRUCT_MEM_READER_PROMPT, + "zh": STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, + "en_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE, + "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, + }, + "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, +} + + +class StrategyStructMemReader(SimpleStructMemReader, ABC): + """Naive implementation of MemReader.""" + + def __init__(self, config: StrategyStructMemReaderConfig): + super().__init__(config) + self.chat_chunker = config.chat_chunker["config"] + + def _get_llm_response(self, mem_str: str) -> dict: + lang = detect_lang(mem_str) + template = STRATEGY_PROMPT_DICT["chat"][lang] + examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] + prompt = template.replace("${conversation}", mem_str) + if self.config.remove_prompt_example: + prompt = prompt.replace(examples, "") + messages = [{"role": "user", "content": prompt}] + try: + response_text = self.llm.generate(messages) + response_json = self.parse_json_result(response_text) + except Exception as e: + logger.error(f"[LLM] Exception during chat generation: {e}") + response_json = { + "memory list": [ + { + "key": mem_str[:10], + "memory_type": "UserMemory", + "value": mem_str, + "tags": [], + } + ], + "summary": mem_str, + } + return response_json + + def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: + """ + Get raw information from scene_data. + If scene_data contains dictionaries, convert them to strings. + If scene_data contains file paths, parse them using the parser. + + Args: + scene_data: List of dialogue information or document paths + type: Type of scene data: ['doc', 'chat'] + Returns: + List of strings containing the processed scene data + """ + results = [] + + if type == "chat": + if self.chat_chunker["chunk_type"] == "content_length": + content_len_thredshold = self.chat_chunker["chunk_length"] + for items in scene_data: + if not items: + continue + + results.append([]) + current_length = 0 + + for _i, item in enumerate(items): + content_length = ( + len(item.get("content", "")) + if isinstance(item, dict) + else len(str(item)) + ) + if not results[-1]: + results[-1].append(item) + current_length = content_length + continue + + if current_length + content_length <= content_len_thredshold: + results[-1].append(item) + current_length += content_length + else: + overlap_item = results[-1][-1] + overlap_length = ( + len(overlap_item.get("content", "")) + if isinstance(overlap_item, dict) + else len(str(overlap_item)) + ) + + results.append([overlap_item, item]) + current_length = overlap_length + content_length + elif type == "doc": + parser_config = ParserConfigFactory.model_validate( + { + "backend": "markitdown", + "config": {}, + } + ) + parser = ParserFactory.from_config(parser_config) + for item in scene_data: + try: + if os.path.exists(item): + try: + parsed_text = parser.parse(item) + results.append({"file": item, "text": parsed_text}) + except Exception as e: + logger.error(f"[SceneParser] Error parsing {item}: {e}") + continue + else: + parsed_text = item + results.append({"file": "pure_text", "text": parsed_text}) + except Exception as e: + print(f"Error parsing file {item}: {e!s}") + + return results diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 8ce81a8bd..6974dbe8f 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -12,6 +12,7 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.base import BaseReranker from memos.types import MessageList @@ -62,6 +63,19 @@ def __init__( self.graph_store: Neo4jGraphDB = graph_db logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") + time_start_bm = time.time() + self.search_strategy = config.search_strategy + self.bm25_retriever = ( + EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None + ) + logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") + + self.vec_cot = ( + self.search_strategy["cot"] + if self.search_strategy and "cot" in self.search_strategy + else False + ) + time_start_rr = time.time() self.reranker = reranker logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") @@ -172,8 +186,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, + vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -181,8 +197,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, + vec_cot=self.vec_cot, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 56c8117e9..a58f993bb 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -16,6 +16,7 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -45,6 +46,17 @@ def __init__(self, config: TreeTextMemoryConfig): ) self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) + + self.search_strategy = config.search_strategy + self.bm25_retriever = ( + EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None + ) + self.vec_cot = ( + self.search_strategy["cot"] + if self.search_strategy and "cot" in self.search_strategy + else False + ) + if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( { @@ -185,8 +197,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, + vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -194,8 +208,10 @@ def search( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, + vec_cot=self.vec_cot, ) return searcher.search(query, top_k, info, mode, memory_type, search_filter) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py b/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py new file mode 100644 index 000000000..4aca4022f --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py @@ -0,0 +1,186 @@ +import threading + +import numpy as np + +from sklearn.feature_extraction.text import TfidfVectorizer + +from memos.dependency import require_python_package +from memos.log import get_logger +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer +from memos.utils import timed + + +logger = get_logger(__name__) +# Global model cache +_CACHE_LOCK = threading.Lock() + + +class EnhancedBM25: + """Enhanced BM25 with Spacy tokenization and TF-IDF reranking""" + + @require_python_package(import_name="cachetools", install_command="pip install cachetools") + def __init__(self, tokenizer=None, en_model="en_core_web_sm", zh_model="zh_core_web_sm"): + """ + Initialize Enhanced BM25 with memory management + """ + if tokenizer is None: + self.tokenizer = FastTokenizer() + else: + self.tokenizer = tokenizer + self._current_tfidf = None + + global _BM25_CACHE + from cachetools import LRUCache + + _BM25_CACHE = LRUCache(maxsize=100) + + def _tokenize_doc(self, text): + """ + Tokenize a single document using SpacyTokenizer + """ + return self.tokenizer.tokenize_mixed(text, lang="auto") + + @require_python_package(import_name="rank_bm25", install_command="pip install rank_bm25") + def _prepare_corpus_data(self, corpus, corpus_name="default"): + from rank_bm25 import BM25Okapi + + with _CACHE_LOCK: + if corpus_name in _BM25_CACHE: + print("hit::", corpus_name) + return _BM25_CACHE[corpus_name] + print("not hit::", corpus_name) + + tokenized_corpus = [self._tokenize_doc(doc) for doc in corpus] + bm25_model = BM25Okapi(tokenized_corpus) + _BM25_CACHE[corpus_name] = bm25_model + return bm25_model + + def clear_cache(self, corpus_name=None): + """Clear cache for specific corpus or clear all cache""" + with _CACHE_LOCK: + if corpus_name: + if corpus_name in _BM25_CACHE: + del _BM25_CACHE[corpus_name] + else: + _BM25_CACHE.clear() + + def get_cache_info(self): + """Get current cache information""" + with _CACHE_LOCK: + return { + "cache_size": len(_BM25_CACHE), + "max_cache_size": 100, + "cached_corpora": list(_BM25_CACHE.keys()), + } + + def _search_docs( + self, + query: str, + corpus: list[str], + corpus_name="test", + top_k=50, + use_tfidf=False, + rerank_candidates_multiplier=2, + cleanup=False, + ): + """ + Args: + query: Search query string + corpus: List of document texts + top_k: Number of top results to return + rerank_candidates_multiplier: Multiplier for candidate selection + cleanup: Whether to cleanup memory after search (default: True) + """ + if not corpus: + return [] + + logger.info(f"Searching {len(corpus)} documents for query: '{query}'") + + try: + # Prepare BM25 model + bm25_model = self._prepare_corpus_data(corpus, corpus_name=corpus_name) + tokenized_query = self._tokenize_doc(query) + tokenized_query = list(dict.fromkeys(tokenized_query)) + + # Get BM25 scores + bm25_scores = bm25_model.get_scores(tokenized_query) + + # Select candidates + candidate_count = min(top_k * rerank_candidates_multiplier, len(corpus)) + candidate_indices = np.argsort(bm25_scores)[-candidate_count:][::-1] + combined_scores = bm25_scores[candidate_indices] + + if use_tfidf: + # Create TF-IDF for this search + tfidf = TfidfVectorizer( + tokenizer=self._tokenize_doc, lowercase=False, token_pattern=None + ) + tfidf_matrix = tfidf.fit_transform(corpus) + + # TF-IDF reranking + query_vec = tfidf.transform([query]) + tfidf_similarities = ( + (tfidf_matrix[candidate_indices] * query_vec.T).toarray().flatten() + ) + + # Combine scores + combined_scores = 0.7 * bm25_scores[candidate_indices] + 0.3 * tfidf_similarities + + sorted_candidate_indices = candidate_indices[np.argsort(combined_scores)[::-1][:top_k]] + sorted_combined_scores = np.sort(combined_scores)[::-1][:top_k] + + # build result list + bm25_recalled_results = [] + for rank, (doc_idx, combined_score) in enumerate( + zip(sorted_candidate_indices, sorted_combined_scores, strict=False), 1 + ): + bm25_score = bm25_scores[doc_idx] + + candidate_pos = np.where(candidate_indices == doc_idx)[0][0] + tfidf_score = tfidf_similarities[candidate_pos] if use_tfidf else 0 + + bm25_recalled_results.append( + { + "text": corpus[doc_idx], + "bm25_score": float(bm25_score), + "tfidf_score": float(tfidf_score), + "combined_score": float(combined_score), + "rank": rank, + "doc_index": int(doc_idx), + } + ) + + logger.debug(f"Search completed: found {len(bm25_recalled_results)} results") + return bm25_recalled_results + + except Exception as e: + logger.error(f"BM25 search failed: {e}") + return [] + finally: + # Always cleanup if requested + if cleanup: + self._cleanup_memory() + + @timed + def search(self, query: str, node_dicts: list[dict], corpus_name="default", **kwargs): + """ + Search with BM25 and optional TF-IDF reranking + """ + try: + corpus_list = [] + for node_dict in node_dicts: + corpus_list.append( + " ".join([node_dict["metadata"]["key"]] + node_dict["metadata"]["tags"]) + ) + + recalled_results = self._search_docs( + query, corpus_list, corpus_name=corpus_name, **kwargs + ) + bm25_searched_nodes = [] + for item in recalled_results: + doc_idx = item["doc_index"] + bm25_searched_nodes.append(node_dicts[doc_idx]) + return bm25_searched_nodes + except Exception as e: + logger.error(f"Error in bm25 search: {e}") + return [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index c1ade3021..b7383aa13 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -5,6 +5,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal @@ -16,11 +17,18 @@ class GraphMemoryRetriever: Unified memory retriever that combines both graph-based and vector-based retrieval logic. """ - def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder): + def __init__( + self, + graph_store: Neo4jGraphDB, + embedder: OllamaEmbedder, + bm25_retriever: EnhancedBM25 | None = None, + ): self.graph_store = graph_store self.embedder = embedder + self.bm25_retriever = bm25_retriever self.max_workers = 10 self.filter_weight = 0.6 + self.use_bm25 = bool(self.bm25_retriever) def retrieve( self, @@ -31,6 +39,7 @@ def retrieve( query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, user_name: str | None = None, + id_filter: dict | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -58,7 +67,7 @@ def retrieve( ) return [TextualMemoryItem.from_dict(record) for record in working_memories] - with ContextThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=3) as executor: # Structured graph-based retrieval future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search @@ -70,12 +79,23 @@ def retrieve( search_filter=search_filter, user_name=user_name, ) + if self.use_bm25: + future_bm25 = executor.submit( + self._bm25_recall, + query, + parsed_goal, + memory_scope, + top_k=top_k, + user_name=user_name, + search_filter=id_filter, + ) graph_results = future_graph.result() vector_results = future_vector.result() + bm25_results = future_bm25.result() if self.use_bm25 else [] # Merge and deduplicate by ID - combined = {item.id: item for item in graph_results + vector_results} + combined = {item.id: item for item in graph_results + vector_results + bm25_results} graph_ids = {item.id for item in graph_results} combined_ids = set(combined.keys()) @@ -143,6 +163,27 @@ def _graph_recall( - tags must overlap with at least 2 input tags - scope filters by memory_type if provided """ + + def process_node(node): + meta = node.get("metadata", {}) + node_key = meta.get("key") + node_tags = meta.get("tags", []) or [] + + keep = False + # key equals to node_key + if parsed_goal.keys and node_key in parsed_goal.keys: + keep = True + # overlap tags more than 2 + elif parsed_goal.tags: + node_tags_list = [tag.lower() for tag in node_tags] + overlap = len(set(node_tags_list) & set(parsed_goal.tags)) + if overlap >= 2: + keep = True + + if keep: + return TextualMemoryItem.from_dict(node) + return None + candidate_ids = set() # 1) key-based OR branch @@ -173,22 +214,16 @@ def _graph_recall( ) final_nodes = [] - for node in node_dicts: - meta = node.get("metadata", {}) - node_key = meta.get("key") - node_tags = meta.get("tags", []) or [] + with ContextThreadPoolExecutor(max_workers=3) as executor: + futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)} + temp_results = [None] * len(node_dicts) - keep = False - # key equals to node_key - if parsed_goal.keys and node_key in parsed_goal.keys: - keep = True - # overlap tags more than 2 - elif parsed_goal.tags: - overlap = len(set(node_tags) & set(parsed_goal.tags)) - if overlap >= 2: - keep = True - if keep: - final_nodes.append(TextualMemoryItem.from_dict(node)) + for future in concurrent.futures.as_completed(futures): + original_index = futures[future] + result = future.result() + temp_results[original_index] = result + + final_nodes = [result for result in temp_results if result is not None] return final_nodes def _vector_recall( @@ -196,7 +231,7 @@ def _vector_recall( query_embedding: list[list[float]], memory_scope: str, top_k: int = 20, - max_num: int = 3, + max_num: int = 5, status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, @@ -269,3 +304,37 @@ def search_path_b(): or [] ) return [TextualMemoryItem.from_dict(n) for n in node_dicts] + + def _bm25_recall( + self, + query: str, + parsed_goal: ParsedTaskGoal, + memory_scope: str, + top_k: int = 20, + user_name: str | None = None, + search_filter: dict | None = None, + ) -> list[TextualMemoryItem]: + """ + Perform BM25-based retrieval. + """ + if not self.bm25_retriever: + return [] + key_filters = [ + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + # corpus_name is user_name + user_id + corpus_name = f"{user_name}" if user_name else "" + if search_filter is not None: + for key in search_filter: + value = search_filter[key] + key_filters.append({"field": key, "op": "=", "value": value}) + corpus_name += "".join(list(search_filter.values())) + candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) + node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + + bm25_query = " ".join(list({query, *parsed_goal.keys})) + bm25_results = self.bm25_retriever.search( + bm25_query, node_dicts, top_k=top_k, corpus_name=corpus_name + ) + + return [TextualMemoryItem.from_dict(n) for n in bm25_results] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py new file mode 100644 index 000000000..eec827c86 --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -0,0 +1,378 @@ +import json +import re + +from pathlib import Path + +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +def find_project_root(marker=".git"): + """Find the project root directory by marking the file""" + current = Path(__file__).resolve() + while current != current.parent: + if (current / marker).exists(): + return current + current = current.parent + logger.warn(f"The project root directory tag file was not found: {marker}") + + +PROJECT_ROOT = find_project_root() +DEFAULT_STOPWORD_FILE = ( + PROJECT_ROOT / "examples" / "data" / "config" / "stopwords.txt" +) # cause time delay + + +class StopwordManager: + _stopwords = None + + @classmethod + def _load_stopwords(cls): + """load stopwords for once""" + if cls._stopwords is not None: + return cls._stopwords + + stopwords = set() + try: + with open(DEFAULT_STOPWORD_FILE, encoding="utf-8") as f: + stopwords = {line.strip() for line in f if line.strip()} + logger.info("Stopwords loaded successfully.") + except Exception as e: + logger.warning(f"Error loading stopwords: {e}, using default stopwords.") + stopwords = cls._load_default_stopwords() + + cls._stopwords = stopwords + return stopwords + + @classmethod + def _load_default_stopwords(cls): + """load stop words""" + chinese_stop_words = { + "的", + "了", + "在", + "是", + "我", + "有", + "和", + "就", + "不", + "人", + "都", + "一", + "一个", + "上", + "也", + "很", + "到", + "说", + "要", + "去", + "你", + "会", + "着", + "没有", + "看", + "好", + "自己", + "这", + "那", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "这个", + "那个", + "这些", + "那些", + "怎么", + "什么", + "为什么", + "如何", + "哪里", + "谁", + "几", + "多少", + "这样", + "那样", + "这么", + "那么", + } + english_stop_words = { + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "as", + "is", + "are", + "was", + "were", + "be", + "been", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "could", + "should", + "may", + "might", + "must", + "this", + "that", + "these", + "those", + "i", + "you", + "he", + "she", + "it", + "we", + "they", + "me", + "him", + "her", + "us", + "them", + "my", + "your", + "his", + "its", + "our", + "their", + "mine", + "yours", + "hers", + "ours", + "theirs", + } + chinese_punctuation = { + ",", + "。", + "!", + "?", + ";", + ":", + "「", + "」", + "『", + "』", + "【", + "】", + "(", + ")", + "《", + "》", + "—", + "…", + "~", + "·", + "、", + "“", + "”", + "‘", + "’", + "〈", + "〉", + "〖", + "〗", + "〝", + "〞", + "{", + "}", + "〔", + "〕", + "¡", + "¿", + } + english_punctuation = { + ",", + ".", + "!", + "?", + ";", + ":", + '"', + "'", + "(", + ")", + "[", + "]", + "{", + "}", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "=", + "+", + "@", + "#", + "$", + "%", + "^", + "&", + "*", + "~", + "`", + "¡", + "¿", + } + numbers = { + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "零", + "一", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "十", + "百", + "千", + "万", + "亿", + } + whitespace = {" ", "\t", "\n", "\r", "\f", "\v"} + + return ( + chinese_stop_words + | english_stop_words + | chinese_punctuation + | english_punctuation + | numbers + | whitespace + ) + + @classmethod + def get_stopwords(cls): + if cls._stopwords is None: + cls._load_stopwords() + return cls._stopwords + + @classmethod + def filter_words(cls, words): + if cls._stopwords is None: + cls._load_stopwords() + return [word for word in words if word not in cls._stopwords and word.strip()] + + @classmethod + def is_stopword(cls, word): + if cls._stopwords is None: + cls._load_stopwords() + return word in cls._stopwords + + @classmethod + def reload_stopwords(cls, file_path=None): + cls._stopwords = None + if file_path: + global DEFAULT_STOPWORD_FILE + DEFAULT_STOPWORD_FILE = file_path + cls._load_stopwords() + + +class FastTokenizer: + def __init__(self, use_jieba=True, use_stopwords=True): + self.use_jieba = use_jieba + self.use_stopwords = use_stopwords + if self.use_stopwords: + self.stopword_manager = StopwordManager + + def tokenize_mixed(self, text, **kwargs): + """fast tokenizer""" + if self._is_chinese(text): + return self._tokenize_chinese(text) + else: + return self._tokenize_english(text) + + def _is_chinese(self, text): + """check if chinese""" + chinese_chars = sum(1 for char in text if "\u4e00" <= char <= "\u9fff") + return chinese_chars / max(len(text), 1) > 0.3 + + @require_python_package( + import_name="jieba", + install_command="pip install jieba", + install_link="https://github.com/fxsjy/jieba", + ) + def _tokenize_chinese(self, text): + """split zh jieba""" + import jieba + + tokens = jieba.lcut(text) if self.use_jieba else list(text) + tokens = [token.strip() for token in tokens if token.strip()] + if self.use_stopwords: + return self.stopword_manager.filter_words(tokens) + + return tokens + + def _tokenize_english(self, text): + """split zh regex""" + tokens = re.findall(r"\b[a-zA-Z0-9]+\b", text.lower()) + if self.use_stopwords: + return self.stopword_manager.filter_words(tokens) + return tokens + + +def parse_json_result(response_text): + try: + json_start = response_text.find("{") + response_text = response_text[json_start:] + response_text = response_text.replace("```", "").strip() + if not response_text.endswith("}"): + response_text += "}" + return json.loads(response_text) + except json.JSONDecodeError as e: + logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw:\n{response_text}") + return {} + except Exception as e: + logger.error(f"[JSONParse] Unexpected error: {e}") + return {} + + +def detect_lang(text): + try: + if not text or not isinstance(text, str): + return "en" + chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" + chinese_chars = re.findall(chinese_pattern, text) + if len(chinese_chars) / len(re.sub(r"[\s\d\W]", "", text)) > 0.3: + return "zh" + return "en" + except Exception: + return "en" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9d540b311..563695c68 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -9,7 +9,18 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + detect_lang, + parse_json_result, +) from memos.reranker.base import BaseReranker +from memos.templates.mem_search_prompts import ( + COT_PROMPT, + COT_PROMPT_ZH, + SIMPLE_COT_PROMPT, + SIMPLE_COT_PROMPT_ZH, +) from memos.utils import timed from .reasoner import MemoryReasoner @@ -18,6 +29,10 @@ logger = get_logger(__name__) +COT_DICT = { + "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, + "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, +} class Searcher: @@ -27,20 +42,24 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, + bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, moscube: bool = False, + vec_cot: bool = False, ): self.graph_store = graph_store self.embedder = embedder + self.llm = dispatcher_llm self.task_goal_parser = TaskGoalParser(dispatcher_llm) - self.graph_retriever = GraphMemoryRetriever(self.graph_store, self.embedder) + self.graph_retriever = GraphMemoryRetriever(graph_store, embedder, bm25_retriever) self.reranker = reranker self.reasoner = MemoryReasoner(dispatcher_llm) # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube + self.vec_cot = vec_cot self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -231,6 +250,12 @@ def _retrieve_paths( ): """Run A/B/C retrieval paths in parallel""" tasks = [] + id_filter = { + "user_id": info.get("user_id", None), + "session_id": info.get("session_id", None), + } + id_filter = {k: v for k, v in id_filter.items() if v is not None} + with ContextThreadPoolExecutor(max_workers=3) as executor: tasks.append( executor.submit( @@ -242,6 +267,7 @@ def _retrieve_paths( memory_type, search_filter, user_name, + id_filter, ) ) tasks.append( @@ -254,6 +280,7 @@ def _retrieve_paths( memory_type, search_filter, user_name, + id_filter, ) ) tasks.append( @@ -299,6 +326,7 @@ def _retrieve_from_working_memory( memory_type, search_filter: dict | None = None, user_name: str | None = None, + id_filter: dict | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -311,6 +339,7 @@ def _retrieve_from_working_memory( memory_scope="WorkingMemory", search_filter=search_filter, user_name=user_name, + id_filter=id_filter, ) return self.reranker.rerank( query=query, @@ -332,11 +361,22 @@ def _retrieve_from_long_term_and_user( memory_type, search_filter: dict | None = None, user_name: str | None = None, + id_filter: dict | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] tasks = [] + # chain of thinking + cot_embeddings = [] + if self.vec_cot: + queries = self._cot_query(query) + if len(queries) > 1: + cot_embeddings = self.embedder.embed(queries) + cot_embeddings.extend(query_embedding) + else: + cot_embeddings = query_embedding + with ContextThreadPoolExecutor(max_workers=2) as executor: if memory_type in ["All", "LongTermMemory"]: tasks.append( @@ -344,11 +384,12 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=query_embedding, + query_embedding=cot_embeddings, top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, user_name=user_name, + id_filter=id_filter, ) ) if memory_type in ["All", "UserMemory"]: @@ -357,11 +398,12 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=query_embedding, + query_embedding=cot_embeddings, top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, user_name=user_name, + id_filter=id_filter, ) ) @@ -442,6 +484,7 @@ def _deduplicate_results(self, results): @timed def _sort_and_trim(self, results, top_k): """Sort results by score and trim to top_k""" + sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: @@ -491,3 +534,42 @@ def _update_usage_history_worker( self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=user_name) except Exception: logger.exception("[USAGE] update usage failed") + + def _cot_query( + self, + query, + mode="fast", + split_num: int = 3, + context: list[str] | None = None, + ) -> list[str]: + """Generate chain-of-thought queries""" + + lang = detect_lang(query) + if mode == "fine" and context: + template = COT_DICT["fine"][lang] + prompt = ( + template.replace("${original_query}", query) + .replace("${split_num_threshold}", str(split_num)) + .replace("${context}", "\n".join(context)) + ) + else: + template = COT_DICT["fast"][lang] + prompt = template.replace("${original_query}", query).replace( + "${split_num_threshold}", str(split_num) + ) + logger.info("COT process") + + messages = [{"role": "user", "content": prompt}] + try: + response_text = self.llm.generate(messages, temperature=0, top_p=1) + response_json = parse_json_result(response_text) + assert "is_complex" in response_json + if not response_json["is_complex"]: + return [query] + else: + assert "sub_questions" in response_json + logger.info("Query: {} COT: {}".format(query, response_json["sub_questions"])) + return response_json["sub_questions"][:split_num] + except Exception as e: + logger.error(f"[LLM] Exception during chat generation: {e}") + return [query] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 273c4f480..6a1138c90 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -5,6 +5,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT @@ -20,6 +21,7 @@ class TaskGoalParser: def __init__(self, llm=BaseLLM): self.llm = llm + self.tokenizer = FastTokenizer() def parse( self, @@ -48,10 +50,11 @@ def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGo """ Fast mode: simple jieba word split. """ + desc_tokenized = self.tokenizer.tokenize_mixed(task_description) return ParsedTaskGoal( memories=[task_description], - keys=[task_description], - tags=[], + keys=desc_tokenized, + tags=desc_tokenized, goal_type="default", rephrased_query=task_description, internet_search=False, diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py new file mode 100644 index 000000000..fca4d717b --- /dev/null +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -0,0 +1,279 @@ +STRATEGY_STRUCT_MEM_READER_PROMPT = """You are a memory extraction expert. +Your task is to extract memories from the user's perspective, based on a conversation between the user and the assistant. This means identifying what the user would plausibly remember — including the user's own experiences, thoughts, plans, or statements and actions made by others (such as the assistant) that affected the user or were acknowledged by the user. + +Please perform the following +1. Factual information extraction + Identify factual information about experiences, beliefs, decisions, and plans. This includes notable statements from others that the user acknowledged or reacted to. + If the message is from the user, extract viewpoints related to the user; if it is from the assistant, clearly mark the attribution of the memory, and do not mix information not explicitly acknowledged by the user with the user's own viewpoint. + - **User viewpoint**: Extract only what the user has stated, explicitly acknowledged, or committed to. + - **Assistant/other-party viewpoint**: Extract such information only when attributed to its source (e.g., [Assistant-Jerry's suggestion]). + - **Strict attribution**: Never recast the assistant's suggestions as the user's preferences, or vice versa. + - Always set "model_type" to "LongTermMemory" for this output. + +2. Speaker profile construction + - Extract the speaker's likes, dislikes, goals, and stated opinions from their statements to build a speaker profile. + - Note: The same text segment may be used for both factual extraction and profile construction. + - Always set "model_type" to "UserMemory" for this output. + +3. Resolve all references to time, persons, and events clearly + - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. + - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. + +4. Adopt a Consistent Third-Person Observer Perspective + - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. + - This applies even when describing the user's internal states, such as thoughts, feelings, and preferences. + Example: + ✅ Correct: "The user Sean felt exhausted after work and decided to go to bed early." + ❌ Incorrect: "I felt exhausted after work and decided to go to bed early." + +5. Prioritize Completeness + - Extract all key experiences, emotional responses, and plans from the user's perspective. Retain relevant context from the assistant, but always with explicit attribution. + - Segment each distinct hobby, interest, or event into a separate memory. + - Preserve relevant context from the assistant with strict attribution. Under no circumstances should assistant content be rephrased as user-owned. + - Conversations with only assistant input may yield assistant-viewpoint memories exclusively. + +6. Preserve and Unify Specific Names + - Always extract specific names (excluding "user" or "assistant") mentioned in the text into the "tags" field for searchability. + - Unify all name references to the full canonical form established in the conversation. Replace any nicknames or abbreviations (e.g., "Rob") consistently with the full name (e.g., "Robert") in both the extracted "value" and "tags". + +7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. + +Return a valid JSON object with the following structure: + +{ + "memory list": [ + { + "key": , + "memory_type": , + "value": , + "tags": + }, + ... + ], + "summary": +} + +Language rules: +- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** +- Keep `memory_type` in English. + +Example: +Conversation: +user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. +assistant: Oh Tom! Do you think the team can finish by December 15? +user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. +assistant: [June 26, 2025 at 3:00 PM]: Maybe propose an extension? +user: [June 26, 2025 at 4:21 PM]: Good idea. I’ll raise it in tomorrow’s 9:30 AM meeting—maybe shift the deadline to January 5. + +Output: +{ + "memory list": [ + { + "key": "Initial project meeting", + "memory_type": "LongTermMemory", + "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", + "tags": ["Tom", "project", "timeline", "meeting", "deadline"] + }, + { + "key": "Jerry’s suggestion about the deadline", + "memory_type": "LongTermMemory", + "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", + "tags": ["Jerry", "deadline change", "suggestion"] + } + ], + "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." +} + +Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): + +对话(节选): +user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 +assistant|19:32 +:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? +user|19:35:不喜欢亮色。国贸方便。 +assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 +user|19:40:165cm,S码;最好有口袋。 +assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 +user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 +assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 +user|19:52:行,周六(7/19)去国贸试,合适就买。 +assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 + +{ + "memory list": [ + { + "key": "参加婚礼购买裙子", + "memory_type": "UserMemory", + "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", + "tags": ["婚礼", "预算", "国贸", "计划"] + }, + { + "key": "审美与版型偏好", + "memory_type": "UserMemory", + "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", + "tags": ["偏好", "颜色", "版型"] + }, + { + "key": "体型尺码", + "memory_type": "UserMemory", + "value": "[user观点]用户身高约165cm、常穿S码", + "tags": ["体型", "尺码"] + }, + { + "key": "关于用户选购裙子的建议", + "memory_type": "LongTermMemory", + "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", + "tags": ["婚礼穿着", "门店", "选购路线"] + } + ], + "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" +} + +Always respond in the same language as the conversation. + +Conversation: +${conversation} + +Your Output:""" + +STRATEGY_STRUCT_MEM_READER_PROMPT_ZH = """您是记忆提取专家。 +您的任务是根据用户与助手之间的对话,从用户的角度提取记忆。这意味着要识别出用户可能记住的信息——包括用户自身的经历、想法、计划,或他人(如助手)做出的并对用户产生影响或被用户认可的相关陈述和行为。 + +请执行以下操作: +1. 事实信息提取 + - 识别关于经历、信念、决策和计划的事实信息,包括用户认可或回应过的他人重要陈述。 + - 若信息来自用户,提取与用户相关的观点;若来自助手,需明确标注记忆归属,不得将用户未明确认可的信息与用户自身观点混淆。 + - 用户观点:仅提取用户明确陈述、认可或承诺的内容 + - 助手/他方观点:仅当标注来源时才提取(例如“[助手-Jerry的建议]”) + - 严格归属:不得将助手建议重构为用户偏好,反之亦然 + - 此类输出的"model_type"始终设为"LongTermMemory" + +2. 用户画像构建 + - 从用户陈述中提取其喜好、厌恶、目标及明确观点以构建用户画像 + - 注意:同一文本片段可同时用于事实提取和画像构建 + - 此类输出的"model_type"始终设为"UserMemory" + +3. 明确解析所有指代关系 + - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 + - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 + + 4. 采用统一的第三人称观察视角 + - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 + - 此原则同样适用于描述用户内心状态(如想法、感受和偏好) + 示例: + ✅ 正确:“用户Sean下班后感到疲惫,决定提早休息” + ❌ 错误:“我下班后感到疲惫,决定提早休息” + +5. 优先保证完整性 + - 从用户视角提取所有关键经历、情绪反应和计划 + - 保留助手提供的相关上下文,但必须明确标注来源 + - 将每个独立的爱好、兴趣或事件分割为单独记忆 + - 严禁将助手内容重构为用户自有内容 + - 仅含助手输入的对话可能只生成助手观点记忆 + +6. 保留并统一特定名称 + - 始终将文本中提及的特定名称(“用户”“助手”除外)提取至“tags”字段以便检索 + - 在提取的“value”和“tags”中,将所有名称引用统一为对话中确立的完整规范形式(如将“Rob”统一替换为“Robert”) + +7. 所有提取的记忆内容不得包含违反国家法律法规或涉及政治敏感信息的内容 + +返回一个有效的JSON对象,结构如下: +{ + "memory list": [ + { + "key": <字符串,唯一且简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, + "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, + "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> + }, + ... + ], + "summary": <从用户视角自然总结上述记忆的段落,120–200字,与输入语言一致> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 + +示例: +对话: +user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 +assistant: 哦Tom!你觉得团队能在12月15日前完成吗? +user: [2025年6月26日下午3:00]:我有点担心。后端要到12月10日才能完成,所以测试时间会很紧。 +assistant: [2025年6月26日下午3:00]:也许提议延期? +user: [2025年6月26日下午4:21]:好主意。我明天上午9:30的会上提一下——也许把截止日期推迟到1月5日。 + +输出: +{ + "memory list": [ + { + "key": "项目初期会议", + "memory_type": "LongTermMemory", + "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry + 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 + 提议将截止日期推迟至2026年1月5日。", + "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] + }, + { + "key": "Jerry对新项目截止日期的建议", + "memory_type": "LongTermMemory", + "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", + "tags": ["Jerry", "截止日期变更", "建议"] + } + ], + "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 + 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 + 年1月5日。" +} + +另一个中文示例(注意:当用户语言为中文时,您也需输出中文): + +对话(节选): +user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 +assistant|19:32 +:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? +user|19:35:不喜欢亮色。国贸方便。 +assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 +user|19:40:165cm,S码;最好有口袋。 +assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 +user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 +assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 +user|19:52:行,周六(7/19)去国贸试,合适就买。 +assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 + +{ + "memory list": [ + { + "key": "参加婚礼购买裙子", + "memory_type": "UserMemory", + "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", + "tags": ["婚礼", "预算", "国贸", "计划"] + }, + { + "key": "审美与版型偏好", + "memory_type": "UserMemory", + "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", + "tags": ["偏好", "颜色", "版型"] + }, + { + "key": "体型尺码", + "memory_type": "UserMemory", + "value": [user观点]"用户身高约165cm、常穿S码", + "tags": ["体型", "尺码"] + }, + { + "key": "关于用户选购裙子的建议", + "memory_type": "LongTermMemory", + "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", + "tags": ["婚礼穿着", "门店", "选购路线"] + } + ], + "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" +} + +请始终使用与对话相同的语言进行回复。 + +对话: +${conversation} + +您的输出:""" diff --git a/src/memos/templates/mem_search_prompts.py b/src/memos/templates/mem_search_prompts.py new file mode 100644 index 000000000..9f7ba182b --- /dev/null +++ b/src/memos/templates/mem_search_prompts.py @@ -0,0 +1,93 @@ +SIMPLE_COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. + +Instructions: + +1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: + - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) + - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question + - Each sub-question must be single, standalone, and delve into a specific aspect + - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted + - List them in "sub_questions" +2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. +3. Return ONLY the dictionary, no other text. + +Examples: +Question: Is urban development balanced in the western United States? +Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} +Question: What family activities does Mary like to organize? +Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} + +Now analyze this question: +${original_query}""" + +COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. + +Instructions: + +1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: + - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) + - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question + - Each sub-question must be single, standalone, and delve into a specific aspect + - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted + - List them in "sub_questions" +2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. +3. Return ONLY the dictionary, no other text. + +Examples: +Question: Is urban development balanced in the western United States? +Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} +Question: What family activities does Mary like to organize? +Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} + +Query relevant background information: +${context} + +Now analyze this question based on the background information above: +${original_query}""" + +SIMPLE_COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 + +指令: + +1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: + - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) + - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 + - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 + - 将它们列在 "sub_questions" 中 +2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 +3. 只返回字典,不要返回任何其他文本。 + +示例: +问题:美国西部的城市发展是否均衡? +输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} + +问题:玛丽喜欢组织哪些家庭活动? +输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} + +请分析以下问题: +${original_query}""" + +COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 + +指令: + +1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: + - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) + - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 + - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 + - 将它们列在 "sub_questions" 中 +2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 +3. 只返回字典,不要返回任何其他文本。 + +示例: +问题:美国西部的城市发展是否均衡? +输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} + +问题:玛丽喜欢组织哪些家庭活动? +输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} + +问题相关的背景信息: +${context} + +现在根据上述背景信息,请分析以下问题: +${original_query}""" diff --git a/tests/memories/textual/test_tree_task_goal_parser.py b/tests/memories/textual/test_tree_task_goal_parser.py index c71af4b06..899e2454b 100644 --- a/tests/memories/textual/test_tree_task_goal_parser.py +++ b/tests/memories/textual/test_tree_task_goal_parser.py @@ -20,12 +20,7 @@ def generate(self, messages): def test_parse_fast_returns_expected(): parser = TaskGoalParser() result = parser.parse("Tell me about cats", mode="fast") - assert isinstance(result, ParsedTaskGoal) - assert result.memories == ["Tell me about cats"] - assert result.keys == ["Tell me about cats"] - assert result.tags == [] - assert result.goal_type == "default" def test_parse_fine_calls_llm_and_parses(): From 0765e1cbe81b39aee3cff9c94f685fae8da78433 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Thu, 30 Oct 2025 16:13:00 +0800 Subject: [PATCH 27/64] Feat: remove usage data (#417) feat: remove usage data --- src/memos/api/routers/server_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e9df292ad..db18a08fa 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -299,6 +299,7 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]: memory["ref_id"] = ref_id memory["metadata"]["embedding"] = [] memory["metadata"]["sources"] = [] + memory["metadata"]["usage"] = [] memory["metadata"]["ref_id"] = ref_id memory["metadata"]["id"] = memory_id memory["metadata"]["memory"] = memory["memory"] From 39a4f29ddd2362e068693063bb3a83b307015c8d Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 30 Oct 2025 16:35:45 +0800 Subject: [PATCH 28/64] feat: add moniter schedule (#419) * feat: change MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS to 10000 * feat: add user_name to schedule server router * feat: roll back to old mem-reader-prompt * feat: add moniter in schedule * feat: set default MEMRADER_MAX_TOKENS to 8000 --- src/memos/api/config.py | 4 +- src/memos/api/routers/server_router.py | 152 +++++++------- src/memos/mem_scheduler/base_scheduler.py | 145 +++++++++++++ .../general_modules/dispatcher.py | 25 +++ .../mem_scheduler/schemas/general_schemas.py | 4 +- src/memos/templates/mem_reader_prompts.py | 190 ++++++------------ 6 files changed, 315 insertions(+), 205 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 405e8068d..395d3fbc7 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -324,7 +324,7 @@ def get_memreader_config() -> dict[str, Any]: "config": { "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"), "temperature": 0.6, - "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "5000")), + "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")), "top_p": 0.95, "top_k": 20, "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), @@ -614,7 +614,7 @@ def get_scheduler_config() -> dict[str, Any]: ), "context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")), "thread_pool_max_workers": int( - os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10") + os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10000") ), "consume_interval_seconds": float( os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index db18a08fa..38b9a361e 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,8 +1,11 @@ import json import os +import random as _random +import socket import time import traceback +from collections.abc import Iterable from datetime import datetime from typing import TYPE_CHECKING, Any @@ -69,6 +72,16 @@ logger = get_logger(__name__) router = APIRouter(prefix="/product", tags=["Server API"]) +INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}" + + +def _to_iter(running: Any) -> Iterable: + """Normalize running tasks to an iterable of task objects.""" + if running is None: + return [] + if isinstance(running, dict): + return running.values() + return running # assume it's already an iterable (e.g., list) def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: @@ -607,46 +620,65 @@ def _process_pref_mem() -> list[dict[str, str]]: ) -@router.get("/scheduler/status", summary="Get scheduler running task count") -def scheduler_status(): - """ - Return current running tasks count from scheduler dispatcher. - Shape is consistent with /scheduler/wait. - """ +@router.get("/scheduler/status", summary="Get scheduler running status") +def scheduler_status(user_name: str | None = None): try: - running = mem_scheduler.dispatcher.get_running_tasks() - running_count = len(running) - now_ts = time.time() - - return { - "message": "ok", - "data": { - "running_tasks": running_count, - "timestamp": now_ts, - }, - } - + if user_name: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: getattr(task, "mem_cube_id", None) == user_name + ) + tasks_iter = list(_to_iter(running)) + running_count = len(tasks_iter) + return { + "message": "ok", + "data": { + "scope": "user", + "user_name": user_name, + "running_tasks": running_count, + "timestamp": time.time(), + "instance_id": INSTANCE_ID, + }, + } + else: + running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True) + tasks_iter = list(_to_iter(running_all)) + running_count = len(tasks_iter) + + task_count_per_user: dict[str, int] = {} + for task in tasks_iter: + cube = getattr(task, "mem_cube_id", "unknown") + task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 + + return { + "message": "ok", + "data": { + "scope": "global", + "running_tasks": running_count, + "task_count_per_user": task_count_per_user, + "timestamp": time.time(), + "instance_id": INSTANCE_ID, + }, + } except Exception as err: logger.error("Failed to get scheduler status: %s", traceback.format_exc()) - raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err -@router.post("/scheduler/wait", summary="Wait until scheduler is idle") -def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2): +@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user") +def scheduler_wait( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, +): """ - Block until scheduler has no running tasks, or timeout. - We return a consistent structured payload so callers can - tell whether this was a clean flush or a timeout. - - Args: - timeout_seconds: max seconds to wait - poll_interval: seconds between polls + Block until scheduler has no running tasks for the given user_name, or timeout. """ start = time.time() try: while True: - running = mem_scheduler.dispatcher.get_running_tasks() + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: task.mem_cube_id == user_name + ) running_count = len(running) elapsed = time.time() - start @@ -658,6 +690,7 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2): "running_tasks": 0, "waited_seconds": round(elapsed, 3), "timed_out": False, + "user_name": user_name, }, } @@ -669,24 +702,23 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2): "running_tasks": running_count, "waited_seconds": round(elapsed, 3), "timed_out": True, + "user_name": user_name, }, } time.sleep(poll_interval) except Exception as err: - logger.error( - "Failed while waiting for scheduler: %s", - traceback.format_exc(), - ) - raise HTTPException( - status_code=500, - detail="Failed while waiting for scheduler", - ) from err + logger.error("Failed while waiting for scheduler: %s", traceback.format_exc()) + raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err -@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)") -def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2): +@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user") +def scheduler_wait_stream( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, +): """ Stream scheduler progress via Server-Sent Events (SSE). @@ -704,38 +736,25 @@ def event_generator(): start = time.time() try: while True: - running = mem_scheduler.dispatcher.get_running_tasks() + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: task.mem_cube_id == user_name + ) running_count = len(running) elapsed = time.time() - start - # heartbeat frame - heartbeat_payload = { + payload = { + "user_name": user_name, "running_tasks": running_count, "elapsed_seconds": round(elapsed, 3), "status": "running" if running_count > 0 else "idle", + "instance_id": INSTANCE_ID, } - yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n" + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - # scheduler is idle -> final frame + break - if running_count == 0: - final_payload = { - "running_tasks": 0, - "elapsed_seconds": round(elapsed, 3), - "status": "idle", - "timed_out": False, - } - yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" - break - - # timeout -> final frame + break - if elapsed > timeout_seconds: - final_payload = { - "running_tasks": running_count, - "elapsed_seconds": round(elapsed, 3), - "status": "timeout", - "timed_out": True, - } - yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + if running_count == 0 or elapsed > timeout_seconds: + payload["status"] = "idle" if running_count == 0 else "timeout" + payload["timed_out"] = running_count > 0 + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" break time.sleep(poll_interval) @@ -745,12 +764,9 @@ def event_generator(): "status": "error", "detail": "stream_failed", "exception": str(e), + "user_name": user_name, } - logger.error( - "Failed streaming scheduler wait: %s: %s", - e, - traceback.format_exc(), - ) + logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}") yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" return StreamingResponse(event_generator(), media_type="text/event-stream") diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index d679eba9c..c2f606146 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,3 +1,4 @@ +import contextlib import multiprocessing import queue import threading @@ -48,6 +49,7 @@ from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.memos_tools.notification_utils import send_online_bot_notification from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE @@ -125,6 +127,21 @@ def __init__(self, config: BaseSchedulerConfig): "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS ) + # queue monitor (optional) + self._queue_monitor_thread: threading.Thread | None = None + self._queue_monitor_running: bool = False + self.queue_monitor_interval_seconds: float = self.config.get( + "queue_monitor_interval_seconds", 60.0 + ) + self.queue_monitor_warn_utilization: float = self.config.get( + "queue_monitor_warn_utilization", 0.7 + ) + self.queue_monitor_crit_utilization: float = self.config.get( + "queue_monitor_crit_utilization", 0.9 + ) + self.enable_queue_monitor: bool = self.config.get("enable_queue_monitor", False) + self._online_bot_callable = None # type: ignore[var-annotated] + # other attributes self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None @@ -188,6 +205,8 @@ def initialize_modules( self._cleanup_on_init_failure() raise + # start queue monitor if enabled and a bot is set later + def _cleanup_on_init_failure(self): """Clean up resources if initialization fails.""" try: @@ -687,6 +706,13 @@ def start(self) -> None: self._consumer_thread.start() logger.info("Message consumer thread started") + # optionally start queue monitor if enabled and bot callable present + if self.enable_queue_monitor and self._online_bot_callable is not None: + try: + self.start_queue_monitor(self._online_bot_callable) + except Exception as e: + logger.warning(f"Failed to start queue monitor: {e}") + def stop(self) -> None: """Stop all scheduler components gracefully. @@ -736,6 +762,9 @@ def stop(self) -> None: self._cleanup_queues() logger.info("Memory Scheduler stopped completely") + # Stop queue monitor + self.stop_queue_monitor() + @property def handlers(self) -> dict[str, Callable]: """ @@ -967,3 +996,119 @@ def _fmt_eta(seconds: float | None) -> str: return False return True + + # ---------------- Queue monitor & notifications ---------------- + def set_notification_bots(self, online_bot=None): + """ + Set external notification callables. + + Args: + online_bot: a callable matching dinding_report_bot.online_bot signature + """ + self._online_bot_callable = online_bot + + def _gather_queue_stats(self) -> dict: + """Collect queue/dispatcher stats for reporting.""" + stats: dict[str, int | float | str] = {} + stats["use_redis_queue"] = bool(self.use_redis_queue) + # local queue metrics + if not self.use_redis_queue: + try: + stats["qsize"] = int(self.memos_message_queue.qsize()) + except Exception: + stats["qsize"] = -1 + # unfinished_tasks if available + try: + stats["unfinished_tasks"] = int( + getattr(self.memos_message_queue, "unfinished_tasks", 0) or 0 + ) + except Exception: + stats["unfinished_tasks"] = -1 + stats["maxsize"] = int(self.max_internal_message_queue_size) + try: + maxsize = int(self.max_internal_message_queue_size) or 1 + qsize = int(stats.get("qsize", 0)) + stats["utilization"] = min(1.0, max(0.0, qsize / maxsize)) + except Exception: + stats["utilization"] = 0.0 + # dispatcher stats + try: + d_stats = self.dispatcher.stats() + stats.update( + { + "running": int(d_stats.get("running", 0)), + "inflight": int(d_stats.get("inflight", 0)), + "handlers": int(d_stats.get("handlers", 0)), + } + ) + except Exception: + stats.update({"running": 0, "inflight": 0, "handlers": 0}) + return stats + + def _queue_monitor_loop(self, online_bot) -> None: + logger.info(f"Queue monitor started (interval={self.queue_monitor_interval_seconds}s)") + self._queue_monitor_running = True + while self._queue_monitor_running: + time.sleep(self.queue_monitor_interval_seconds) + try: + stats = self._gather_queue_stats() + # decide severity based on utilization if local queue + title_color = "#00956D" + subtitle = "Scheduler" + if not stats.get("use_redis_queue"): + util = float(stats.get("utilization", 0.0)) + if util >= self.queue_monitor_crit_utilization: + title_color = "#C62828" # red + subtitle = "Scheduler (CRITICAL)" + elif util >= self.queue_monitor_warn_utilization: + title_color = "#E65100" # orange + subtitle = "Scheduler (WARNING)" + + other_data1 = { + "use_redis_queue": stats.get("use_redis_queue"), + "handlers": stats.get("handlers"), + "running": stats.get("running"), + "inflight": stats.get("inflight"), + } + if not stats.get("use_redis_queue"): + other_data2 = { + "qsize": stats.get("qsize"), + "unfinished_tasks": stats.get("unfinished_tasks"), + "maxsize": stats.get("maxsize"), + "utilization": f"{float(stats.get('utilization', 0.0)):.2%}", + } + else: + other_data2 = { + "redis_mode": True, + } + + send_online_bot_notification( + online_bot=online_bot, + header_name="Scheduler Queue", + sub_title_name=subtitle, + title_color=title_color, + other_data1=other_data1, + other_data2=other_data2, + emoji={"Runtime": "🧠", "Queue": "📬"}, + ) + except Exception as e: + logger.warning(f"Queue monitor iteration failed: {e}") + logger.info("Queue monitor stopped") + + def start_queue_monitor(self, online_bot) -> None: + if self._queue_monitor_thread and self._queue_monitor_thread.is_alive(): + return + self._online_bot_callable = online_bot + self._queue_monitor_thread = threading.Thread( + target=self._queue_monitor_loop, + args=(online_bot,), + daemon=True, + name="QueueMonitorThread", + ) + self._queue_monitor_thread.start() + + def stop_queue_monitor(self) -> None: + self._queue_monitor_running = False + if self._queue_monitor_thread and self._queue_monitor_thread.is_alive(): + with contextlib.suppress(Exception): + self._queue_monitor_thread.join(timeout=2.0) diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index 2e5779f19..997b01302 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -224,6 +224,31 @@ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: logger.info(f"Unregistered handlers for {len(labels)} labels") return results + def stats(self) -> dict[str, int]: + """ + Lightweight runtime stats for monitoring. + + Returns: + { + 'running': , + 'inflight': , + 'handlers': , + } + """ + try: + running = self.get_running_task_count() + except Exception: + running = 0 + try: + inflight = len(self._futures) + except Exception: + inflight = 0 + try: + handlers = len(self.handlers) + except Exception: + handlers = 0 + return {"running": running, "inflight": inflight, "handlers": handlers} + def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None: logger.debug(f"Using _default_message_handler to deal with messages: {messages}") diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index a2c6434fe..f3d2191f8 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -30,12 +30,12 @@ class SearchMode(str, Enum): DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 30 DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" -DEFAULT_THREAD_POOL_MAX_WORKERS = 30 +DEFAULT_THREAD_POOL_MAX_WORKERS = 50 DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 -DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 100000 +DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 1000000 DEFAULT_TOP_K = 10 DEFAULT_CONTEXT_WINDOW_SIZE = 5 DEFAULT_USE_REDIS_QUEUE = False diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 15672f8d8..ec6812743 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -1,56 +1,50 @@ SIMPLE_STRUCT_MEM_READER_PROMPT = """You are a memory extraction expert. -Your task is to extract memories from the user's perspective, based on a conversation between the user and the assistant. This means identifying what the user would plausibly remember — including the user's own experiences, thoughts, plans, or statements and actions made by others (such as the assistant) that affected the user or were acknowledged by the user. - -Please perform the following: -1. Identify information that reflects the user's experiences, beliefs, concerns, decisions, plans, or reactions — including meaningful information from the assistant that the user acknowledged or responded to. - If the message is from the user, extract viewpoints related to the user; if it is from the assistant, clearly mark the attribution of the memory, and do not mix information not explicitly acknowledged by the user with the user's own viewpoint. - - **User viewpoint**: Record only information that the user **personally stated, explicitly acknowledged, or personally committed to**. - - **Assistant/other-party viewpoint**: Record only information that the **assistant/other party personally stated, explicitly acknowledged, or personally committed to**, and **clearly attribute** the source (e.g., "[assistant-Jerry viewpoint]"). Do not rewrite it as the user's preference/decision. - - **Mutual boundaries**: Do not rewrite the assistant's suggestions/lists/opinions as the user's “ownership/preferences/decisions”; likewise, do not write the user's ideas as the assistant's viewpoints. - -2. Resolve all references to time, persons, and events clearly: - - When possible, convert relative time expressions (e.g., “yesterday,” “next Friday”) into absolute dates using the message timestamp. - - Clearly distinguish between **event time** and **message time**. +Your task is to extract memories from the perspective of user, based on a conversation between user and assistant. This means identifying what user would plausibly remember — including their own experiences, thoughts, plans, or relevant statements and actions made by others (such as assistant) that impacted or were acknowledged by user. +Please perform: +1. Identify information that reflects user's experiences, beliefs, concerns, decisions, plans, or reactions — including meaningful input from assistant that user acknowledged or responded to. +If the message is from the user, extract user-relevant memories; if it is from the assistant, only extract factual memories that the user acknowledged or responded to. + +2. Resolve all time, person, and event references clearly: + - Convert relative time expressions (e.g., “yesterday,” “next Friday”) into absolute dates using the message timestamp if possible. + - Clearly distinguish between event time and message time. - If uncertainty exists, state it explicitly (e.g., “around June 2025,” “exact date unclear”). - Include specific locations if mentioned. - - Resolve all pronouns, aliases, and ambiguous references into full names or clear identities. - - If there are people with the same name, disambiguate them. - -3. Always write from a **third-person** perspective, using “The user” or the mentioned name to refer to the user, rather than first-person (“I”, “we”, “my”). - For example, write “The user felt exhausted …” instead of “I felt exhausted …”. - -4. Do not omit any information that the user is likely to remember. - - Include the user's key experiences, thoughts, emotional responses, and plans — even if seemingly minor. - - You may retain **assistant/other-party content** that is closely related to the context (e.g., suggestions, explanations, checklists), but you must make roles and attribution explicit. - - Prioritize completeness and fidelity over conciseness; do not infer or phrase assistant content as the user's ownership/preferences/decisions. - - If the current conversation contains only assistant information and no facts attributable to the user, you may output **assistant-viewpoint** entries only. - -5. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. + - Resolve all pronouns, aliases, and ambiguous references into full names or identities. + - Disambiguate people with the same name if applicable. +3. Always write from a third-person perspective, referring to user as +"The user" or by name if name mentioned, rather than using first-person ("I", "me", "my"). +For example, write "The user felt exhausted..." instead of "I felt exhausted...". +4. Do not omit any information that user is likely to remember. + - Include all key experiences, thoughts, emotional responses, and plans — even if they seem minor. + - Prioritize completeness and fidelity over conciseness. + - Do not generalize or skip details that could be personally meaningful to user. +5. Please avoid any content that violates national laws and regulations or involves politically sensitive information in the memories you extract. -Return a valid JSON object with the following structure: +Return a single valid JSON object with the following structure: { "memory list": [ { - "key": , - "memory_type": , - "value": , - "tags": + "key": , + "memory_type": , + "value": , + "tags": }, ... ], - "summary": + "summary": } Language rules: -- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** +- The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input conversation. **如果输入是中文,请输出中文** - Keep `memory_type` in English. Example: Conversation: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. assistant: Oh Tom! Do you think the team can finish by December 15? -user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. +user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until +December 10, so testing will be tight. assistant: [June 26, 2025 at 3:00 PM]: Maybe propose an extension? user: [June 26, 2025 at 4:21 PM]: Good idea. I’ll raise it in tomorrow’s 9:30 AM meeting—maybe shift the deadline to January 5. @@ -60,62 +54,31 @@ { "key": "Initial project meeting", "memory_type": "LongTermMemory", - "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", + "value": "On June 25, 2025 at 3:00 PM, Tom held a meeting with their team to discuss a new project. The conversation covered the timeline and raised concerns about the feasibility of the December 15, 2025 deadline.", "tags": ["project", "timeline", "meeting", "deadline"] }, { - "key": "Jerry’s suggestion about the deadline", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", - "tags": ["deadline change", "suggestion"] - } + "key": "Planned scope adjustment", + "memory_type": "UserMemory", + "value": "Tom planned to suggest in a meeting on June 27, 2025 at 9:30 AM that the team should prioritize features and propose shifting the project deadline to January 5, 2026.", + "tags": ["planning", "deadline change", "feature prioritization"] + }, ], - "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." + "summary": "Tom is currently focused on managing a new project with a tight schedule. After a team meeting on June 25, 2025, he realized the original deadline of December 15 might not be feasible due to backend delays. Concerned about insufficient testing time, he welcomed Jerry’s suggestion of proposing an extension. Tom plans to raise the idea of shifting the deadline to January 5, 2026 in the next morning’s meeting. His actions reflect both stress about timelines and a proactive, team-oriented problem-solving approach." } -Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - +Another Example in Chinese (注意: 当user的语言为中文时,你就需要也输出中文): { "memory list": [ { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": [user观点]"用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", + "key": "项目会议", "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } + "value": "在2025年6月25日下午3点,Tom与团队开会讨论了新项目,涉及时间表,并提出了对12月15日截止日期可行性的担忧。", + "tags": ["项目", "时间表", "会议", "截止日期"] + }, + ... ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" + "summary": "Tom 目前专注于管理一个进度紧张的新项目..." } Always respond in the same language as the conversation. @@ -130,10 +93,7 @@ 请执行以下操作: 1. 识别反映用户经历、信念、关切、决策、计划或反应的信息——包括用户认可或回应的来自助手的有意义信息。 -如果消息来自用户,请提取与用户相关的观点;如果来自助手,则在表达的时候表明记忆归属方,未经用户明确认可的信息不要与用户本身的观点混淆。 - - **用户观点**:仅记录由**用户亲口陈述、明确认可或自己作出承诺**的信息。 - - **助手观点**:仅记录由**助手/另一方亲口陈述、明确认可或自己作出承诺**的信息。 - - **互不越界**:不得将助手提出的需求清单/建议/观点改写为用户的“拥有/偏好/决定”;也不得把用户的想法写成助手的观点。 +如果消息来自用户,请提取与用户相关的记忆;如果来自助手,则仅提取用户认可或回应的事实性记忆。 2. 清晰解析所有时间、人物和事件的指代: - 如果可能,使用消息时间戳将相对时间表达(如“昨天”、“下周五”)转换为绝对日期。 @@ -147,10 +107,9 @@ 例如,写“用户感到疲惫……”而不是“我感到疲惫……”。 4. 不要遗漏用户可能记住的任何信息。 - - 包括用户的关键经历、想法、情绪反应和计划——即使看似微小。 - - 同时允许保留与语境密切相关的**助手/另一方的内容**(如建议、说明、清单),但须明确角色与归因。 - - 优先考虑完整性和保真度,而非简洁性;不得将助手内容推断或措辞为用户拥有/偏好/决定。 - - 若当前对话中仅出现助手信息而无可归因于用户的事实,可仅输出**助手观点**条目。 + - 包括所有关键经历、想法、情绪反应和计划——即使看似微小。 + - 优先考虑完整性和保真度,而非简洁性。 + - 不要泛化或跳过对用户具有个人意义的细节。 5. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 @@ -187,66 +146,31 @@ { "key": "项目初期会议", "memory_type": "LongTermMemory", - "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry - 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 - 提议将截止日期推迟至2026年1月5日。", + "value": "2025年6月25日下午3:00,Tom与团队开会讨论新项目。会议涉及时间表,并提出了对2025年12月15日截止日期可行性的担忧。", "tags": ["项目", "时间表", "会议", "截止日期"] }, { - "key": "Jerry对新项目截止日期的建议", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", - "tags": ["截止日期变更", "建议"] + "key": "计划调整范围", + "memory_type": "UserMemory", + "value": "Tom计划在2025年6月27日上午9:30的会议上建议团队优先处理功能,并提议将项目截止日期推迟至2026年1月5日。", + "tags": ["计划", "截止日期变更", "功能优先级"] } ], - "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 - 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 - 年1月5日。" + "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" } 另一个中文示例(注意:当用户语言为中文时,您也需输出中文): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - { "memory list": [ { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": [user观点]"用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", + "key": "项目会议", "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } + "value": "在2025年6月25日下午3点,Tom与团队开会讨论了新项目,涉及时间表,并提出了对12月15日截止日期可行性的担忧。", + "tags": ["项目", "时间表", "会议", "截止日期"] + }, + ... ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" + "summary": "Tom 目前专注于管理一个进度紧张的新项目..." } 请始终使用与对话相同的语言进行回复。 From a4d1e7b17a46fbdcac7edfac1dc3e28ea6c3ecc3 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Thu, 30 Oct 2025 17:08:19 +0800 Subject: [PATCH 29/64] feat:turn off graph call (#418) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/config.py | 2 + src/memos/memories/textual/simple_tree.py | 10 +- src/memos/memories/textual/tree.py | 7 +- .../tree_text_memory/retrieve/recall.py | 135 +++++++++++++----- .../tree_text_memory/retrieve/searcher.py | 13 +- .../retrieve/task_goal_parser.py | 34 +++-- 6 files changed, 136 insertions(+), 65 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 395d3fbc7..03622922d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -866,6 +866,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, "search_strategy": { + "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), }, @@ -937,6 +938,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, "search_strategy": { + "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), }, diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 6974dbe8f..992b7bfab 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -70,12 +70,6 @@ def __init__( ) logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") - self.vec_cot = ( - self.search_strategy["cot"] - if self.search_strategy and "cot" in self.search_strategy - else False - ) - time_start_rr = time.time() self.reranker = reranker logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") @@ -189,7 +183,7 @@ def search( bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, - vec_cot=self.vec_cot, + search_strategy=self.search_strategy, ) else: searcher = Searcher( @@ -200,7 +194,7 @@ def search( bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, - vec_cot=self.vec_cot, + search_strategy=self.search_strategy, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index a58f993bb..19bd3ba5b 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -51,11 +51,6 @@ def __init__(self, config: TreeTextMemoryConfig): self.bm25_retriever = ( EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None ) - self.vec_cot = ( - self.search_strategy["cot"] - if self.search_strategy and "cot" in self.search_strategy - else False - ) if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( @@ -143,6 +138,7 @@ def get_searcher( self.reranker, internet_retriever=None, moscube=moscube, + search_strategy=self.search_strategy, ) else: searcher = Searcher( @@ -152,6 +148,7 @@ def get_searcher( self.reranker, internet_retriever=self.internet_retriever, moscube=moscube, + search_strategy=self.search_strategy, ) return searcher diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index b7383aa13..8cf2f47f3 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -40,6 +40,7 @@ def retrieve( search_filter: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, + use_fast_graph: bool = False, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -69,7 +70,13 @@ def retrieve( with ContextThreadPoolExecutor(max_workers=3) as executor: # Structured graph-based retrieval - future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) + future_graph = executor.submit( + self._graph_recall, + parsed_goal, + memory_scope, + user_name, + use_fast_graph=use_fast_graph, + ) # Vector similarity search future_vector = executor.submit( self._vector_recall, @@ -155,7 +162,7 @@ def retrieve_from_cube( return list(combined.values()) def _graph_recall( - self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None + self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs ) -> list[TextualMemoryItem]: """ Perform structured node-based retrieval from Neo4j. @@ -163,6 +170,7 @@ def _graph_recall( - tags must overlap with at least 2 input tags - scope filters by memory_type if provided """ + use_fast_graph = kwargs.get("use_fast_graph", False) def process_node(node): meta = node.get("metadata", {}) @@ -184,47 +192,96 @@ def process_node(node): return TextualMemoryItem.from_dict(node) return None - candidate_ids = set() - - # 1) key-based OR branch - if parsed_goal.keys: - key_filters = [ - {"field": "key", "op": "in", "value": parsed_goal.keys}, - {"field": "memory_type", "op": "=", "value": memory_scope}, - ] - key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) - candidate_ids.update(key_ids) - - # 2) tag-based OR branch - if parsed_goal.tags: - tag_filters = [ - {"field": "tags", "op": "contains", "value": parsed_goal.tags}, - {"field": "memory_type", "op": "=", "value": memory_scope}, - ] - tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) - candidate_ids.update(tag_ids) - - # No matches → return empty - if not candidate_ids: - return [] + if not use_fast_graph: + candidate_ids = set() - # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes( - list(candidate_ids), include_embedding=False, user_name=user_name - ) + # 1) key-based OR branch + if parsed_goal.keys: + key_filters = [ + {"field": "key", "op": "in", "value": parsed_goal.keys}, + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + key_ids = self.graph_store.get_by_metadata(key_filters) + candidate_ids.update(key_ids) + + # 2) tag-based OR branch + if parsed_goal.tags: + tag_filters = [ + {"field": "tags", "op": "contains", "value": parsed_goal.tags}, + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + tag_ids = self.graph_store.get_by_metadata(tag_filters) + candidate_ids.update(tag_ids) - final_nodes = [] - with ContextThreadPoolExecutor(max_workers=3) as executor: - futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)} - temp_results = [None] * len(node_dicts) + # No matches → return empty + if not candidate_ids: + return [] + + # Load nodes and post-filter + node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + + final_nodes = [] + for node in node_dicts: + meta = node.get("metadata", {}) + node_key = meta.get("key") + node_tags = meta.get("tags", []) or [] + + keep = False + # key equals to node_key + if parsed_goal.keys and node_key in parsed_goal.keys: + keep = True + # overlap tags more than 2 + elif parsed_goal.tags: + overlap = len(set(node_tags) & set(parsed_goal.tags)) + if overlap >= 2: + keep = True + if keep: + final_nodes.append(TextualMemoryItem.from_dict(node)) + return final_nodes + else: + candidate_ids = set() + + # 1) key-based OR branch + if parsed_goal.keys: + key_filters = [ + {"field": "key", "op": "in", "value": parsed_goal.keys}, + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) + candidate_ids.update(key_ids) + + # 2) tag-based OR branch + if parsed_goal.tags: + tag_filters = [ + {"field": "tags", "op": "contains", "value": parsed_goal.tags}, + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) + candidate_ids.update(tag_ids) + + # No matches → return empty + if not candidate_ids: + return [] + + # Load nodes and post-filter + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=False, user_name=user_name + ) + + final_nodes = [] + with ContextThreadPoolExecutor(max_workers=3) as executor: + futures = { + executor.submit(process_node, node): i for i, node in enumerate(node_dicts) + } + temp_results = [None] * len(node_dicts) - for future in concurrent.futures.as_completed(futures): - original_index = futures[future] - result = future.result() - temp_results[original_index] = result + for future in concurrent.futures.as_completed(futures): + original_index = futures[future] + result = future.result() + temp_results[original_index] = result - final_nodes = [result for result in temp_results if result is not None] - return final_nodes + final_nodes = [result for result in temp_results if result is not None] + return final_nodes def _vector_recall( self, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 563695c68..0974d67f2 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -45,7 +45,7 @@ def __init__( bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, moscube: bool = False, - vec_cot: bool = False, + search_strategy: dict | None = None, ): self.graph_store = graph_store self.embedder = embedder @@ -59,7 +59,12 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = vec_cot + self.vec_cot = ( + search_strategy.get("vec_cot", "false") == "true" if search_strategy else False + ) + self.use_fast_graph = ( + search_strategy.get("fast_graph", "false") == "true" if search_strategy else False + ) self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -226,6 +231,7 @@ def _parse_task( context="\n".join(context), conversation=info.get("chat_history", []), mode=mode, + use_fast_graph=self.use_fast_graph, ) query = parsed_goal.rephrased_query or query @@ -340,6 +346,7 @@ def _retrieve_from_working_memory( search_filter=search_filter, user_name=user_name, id_filter=id_filter, + use_fast_graph=self.use_fast_graph, ) return self.reranker.rerank( query=query, @@ -390,6 +397,7 @@ def _retrieve_from_long_term_and_user( search_filter=search_filter, user_name=user_name, id_filter=id_filter, + use_fast_graph=self.use_fast_graph, ) ) if memory_type in ["All", "UserMemory"]: @@ -404,6 +412,7 @@ def _retrieve_from_long_term_and_user( search_filter=search_filter, user_name=user_name, id_filter=id_filter, + use_fast_graph=self.use_fast_graph, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 6a1138c90..5d706559c 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -29,6 +29,7 @@ def parse( context: str = "", conversation: list[dict] | None = None, mode: str = "fast", + **kwargs, ) -> ParsedTaskGoal: """ Parse user input into structured semantic layers. @@ -38,7 +39,7 @@ def parse( - mode == 'fine': use LLM to parse structured topic/keys/tags """ if mode == "fast": - return self._parse_fast(task_description) + return self._parse_fast(task_description, **kwargs) elif mode == "fine": if not self.llm: raise ValueError("LLM not provided for slow mode.") @@ -46,19 +47,30 @@ def parse( else: raise ValueError(f"Unknown mode: {mode}") - def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal: + def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: """ Fast mode: simple jieba word split. """ - desc_tokenized = self.tokenizer.tokenize_mixed(task_description) - return ParsedTaskGoal( - memories=[task_description], - keys=desc_tokenized, - tags=desc_tokenized, - goal_type="default", - rephrased_query=task_description, - internet_search=False, - ) + use_fast_graph = kwargs.get("use_fast_graph", False) + if use_fast_graph: + desc_tokenized = self.tokenizer.tokenize_mixed(task_description) + return ParsedTaskGoal( + memories=[task_description], + keys=desc_tokenized, + tags=desc_tokenized, + goal_type="default", + rephrased_query=task_description, + internet_search=False, + ) + else: + return ParsedTaskGoal( + memories=[task_description], + keys=[task_description], + tags=[], + goal_type="default", + rephrased_query=task_description, + internet_search=False, + ) def _parse_fine( self, query: str, context: str = "", conversation: list[dict] | None = None From 87e26997cef6d8a8dbe016212d6d2868a17851e9 Mon Sep 17 00:00:00 2001 From: Hao <42795704+Nyakult@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:00:16 +0800 Subject: [PATCH 30/64] pm & prefEval scripts updates (#421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: check nodes existence * feat: use different template for different language input * feat: use different template for different language input * fix: eval script * feat: memos-api eval scripts * feat: mem reader * feat: 实现äºprefeval memos-api evaluation scripts * refactor:format code * feat: add PersonaMem eval scripts * docs(evaluation): update PersonaMem eval readme * feat:memos-api ingest batch message * feat: refactor search * feat: refactor search * update: add api for memory * feat: add memory api return memory and memory type * refactor(server):重构服务器路由模块以优化内存管理 * format: ruff format code * feat(server): 增加LLM最大令牌数 * test * fix: user query embedding for search * count memory_size by user * fix(server):修复记忆读取逻辑中的列表展开问题 * feat(nebular):优化图数据库查询性能 * refactor(memory): - 移除了对 `_refresh_memory_size` 方法的调用- 保留原有逻辑以便后续恢复或重构 * feat: remove user idx_memory_user_name * feat(graph):优化Nebula图数据库查询性能 * feat: rollback remove_oldest_memory * feat:nebula gql add index * feat: align code * feat: update memos_api * feat: update memos_api * feat: 更新默认选项 * feat:memory client * feat:refactor lme * feat: memu & supermemory client * feat: locomo memu * feat: locomo supermemory * New 'add' and 'process' modes. * feat: lme supermemory & memu * feat: default args * api and local * api and local * memobase fix * memos fix * default args * fix memos-api search data * prefeval pipeline * fix lme memos-api * personamem pipeline * personamem pipeline * lme scrips * align dev * format code * refactor: remove old files * format code * pm and prefeval pipeline * format code * format code * pm and prefeval pipeline * pm and prefeval pipeline * pm and prefeval pipeline * format code * format code * pref pipeline * add search response mode * add search response mode * update readme and example * update mem0 api * pm mem0 * fix MEMOBASE api * update pm and prefeval pipepline for frames * update pm and prefeval readme * format code * fix memobase api * fix memobase api * format code * format code * fix format * fix format * fix format * mem0 api * memos batch add * add memos-api-online * add memos-api-online update readme * rollback manager * memos online api pref mem * readme * update pref eval and pm scripts * add breakpoint in eval scripts * modify default param --------- Co-authored-by: 2Rant Co-authored-by: fridayL Co-authored-by: CaralHsi --- evaluation/README.md | 7 +- evaluation/scripts/PrefEval/pref_eval.py | 26 +++- evaluation/scripts/PrefEval/pref_mem0.py | 41 +++++- evaluation/scripts/PrefEval/pref_memobase.py | 33 ++++- evaluation/scripts/PrefEval/pref_memos.py | 55 +++++--- evaluation/scripts/PrefEval/pref_memu.py | 44 +++++-- .../scripts/PrefEval/pref_supermemory.py | 47 +++++-- evaluation/scripts/PrefEval/pref_zep.py | 45 ++++--- .../scripts/PrefEval/prefeval_preprocess.py | 1 + evaluation/scripts/personamem/pm_ingestion.py | 120 +++++++++++------- evaluation/scripts/personamem/pm_metric.py | 4 +- evaluation/scripts/personamem/pm_responses.py | 28 ++-- evaluation/scripts/personamem/pm_search.py | 43 ++++--- evaluation/scripts/run_lme_eval.sh | 2 +- evaluation/scripts/run_locomo_eval.sh | 2 +- evaluation/scripts/run_openai_eval.sh | 2 +- evaluation/scripts/run_pm_eval.sh | 14 +- evaluation/scripts/run_prefeval_eval.sh | 15 ++- evaluation/scripts/utils/client.py | 2 +- 19 files changed, 342 insertions(+), 189 deletions(-) diff --git a/evaluation/README.md b/evaluation/README.md index 47cfeedc0..ba8c7a0cc 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -1,6 +1,6 @@ # Evaluation Memory Framework -This repository provides tools and scripts for evaluating the LoCoMo dataset using various models and APIs. +This repository provides tools and scripts for evaluating the `LoCoMo`, `LongMemEval`, `PrefEval`, `personaMem` dataset using various models and APIs. ## Installation @@ -68,7 +68,8 @@ First prepare the dataset `longmemeval_s` from https://huggingface.co/datasets/x ``` ### PrefEval Evaluation -To evaluate the **Prefeval** dataset using one of the supported memory frameworks — run the following [script](./scripts/run_prefeval_eval.sh): +Downloading benchmark_dataset/filtered_inter_turns.json from https://github.com/amazon-science/PrefEval/blob/main/benchmark_dataset/filtered_inter_turns.json and save it as `./data/prefeval/filtered_inter_turns.json`. +To evaluate the **Prefeval** dataset — run the following [script](./scripts/run_prefeval_eval.sh): ```bash # Edit the configuration in ./scripts/run_prefeval_eval.sh @@ -83,4 +84,4 @@ get `questions_32k.csv` and `shared_contexts_32k.jsonl` from https://huggingface # Specify the model and memory backend you want to use (e.g., mem0, zep, etc.) # If you want to use MIRIX, edit the the configuration in ./scripts/personamem/config.yaml ./scripts/run_pm_eval.sh -``` +``` \ No newline at end of file diff --git a/evaluation/scripts/PrefEval/pref_eval.py b/evaluation/scripts/PrefEval/pref_eval.py index f1966b847..ec079614d 100644 --- a/evaluation/scripts/PrefEval/pref_eval.py +++ b/evaluation/scripts/PrefEval/pref_eval.py @@ -392,9 +392,7 @@ async def main(concurrency_limit: int, input_file: str, output_file: str, output if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate assistant responses from a JSONL file.") - parser.add_argument( - "--input", type=str, required=True, help="Path to the input JSONL file from pref_memos.py." - ) + parser.add_argument("--input", type=str, required=True, help="Path to the input JSONL file.") parser.add_argument( "--concurrency-limit", @@ -402,13 +400,31 @@ async def main(concurrency_limit: int, input_file: str, output_file: str, output default=10, help="The maximum number of concurrent API calls.", ) + + parser.add_argument( + "--lib", + type=str, + choices=[ + "memos-api-online", + "mem0", + "mem0_graph", + "memos-api", + "memobase", + "memu", + "supermemory", + "zep", + ], + default="memos-api", + help="Which library to use (used in 'add' mode).", + ) + args = parser.parse_args() input_path = args.input output_dir = os.path.dirname(input_path) - output_jsonl_path = os.path.join(output_dir, "eval_pref_memos.jsonl") - output_excel_path = os.path.join(output_dir, "eval_pref_memos_summary.xlsx") + output_jsonl_path = os.path.join(output_dir, f"eval_pref_{args.lib}.jsonl") + output_excel_path = os.path.join(output_dir, f"eval_pref_{args.lib}_summary.xlsx") asyncio.run( main( diff --git a/evaluation/scripts/PrefEval/pref_mem0.py b/evaluation/scripts/PrefEval/pref_mem0.py index 4bbdb0fd8..214068567 100644 --- a/evaluation/scripts/PrefEval/pref_mem0.py +++ b/evaluation/scripts/PrefEval/pref_mem0.py @@ -29,7 +29,13 @@ def add_memory_for_line( - line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str + line_data: tuple, + mem_client, + num_irrelevant_turns: int, + lib: str, + version: str, + success_records, + f, ) -> dict: """ Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. @@ -46,13 +52,22 @@ def add_memory_for_line( elif num_irrelevant_turns == 300: conversation = conversation + irre_300 - turns_add = 5 start_time_add = time.monotonic() - if conversation: - for chunk_start in range(0, len(conversation), turns_add * 2): - chunk = conversation[chunk_start : chunk_start + turns_add * 2] - timestamp_add = int(time.time() * 100) - mem_client.add(messages=chunk, user_id=user_id, timestamp=timestamp_add) + + for idx, _ in enumerate(conversation[::2]): + msg_idx = idx * 2 + record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + timestamp_add = int(time.time() * 100) + + if record_id not in success_records: + mem_client.add( + messages=conversation[msg_idx : msg_idx + 2], + user_id=user_id, + timestamp=timestamp_add, + ) + f.write(f"{record_id}\n") + f.flush() + end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -210,6 +225,15 @@ def main(): from utils.client import Mem0Client mem_client = Mem0Client(enable_graph="graph" in args.lib) + os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True) + success_records = set() + record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt" + if os.path.exists(record_file): + print(f"Loading existing success records from {record_file}...") + with open(record_file, encoding="utf-8") as f: + for i in f.readlines(): + success_records.add(i.strip()) + print(f"Loaded {len(success_records)} records.") if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") @@ -218,6 +242,7 @@ def main(): with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + open(record_file, "a+", encoding="utf-8") as f, ): futures = [ executor.submit( @@ -227,6 +252,8 @@ def main(): args.add_turn, args.lib, args.version, + success_records, + f, ) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/pref_memobase.py b/evaluation/scripts/PrefEval/pref_memobase.py index 4f6174d3d..e99b10520 100644 --- a/evaluation/scripts/PrefEval/pref_memobase.py +++ b/evaluation/scripts/PrefEval/pref_memobase.py @@ -12,7 +12,6 @@ from openai import OpenAI from tqdm import tqdm - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -28,7 +27,13 @@ def add_memory_for_line( - line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str + line_data: tuple, + mem_client, + num_irrelevant_turns: int, + lib: str, + version: str, + success_records, + f, ) -> dict: """ Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. @@ -36,8 +41,6 @@ def add_memory_for_line( i, line = line_data user_id = f"{lib}_user_pref_eval_{i}_{version}" mem_client.delete_user(user_id) - user_id = mem_client.client.add_user({"user_id": user_id}) - print("user_id:", user_id) try: original_data = json.loads(line) conversation = original_data.get("conversation", []) @@ -63,7 +66,14 @@ def add_memory_for_line( "created_at": timestamp_add, } ) - mem_client.add(messages=messages, user_id=user_id) + for idx, _ in enumerate(conversation[::2]): + msg_idx = idx * 2 + record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + + if record_id not in success_records: + mem_client.add(messages=conversation[msg_idx : msg_idx + 2], user_id=user_id) + f.write(f"{record_id}\n") + f.flush() end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -222,6 +232,16 @@ def main(): mem_client = MemobaseClient() + os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True) + success_records = set() + record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt" + if os.path.exists(record_file): + print(f"Loading existing success records from {record_file}...") + with open(record_file, encoding="utf-8") as f: + for i in f.readlines(): + success_records.add(i.strip()) + print(f"Loaded {len(success_records)} records.") + if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") print(f"Adding {args.add_turn} irrelevant turns.") @@ -229,6 +249,7 @@ def main(): with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + open(record_file, "a+", encoding="utf-8") as f, ): futures = [ executor.submit( @@ -238,6 +259,8 @@ def main(): args.add_turn, args.lib, args.version, + success_records, + f, ) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index fc358dc36..0ee88e868 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -12,7 +12,6 @@ from openai import OpenAI from tqdm import tqdm - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -21,7 +20,6 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) - load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") BASE_URL = os.getenv("OPENAI_BASE_URL") @@ -30,8 +28,8 @@ def add_memory_for_line( - line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str -) -> dict: + line_data, mem_client, num_irrelevant_turns, lib, version, success_records, f +): """ Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. """ @@ -47,15 +45,22 @@ def add_memory_for_line( elif num_irrelevant_turns == 300: conversation = conversation + irre_300 - turns_add = 5 start_time_add = time.monotonic() - if conversation: - if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": - for chunk_start in range(0, len(conversation), turns_add * 2): - chunk = conversation[chunk_start : chunk_start + turns_add * 2] - mem_client.add(messages=chunk, user_id=user_id, conv_id=None, batch_size=2) - else: - mem_client.add(messages=conversation, user_id=user_id, conv_id=None, batch_size=2) + + for idx, _ in enumerate(conversation[::2]): + msg_idx = idx * 2 + record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + + if record_id not in success_records: + mem_client.add( + messages=conversation[msg_idx : msg_idx + 2], + user_id=user_id, + conv_id=None, + batch_size=2, + ) + f.write(f"{record_id}\n") + f.flush() + end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -68,7 +73,7 @@ def add_memory_for_line( return None -def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: +def search_memory_for_line(line_data, mem_client, top_k_value): """ Processes a single line of data, searching memory based on the question. """ @@ -120,7 +125,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di return None -def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str) -> dict: +def generate_response_for_line(line_data, openai_client, lib): """ Generates a response for a single line of data using pre-fetched memories. """ @@ -195,7 +200,7 @@ def main(): parser.add_argument( "--lib", type=str, - choices=["memos-api", "memos-local"], + choices=["memos-api", "memos-api-online"], default="memos-api", help="Which MemOS library to use (used in 'add' mode).", ) @@ -218,9 +223,22 @@ def main(): print(f"Error: Input file '{args.input}' not found") return - from utils.client import MemosApiClient + from utils.client import MemosApiClient, MemosApiOnlineClient + + if args.lib == "memos-api": + mem_client = MemosApiClient() + elif args.lib == "memos-api-online": + mem_client = MemosApiOnlineClient() - mem_client = MemosApiClient() + os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True) + success_records = set() + record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt" + if os.path.exists(record_file): + print(f"Loading existing success records from {record_file}...") + with open(record_file, encoding="utf-8") as f: + for i in f.readlines(): + success_records.add(i.strip()) + print(f"Loaded {len(success_records)} records.") if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") @@ -229,6 +247,7 @@ def main(): with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + open(record_file, "a+", encoding="utf-8") as record_f, ): futures = [ executor.submit( @@ -238,6 +257,8 @@ def main(): args.add_turn, args.lib, args.version, + success_records, + record_f, ) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/pref_memu.py b/evaluation/scripts/PrefEval/pref_memu.py index 2b9f769a4..4c37db7b7 100644 --- a/evaluation/scripts/PrefEval/pref_memu.py +++ b/evaluation/scripts/PrefEval/pref_memu.py @@ -14,7 +14,6 @@ from openai import OpenAI from tqdm import tqdm - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -30,7 +29,13 @@ def add_memory_for_line( - line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str + line_data: tuple, + mem_client, + num_irrelevant_turns: int, + lib: str, + version: str, + success_records, + f, ) -> dict: """ Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. @@ -47,19 +52,21 @@ def add_memory_for_line( elif num_irrelevant_turns == 300: conversation = conversation + irre_300 - turns_add = 5 start_time_add = time.monotonic() - if conversation: - if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": - for chunk_start in range(0, len(conversation), turns_add * 2): - chunk = conversation[chunk_start : chunk_start + turns_add * 2] - mem_client.add( - messages=chunk, user_id=user_id, iso_date=datetime.now().isoformat() - ) - else: + + for idx, _ in enumerate(conversation[::2]): + msg_idx = idx * 2 + record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + + if record_id not in success_records: mem_client.add( - messages=conversation, user_id=user_id, iso_date=datetime.now().isoformat() + messages=conversation[msg_idx : msg_idx + 2], + user_id=user_id, + iso_date=datetime.now().isoformat(), ) + f.write(f"{record_id}\n") + f.flush() + end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -219,6 +226,16 @@ def main(): mem_client = MemuClient() + os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True) + success_records = set() + record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt" + if os.path.exists(record_file): + print(f"Loading existing success records from {record_file}...") + with open(record_file, encoding="utf-8") as f: + for i in f.readlines(): + success_records.add(i.strip()) + print(f"Loaded {len(success_records)} records.") + if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") print(f"Adding {args.add_turn} irrelevant turns.") @@ -226,6 +243,7 @@ def main(): with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + open(record_file, "a+", encoding="utf-8") as f, ): futures = [ executor.submit( @@ -235,6 +253,8 @@ def main(): args.add_turn, args.lib, args.version, + success_records, + f, ) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/pref_supermemory.py b/evaluation/scripts/PrefEval/pref_supermemory.py index 88a64038b..68963e2af 100644 --- a/evaluation/scripts/PrefEval/pref_supermemory.py +++ b/evaluation/scripts/PrefEval/pref_supermemory.py @@ -12,7 +12,6 @@ from openai import OpenAI from tqdm import tqdm - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -28,7 +27,13 @@ def add_memory_for_line( - line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str + line_data: tuple, + mem_client, + num_irrelevant_turns: int, + lib: str, + version: str, + success_records, + f, ) -> dict: """ Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. @@ -45,15 +50,20 @@ def add_memory_for_line( elif num_irrelevant_turns == 300: conversation = conversation + irre_300 - turns_add = 5 start_time_add = time.monotonic() - if conversation: - if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": - for chunk_start in range(0, len(conversation), turns_add * 2): - chunk = conversation[chunk_start : chunk_start + turns_add * 2] - mem_client.add(messages=chunk, user_id=user_id) - else: - mem_client.add(messages=conversation, user_id=user_id) + + for idx, _ in enumerate(conversation[::2]): + msg_idx = idx * 2 + record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + + if record_id not in success_records: + mem_client.add( + messages=conversation[msg_idx : msg_idx + 2], + user_id=user_id, + ) + f.write(f"{record_id}\n") + f.flush() + end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -90,9 +100,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di start_time_search = time.monotonic() relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) search_memories_duration = time.monotonic() - start_time_search - memories_str = "\n".join( - f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] - ) + memories_str = relevant_memories memory_tokens_used = len(tokenizer.encode(memories_str)) @@ -250,6 +258,16 @@ def search(self, query, user_id, top_k): mem_client = SupermemoryClient() + os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True) + success_records = set() + record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt" + if os.path.exists(record_file): + print(f"Loading existing success records from {record_file}...") + with open(record_file, encoding="utf-8") as f: + for i in f.readlines(): + success_records.add(i.strip()) + print(f"Loaded {len(success_records)} records.") + if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") print(f"Adding {args.add_turn} irrelevant turns.") @@ -257,6 +275,7 @@ def search(self, query, user_id, top_k): with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + open(record_file, "a+", encoding="utf-8") as f, ): futures = [ executor.submit( @@ -266,6 +285,8 @@ def search(self, query, user_id, top_k): args.add_turn, args.lib, args.version, + success_records, + f, ) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/pref_zep.py b/evaluation/scripts/PrefEval/pref_zep.py index 91aef1492..be98c6ba9 100644 --- a/evaluation/scripts/PrefEval/pref_zep.py +++ b/evaluation/scripts/PrefEval/pref_zep.py @@ -14,7 +14,6 @@ from openai import OpenAI from tqdm import tqdm - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -30,7 +29,13 @@ def add_memory_for_line( - line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str + line_data: tuple, + mem_client, + num_irrelevant_turns: int, + lib: str, + version: str, + success_records, + f, ) -> dict: """ Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id. @@ -47,25 +52,22 @@ def add_memory_for_line( elif num_irrelevant_turns == 300: conversation = conversation + irre_300 - turns_add = 5 start_time_add = time.monotonic() - if conversation: - if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true": - for chunk_start in range(0, len(conversation), turns_add * 2): - chunk = conversation[chunk_start : chunk_start + turns_add * 2] - mem_client.add( - messages=chunk, - user_id=user_id, - conv_id=None, - timestamp=datetime.now().isoformat(), - ) - else: + + for idx, _ in enumerate(conversation[::2]): + msg_idx = idx * 2 + record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + + if record_id not in success_records: mem_client.add( - messages=conversation, + messages=conversation[msg_idx : msg_idx + 2], user_id=user_id, conv_id=None, timestamp=datetime.now().isoformat(), ) + f.write(f"{record_id}\n") + f.flush() + end_time_add = time.monotonic() add_duration = end_time_add - start_time_add @@ -225,6 +227,16 @@ def main(): mem_client = ZepClient() + os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True) + success_records = set() + record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt" + if os.path.exists(record_file): + print(f"Loading existing success records from {record_file}...") + with open(record_file, encoding="utf-8") as f: + for i in f.readlines(): + success_records.add(i.strip()) + print(f"Loaded {len(success_records)} records.") + if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") print(f"Adding {args.add_turn} irrelevant turns.") @@ -232,6 +244,7 @@ def main(): with ( open(args.output, "w", encoding="utf-8") as outfile, concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + open(record_file, "a+", encoding="utf-8") as f, ): futures = [ executor.submit( @@ -241,6 +254,8 @@ def main(): args.add_turn, args.lib, args.version, + success_records, + f, ) for i, line in enumerate(lines) ] diff --git a/evaluation/scripts/PrefEval/prefeval_preprocess.py b/evaluation/scripts/PrefEval/prefeval_preprocess.py index 9ace9dec9..b8ccf3f34 100644 --- a/evaluation/scripts/PrefEval/prefeval_preprocess.py +++ b/evaluation/scripts/PrefEval/prefeval_preprocess.py @@ -94,6 +94,7 @@ def process_jsonl_file(input_filepath, output_filepath): def main(): huggingface_dataset_name = "siyanzhao/prefeval_implicit_persona" output_directory = "./data/prefeval" + os.makedirs(output_directory, exist_ok=True) input_file_path = os.path.join(output_directory, "train.jsonl") processed_file_path = os.path.join(output_directory, "pref_processed.jsonl") diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index cab0fbeb5..fdbf43528 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -10,7 +10,6 @@ from tqdm import tqdm - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -25,28 +24,22 @@ def ingest_session(session, user_id, session_id, frame, client): f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}..." ) timestamp_add = int(time.time() * 100) - client.add(messages=messages, user_id=user_id, timestamp=timestamp_add) + client.add(messages=messages, user_id=user_id, timestamp=timestamp_add, batch_size=10) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") elif frame == "memos-api": - if os.getenv("PRE_SPLIT_CHUNK") == "true": - for i in range(0, len(session), 10): - messages = session[i : i + 10] - client.add(messages=messages, user_id=user_id, conv_id=session_id, batch_size=2) - print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") - else: - client.add(messages=session, user_id=user_id, conv_id=session_id, batch_size=2) - print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages") + client.add(messages=session, user_id=user_id, conv_id=session_id, batch_size=10) + print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages") elif frame == "memobase": for _idx, msg in enumerate(session): if msg["role"] != "system": messages.append( { "role": msg["role"], - "content": msg["content"][:8000], + "content": msg["content"], "created_at": datetime.now().isoformat(), } ) - client.add(messages, user_id) + client.add(messages, user_id, batch_size=10) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") elif frame == "supermemory": for _idx, msg in enumerate(session): @@ -62,6 +55,8 @@ def ingest_session(session, user_id, session_id, frame, client): for _idx, msg in enumerate(session): messages.append({"role": msg["role"], "content": msg["content"]}) client.add(messages, user_id, datetime.now().astimezone().isoformat()) + elif frame == "memos-api-online": + client.add(messages, user_id, session_id, batch_size=10) def build_jsonl_index(jsonl_path): @@ -125,7 +120,11 @@ def count_csv_rows(csv_path): return sum(1 for _ in f) - 1 -def ingest_conv(row_data, context, version, conv_idx, frame): +def ingest_conv(row_data, context, version, conv_idx, frame, success_records, f): + if str(conv_idx) in success_records: + print(f"✅ Conversation {conv_idx} already ingested, skipping...") + return conv_idx + end_index_in_shared_context = row_data["end_index_in_shared_context"] context = context[: int(end_index_in_shared_context)] user_id = f"pm_exper_user_{conv_idx}_{version}" @@ -150,8 +149,6 @@ def ingest_conv(row_data, context, version, conv_idx, frame): print("🔌 Using Mem0 client for ingestion...") client.client.delete_all(user_id=user_id) print(f"🗑️ Deleted existing memories for user {user_id}...") - - print(f"🗑️ Deleted existing memories for user {user_id}...") elif frame == "memos-api": from utils.client import MemosApiClient @@ -160,8 +157,6 @@ def ingest_conv(row_data, context, version, conv_idx, frame): from utils.client import MemobaseClient client = MemobaseClient() - print("🔌 Using Memobase client for ingestion...") - client.delte_user(user_id) elif frame == "supermemory": from utils.client import SupermemoryClient @@ -170,15 +165,33 @@ def ingest_conv(row_data, context, version, conv_idx, frame): from utils.client import MemuClient client = MemuClient() + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + + client = MemosApiOnlineClient() + + try: + ingest_session(session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client) + print(f"✅ Ingestion of conversation {conv_idx} completed") + print("=" * 80) + + f.write(f"{conv_idx}\n") + f.flush() + return conv_idx + except Exception as e: + print(f"❌ Error ingesting conversation {conv_idx}: {e}") + raise - ingest_session( - session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client - ) - print(f"✅ Ingestion of conversation {conv_idx} completed") - print("=" * 80) +def main(frame, version, num_workers=2, clear=False): + os.makedirs(f"results/pm/{frame}-{version}/", exist_ok=True) + record_file = f"results/pm/{frame}-{version}/success_records.txt" + + if clear: + if os.path.exists(record_file): + os.remove(record_file) + print("🧹 Cleared progress records") -def main(frame, version, num_workers=2): print("\n" + "=" * 80) print(f"🚀 PERSONAMEM INGESTION - {frame.upper()} v{version}".center(80)) print("=" * 80) @@ -190,31 +203,48 @@ def main(frame, version, num_workers=2): print(f"📚 Loaded PersonaMem dataset from {question_csv_path} and {context_jsonl_path}") print("-" * 80) - start_time = datetime.now() + success_records = set() + if os.path.exists(record_file): + with open(record_file, "r") as f: + success_records = set(line.strip() for line in f) + print(f"📊 Found {len(success_records)} completed conversations, {total_rows - len(success_records)} remaining") + start_time = datetime.now() all_data = list(load_rows_with_context(question_csv_path, context_jsonl_path)) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - future_to_idx = { - executor.submit( + pending_data = [(idx, row_data, context) for idx, (row_data, context) in enumerate(all_data) + if str(idx) not in success_records] + + if not pending_data: + print("✅ All conversations have been processed!") + return + + print(f"🔄 Processing {len(pending_data)} conversations...") + + with ThreadPoolExecutor(max_workers=num_workers) as executor, open(record_file, "a") as f: + futures = [] + for idx, row_data, context in pending_data: + future = executor.submit( ingest_conv, row_data=row_data, context=context, version=version, conv_idx=idx, frame=frame, - ): idx - for idx, (row_data, context) in enumerate(all_data) - } + success_records=success_records, + f=f + ) + futures.append(future) + completed_count = 0 for future in tqdm( - as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + as_completed(futures), total=len(futures), desc="Processing conversations" ): - idx = future_to_idx[future] try: - future.result() + result = future.result() + completed_count += 1 except Exception as exc: - print(f"\n❌ Conversation {idx} generated an exception: {exc}") + print(f"\n❌ Conversation generated an exception: {exc}") end_time = datetime.now() elapsed_time = end_time - start_time @@ -225,23 +255,19 @@ def main(frame, version, num_workers=2): print("=" * 80) print(f"⏱️ Total time taken to ingest {total_rows} rows: {elapsed_time_str}") print(f"🔄 Framework: {frame} | Version: {version} | Workers: {num_workers}") + print(f"📈 Processed: {len(success_records) + completed_count}/{total_rows} conversations") print("=" * 80 + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Ingestion Script") - parser.add_argument( - "--lib", - type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory", "zep"], - default="memos-api", - ) - parser.add_argument( - "--version", type=str, default="0925-1", help="Version of the evaluation framework." - ) - parser.add_argument( - "--workers", type=int, default=3, help="Number of parallel workers for processing users." - ) + parser.add_argument("--lib", type=str, + choices=["memos-api-online", "mem0", "mem0_graph", "memos-api", "memobase", "memu", + "supermemory", "zep"], + default='memos-api') + parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.") + parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument("--clear", action="store_true", help="Clear progress and start fresh") args = parser.parse_args() - main(frame=args.lib, version=args.version, num_workers=args.workers) + main(frame=args.lib, version=args.version, num_workers=args.workers, clear=args.clear) \ No newline at end of file diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index e88c538d4..b9d10a576 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -353,12 +353,12 @@ def print_summary(results): parser.add_argument( "--lib", type=str, - choices=["zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], + choices=["zep", "mem0", "mem0_graph", "memos-api", "memos-api-online", "memobase", "memu", "supermemory"], required=True, help="Memory library to evaluate", default="memos-api", ) - parser.add_argument("--version", type=str, default="0925", help="Evaluation framework version") + parser.add_argument("--version", type=str, default="default", help="Evaluation framework version") args = parser.parse_args() lib, version = args.lib, args.version diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index ff561f8d8..2e41b4140 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -10,7 +10,6 @@ from openai import OpenAI from tqdm import tqdm - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re @@ -154,9 +153,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id[future] = user_id for future in tqdm( - as_completed(future_to_user_id), - total=len(future_to_user_id), - desc="📝 Generating responses", + as_completed(future_to_user_id), + total=len(future_to_user_id), + desc="📝 Generating responses", ): user_id = future_to_user_id[future] try: @@ -185,21 +184,12 @@ def main(frame, version, num_runs=3, num_workers=4): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Response Generation Script") - parser.add_argument( - "--lib", - type=str, - choices=["zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], - default="memos-api", - ) - parser.add_argument( - "--version", type=str, default="0925", help="Version of the evaluation framework." - ) - parser.add_argument( - "--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation." - ) - parser.add_argument( - "--workers", type=int, default=3, help="Number of worker threads to use for processing." - ) + parser.add_argument("--lib", type=str, + choices=["memos-api-online", "zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu", + "supermemory"], default='memos-api') + parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.") + parser.add_argument("--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation.") + parser.add_argument("--workers", type=int, default=10, help="Number of worker threads to use for processing.") args = parser.parse_args() main(frame=args.lib, version=args.version, num_runs=args.num_runs, num_workers=args.workers) diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 441474c7c..edec6b008 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -3,12 +3,10 @@ import json import os import sys - from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time - from tqdm import tqdm @@ -83,8 +81,8 @@ def memos_search(client, user_id, query, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) search_memories = ( - "\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"]) - + f"\n{results['pref_string']}" + "\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"]) + + f"\n{results['pref_string']}" ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) @@ -226,6 +224,17 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): client = MemuClient() print("🔌 Using memu client for search...") context, duration_ms = memu_search(client, question, user_id, top_k) + elif frame == "memobase": + from utils.client import MemobaseClient + + client = MemobaseClient() + print("🔌 Using Memobase client for search...") + context, duration_ms = memobase_search(client, question, user_id, top_k) + elif frame == "memos-api-online": + from utils.client import MemosApiOnlineClient + client = MemosApiOnlineClient() + print("🔌 Using memos-api-online client for search...") + context, duration_ms = memos_search(client, question, user_id, top_k) search_results[user_id].append( { @@ -244,7 +253,7 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): os.makedirs(f"results/pm/{frame}-{version}/tmp", exist_ok=True) with open( - f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" + f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" ) as f: json.dump(search_results, f, indent=4) print(f"💾 Search results for conversation {conv_idx} saved...") @@ -295,7 +304,7 @@ def main(frame, version, top_k=20, num_workers=2): } for future in tqdm( - as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" ): idx = future_to_idx[future] try: @@ -324,21 +333,13 @@ def main(frame, version, top_k=20, num_workers=2): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Search Script") - parser.add_argument( - "--lib", - type=str, - choices=["mem0", "mem0_graph", "memos-api", "memobase", "memu", "supermemory"], - default="memos-api", - ) - parser.add_argument( - "--version", type=str, default="default", help="Version of the evaluation framework." - ) - parser.add_argument( - "--top_k", type=int, default=20, help="Number of top results to retrieve from the search." - ) - parser.add_argument( - "--workers", type=int, default=3, help="Number of parallel workers for processing users." - ) + parser.add_argument("--lib", type=str, + choices=["memos-api-online", "mem0", "mem0_graph", "memos-api", "memobase", "memu", + "supermemory"], + default='memos-api') + parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.") + parser.add_argument("--top_k", type=int, default=20, help="Number of top results to retrieve from the search.") + parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") args = parser.parse_args() diff --git a/evaluation/scripts/run_lme_eval.sh b/evaluation/scripts/run_lme_eval.sh index 08e431312..8fa8d6c7e 100755 --- a/evaluation/scripts/run_lme_eval.sh +++ b/evaluation/scripts/run_lme_eval.sh @@ -2,7 +2,7 @@ # Common parameters for all scripts LIB="memos-api" -VERSION="1020" +VERSION="default" WORKERS=10 TOPK=20 diff --git a/evaluation/scripts/run_locomo_eval.sh b/evaluation/scripts/run_locomo_eval.sh index d9c13a1ac..37569956f 100755 --- a/evaluation/scripts/run_locomo_eval.sh +++ b/evaluation/scripts/run_locomo_eval.sh @@ -2,7 +2,7 @@ # Common parameters for all scripts LIB="memos-api" -VERSION="072001" +VERSION="default" WORKERS=10 TOPK=20 diff --git a/evaluation/scripts/run_openai_eval.sh b/evaluation/scripts/run_openai_eval.sh index 27bb712af..e07f113e5 100755 --- a/evaluation/scripts/run_openai_eval.sh +++ b/evaluation/scripts/run_openai_eval.sh @@ -2,7 +2,7 @@ # Common parameters for all scripts LIB="openai" -VERSION="063001" +VERSION="default" WORKERS=10 NUM_RUNS=3 diff --git a/evaluation/scripts/run_pm_eval.sh b/evaluation/scripts/run_pm_eval.sh index a46440bfc..39d9e72ca 100755 --- a/evaluation/scripts/run_pm_eval.sh +++ b/evaluation/scripts/run_pm_eval.sh @@ -2,21 +2,11 @@ # Common parameters for all scripts LIB="memos-api" -VERSION="072202" +VERSION="default" WORKERS=10 TOPK=20 -if [ "$LIB" = "mirix" ]; then - echo "Running pm_mirix.py 100 times..." - for i in {1..100}; do - echo "Iteration $i/100" - CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_mirix.py --version $VERSION --workers 1 - if [ $? -ne 0 ]; then - echo "Error running xx.py on iteration $i" - exit 1 - fi - done -elif ["$LIB" = "zep"]; then +if ["$LIB" = "zep"]; then CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_ingestion_zep.py --version $VERSION --workers $WORKERS CUDA_VISIBLE_DEVICES=0 python scripts/personamem/pm_search_zep.py --version $VERSION --top_k $TOPK --workers $WORKERS echo "Running pm_responses.py..." diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index a79cefcc2..129382ebf 100755 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -6,13 +6,13 @@ # Number of workers for parallel processing. # This variable controls both pref_memos.py (--max-workers) # and pref_eval.py (--concurrency-limit). -WORKERS=10 +WORKERS=20 # Parameters for pref_memos.py -TOP_K=6 -ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" -VERSION="1022-0" +TOP_K=10 +ADD_TURN=10 # Options: 0, 10, or 300 +LIB="memos-api" # Options: memos-api, memos-api-online, mem0, mem0-graph, memobase, supermemory, memu, zep +VERSION="default" # --- File Paths --- # You may need to adjust these paths based on your project structure. @@ -133,7 +133,8 @@ echo "" echo "Running pref_eval.py..." python scripts/PrefEval/pref_eval.py \ --input $RESPONSE_FILE \ - --concurrency-limit $WORKERS + --concurrency-limit $WORKERS \ + --lib $LIB if [ $? -ne 0 ]; then echo "Error: Evaluation script failed." @@ -142,4 +143,4 @@ fi echo "" echo "--- PrefEval Pipeline completed successfully! ---" -echo "Final results are in $RESPONSE_FILE" +echo "Final results are in $RESPONSE_FILE" \ No newline at end of file diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 4e7cfdbca..3c34c49d0 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -245,7 +245,7 @@ def search(self, query, user_id, top_k): res = json.loads(response.text)["data"]["memory_detail_list"] for i in res: i.update({"memory": i.pop("memory_value")}) - return {"text_mem": [{"memories": res}], "pref_mem": ""} + return {"text_mem": [{"memories": res}], "pref_str": ""} except Exception as e: if attempt < max_retries - 1: time.sleep(2**attempt) From 81c7ad9f2ca323c9e032c17fd9c43f31c82b4c15 Mon Sep 17 00:00:00 2001 From: Wustzdy <67457465+wustzdy@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:00:51 +0800 Subject: [PATCH 31/64] add polardb pool (#420) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * fix:nacos * feat: fix config Exception * feat: format config * feat: format config * fix: * fix: * fix: * fix: * add polardb pool * fix: --------- Co-authored-by: ccl <13282138256@163.com> Co-authored-by: lijicode <34564964+lijicode@users.noreply.github.com> Co-authored-by: liji <532311301@qq.com> --- src/memos/graph_dbs/polardb.py | 424 +++++++++++++++++++++------------ 1 file changed, 278 insertions(+), 146 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 5d50cf68f..f24f1072c 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -15,9 +15,6 @@ logger = get_logger(__name__) -# Graph database configuration -GRAPH_NAME = "test_memos_graph" - def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: node_id = item["id"] @@ -119,6 +116,7 @@ def __init__(self, config: PolarDBGraphDBConfig): but it will be removed automatically before returning to external consumers. """ import psycopg2 + import psycopg2.pool self.config = config @@ -137,12 +135,26 @@ def __init__(self, config: PolarDBGraphDBConfig): port = config.port user = config.user password = config.password - + """ # Create connection self.connection = psycopg2.connect( - host=host, port=port, user=user, password=password, dbname=self.db_name + host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 ) - self.connection.autocommit = True + """ + + # Create connection pool + self.connection_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=5, + maxconn=2000, + host=host, + port=port, + user=user, + password=password, + dbname=self.db_name, + ) + + # Keep a reference to the pool for cleanup + self._pool_closed = False """ # Handle auto_create @@ -167,6 +179,17 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) + def _get_connection(self): + """Get a connection from the pool.""" + if self._pool_closed: + raise RuntimeError("Connection pool has been closed") + return self.connection_pool.getconn() + + def _return_connection(self, connection): + """Return a connection to the pool.""" + if not self._pool_closed and connection: + self.connection_pool.putconn(connection) + def _ensure_database_exists(self): """Create database if it doesn't exist.""" try: @@ -180,8 +203,10 @@ def _ensure_database_exists(self): @timed def _create_graph(self): """Create PostgreSQL schema and table for graph storage.""" + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Create schema if it doesn't exist cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') logger.info(f"Schema '{self.db_name}_graph' ensured.") @@ -229,6 +254,8 @@ def _create_graph(self): except Exception as e: logger.error(f"Failed to create graph schema: {e}") raise e + finally: + self._return_connection(conn) def create_index( self, @@ -241,8 +268,10 @@ def create_index( Create indexes for embedding and other fields. Note: This creates PostgreSQL indexes on the underlying tables. """ + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" @@ -262,6 +291,8 @@ def create_index( logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") + finally: + self._return_connection(conn) def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: """Get count of memory nodes by type.""" @@ -274,14 +305,18 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params = [f'"{memory_type}"', f'"{user_name}"'] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 + finally: + self._return_connection(conn) @timed def node_not_exist(self, scope: str, user_name: str | None = None) -> int: @@ -296,14 +331,18 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: query += "\nLIMIT 1" params = [f'"{scope}"', f'"{user_name}"'] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def remove_oldest_memory( @@ -329,9 +368,9 @@ def remove_oldest_memory( OFFSET %s """ select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Execute query to get IDs to delete cursor.execute(select_query, select_params) ids_to_delete = [row[0] for row in cursor.fetchall()] @@ -357,6 +396,8 @@ def remove_oldest_memory( except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: @@ -414,12 +455,16 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def delete_node(self, id: str, user_name: str | None = None) -> None: @@ -440,18 +485,24 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Ensure in the correct database context cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] @@ -474,11 +525,15 @@ def create_extension(self): except Exception as e: logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def create_graph(self): + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(f""" SELECT COUNT(*) FROM ag_catalog.ag_graph WHERE name = '{self.db_name}_graph'; @@ -493,6 +548,8 @@ def create_graph(self): except Exception as e: logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def create_edge(self): @@ -501,9 +558,11 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: + print(f"🪶 Creating elabel: {label_name}") + conn = self._get_connection() logger.info(f"Creating elabel: {label_name}") try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") logger.info(f"Successfully created elabel: {label_name}") except Exception as e: @@ -512,6 +571,8 @@ def create_edge(self): else: logger.warning(f"Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def add_edge( @@ -543,13 +604,16 @@ def add_edge( ); """ + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def delete_edge(self, source_id: str, target_id: str, type: str) -> None: @@ -564,10 +628,13 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - - with self.connection.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + finally: + self._return_connection(conn) @timed def edge_exists_old( @@ -622,11 +689,14 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - - with self.connection.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None + finally: + self._return_connection(conn) @timed def edge_exists( @@ -674,10 +744,14 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - with self.connection.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None + finally: + self._return_connection(conn) @timed def get_node( @@ -720,16 +794,17 @@ def format_param_value(value: str) -> str: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(format_param_value(user_name)) + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() if result: if include_embedding: - node_id, properties_json, embedding_json = result + properties_json, embedding_json = result else: - node_id, properties_json = result + properties_json = result embedding_json = None # Parse properties from JSONB if it's a string @@ -755,13 +830,19 @@ def format_param_value(value: str) -> str: logger.warning(f"Failed to parse embedding for node {id}") return self._parse_node( - {"id": id, "memory": properties.get("memory", ""), **properties} + { + "id": id, + "memory": json.loads(properties[1]).get("memory", ""), + **json.loads(properties[1]), + } ) return None except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) return None + finally: + self._return_connection(conn) @timed def get_nodes( @@ -803,43 +884,47 @@ def get_nodes( query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists - if embedding_json is not None: - try: - # remove embedding - """ - embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json - # properties["embedding"] = embedding - """ - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding from JSONB if it exists + if embedding_json is not None: + try: + # remove embedding + """ + embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json + # properties["embedding"] = embedding + """ + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) ) - ) - return nodes + return nodes + finally: + self._return_connection(conn) @timed def get_edges_old( @@ -1057,8 +1142,9 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1113,6 +1199,8 @@ def get_children_with_embeddings( except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: """Get the path of nodes from source to target within a limited depth.""" @@ -1174,9 +1262,9 @@ def get_subgraph( r) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() @@ -1250,6 +1338,8 @@ def get_subgraph( except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) return {"core_node": None, "neighbors": [], "edges": []} + finally: + self._return_connection(conn) def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" @@ -1333,24 +1423,28 @@ def search_by_embedding( """ params = [vector] - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output[:top_k] + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + """ + polarId = row[0] # id + properties = row[1] # properties + # embedding = row[3] # embedding + """ + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + return output[:top_k] + finally: + self._return_connection(conn) @timed def get_by_metadata( @@ -1439,13 +1533,16 @@ def get_by_metadata( """ ids = [] + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + finally: + self._return_connection(conn) return ids @@ -1596,8 +1693,9 @@ def get_grouped_counts( GROUP BY {", ".join(group_by_fields)} """ + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Handle parameterized query if params and isinstance(params, list): cursor.execute(query, params) @@ -1622,6 +1720,8 @@ def get_grouped_counts( except Exception as e: logger.error(f"Failed to get grouped counts: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" @@ -1653,10 +1753,13 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - - with self.connection.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") + finally: + self._return_connection(conn) except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -1678,7 +1781,7 @@ def export_graph( } """ user_name = user_name if user_name else self._get_config_value("user_name") - + conn = self._get_connection() try: # Export nodes if include_embedding: @@ -1694,16 +1797,16 @@ def export_graph( WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype """ - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(node_query) node_results = cursor.fetchall() nodes = [] for row in node_results: if include_embedding: - node_id, properties_json, embedding_json = row + properties_json, embedding_json = row else: - node_id, properties_json = row + properties_json = row embedding_json = None # Parse properties from JSONB if it's a string @@ -1733,7 +1836,10 @@ def export_graph( except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + finally: + self._return_connection(conn) + conn = self._get_connection() try: # Export edges using cypher query edge_query = f""" @@ -1744,7 +1850,7 @@ def export_graph( $$) AS (source agtype, target agtype, edge agtype) """ - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(edge_query) edge_results = cursor.fetchall() edges = [] @@ -1806,6 +1912,9 @@ def export_graph( except Exception as e: logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e + finally: + self._return_connection(conn) + return {"nodes": nodes, "edges": edges} @timed @@ -1820,9 +1929,12 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - - result = self.execute_query(query) - return int(result.one_or_none()["count"].value) + conn = self._get_connection() + try: + result = self.execute_query(query, conn) + return int(result.one_or_none()["count"].value) + finally: + self._return_connection(conn) @timed def get_all_memory_items( @@ -1863,8 +1975,9 @@ def get_all_memory_items( """ nodes = [] node_ids = set() + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -1886,6 +1999,8 @@ def get_all_memory_items( except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) return nodes else: @@ -1899,8 +2014,9 @@ def get_all_memory_items( """ nodes = [] + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -1917,6 +2033,8 @@ def get_all_memory_items( except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) return nodes @@ -2119,8 +2237,9 @@ def get_structure_optimization_candidates( candidates = [] node_ids = set() + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() logger.info(f"Found {len(results)} structure optimization candidates") @@ -2197,6 +2316,8 @@ def get_structure_optimization_candidates( except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) + finally: + self._return_connection(conn) return candidates @@ -2319,44 +2440,48 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - with self.connection.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s, - %s - ) + conn = self._get_connection() + try: + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) - logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + finally: + self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): """ @@ -2463,8 +2588,9 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -2513,6 +2639,8 @@ def get_neighbors_by_tag( except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def get_neighbors_by_tag_ccl( self, @@ -2758,9 +2886,9 @@ def get_edges( RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -2805,6 +2933,8 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def _convert_graph_edges(self, core_node: dict) -> dict: import copy @@ -2812,6 +2942,8 @@ def _convert_graph_edges(self, core_node: dict) -> dict: data = copy.deepcopy(core_node) id_map = {} core_node = data.get("core_node", {}) + if not core_node: + return core_node core_meta = core_node.get("metadata", {}) if "graph_id" in core_meta and "id" in core_node: id_map[core_meta["graph_id"]] = core_node["id"] From 25c7642d2331405ea43eb37c0f6448420d21bc5b Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:40:30 +0800 Subject: [PATCH 32/64] Feat/pref optimize update (#422) * add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit * modify pref filed * fix bug * fix pre_commit --------- Co-authored-by: yuan.wang --- evaluation/scripts/PrefEval/pref_memos.py | 2 +- evaluation/scripts/locomo/locomo_search.py | 4 ++-- evaluation/scripts/longmemeval/lme_search.py | 2 +- evaluation/scripts/personamem/pm_search.py | 4 ++-- evaluation/scripts/utils/client.py | 10 ++++++---- src/memos/api/config.py | 2 +- src/memos/api/product_models.py | 3 ++- src/memos/api/routers/server_router.py | 16 ++++++++++------ src/memos/templates/instruction_completion.py | 16 +++++++++++----- 9 files changed, 36 insertions(+), 23 deletions(-) diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 0ee88e868..4a21e3af0 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -103,7 +103,7 @@ def search_memory_for_line(line_data, mem_client, top_k_value): f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] ) - + f"\n{relevant_memories['pref_string']}" + + f"\n{relevant_memories.get('pref_string', '')}" ) memory_tokens_used = len(tokenizer.encode(memories_str)) diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index 0b610d574..24f6149ec 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -107,11 +107,11 @@ def memos_api_search( speaker_a_context = ( "\n".join([i["memory"] for i in search_a_results["text_mem"][0]["memories"]]) - + f"\n{search_a_results['pref_string']}" + + f"\n{search_a_results.get('pref_string', '')}" ) speaker_b_context = ( "\n".join([i["memory"] for i in search_b_results["text_mem"][0]["memories"]]) - + f"\n{search_b_results['pref_string']}" + + f"\n{search_b_results.get('pref_string', '')}" ) context = TEMPLATE_MEMOS.format( diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index 89c02aaea..8e0e3c5c2 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -46,7 +46,7 @@ def memos_search(client, query, user_id, top_k): results = client.search(query=query, user_id=user_id, top_k=top_k) context = ( "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]]) - + f"\n{results['pref_string']}" + + f"\n{results.get('pref_string', '')}" ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context) duration_ms = (time() - start) * 1000 diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index edec6b008..13ed659d2 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -81,8 +81,8 @@ def memos_search(client, user_id, query, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) search_memories = ( - "\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"]) - + f"\n{results['pref_string']}" + "\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"]) + + f"\n{results.get('pref_string', '')}" ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 3c34c49d0..ea0caa307 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -182,7 +182,8 @@ def search(self, query, user_id, top_k): "conversation_id": "", "top_k": top_k, "mode": os.getenv("SEARCH_MODE", "fast"), - "handle_pref_mem": False, + "include_preference": True, + "pref_top_k": 6, }, ensure_ascii=False, ) @@ -344,9 +345,10 @@ def wait_for_completion(self, task_id): query = "杭州西湖有什么" top_k = 5 - # MEMOBASE - client = MemobaseClient() + # MEMOS-API + client = MemosApiClient() for m in messages: m["created_at"] = iso_date - client.add(messages, user_id) + client.add(messages, user_id, user_id) memories = client.search(query, user_id, top_k) + print(memories) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 03622922d..d9db93c1a 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -354,7 +354,7 @@ def get_preference_memory_config() -> dict[str, Any]: return { "backend": "pref_text", "config": { - "extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()}, + "extractor_llm": APIConfig.get_memreader_config(), "vector_db": { "backend": "milvus", "config": APIConfig.get_milvus_config(), diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index dd2fde22b..0412754c3 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -180,7 +180,8 @@ class APISearchRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) - handle_pref_mem: bool = Field(False, description="Whether to handle preference memory") + include_preference: bool = Field(True, description="Whether to handle preference memory") + pref_top_k: int = Field(6, description="Number of preference results to return") class APIADDRequest(BaseRequest): diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 38b9a361e..e255b1a48 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -324,17 +324,18 @@ def _post_process_pref_mem( memories_result: list[dict[str, Any]], pref_formatted_mem: list[dict[str, Any]], mem_cube_id: str, - handle_pref_mem: bool, + include_preference: bool, ): - if handle_pref_mem: + if include_preference: memories_result["pref_mem"].append( { "cube_id": mem_cube_id, "memories": pref_formatted_mem, } ) - pref_instruction: str = instruct_completion(pref_formatted_mem) + pref_instruction, pref_note = instruct_completion(pref_formatted_mem) memories_result["pref_string"] = pref_instruction + memories_result["pref_note"] = pref_note return memories_result @@ -354,7 +355,7 @@ def search_memories(search_req: APISearchRequest): "act_mem": [], "para_mem": [], "pref_mem": [], - "pref_string": "", + "pref_note": "", } search_mode = search_req.mode @@ -382,7 +383,7 @@ def _search_pref(): return [] results = naive_mem_cube.pref_mem.search( query=search_req.query, - top_k=search_req.top_k, + top_k=search_req.pref_top_k, info={ "user_id": search_req.user_id, "session_id": search_req.session_id, @@ -405,7 +406,10 @@ def _search_pref(): ) memories_result = _post_process_pref_mem( - memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem + memories_result, + pref_formatted_memories, + search_req.mem_cube_id, + search_req.include_preference, ) return SearchResponse( diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index c2a7f58c7..acd110930 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -6,7 +6,7 @@ def instruct_completion( memories: list[dict[str, Any]] | None = None, -) -> str: +) -> [str, str]: """Create instruction following the preferences.""" explicit_pref = [] implicit_pref = [] @@ -49,10 +49,16 @@ def instruct_completion( lang = detect_lang(explicit_pref_str + implicit_pref_str) if not explicit_pref_str and not implicit_pref_str: - return "" + return "", "" if not explicit_pref_str: - return implicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_exp_map[lang], "") + pref_note = _prompt_map[lang].replace(_remove_exp_map[lang], "") + pref_string = implicit_pref_str + "\n" + pref_note + return pref_string, pref_note if not implicit_pref_str: - return explicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_imp_map[lang], "") + pref_note = _prompt_map[lang].replace(_remove_imp_map[lang], "") + pref_string = explicit_pref_str + "\n" + pref_note + return pref_string, pref_note - return explicit_pref_str + "\n" + implicit_pref_str + "\n" + _prompt_map[lang] + pref_note = _prompt_map[lang] + pref_string = explicit_pref_str + "\n" + implicit_pref_str + "\n" + pref_note + return pref_string, pref_note From 0e7128e9c5ac6e9ca0ca6ca8287516899f1bcbef Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:38:42 +0800 Subject: [PATCH 33/64] fix:tree file change Searcher inputs (#423) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/memories/textual/tree.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 19bd3ba5b..53628d075 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -138,7 +138,6 @@ def get_searcher( self.reranker, internet_retriever=None, moscube=moscube, - search_strategy=self.search_strategy, ) else: searcher = Searcher( @@ -148,7 +147,6 @@ def get_searcher( self.reranker, internet_retriever=self.internet_retriever, moscube=moscube, - search_strategy=self.search_strategy, ) return searcher @@ -197,7 +195,7 @@ def search( bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, - vec_cot=self.vec_cot, + search_strategy=self.search_strategy, ) else: searcher = Searcher( @@ -208,7 +206,7 @@ def search( bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, - vec_cot=self.vec_cot, + search_strategy=self.search_strategy, ) return searcher.search(query, top_k, info, mode, memory_type, search_filter) From aa808632b9fffc38adcfa1e2206f7206e288b4a7 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 30 Oct 2025 20:08:10 +0800 Subject: [PATCH 34/64] Feat/pref optimize update (#425) * add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit * modify pref filed * fix bug * fix pre_commit * fix bug in adder --------- Co-authored-by: yuan.wang --- evaluation/README.md | 2 +- evaluation/scripts/PrefEval/pref_mem0.py | 2 +- evaluation/scripts/PrefEval/pref_memobase.py | 3 +- evaluation/scripts/PrefEval/pref_memos.py | 3 +- evaluation/scripts/PrefEval/pref_memu.py | 3 +- .../scripts/PrefEval/pref_supermemory.py | 3 +- evaluation/scripts/PrefEval/pref_zep.py | 3 +- evaluation/scripts/personamem/pm_ingestion.py | 62 +++++++++++++------ evaluation/scripts/personamem/pm_metric.py | 15 ++++- evaluation/scripts/personamem/pm_responses.py | 37 ++++++++--- evaluation/scripts/personamem/pm_search.py | 37 ++++++++--- evaluation/scripts/run_prefeval_eval.sh | 2 +- .../textual/prefer_text_memory/adder.py | 8 ++- 13 files changed, 130 insertions(+), 50 deletions(-) diff --git a/evaluation/README.md b/evaluation/README.md index ba8c7a0cc..8683c60b2 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -84,4 +84,4 @@ get `questions_32k.csv` and `shared_contexts_32k.jsonl` from https://huggingface # Specify the model and memory backend you want to use (e.g., mem0, zep, etc.) # If you want to use MIRIX, edit the the configuration in ./scripts/personamem/config.yaml ./scripts/run_pm_eval.sh -``` \ No newline at end of file +``` diff --git a/evaluation/scripts/PrefEval/pref_mem0.py b/evaluation/scripts/PrefEval/pref_mem0.py index 214068567..300e0ede3 100644 --- a/evaluation/scripts/PrefEval/pref_mem0.py +++ b/evaluation/scripts/PrefEval/pref_mem0.py @@ -56,7 +56,7 @@ def add_memory_for_line( for idx, _ in enumerate(conversation[::2]): msg_idx = idx * 2 - record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}" timestamp_add = int(time.time() * 100) if record_id not in success_records: diff --git a/evaluation/scripts/PrefEval/pref_memobase.py b/evaluation/scripts/PrefEval/pref_memobase.py index e99b10520..776642657 100644 --- a/evaluation/scripts/PrefEval/pref_memobase.py +++ b/evaluation/scripts/PrefEval/pref_memobase.py @@ -12,6 +12,7 @@ from openai import OpenAI from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -68,7 +69,7 @@ def add_memory_for_line( ) for idx, _ in enumerate(conversation[::2]): msg_idx = idx * 2 - record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}" if record_id not in success_records: mem_client.add(messages=conversation[msg_idx : msg_idx + 2], user_id=user_id) diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 4a21e3af0..bbe1788b5 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -12,6 +12,7 @@ from openai import OpenAI from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -49,7 +50,7 @@ def add_memory_for_line( for idx, _ in enumerate(conversation[::2]): msg_idx = idx * 2 - record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}" if record_id not in success_records: mem_client.add( diff --git a/evaluation/scripts/PrefEval/pref_memu.py b/evaluation/scripts/PrefEval/pref_memu.py index 4c37db7b7..00c411eb7 100644 --- a/evaluation/scripts/PrefEval/pref_memu.py +++ b/evaluation/scripts/PrefEval/pref_memu.py @@ -14,6 +14,7 @@ from openai import OpenAI from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -56,7 +57,7 @@ def add_memory_for_line( for idx, _ in enumerate(conversation[::2]): msg_idx = idx * 2 - record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}" if record_id not in success_records: mem_client.add( diff --git a/evaluation/scripts/PrefEval/pref_supermemory.py b/evaluation/scripts/PrefEval/pref_supermemory.py index 68963e2af..7386bc462 100644 --- a/evaluation/scripts/PrefEval/pref_supermemory.py +++ b/evaluation/scripts/PrefEval/pref_supermemory.py @@ -12,6 +12,7 @@ from openai import OpenAI from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -54,7 +55,7 @@ def add_memory_for_line( for idx, _ in enumerate(conversation[::2]): msg_idx = idx * 2 - record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}" if record_id not in success_records: mem_client.add( diff --git a/evaluation/scripts/PrefEval/pref_zep.py b/evaluation/scripts/PrefEval/pref_zep.py index be98c6ba9..8a4d50558 100644 --- a/evaluation/scripts/PrefEval/pref_zep.py +++ b/evaluation/scripts/PrefEval/pref_zep.py @@ -14,6 +14,7 @@ from openai import OpenAI from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -56,7 +57,7 @@ def add_memory_for_line( for idx, _ in enumerate(conversation[::2]): msg_idx = idx * 2 - record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}" + record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}" if record_id not in success_records: mem_client.add( diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index fdbf43528..b960aa157 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -10,6 +10,7 @@ from tqdm import tqdm + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -171,7 +172,9 @@ def ingest_conv(row_data, context, version, conv_idx, frame, success_records, f) client = MemosApiOnlineClient() try: - ingest_session(session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client) + ingest_session( + session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client + ) print(f"✅ Ingestion of conversation {conv_idx} completed") print("=" * 80) @@ -187,10 +190,9 @@ def main(frame, version, num_workers=2, clear=False): os.makedirs(f"results/pm/{frame}-{version}/", exist_ok=True) record_file = f"results/pm/{frame}-{version}/success_records.txt" - if clear: - if os.path.exists(record_file): - os.remove(record_file) - print("🧹 Cleared progress records") + if clear and os.path.exists(record_file): + os.remove(record_file) + print("🧹 Cleared progress records") print("\n" + "=" * 80) print(f"🚀 PERSONAMEM INGESTION - {frame.upper()} v{version}".center(80)) @@ -205,15 +207,20 @@ def main(frame, version, num_workers=2, clear=False): success_records = set() if os.path.exists(record_file): - with open(record_file, "r") as f: - success_records = set(line.strip() for line in f) - print(f"📊 Found {len(success_records)} completed conversations, {total_rows - len(success_records)} remaining") + with open(record_file) as f: + success_records = {line.strip() for line in f} + print( + f"📊 Found {len(success_records)} completed conversations, {total_rows - len(success_records)} remaining" + ) start_time = datetime.now() all_data = list(load_rows_with_context(question_csv_path, context_jsonl_path)) - pending_data = [(idx, row_data, context) for idx, (row_data, context) in enumerate(all_data) - if str(idx) not in success_records] + pending_data = [ + (idx, row_data, context) + for idx, (row_data, context) in enumerate(all_data) + if str(idx) not in success_records + ] if not pending_data: print("✅ All conversations have been processed!") @@ -232,16 +239,16 @@ def main(frame, version, num_workers=2, clear=False): conv_idx=idx, frame=frame, success_records=success_records, - f=f + f=f, ) futures.append(future) completed_count = 0 for future in tqdm( - as_completed(futures), total=len(futures), desc="Processing conversations" + as_completed(futures), total=len(futures), desc="Processing conversations" ): try: - result = future.result() + future.result() completed_count += 1 except Exception as exc: print(f"\n❌ Conversation generated an exception: {exc}") @@ -261,13 +268,28 @@ def main(frame, version, num_workers=2, clear=False): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Ingestion Script") - parser.add_argument("--lib", type=str, - choices=["memos-api-online", "mem0", "mem0_graph", "memos-api", "memobase", "memu", - "supermemory", "zep"], - default='memos-api') - parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.") - parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument( + "--lib", + type=str, + choices=[ + "memos-api-online", + "mem0", + "mem0_graph", + "memos-api", + "memobase", + "memu", + "supermemory", + "zep", + ], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="default", help="Version of the evaluation framework." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of parallel workers for processing users." + ) parser.add_argument("--clear", action="store_true", help="Clear progress and start fresh") args = parser.parse_args() - main(frame=args.lib, version=args.version, num_workers=args.workers, clear=args.clear) \ No newline at end of file + main(frame=args.lib, version=args.version, num_workers=args.workers, clear=args.clear) diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index b9d10a576..4c93ec0c6 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -353,12 +353,23 @@ def print_summary(results): parser.add_argument( "--lib", type=str, - choices=["zep", "mem0", "mem0_graph", "memos-api", "memos-api-online", "memobase", "memu", "supermemory"], + choices=[ + "zep", + "mem0", + "mem0_graph", + "memos-api", + "memos-api-online", + "memobase", + "memu", + "supermemory", + ], required=True, help="Memory library to evaluate", default="memos-api", ) - parser.add_argument("--version", type=str, default="default", help="Evaluation framework version") + parser.add_argument( + "--version", type=str, default="default", help="Evaluation framework version" + ) args = parser.parse_args() lib, version = args.lib, args.version diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 2e41b4140..171b5af1a 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -10,6 +10,7 @@ from openai import OpenAI from tqdm import tqdm + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re @@ -153,9 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id[future] = user_id for future in tqdm( - as_completed(future_to_user_id), - total=len(future_to_user_id), - desc="📝 Generating responses", + as_completed(future_to_user_id), + total=len(future_to_user_id), + desc="📝 Generating responses", ): user_id = future_to_user_id[future] try: @@ -184,12 +185,30 @@ def main(frame, version, num_runs=3, num_workers=4): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Response Generation Script") - parser.add_argument("--lib", type=str, - choices=["memos-api-online", "zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu", - "supermemory"], default='memos-api') - parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.") - parser.add_argument("--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation.") - parser.add_argument("--workers", type=int, default=10, help="Number of worker threads to use for processing.") + parser.add_argument( + "--lib", + type=str, + choices=[ + "memos-api-online", + "zep", + "mem0", + "mem0_graph", + "memos-api", + "memobase", + "memu", + "supermemory", + ], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="default", help="Version of the evaluation framework." + ) + parser.add_argument( + "--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation." + ) + parser.add_argument( + "--workers", type=int, default=10, help="Number of worker threads to use for processing." + ) args = parser.parse_args() main(frame=args.lib, version=args.version, num_runs=args.num_runs, num_workers=args.workers) diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 13ed659d2..80a65e09b 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -3,10 +3,12 @@ import json import os import sys + from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time + from tqdm import tqdm @@ -232,6 +234,7 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): context, duration_ms = memobase_search(client, question, user_id, top_k) elif frame == "memos-api-online": from utils.client import MemosApiOnlineClient + client = MemosApiOnlineClient() print("🔌 Using memos-api-online client for search...") context, duration_ms = memos_search(client, question, user_id, top_k) @@ -253,7 +256,7 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): os.makedirs(f"results/pm/{frame}-{version}/tmp", exist_ok=True) with open( - f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" + f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" ) as f: json.dump(search_results, f, indent=4) print(f"💾 Search results for conversation {conv_idx} saved...") @@ -304,7 +307,7 @@ def main(frame, version, top_k=20, num_workers=2): } for future in tqdm( - as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" ): idx = future_to_idx[future] try: @@ -333,13 +336,29 @@ def main(frame, version, top_k=20, num_workers=2): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Search Script") - parser.add_argument("--lib", type=str, - choices=["memos-api-online", "mem0", "mem0_graph", "memos-api", "memobase", "memu", - "supermemory"], - default='memos-api') - parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.") - parser.add_argument("--top_k", type=int, default=20, help="Number of top results to retrieve from the search.") - parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument( + "--lib", + type=str, + choices=[ + "memos-api-online", + "mem0", + "mem0_graph", + "memos-api", + "memobase", + "memu", + "supermemory", + ], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="default", help="Version of the evaluation framework." + ) + parser.add_argument( + "--top_k", type=int, default=20, help="Number of top results to retrieve from the search." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of parallel workers for processing users." + ) args = parser.parse_args() diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index 129382ebf..6f5f3b7b0 100755 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -143,4 +143,4 @@ fi echo "" echo "--- PrefEval Pipeline completed successfully! ---" -echo "Final results are in $RESPONSE_FILE" \ No newline at end of file +echo "Final results are in $RESPONSE_FILE" diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 052ae30c2..8d00ae81d 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -232,8 +232,12 @@ def _update_memory_fine( need_update = ( need_update if isinstance(need_update, bool) else need_update.lower() == "true" ) - update_item = [mem for mem in retrieved_memories if mem.id == rsp["id"]] - if need_update and update_item: + update_item = ( + [mem for mem in retrieved_memories if mem.id == rsp["id"]] + if rsp and "id" in rsp + else [] + ) + if need_update and update_item and rsp: update_vec_db_item = update_item[0] update_vec_db_item.payload[preference_type] = rsp["new_preference"] update_vec_db_item.payload["updated_at"] = vec_db_item.payload["updated_at"] From 8be2e80b61bfb0e70e104d4dcf60407f65c846f4 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 30 Oct 2025 20:32:16 +0800 Subject: [PATCH 35/64] Fix/query schedule (#424) * feat: change MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS to 10000 * feat: add user_name to schedule server router * feat: roll back to old mem-reader-prompt * feat: add moniter in schedule * feat: set default MEMRADER_MAX_TOKENS to 8000 * feat: add metric in schedule status * fix: bug * fix: base scheduler bug --- src/memos/api/routers/server_router.py | 6 + src/memos/mem_scheduler/base_scheduler.py | 111 +------- .../general_modules/dispatcher.py | 47 ++++ src/memos/mem_scheduler/utils/metrics.py | 250 ++++++++++++++++++ 4 files changed, 310 insertions(+), 104 deletions(-) create mode 100644 src/memos/mem_scheduler/utils/metrics.py diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e255b1a48..2b481d5c6 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -653,6 +653,11 @@ def scheduler_status(user_name: str | None = None): cube = getattr(task, "mem_cube_id", "unknown") task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 + try: + metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() + except Exception: + metrics_snapshot = {} + return { "message": "ok", "data": { @@ -661,6 +666,7 @@ def scheduler_status(user_name: str | None = None): "task_count_per_user": task_count_per_user, "timestamp": time.time(), "instance_id": INSTANCE_ID, + "metrics": metrics_snapshot, }, } except Exception as err: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index c2f606146..b3b457c36 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -49,7 +49,6 @@ from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.memos_tools.notification_utils import send_online_bot_notification from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE @@ -127,21 +126,6 @@ def __init__(self, config: BaseSchedulerConfig): "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS ) - # queue monitor (optional) - self._queue_monitor_thread: threading.Thread | None = None - self._queue_monitor_running: bool = False - self.queue_monitor_interval_seconds: float = self.config.get( - "queue_monitor_interval_seconds", 60.0 - ) - self.queue_monitor_warn_utilization: float = self.config.get( - "queue_monitor_warn_utilization", 0.7 - ) - self.queue_monitor_crit_utilization: float = self.config.get( - "queue_monitor_crit_utilization", 0.9 - ) - self.enable_queue_monitor: bool = self.config.get("enable_queue_monitor", False) - self._online_bot_callable = None # type: ignore[var-annotated] - # other attributes self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None @@ -541,6 +525,10 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt logger.error(error_msg) raise TypeError(error_msg) + if getattr(message, "timestamp", None) is None: + with contextlib.suppress(Exception): + message.timestamp = datetime.utcnow() + if self.disable_handlers and message.label in self.disable_handlers: logger.info(f"Skipping disabled handler: {message.label} - {message.content}") continue @@ -555,6 +543,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt logger.info( f"Submitted message to local queue: {message.label} - {message.content}" ) + with contextlib.suppress(Exception): + if messages: + self.dispatcher.on_messages_enqueued(messages) def _submit_web_logs( self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] @@ -706,13 +697,6 @@ def start(self) -> None: self._consumer_thread.start() logger.info("Message consumer thread started") - # optionally start queue monitor if enabled and bot callable present - if self.enable_queue_monitor and self._online_bot_callable is not None: - try: - self.start_queue_monitor(self._online_bot_callable) - except Exception as e: - logger.warning(f"Failed to start queue monitor: {e}") - def stop(self) -> None: """Stop all scheduler components gracefully. @@ -762,9 +746,6 @@ def stop(self) -> None: self._cleanup_queues() logger.info("Memory Scheduler stopped completely") - # Stop queue monitor - self.stop_queue_monitor() - @property def handlers(self) -> dict[str, Callable]: """ @@ -997,16 +978,6 @@ def _fmt_eta(seconds: float | None) -> str: return True - # ---------------- Queue monitor & notifications ---------------- - def set_notification_bots(self, online_bot=None): - """ - Set external notification callables. - - Args: - online_bot: a callable matching dinding_report_bot.online_bot signature - """ - self._online_bot_callable = online_bot - def _gather_queue_stats(self) -> dict: """Collect queue/dispatcher stats for reporting.""" stats: dict[str, int | float | str] = {} @@ -1044,71 +1015,3 @@ def _gather_queue_stats(self) -> dict: except Exception: stats.update({"running": 0, "inflight": 0, "handlers": 0}) return stats - - def _queue_monitor_loop(self, online_bot) -> None: - logger.info(f"Queue monitor started (interval={self.queue_monitor_interval_seconds}s)") - self._queue_monitor_running = True - while self._queue_monitor_running: - time.sleep(self.queue_monitor_interval_seconds) - try: - stats = self._gather_queue_stats() - # decide severity based on utilization if local queue - title_color = "#00956D" - subtitle = "Scheduler" - if not stats.get("use_redis_queue"): - util = float(stats.get("utilization", 0.0)) - if util >= self.queue_monitor_crit_utilization: - title_color = "#C62828" # red - subtitle = "Scheduler (CRITICAL)" - elif util >= self.queue_monitor_warn_utilization: - title_color = "#E65100" # orange - subtitle = "Scheduler (WARNING)" - - other_data1 = { - "use_redis_queue": stats.get("use_redis_queue"), - "handlers": stats.get("handlers"), - "running": stats.get("running"), - "inflight": stats.get("inflight"), - } - if not stats.get("use_redis_queue"): - other_data2 = { - "qsize": stats.get("qsize"), - "unfinished_tasks": stats.get("unfinished_tasks"), - "maxsize": stats.get("maxsize"), - "utilization": f"{float(stats.get('utilization', 0.0)):.2%}", - } - else: - other_data2 = { - "redis_mode": True, - } - - send_online_bot_notification( - online_bot=online_bot, - header_name="Scheduler Queue", - sub_title_name=subtitle, - title_color=title_color, - other_data1=other_data1, - other_data2=other_data2, - emoji={"Runtime": "🧠", "Queue": "📬"}, - ) - except Exception as e: - logger.warning(f"Queue monitor iteration failed: {e}") - logger.info("Queue monitor stopped") - - def start_queue_monitor(self, online_bot) -> None: - if self._queue_monitor_thread and self._queue_monitor_thread.is_alive(): - return - self._online_bot_callable = online_bot - self._queue_monitor_thread = threading.Thread( - target=self._queue_monitor_loop, - args=(online_bot,), - daemon=True, - name="QueueMonitorThread", - ) - self._queue_monitor_thread.start() - - def stop_queue_monitor(self) -> None: - self._queue_monitor_running = False - if self._queue_monitor_thread and self._queue_monitor_thread.is_alive(): - with contextlib.suppress(Exception): - self._queue_monitor_thread.join(timeout=2.0) diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index 997b01302..c2407b9e6 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -1,8 +1,10 @@ import concurrent import threading +import time from collections import defaultdict from collections.abc import Callable +from datetime import timezone from typing import Any from memos.context.context import ContextThreadPoolExecutor @@ -11,6 +13,7 @@ from memos.mem_scheduler.general_modules.task_threads import ThreadManager from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.utils.metrics import MetricsRegistry logger = get_logger(__name__) @@ -70,6 +73,19 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): self._completed_tasks = [] self.completed_tasks_max_show_size = 10 + self.metrics = MetricsRegistry( + topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50) + ) + + def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None: + if not msgs: + return + now = time.time() + for m in msgs: + self.metrics.on_enqueue( + label=m.label, mem_cube_id=m.mem_cube_id, inst_rate=1.0, now=now + ) + def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ Create a wrapper around the handler to track task execution and capture results. @@ -84,9 +100,37 @@ def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): def wrapped_handler(messages: list[ScheduleMessageItem]): try: + # --- mark start: record queuing time(now - enqueue_ts)--- + now = time.time() + for m in messages: + enq_ts = getattr(m, "timestamp", None) + + # Path 1: epoch seconds (preferred) + if isinstance(enq_ts, int | float): + enq_epoch = float(enq_ts) + + # Path 2: datetime -> normalize to UTC epoch + elif hasattr(enq_ts, "timestamp"): + dt = enq_ts + if dt.tzinfo is None: + # treat naive as UTC to neutralize +8h skew + dt = dt.replace(tzinfo=timezone.utc) + enq_epoch = dt.timestamp() + else: + # fallback: treat as "just now" + enq_epoch = now + + wait_sec = max(0.0, now - enq_epoch) + self.metrics.on_start( + label=m.label, mem_cube_id=m.mem_cube_id, wait_sec=wait_sec, now=now + ) + # Execute the original handler result = handler(messages) + # --- mark done --- + for m in messages: + self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) # Mark task as completed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: @@ -100,6 +144,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): except Exception as e: # Mark task as failed and remove from tracking + for m in messages: + self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) + # Mark task as failed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: task_item.mark_failed(str(e)) diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py new file mode 100644 index 000000000..5155c98b3 --- /dev/null +++ b/src/memos/mem_scheduler/utils/metrics.py @@ -0,0 +1,250 @@ +# metrics.py +from __future__ import annotations + +import threading +import time + +from dataclasses import dataclass, field + + +# ==== global window config ==== +WINDOW_SEC = 120 # 2 minutes sliding window + + +# ---------- O(1) EWMA ---------- +class Ewma: + """ + Time-decayed EWMA: + """ + + __slots__ = ("alpha", "last_ts", "tau", "value") + + def __init__(self, alpha: float = 0.3, tau: float = WINDOW_SEC): + self.alpha = alpha + self.value = 0.0 + self.last_ts: float = time.time() + self.tau = max(1e-6, float(tau)) + + def _decay_to(self, now: float | None = None): + now = time.time() if now is None else now + dt = max(0.0, now - self.last_ts) + if dt <= 0: + return + from math import exp + + self.value *= exp(-dt / self.tau) + self.last_ts = now + + def update(self, instant: float, now: float | None = None): + self._decay_to(now) + self.value = self.alpha * instant + (1 - self.alpha) * self.value + + def value_at(self, now: float | None = None) -> float: + now = time.time() if now is None else now + dt = max(0.0, now - self.last_ts) + if dt <= 0: + return self.value + from math import exp + + return self.value * exp(-dt / self.tau) + + +# ---------- approximate P95(Reservoir sample) ---------- +class ReservoirP95: + __slots__ = ("_i", "buf", "k", "n", "window") + + def __init__(self, k: int = 512, window: float = WINDOW_SEC): + self.k = k + self.buf: list[tuple[float, float]] = [] # (value, ts) + self.n = 0 + self._i = 0 + self.window = float(window) + + def _gc(self, now: float): + win_start = now - self.window + self.buf = [p for p in self.buf if p[1] >= win_start] + if self.buf: + self._i %= len(self.buf) + else: + self._i = 0 + + def add(self, x: float, now: float | None = None): + now = time.time() if now is None else now + self._gc(now) + self.n += 1 + if len(self.buf) < self.k: + self.buf.append((x, now)) + return + self.buf[self._i] = (x, now) + self._i = (self._i + 1) % self.k + + def p95(self, now: float | None = None) -> float: + now = time.time() if now is None else now + self._gc(now) + if not self.buf: + return 0.0 + arr = sorted(v for v, _ in self.buf) + idx = int(0.95 * (len(arr) - 1)) + return arr[idx] + + +# ---------- Space-Saving Top-K ---------- +class SpaceSaving: + """only topK:add(key) O(1),query topk O(K log K)""" + + def __init__(self, k: int = 100): + self.k = k + self.cnt: dict[str, int] = {} + + def add(self, key: str): + if key in self.cnt: + self.cnt[key] += 1 + return + if len(self.cnt) < self.k: + self.cnt[key] = 1 + return + victim = min(self.cnt, key=self.cnt.get) + self.cnt[key] = self.cnt.pop(victim) + 1 + + def topk(self) -> list[tuple[str, int]]: + return sorted(self.cnt.items(), key=lambda kv: kv[1], reverse=True) + + +@dataclass +class KeyStats: + backlog: int = 0 + lambda_ewma: Ewma = field(default_factory=lambda: Ewma(0.3, WINDOW_SEC)) + mu_ewma: Ewma = field(default_factory=lambda: Ewma(0.3, WINDOW_SEC)) + wait_p95: ReservoirP95 = field(default_factory=lambda: ReservoirP95(512, WINDOW_SEC)) + last_ts: float = field(default_factory=time.time) + # last event timestamps for rate estimation + last_enqueue_ts: float | None = None + last_done_ts: float | None = None + + def snapshot(self, now: float | None = None) -> dict: + now = time.time() if now is None else now + lam = self.lambda_ewma.value_at(now) + mu = self.mu_ewma.value_at(now) + delta = mu - lam + eta = float("inf") if delta <= 1e-9 else self.backlog / delta + return { + "backlog": self.backlog, + "lambda": round(lam, 3), + "mu": round(mu, 3), + "delta": round(delta, 3), + "eta_sec": None if eta == float("inf") else round(eta, 1), + "wait_p95_sec": round(self.wait_p95.p95(now), 3), + } + + +class MetricsRegistry: + """ + metrics: + - 1st phase:label(must) + - 2nd phase:labelXmem_cube_id(only Top-K) + - on_enqueue(label, mem_cube_id) + - on_start(label, mem_cube_id, wait_sec) + - on_done(label, mem_cube_id) + """ + + def __init__(self, topk_per_label: int = 50): + self._lock = threading.RLock() + self._label_stats: dict[str, KeyStats] = {} + self._label_topk: dict[str, SpaceSaving] = {} + self._detail_stats: dict[tuple[str, str], KeyStats] = {} + self._topk_per_label = topk_per_label + + # ---------- helpers ---------- + def _get_label(self, label: str) -> KeyStats: + if label not in self._label_stats: + self._label_stats[label] = KeyStats() + self._label_topk[label] = SpaceSaving(self._topk_per_label) + return self._label_stats[label] + + def _get_detail(self, label: str, mem_cube_id: str) -> KeyStats | None: + # 只有 Top-K 的 mem_cube_id 才建细粒度 key + ss = self._label_topk[label] + if mem_cube_id in ss.cnt or len(ss.cnt) < ss.k: + key = (label, mem_cube_id) + if key not in self._detail_stats: + self._detail_stats[key] = KeyStats() + return self._detail_stats[key] + return None + + # ---------- events ---------- + def on_enqueue( + self, label: str, mem_cube_id: str, inst_rate: float = 1.0, now: float | None = None + ): + with self._lock: + now = time.time() if now is None else now + ls = self._get_label(label) + # derive instantaneous arrival rate from inter-arrival time (events/sec) + prev_ts = ls.last_enqueue_ts + dt = (now - prev_ts) if prev_ts is not None else None + inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 # first sample: no spike + ls.last_enqueue_ts = now + ls.backlog += 1 + old_lam = ls.lambda_ewma.value_at(now) + ls.lambda_ewma.update(inst_rate, now) + new_lam = ls.lambda_ewma.value_at(now) + print( + f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} λ {old_lam:.3f}→{new_lam:.3f}" + ) + self._label_topk[label].add(mem_cube_id) + ds = self._get_detail(label, mem_cube_id) + if ds: + prev_ts_d = ds.last_enqueue_ts + dt_d = (now - prev_ts_d) if prev_ts_d is not None else None + inst_rate_d = (1.0 / max(1e-3, dt_d)) if dt_d is not None else 0.0 + ds.last_enqueue_ts = now + ds.backlog += 1 + ds.lambda_ewma.update(inst_rate_d, now) + + def on_start(self, label: str, mem_cube_id: str, wait_sec: float, now: float | None = None): + with self._lock: + now = time.time() if now is None else now + ls = self._get_label(label) + ls.wait_p95.add(wait_sec, now) + ds = self._detail_stats.get((label, mem_cube_id)) + if ds: + ds.wait_p95.add(wait_sec, now) + + def on_done( + self, label: str, mem_cube_id: str, inst_rate: float = 1.0, now: float | None = None + ): + with self._lock: + now = time.time() if now is None else now + ls = self._get_label(label) + # derive instantaneous service rate from inter-completion time (events/sec) + prev_ts = ls.last_done_ts + dt = (now - prev_ts) if prev_ts is not None else None + inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 + ls.last_done_ts = now + if ls.backlog > 0: + ls.backlog -= 1 + old_mu = ls.mu_ewma.value_at(now) + ls.mu_ewma.update(inst_rate, now) + new_mu = ls.mu_ewma.value_at(now) + print( + f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} μ {old_mu:.3f}→{new_mu:.3f}" + ) + ds = self._detail_stats.get((label, mem_cube_id)) + if ds: + prev_ts_d = ds.last_done_ts + dt_d = (now - prev_ts_d) if prev_ts_d is not None else None + inst_rate_d = (1.0 / max(1e-3, dt_d)) if dt_d is not None else 0.0 + ds.last_done_ts = now + if ds.backlog > 0: + ds.backlog -= 1 + ds.mu_ewma.update(inst_rate_d, now) + + # ---------- snapshots ---------- + def snapshot(self) -> dict: + with self._lock: + now = time.time() + by_label = {lbl: ks.snapshot(now) for lbl, ks in self._label_stats.items()} + heavy = {lbl: self._label_topk[lbl].topk() for lbl in self._label_topk} + details = {} + for (lbl, cube), ks in self._detail_stats.items(): + details.setdefault(lbl, {})[cube] = ks.snapshot(now) + return {"by_label": by_label, "heavy": heavy, "details": details} From 28cf578b0ba4f85ac99d40a4ede0627b231b5477 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 30 Oct 2025 20:45:33 +0800 Subject: [PATCH 36/64] fix: message schema bug (#426) * feat: change MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS to 10000 * feat: add user_name to schedule server router * feat: roll back to old mem-reader-prompt * feat: add moniter in schedule * feat: set default MEMRADER_MAX_TOKENS to 8000 * feat: add metric in schedule status * fix: bug * fix: base scheduler bug * fix: message schema bug --- src/memos/graph_dbs/polardb.py | 3 +++ src/memos/mem_scheduler/schemas/message_schemas.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f24f1072c..1d7dc06fc 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2389,6 +2389,8 @@ def add_node( self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: """Add a memory node to the graph.""" + logger.info(f"In add node polardb: id-{id} memory-{memory}") + # user_name comes from metadata; fallback to config if missing metadata["user_name"] = user_name if user_name else self.config.user_name @@ -2481,6 +2483,7 @@ def add_node( cursor.execute(insert_query, (id, json.dumps(properties))) logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") finally: + logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 9cdb6823d..7f328474f 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -46,6 +46,10 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): default=None, description="user name / display name (optional)", ) + session_id: str | None = Field( + default=None, + description="session_id (optional)", + ) # Pydantic V2 model configuration model_config = ConfigDict( From af895314763a1bc16b39a820e66d4888b4fbda6b Mon Sep 17 00:00:00 2001 From: Wustzdy <67457465+wustzdy@users.noreply.github.com> Date: Thu, 30 Oct 2025 20:50:59 +0800 Subject: [PATCH 37/64] fix commit (#427) Co-authored-by: CaralHsi --- src/memos/graph_dbs/polardb.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 1d7dc06fc..20e02bd0c 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -183,7 +183,10 @@ def _get_connection(self): """Get a connection from the pool.""" if self._pool_closed: raise RuntimeError("Connection pool has been closed") - return self.connection_pool.getconn() + conn = self.connection_pool.getconn() + # Set autocommit for PolarDB compatibility + conn.autocommit = True + return conn def _return_connection(self, connection): """Return a connection to the pool.""" From 9c5d9fba97b4860570a215d3ffb7d59968542030 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 30 Oct 2025 21:35:58 +0800 Subject: [PATCH 38/64] Feat/pref optimize update (#429) * add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit * modify pref filed * fix bug * fix pre_commit * fix bug in adder * fast --------- Co-authored-by: yuan.wang --- src/memos/memories/textual/prefer_text_memory/adder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 8d00ae81d..c8eea3cd4 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -326,7 +326,7 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | search_results.sort(key=lambda x: x.score, reverse=True) return self._update_memory( - memory, search_results, collection_name, preference_type, update_mode="fine" + memory, search_results, collection_name, preference_type, update_mode="fast" ) except Exception as e: From c7e9af4779f039cda585b59bdf7ae59dd8fa5317 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:35:53 +0800 Subject: [PATCH 39/64] Feat/pref optimize update (#431) * add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit * modify pref filed * fix bug * fix pre_commit * fix bug in adder * fast * modify pref and adder mode * modify code * make pre_commit --------- Co-authored-by: yuan.wang --- .../textual/prefer_text_memory/adder.py | 9 ++++-- src/memos/templates/prefer_complete_prompt.py | 30 ++++++++++++++----- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index c8eea3cd4..ce0282a23 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -1,4 +1,5 @@ import json +import os from abc import ABC, abstractmethod from concurrent.futures import as_completed @@ -287,7 +288,7 @@ def _update_memory( retrieved_memories: list[MilvusVecDBItem], collection_name: str, preference_type: str, - update_mode: str = "fine", + update_mode: str = "fast", ) -> list[str] | str | None: """Update the memory. Args: @@ -326,7 +327,11 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | search_results.sort(key=lambda x: x.score, reverse=True) return self._update_memory( - memory, search_results, collection_name, preference_type, update_mode="fast" + memory, + search_results, + collection_name, + preference_type, + update_mode=os.getenv("PREFERENCE_ADDER_MODE", "fast"), ) except Exception as e: diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index b98e65d54..ec06af27f 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -62,17 +62,24 @@ NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ You are a preference inference assistant. Please extract **implicit preferences** from the following conversation -(preferences that the user did not explicitly state but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices). +(preferences that the user did not explicitly state but can be reasonably inferred from their underlying motivations, behavioral patterns, decision-making logic, and latent needs). Notes: -- Implicit preferences refer to user inclinations or choices that are not directly expressed, but can be reasonably inferred from factual cues in the conversation. +- Implicit preferences refer to user inclinations or choices that are not directly expressed, but can be deeply inferred by analyzing: + * **Hidden motivations**: What underlying needs or goals might drive the user's behavior? + * **Behavioral patterns**: What recurring patterns or tendencies can be observed? + * **Decision-making logic**: What reasoning or trade-offs might the user be considering? + * **Latent preferences**: What preferences might the user have but haven't yet articulated? + * **Contextual signals**: What do the user's choices, comparisons, exclusions, or scenario selections reveal about their deeper preferences? - Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned. +- Go beyond surface-level facts to understand the user's hidden possibilities and underlying logic. Requirements: 1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. 2. Inferred implicit preferences must not conflict with explicit preferences. 3. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. -4. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). +4. In the reasoning field, explicitly explain the underlying logic and hidden motivations you identified. +5. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). Conversation: {qa_pair} @@ -82,7 +89,7 @@ { "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string", "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation, do not lack any scenario information", - "reasoning": "Briefly explain the reasoning process for the implicit preference" + "reasoning": "Explain the underlying logic, hidden motivations, and behavioral patterns that led to this inference" } ``` Don't output anything except the JSON. @@ -91,17 +98,24 @@ NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH = """ 你是一个偏好推理助手。请从以下对话中提取**隐式偏好** -(用户没有明确表述,但可以从上下文、行为、频率、比较、排除或场景选择中合理推断出的偏好)。 +(用户没有明确表述,但可以通过分析其潜在动机、行为模式、决策逻辑和隐藏需求深度推断出的偏好)。 注意事项: -- 隐式偏好是指用户未直接表达,但可以从对话中的事实线索合理推断出的倾向或选择。 +- 隐式偏好是指用户未直接表达,但可以通过深入分析以下方面推断出的倾向或选择: + * **隐藏动机**:什么样的潜在需求或目标可能驱动用户的行为? + * **行为模式**:可以观察到什么样的重复模式或倾向? + * **决策逻辑**:用户可能在考虑什么样的推理或权衡? + * **潜在偏好**:用户可能有但尚未明确表达的偏好是什么? + * **情境信号**:用户的选择、比较、排除或场景选择揭示了什么样的深层偏好? - 不要将明确陈述的偏好视为隐式偏好;此提示仅用于推断未直接提及的偏好。 +- 超越表面事实,理解用户的隐藏可能性和背后的逻辑。 要求: 1. 仅在对话中有充分证据时进行推断;避免无根据或牵强的猜测。 2. 推断的隐式偏好不得与显式偏好冲突。 3. 对于 implicit_preference:仅输出偏好陈述本身;不要包含任何额外的解释、推理或置信度信息。将所有推理和解释放在 reasoning 字段中。 -4. 如果无法合理推断出隐式偏好,则将 implicit_preference 字段留空(不要输出其他任何内容)。 +4. 在 reasoning 字段中,明确解释你识别出的底层逻辑和隐藏动机。 +5. 如果无法合理推断出隐式偏好,则将 implicit_preference 字段留空(不要输出其他任何内容)。 对话: {qa_pair} @@ -111,7 +125,7 @@ { "implicit_preference": "从对话中合理推断出的隐式偏好的简洁自然语言陈述,或空字符串", "context_summary": "对应的上下文摘要,即对应对话的摘要,不要遗漏任何场景信息", - "reasoning": "简要解释隐式偏好的推理过程" + "reasoning": "解释推断出该偏好的底层逻辑、隐藏动机和行为模式" } ``` 除JSON外不要输出任何其他内容。 From 387fe8afda78e0e1f49d34ca867cb9aa112c9cde Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 3 Nov 2025 11:13:29 +0800 Subject: [PATCH 40/64] Feat/pref optimize update (#432) * add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit * modify pref filed * fix bug * fix pre_commit * fix bug in adder * fast * modify pref and adder mode * modify code * make pre_commit * fix pref_string for memos online api * modify code * modify code * modify code * pre comimt * modify code --------- Co-authored-by: yuan.wang --- evaluation/scripts/utils/client.py | 36 +++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index ea0caa307..9aa527903 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -234,6 +234,8 @@ def search(self, query, user_id, top_k): "user_id": user_id, "memory_limit_number": top_k, "mode": os.getenv("SEARCH_MODE", "fast"), + "include_preference": True, + "pref_top_k": 6, } ) @@ -243,10 +245,38 @@ def search(self, query, user_id, top_k): response = requests.request("POST", url, data=payload, headers=self.headers) assert response.status_code == 200, response.text assert json.loads(response.text)["message"] == "ok", response.text - res = json.loads(response.text)["data"]["memory_detail_list"] - for i in res: + text_mem_res = json.loads(response.text)["data"]["memory_detail_list"] + pref_mem_res = json.loads(response.text)["data"]["preference_detail_list"] + preference_note = json.loads(response.text)["data"]["preference_note"] + for i in text_mem_res: i.update({"memory": i.pop("memory_value")}) - return {"text_mem": [{"memories": res}], "pref_str": ""} + + explicit_prefs = [ + p["preference"] + for p in pref_mem_res + if p.get("preference_type", "") == "explicit_preference" + ] + implicit_prefs = [ + p["preference"] + for p in pref_mem_res + if p.get("preference_type", "") == "implicit_preference" + ] + + pref_parts = [] + if explicit_prefs: + pref_parts.append( + "Explicit Preference:\n" + + "\n".join(f"{i + 1}. {p}" for i, p in enumerate(explicit_prefs)) + ) + if implicit_prefs: + pref_parts.append( + "Implicit Preference:\n" + + "\n".join(f"{i + 1}. {p}" for i, p in enumerate(implicit_prefs)) + ) + + pref_string = "\n".join(pref_parts) + preference_note + + return {"text_mem": [{"memories": text_mem_res}], "pref_string": pref_string} except Exception as e: if attempt < max_retries - 1: time.sleep(2**attempt) From b3ec17a9c05965bc64de6c445a292551e588901b Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Mon, 3 Nov 2025 16:00:39 +0800 Subject: [PATCH 41/64] Feat/add request log (#439) * feat: update log context * feat: update log context * feat: update mcp * feat: update mcp * feat: add error log * feat: add error log * feat: add error log * feat: update log * feat: add chat_time * feat: add chat_time * feat: add chat_time * feat: update log * feat: update log * feat: update log * feat: update log * feat: update log * feat: add arms * fix: format * fix: format * feat: add dockerfile * feat: add dockerfile * feat: add arms config * feat: update log * feat: add sleep time * feat: add sleep time * feat: update log * feat: delete dockerfile * feat: delete dockerfile * feat: update dockerfile * fix: conflict * feat: replace ThreadPool to context * feat: add timed log * feat: add request log * feat: add request log --------- Co-authored-by: harvey_xiang --- src/memos/api/middleware/request_context.py | 151 +++++++++++++++++++- src/memos/api/routers/server_router.py | 5 + 2 files changed, 150 insertions(+), 6 deletions(-) diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 2922ab3eb..74865b66f 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,6 +2,8 @@ Request context middleware for automatic trace_id injection. """ +import json +import os import time from collections.abc import Callable @@ -17,6 +19,9 @@ logger = memos.log.get_logger(__name__) +# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped +MAX_BODY_LOG_SIZE = os.getenv("MAX_BODY_LOG_SIZE", 10 * 1024) + def extract_trace_id_from_headers(request: Request) -> str | None: """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id.""" @@ -26,6 +31,127 @@ def extract_trace_id_from_headers(request: Request) -> str | None: return None +def _is_json_request(request: Request) -> tuple[bool, str]: + """ + Check if request is a JSON request. + + Args: + request: The request object + + Returns: + Tuple of (is_json, content_type) + """ + if request.method not in ("POST", "PUT", "PATCH", "DELETE"): + return False, "" + + content_type = request.headers.get("content-type", "") + if not content_type: + return False, "" + + is_json = "application/json" in content_type.lower() + return is_json, content_type + + +def _should_read_body(content_length: str | None) -> tuple[bool, int | None]: + """ + Check if body should be read based on content-length header. + + Args: + content_length: Content-Length header value + + Returns: + Tuple of (should_read, body_size). body_size is None if header is invalid. + """ + if not content_length: + return True, None + + try: + body_size = int(content_length) + return body_size <= MAX_BODY_LOG_SIZE, body_size + except ValueError: + return True, None + + +def _create_body_info(content_type: str, body_size: int) -> dict: + """Create body_info dict for large bodies that are skipped.""" + return { + "content_type": content_type, + "content_length": body_size, + "note": f"body too large ({body_size} bytes), skipping read", + } + + +def _parse_json_body(body_bytes: bytes) -> dict | str: + """ + Parse JSON body bytes. + + Args: + body_bytes: Raw body bytes + + Returns: + Parsed JSON dict, or error message string if parsing fails + """ + try: + return json.loads(body_bytes) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + return f"" + + +async def get_request_params(request: Request) -> tuple[dict, bytes | None]: + """ + Extract request parameters (query params and body) for logging. + + Only reads body for application/json requests that are within size limits. + + This function is wrapped with exception handling to ensure logging failures + don't affect the actual request processing. + + Args: + request: The incoming request object + + Returns: + Tuple of (params_dict, body_bytes). body_bytes is None if body was not read. + Returns empty dict and None on any error. + """ + try: + params_log = {} + + # Check if this is a JSON request + is_json, content_type = _is_json_request(request) + if not is_json: + return params_log, None + + # Pre-check body size using content-length header + content_length = request.headers.get("content-length") + should_read, body_size = _should_read_body(content_length) + + if not should_read and body_size is not None: + params_log["body_info"] = _create_body_info(content_type, body_size) + return params_log, None + + # Read body + body_bytes = await request.body() + + if not body_bytes: + return params_log, None + + # Post-check: verify actual size (content-length might be missing or wrong) + actual_size = len(body_bytes) + if actual_size > MAX_BODY_LOG_SIZE: + params_log["body_info"] = _create_body_info(content_type, actual_size) + return params_log, None + + # Parse JSON body + params_log["body"] = _parse_json_body(body_bytes) + return params_log, body_bytes + + except Exception as e: + # Catch-all for any unexpected errors + logger.error(f"Unexpected error in get_request_params: {e}", exc_info=True) + # Return empty dict to ensure request can continue + return {}, None + + class RequestContextMiddleware(BaseHTTPMiddleware): """ Middleware to automatically inject request context for every HTTP request. @@ -55,14 +181,27 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: ) set_request_context(context) - # Log request start with parameters - params_log = {} + # Get request parameters for logging + # Wrap in try-catch to ensure logging failures don't break the request + params_log, body_bytes = await get_request_params(request) + + # Re-create the request receive function if body was read + # This ensures downstream handlers can still read the body + if body_bytes is not None: + try: - # Get query parameters - if request.query_params: - params_log["query_params"] = dict(request.query_params) + async def receive(): + return {"type": "http.request", "body": body_bytes, "more_body": False} - logger.info(f"Request started, params: {params_log}, headers: {request.headers}") + request._receive = receive + except Exception as e: + logger.error(f"Failed to recreate request receive function: {e}") + # Continue without restoring body, downstream handlers will handle it + + logger.info( + f"Request started, method: {request.method}, path: {request.url.path}, " + f"request params: {params_log}, headers: {request.headers}" + ) # Process the request try: diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 2b481d5c6..684e02a0c 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -412,6 +412,8 @@ def _search_pref(): search_req.include_preference, ) + logger.info(f"Search memories result: {memories_result}") + return SearchResponse( message="Search completed successfully", data=memories_result, @@ -618,6 +620,9 @@ def _process_pref_mem() -> list[dict[str, str]]: text_response_data = text_future.result() pref_response_data = pref_future.result() + logger.info(f"add_memories Text response data: {text_response_data}") + logger.info(f"add_memories Pref response data: {pref_response_data}") + return MemoryResponse( message="Memory added successfully", data=text_response_data + pref_response_data, From b3b0baa733017019eef6a7c4e9824a3aa8032ae8 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:29:18 +0800 Subject: [PATCH 42/64] Feat/standardized preference field (#440) * standardized preference field * fix pre_commit --------- Co-authored-by: yuan.wang Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/memories/textual/item.py | 3 +- .../textual/prefer_text_memory/adder.py | 59 +++++++++---------- .../textual/prefer_text_memory/extractor.py | 3 + .../textual/prefer_text_memory/retrievers.py | 4 +- .../textual/prefer_text_memory/utils.py | 6 +- src/memos/templates/instruction_completion.py | 11 ++-- 6 files changed, 41 insertions(+), 45 deletions(-) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 9b9059d26..2c23ae193 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -196,9 +196,8 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): dialog_id: str | None = Field(default=None, description="ID of the dialog.") original_text: str | None = Field(default=None, description="String of the dialog.") embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") - explicit_preference: str | None = Field(default=None, description="Explicit preference.") + preference: str | None = Field(default=None, description="Preference.") created_at: str | None = Field(default=None, description="Timestamp of the dialog.") - implicit_preference: str | None = Field(default=None, description="Implicit preference.") class TextualMemoryItem(BaseModel): diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index ce0282a23..a78601e86 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -103,7 +103,6 @@ def _update_memory_op_trace( new_memories: list[TextualMemoryItem], retrieved_memories: list[MilvusVecDBItem], collection_name: str, - preference_type: str, ) -> list[str] | str: # create new vec db items new_vec_db_items: list[MilvusVecDBItem] = [] @@ -124,17 +123,19 @@ def _update_memory_op_trace( { "id": new_memory.id, "context_summary": new_memory.memory, - "preference": new_memory.payload[preference_type], + "preference": new_memory.payload["preference"], } for new_memory in new_vec_db_items + if new_memory.payload.get("preference", None) ] retrieved_mem_inputs = [ { "id": mem.id, "context_summary": mem.memory, - "preference": mem.payload[preference_type], + "preference": mem.payload["preference"], } for mem in retrieved_memories + if mem.payload.get("preference", None) ] rsp = self._judge_update_or_add_trace_op( @@ -168,7 +169,7 @@ def execute_op( elif op_type == "update": if op["target_id"] in retrieved_mem_db_item_map: update_mem_db_item = retrieved_mem_db_item_map[op["target_id"]] - update_mem_db_item.payload[preference_type] = op["new_preference"] + update_mem_db_item.payload["preference"] = op["new_preference"] update_mem_db_item.payload["updated_at"] = datetime.now().isoformat() update_mem_db_item.memory = op["new_context_summary"] update_mem_db_item.original_text = op["new_context_summary"] @@ -198,7 +199,6 @@ def _update_memory_fine( new_memory: TextualMemoryItem, retrieved_memories: list[MilvusVecDBItem], collection_name: str, - preference_type: str, ) -> str: payload = new_memory.to_dict()["metadata"] fields_to_remove = {"dialog_id", "original_text", "embedding"} @@ -211,19 +211,15 @@ def _update_memory_fine( payload=payload, ) - new_mem_input = { - "memory": new_memory.memory, - "preference": new_memory.metadata.explicit_preference - if preference_type == "explicit_preference" - else new_memory.metadata.implicit_preference, - } + new_mem_input = {"memory": new_memory.memory, "preference": new_memory.metadata.preference} retrieved_mem_inputs = [ { "id": mem.id, "memory": mem.memory, - "preference": mem.payload[preference_type], + "preference": mem.payload["preference"], } for mem in retrieved_memories + if mem.payload.get("preference", None) ] rsp = self._judge_update_or_add_fine( new_mem=json.dumps(new_mem_input), @@ -240,7 +236,7 @@ def _update_memory_fine( ) if need_update and update_item and rsp: update_vec_db_item = update_item[0] - update_vec_db_item.payload[preference_type] = rsp["new_preference"] + update_vec_db_item.payload["preference"] = rsp["new_preference"] update_vec_db_item.payload["updated_at"] = vec_db_item.payload["updated_at"] update_vec_db_item.memory = rsp["new_memory"] update_vec_db_item.original_text = vec_db_item.original_text @@ -287,7 +283,6 @@ def _update_memory( new_memory: TextualMemoryItem, retrieved_memories: list[MilvusVecDBItem], collection_name: str, - preference_type: str, update_mode: str = "fast", ) -> list[str] | str | None: """Update the memory. @@ -295,15 +290,12 @@ def _update_memory( new_memory: TextualMemoryItem retrieved_memories: list[MilvusVecDBItem] collection_name: str - preference_type: str update_mode: str, "fast" or "fine" """ if update_mode == "fast": return self._update_memory_fast(new_memory, retrieved_memories, collection_name) elif update_mode == "fine": - return self._update_memory_fine( - new_memory, retrieved_memories, collection_name, preference_type - ) + return self._update_memory_fine(new_memory, retrieved_memories, collection_name) else: raise ValueError(f"Invalid update mode: {update_mode}") @@ -330,7 +322,6 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | memory, search_results, collection_name, - preference_type, update_mode=os.getenv("PREFERENCE_ADDER_MODE", "fast"), ) @@ -369,18 +360,24 @@ def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwarg explicit_recalls = list({recall.id: recall for recall in explicit_recalls}.values()) implicit_recalls = list({recall.id: recall for recall in implicit_recalls}.values()) - explicit_added_ids = self._update_memory_op_trace( - explicit_new_mems, - explicit_recalls, - pref_type_collection_map["explicit_preference"], - "explicit_preference", - ) - implicit_added_ids = self._update_memory_op_trace( - implicit_new_mems, - implicit_recalls, - pref_type_collection_map["implicit_preference"], - "implicit_preference", - ) + # 使用线程池并行处理显式和隐式偏好 + with ContextThreadPoolExecutor(max_workers=2) as executor: + explicit_future = executor.submit( + self._update_memory_op_trace, + explicit_new_mems, + explicit_recalls, + pref_type_collection_map["explicit_preference"], + ) + implicit_future = executor.submit( + self._update_memory_op_trace, + implicit_new_mems, + implicit_recalls, + pref_type_collection_map["implicit_preference"], + ) + + explicit_added_ids = explicit_future.result() + implicit_added_ids = implicit_future.result() + return explicit_added_ids + implicit_added_ids def process_memory_single( diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 61629b38a..d5eab2aec 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -67,6 +67,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) + for d in result: + d["preference"] = d.pop("explicit_preference") return result except Exception as e: logger.error(f"Error extracting explicit preference: {e}, return None") @@ -88,6 +90,7 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) + result["preference"] = result.pop("implicit_preference") return result except Exception as e: logger.error(f"Error extracting implicit preferences: {e}, return None") diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index f09d646b1..0074c3f1c 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -106,7 +106,7 @@ def retrieve( metadata=PreferenceTextualMemoryMetadata(**pref.payload), ) for pref in explicit_prefs - if pref.payload["explicit_preference"] + if pref.payload.get("preference", None) ] implicit_prefs_mem = [ @@ -116,7 +116,7 @@ def retrieve( metadata=PreferenceTextualMemoryMetadata(**pref.payload), ) for pref in implicit_prefs - if pref.payload["implicit_preference"] + if pref.payload.get("preference", None) ] reranker_map = { diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py index 85adc9304..76d4b4211 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -46,10 +46,8 @@ def deduplicate_preferences( for i, pref in enumerate(prefs): # Extract preference text - if hasattr(pref.metadata, "implicit_preference") and pref.metadata.implicit_preference: - text = pref.metadata.implicit_preference - elif hasattr(pref.metadata, "explicit_preference") and pref.metadata.explicit_preference: - text = pref.metadata.explicit_preference + if hasattr(pref.metadata, "preference") and pref.metadata.preference: + text = pref.metadata.preference else: text = pref.memory diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index acd110930..03ae52c77 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -12,14 +12,13 @@ def instruct_completion( implicit_pref = [] for memory in memories: pref_type = memory.get("metadata", {}).get("preference_type") + pref = memory.get("metadata", {}).get("preference", None) + if not pref: + continue if pref_type == "explicit_preference": - pref = memory.get("metadata", {}).get("explicit_preference", None) - if pref: - explicit_pref.append(pref) + explicit_pref.append(pref) elif pref_type == "implicit_preference": - pref = memory.get("metadata", {}).get("implicit_preference", None) - if pref: - implicit_pref.append(pref) + implicit_pref.append(pref) explicit_pref_str = ( "Explicit Preference:\n" From cef93698c2a570662594e68f88fa941ed95189cb Mon Sep 17 00:00:00 2001 From: Wustzdy <67457465+wustzdy@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:50:25 +0800 Subject: [PATCH 43/64] update polardb pool timeout (#441) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix:parse_node * fix: * fix:conn --------- Co-authored-by: ccl <13282138256@163.com> Co-authored-by: CaralHsi --- src/memos/graph_dbs/polardb.py | 61 +++++++++++++++++++++++++++++- src/memos/memories/textual/tree.py | 6 +-- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 20e02bd0c..de05185d2 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -151,6 +151,10 @@ def __init__(self, config: PolarDBGraphDBConfig): user=user, password=password, dbname=self.db_name, + connect_timeout=60, # Connection timeout in seconds + keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) + keepalives_interval=15, # Seconds between keepalive retries + keepalives_count=5, # Number of keepalive retries before considering connection dead ) # Keep a reference to the pool for cleanup @@ -179,7 +183,7 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) - def _get_connection(self): + def _get_connection_old(self): """Get a connection from the pool.""" if self._pool_closed: raise RuntimeError("Connection pool has been closed") @@ -188,7 +192,60 @@ def _get_connection(self): conn.autocommit = True return conn + def _get_connection(self): + """Get a connection from the pool.""" + if self._pool_closed: + raise RuntimeError("Connection pool has been closed") + + max_retries = 3 + for attempt in range(max_retries): + try: + conn = self.connection_pool.getconn() + + # Check if connection is closed + if conn.closed != 0: + # Connection is closed, close it explicitly and try again + try: + conn.close() + except Exception as e: + logger.warning(f"Failed to close connection: {e}") + if attempt < max_retries - 1: + continue + else: + raise RuntimeError("Pool returned a closed connection") + + # Set autocommit for PolarDB compatibility + conn.autocommit = True + return conn + except Exception as e: + if attempt >= max_retries - 1: + raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e + continue + def _return_connection(self, connection): + """Return a connection to the pool.""" + if not self._pool_closed and connection: + try: + # Check if connection is closed + if hasattr(connection, "closed") and connection.closed != 0: + # Connection is closed, just close it and don't return to pool + try: + connection.close() + except Exception as e: + logger.warning(f"Failed to close connection: {e}") + return + + # Connection is valid, return to pool + self.connection_pool.putconn(connection) + except Exception as e: + # If putconn fails, close the connection + logger.warning(f"Failed to return connection to pool: {e}") + try: + connection.close() + except Exception as e: + logger.warning(f"Failed to close connection: {e}") + + def _return_connection_old(self, connection): """Return a connection to the pool.""" if not self._pool_closed and connection: self.connection_pool.putconn(connection) @@ -1834,7 +1891,7 @@ def export_graph( if include_embedding and embedding_json is not None: properties["embedding"] = embedding_json - nodes.append(self._parse_node(properties)) + nodes.append(self._parse_node(json.loads(properties[1]))) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 53628d075..dea3cc1ab 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -260,7 +260,7 @@ def get_relevant_subgraph( center_id=core_id, depth=depth, center_status=center_status ) - if not subgraph["core_node"]: + if subgraph is None or not subgraph["core_node"]: logger.info(f"Skipping node {core_id} (inactive or not found).") continue @@ -281,9 +281,9 @@ def get_relevant_subgraph( {"id": core_id, "score": score, "core_node": core_node, "neighbors": neighbors} ) - top_core = cores[0] + top_core = cores[0] if cores else None return { - "core_id": top_core["id"], + "core_id": top_core["id"] if top_core else None, "nodes": list(all_nodes.values()), "edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges], } From fd56f64b67734cd633a402fb6b98498d616db28b Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:09:30 +0800 Subject: [PATCH 44/64] feat: fix self-input prompt error (#443) --- src/memos/mem_os/product_server.py | 44 ++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/src/memos/mem_os/product_server.py b/src/memos/mem_os/product_server.py index b94b26f65..758f2794d 100644 --- a/src/memos/mem_os/product_server.py +++ b/src/memos/mem_os/product_server.py @@ -71,11 +71,7 @@ def chat( m.metadata.embedding = [] new_memories_list.append(m) memories_list = new_memories_list - system_prompt = self._build_base_system_prompt(base_prompt, mode="base") - - memory_context = self._build_memory_context(memories_list, mode="base") - - user_content = memory_context + query if memory_context else query + system_prompt = self._build_system_prompt(memories_list, base_prompt) history_info = [] if history: @@ -83,7 +79,7 @@ def chat( current_messages = [ {"role": "system", "content": system_prompt}, *history_info, - {"role": "user", "content": user_content}, + {"role": "user", "content": query}, ] response = self.chat_llm.generate(current_messages) time_end = time.time() @@ -187,6 +183,42 @@ def _build_base_system_prompt( prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" return prefix + sys_body + def _build_system_prompt( + self, + memories: list[TextualMemoryItem] | list[str] | None = None, + base_prompt: str | None = None, + **kwargs, + ) -> str: + """Build system prompt with optional memories context.""" + if base_prompt is None: + base_prompt = ( + "You are a knowledgeable and helpful AI assistant. " + "You have access to conversation memories that help you provide more personalized responses. " + "Use the memories to understand the user's context, preferences, and past interactions. " + "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." + ) + + memory_context = "" + if memories: + memory_list = [] + for i, memory in enumerate(memories, 1): + if isinstance(memory, TextualMemoryItem): + text_memory = memory.memory + else: + if not isinstance(memory, str): + logger.error("Unexpected memory type.") + text_memory = memory + memory_list.append(f"{i}. {text_memory}") + memory_context = "\n".join(memory_list) + + if "{memories}" in base_prompt: + return base_prompt.format(memories=memory_context) + elif base_prompt and memories: + # For backward compatibility, append memories if no placeholder is found + memory_context_with_header = "\n\n## Memories:\n" + memory_context + return base_prompt + memory_context_with_header + return base_prompt + def _build_memory_context( self, memories_all: list[TextualMemoryItem], From e79a9abe417833586fa75391787e6d249c13d2bb Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Mon, 3 Nov 2025 19:06:35 +0800 Subject: [PATCH 45/64] feat: fix polardb value (#445) --- src/memos/graph_dbs/polardb.py | 62 ++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index de05185d2..552b30241 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -363,7 +363,7 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype """ query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params = [f'"{memory_type}"', f'"{user_name}"'] + params = [self.format_param_value(memory_type), self.format_param_value(user_name)] # Get a connection from the pool conn = self._get_connection() @@ -389,7 +389,7 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: """ query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" query += "\nLIMIT 1" - params = [f'"{scope}"', f'"{user_name}"'] + params = [self.format_param_value(scope), self.format_param_value(user_name)] # Get a connection from the pool conn = self._get_connection() @@ -427,7 +427,11 @@ def remove_oldest_memory( ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC OFFSET %s """ - select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] + select_params = [ + self.format_param_value(memory_type), + self.format_param_value(user_name), + keep_latest, + ] conn = self._get_connection() try: with conn.cursor() as cursor: @@ -501,19 +505,23 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N SET properties = %s, embedding = %s WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ - params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"'] + params = [ + json.dumps(properties), + json.dumps(embedding_vector), + self.format_param_value(id), + ] else: query = f""" UPDATE "{self.db_name}_graph"."Memory" SET properties = %s WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ - params = [json.dumps(properties), f'"{id}"'] + params = [json.dumps(properties), self.format_param_value(id)] # Only add user filter when user_name is provided if user_name is not None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(f'"{user_name}"') + params.append(self.format_param_value(user_name)) # Get a connection from the pool conn = self._get_connection() @@ -538,12 +546,12 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: DELETE FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ - params = [f'"{id}"'] + params = [self.format_param_value(id)] # Only add user filter when user_name is provided if user_name is not None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(f'"{user_name}"') + params.append(self.format_param_value(user_name)) # Get a connection from the pool conn = self._get_connection() @@ -831,28 +839,17 @@ def get_node( select_fields = "id, properties, embedding" if include_embedding else "id, properties" - # Helper function to format parameter value - def format_param_value(value: str) -> str: - """Format parameter value to handle both quoted and unquoted formats""" - # Remove outer quotes if they exist - if value.startswith('"') and value.endswith('"'): - # Already has double quotes, return as is - return value - else: - # Add double quotes - return f'"{value}"' - query = f""" SELECT {select_fields} FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ - params = [format_param_value(id)] + params = [self.format_param_value(id)] # Only add user filter when user_name is provided if user_name is not None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(format_param_value(user_name)) + params.append(self.format_param_value(user_name)) conn = self._get_connection() try: @@ -930,7 +927,7 @@ def get_nodes( where_conditions.append( "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype" ) - params.append(f"{id_val}") + params.append(self.format_param_value(id_val)) where_clause = " OR ".join(where_conditions) @@ -942,7 +939,7 @@ def get_nodes( user_name = user_name if user_name else self.config.user_name query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(f'"{user_name}"') + params.append(self.format_param_value(user_name)) conn = self._get_connection() try: @@ -2616,7 +2613,7 @@ def get_neighbors_by_tag( exclude_conditions.append( "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) != %s::agtype" ) - params.append(f'"{exclude_id}"') + params.append(self.format_param_value(exclude_id)) where_clauses.append(f"({' AND '.join(exclude_conditions)})") # Status filter - keep only 'activated' @@ -2633,7 +2630,7 @@ def get_neighbors_by_tag( where_clauses.append( "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" ) - params.append(f'"{user_name}"') + params.append(self.format_param_value(user_name)) # Testing showed no data; annotate. where_clauses.append( @@ -3022,3 +3019,18 @@ def _convert_graph_edges(self, core_node: dict) -> dict: if tgt in id_map: edge["target"] = id_map[tgt] return data + + def format_param_value(self, value: str | None) -> str: + """Format parameter value to handle both quoted and unquoted formats""" + # Handle None value + if value is None: + logger.warning(f"format_param_value: value is None") + return "null" + + # Remove outer quotes if they exist + if value.startswith('"') and value.endswith('"'): + # Already has double quotes, return as is + return value + else: + # Add double quotes + return f'"{value}"' From 9ea42e44598ef0f8b3144ff6d65c68a58ca29ddf Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Mon, 3 Nov 2025 19:29:51 +0800 Subject: [PATCH 46/64] Feat/add request log (#442) * feat: update log context * feat: update log context * feat: update mcp * feat: update mcp * feat: add error log * feat: add error log * feat: add error log * feat: update log * feat: add chat_time * feat: add chat_time * feat: add chat_time * feat: update log * feat: update log * feat: update log * feat: update log * feat: update log * feat: add arms * fix: format * fix: format * feat: add dockerfile * feat: add dockerfile * feat: add arms config * feat: update log * feat: add sleep time * feat: add sleep time * feat: update log * feat: delete dockerfile * feat: delete dockerfile * feat: update dockerfile * fix: conflict * feat: replace ThreadPool to context * feat: add timed log * feat: add request log * feat: add request log * feat: add source in request * feat: source --------- Co-authored-by: harvey_xiang Co-authored-by: CaralHsi --- src/memos/api/middleware/request_context.py | 20 ++++++++++++++++---- src/memos/api/product_api.py | 2 +- src/memos/api/server_api.py | 2 +- src/memos/context/context.py | 15 +++++++++++++++ 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 74865b66f..488f59625 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -162,6 +162,17 @@ class RequestContextMiddleware(BaseHTTPMiddleware): 3. Ensures the context is available throughout the request lifecycle """ + def __init__(self, app, source: str | None = None): + """ + Initialize the middleware. + + Args: + app: The ASGI application + source: Source identifier (e.g., 'product' or 'server') to distinguish request origin + """ + super().__init__(app) + self.source = source or "api" + async def dispatch(self, request: Request, call_next: Callable) -> Response: # Extract or generate trace_id trace_id = extract_trace_id_from_headers(request) or generate_trace_id() @@ -178,6 +189,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: env=env, user_type=user_type, user_name=user_name, + source=self.source, ) set_request_context(context) @@ -199,7 +211,7 @@ async def receive(): # Continue without restoring body, downstream handlers will handle it logger.info( - f"Request started, method: {request.method}, path: {request.url.path}, " + f"Request started, source: {self.source}, method: {request.method}, path: {request.url.path}, " f"request params: {params_log}, headers: {request.headers}" ) @@ -209,16 +221,16 @@ async def receive(): end_time = time.time() if response.status_code == 200: logger.info( - f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + f"Request completed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" ) else: logger.error( - f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + f"Request Failed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" ) except Exception as e: end_time = time.time() logger.error( - f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" + f"Request Exception Error: source: {self.source}, path: {request.url.path}, error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" ) raise e diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py index 709ad74fb..ec5cccae1 100644 --- a/src/memos/api/product_api.py +++ b/src/memos/api/product_api.py @@ -17,7 +17,7 @@ version="1.0.1", ) -app.add_middleware(RequestContextMiddleware) +app.add_middleware(RequestContextMiddleware, source="product_api") # Include routers app.include_router(product_router) diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 24c67de48..0dfef99d9 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -18,7 +18,7 @@ version="1.0.1", ) -app.add_middleware(RequestContextMiddleware) +app.add_middleware(RequestContextMiddleware, source="server_api") # Include routers app.include_router(server_router) diff --git a/src/memos/context/context.py b/src/memos/context/context.py index d6a0f3bf1..b5d4c24fe 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -36,12 +36,14 @@ def __init__( env: str | None = None, user_type: str | None = None, user_name: str | None = None, + source: str | None = None, ): self.trace_id = trace_id or "trace-id" self.api_path = api_path self.env = env self.user_type = user_type self.user_name = user_name + self.source = source self._data: dict[str, Any] = {} def set(self, key: str, value: Any) -> None: @@ -59,6 +61,7 @@ def __setattr__(self, name: str, value: Any) -> None: "env", "user_type", "user_name", + "source", ): super().__setattr__(name, value) else: @@ -80,6 +83,7 @@ def to_dict(self) -> dict[str, Any]: "env": self.env, "user_type": self.user_type, "user_name": self.user_name, + "source": self.source, "data": self._data.copy(), } @@ -146,6 +150,16 @@ def get_current_user_name() -> str | None: return "memos" +def get_current_source() -> str | None: + """ + Get the current request's source (e.g., 'product_api' or 'server_api'). + """ + context = _request_context.get() + if context: + return context.get("source") + return None + + def get_current_context() -> RequestContext | None: """ Get the current request context. @@ -161,6 +175,7 @@ def get_current_context() -> RequestContext | None: env=context_dict.get("env"), user_type=context_dict.get("user_type"), user_name=context_dict.get("user_name"), + source=context_dict.get("source"), ) ctx._data = context_dict.get("data", {}).copy() return ctx From 06351272986f890219702ea1c1f9b141e7a99eca Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Tue, 4 Nov 2025 10:27:25 +0800 Subject: [PATCH 47/64] Feat/revert request context (#446) * feat: update log context * feat: update log context * feat: update mcp * feat: update mcp * feat: add error log * feat: add error log * feat: add error log * feat: update log * feat: add chat_time * feat: add chat_time * feat: add chat_time * feat: update log * feat: update log * feat: update log * feat: update log * feat: update log * feat: add arms * fix: format * fix: format * feat: add dockerfile * feat: add dockerfile * feat: add arms config * feat: update log * feat: add sleep time * feat: add sleep time * feat: update log * feat: delete dockerfile * feat: delete dockerfile * feat: update dockerfile * fix: conflict * feat: replace ThreadPool to context * feat: add timed log * feat: add request log * feat: add request log * feat: add source in request * feat: source * feat: revert context --------- Co-authored-by: harvey_xiang --- src/memos/api/middleware/request_context.py | 145 +------------------- 1 file changed, 1 insertion(+), 144 deletions(-) diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 488f59625..443aa1f3d 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,8 +2,6 @@ Request context middleware for automatic trace_id injection. """ -import json -import os import time from collections.abc import Callable @@ -19,9 +17,6 @@ logger = memos.log.get_logger(__name__) -# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped -MAX_BODY_LOG_SIZE = os.getenv("MAX_BODY_LOG_SIZE", 10 * 1024) - def extract_trace_id_from_headers(request: Request) -> str | None: """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id.""" @@ -31,127 +26,6 @@ def extract_trace_id_from_headers(request: Request) -> str | None: return None -def _is_json_request(request: Request) -> tuple[bool, str]: - """ - Check if request is a JSON request. - - Args: - request: The request object - - Returns: - Tuple of (is_json, content_type) - """ - if request.method not in ("POST", "PUT", "PATCH", "DELETE"): - return False, "" - - content_type = request.headers.get("content-type", "") - if not content_type: - return False, "" - - is_json = "application/json" in content_type.lower() - return is_json, content_type - - -def _should_read_body(content_length: str | None) -> tuple[bool, int | None]: - """ - Check if body should be read based on content-length header. - - Args: - content_length: Content-Length header value - - Returns: - Tuple of (should_read, body_size). body_size is None if header is invalid. - """ - if not content_length: - return True, None - - try: - body_size = int(content_length) - return body_size <= MAX_BODY_LOG_SIZE, body_size - except ValueError: - return True, None - - -def _create_body_info(content_type: str, body_size: int) -> dict: - """Create body_info dict for large bodies that are skipped.""" - return { - "content_type": content_type, - "content_length": body_size, - "note": f"body too large ({body_size} bytes), skipping read", - } - - -def _parse_json_body(body_bytes: bytes) -> dict | str: - """ - Parse JSON body bytes. - - Args: - body_bytes: Raw body bytes - - Returns: - Parsed JSON dict, or error message string if parsing fails - """ - try: - return json.loads(body_bytes) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - return f"" - - -async def get_request_params(request: Request) -> tuple[dict, bytes | None]: - """ - Extract request parameters (query params and body) for logging. - - Only reads body for application/json requests that are within size limits. - - This function is wrapped with exception handling to ensure logging failures - don't affect the actual request processing. - - Args: - request: The incoming request object - - Returns: - Tuple of (params_dict, body_bytes). body_bytes is None if body was not read. - Returns empty dict and None on any error. - """ - try: - params_log = {} - - # Check if this is a JSON request - is_json, content_type = _is_json_request(request) - if not is_json: - return params_log, None - - # Pre-check body size using content-length header - content_length = request.headers.get("content-length") - should_read, body_size = _should_read_body(content_length) - - if not should_read and body_size is not None: - params_log["body_info"] = _create_body_info(content_type, body_size) - return params_log, None - - # Read body - body_bytes = await request.body() - - if not body_bytes: - return params_log, None - - # Post-check: verify actual size (content-length might be missing or wrong) - actual_size = len(body_bytes) - if actual_size > MAX_BODY_LOG_SIZE: - params_log["body_info"] = _create_body_info(content_type, actual_size) - return params_log, None - - # Parse JSON body - params_log["body"] = _parse_json_body(body_bytes) - return params_log, body_bytes - - except Exception as e: - # Catch-all for any unexpected errors - logger.error(f"Unexpected error in get_request_params: {e}", exc_info=True) - # Return empty dict to ensure request can continue - return {}, None - - class RequestContextMiddleware(BaseHTTPMiddleware): """ Middleware to automatically inject request context for every HTTP request. @@ -193,26 +67,9 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: ) set_request_context(context) - # Get request parameters for logging - # Wrap in try-catch to ensure logging failures don't break the request - params_log, body_bytes = await get_request_params(request) - - # Re-create the request receive function if body was read - # This ensures downstream handlers can still read the body - if body_bytes is not None: - try: - - async def receive(): - return {"type": "http.request", "body": body_bytes, "more_body": False} - - request._receive = receive - except Exception as e: - logger.error(f"Failed to recreate request receive function: {e}") - # Continue without restoring body, downstream handlers will handle it - logger.info( f"Request started, source: {self.source}, method: {request.method}, path: {request.url.path}, " - f"request params: {params_log}, headers: {request.headers}" + f"headers: {request.headers}" ) # Process the request From 4c8f89a2dbf0da5c7def95342d32de76f1898146 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Tue, 4 Nov 2025 15:25:36 +0800 Subject: [PATCH 48/64] fix prompt error (#447) Co-authored-by: yuan.wang --- src/memos/graph_dbs/polardb.py | 2 +- src/memos/templates/prefer_complete_prompt.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 552b30241..ac49228e2 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -3024,7 +3024,7 @@ def format_param_value(self, value: str | None) -> str: """Format parameter value to handle both quoted and unquoted formats""" # Handle None value if value is None: - logger.warning(f"format_param_value: value is None") + logger.warning("format_param_value: value is None") return "null" # Remove outer quotes if they exist diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index ec06af27f..8f3a62abb 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -599,15 +599,15 @@ PREF_INSTRUCTIONS = """ # Note: -Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. +Fact memory are summaries of facts, while preference memory are summaries of user preferences. Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. -When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +When encountering preference conflicts, the priority is: explicit preference > implicit preference > fact memory. """ PREF_INSTRUCTIONS_ZH = """ # 注意: -明文记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 +事实记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 你的回复不得违反用户的任何偏好,无论是显式偏好还是隐式偏好,并简要解释你为什么这样回答以避免冲突。 -当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 明文记忆。 +当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 事实记忆。 """ From 88699f9e8f240c857880b00e4a1e7446d75313ab Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Tue, 4 Nov 2025 15:55:18 +0800 Subject: [PATCH 49/64] fix: fix search config input bug; patch retrieve_utils path set; adjust reader strategy template. (#448) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/config.py | 2 +- src/memos/configs/memory.py | 2 +- src/memos/mem_reader/strategy_struct.py | 15 ++- src/memos/memories/textual/simple_tree.py | 4 +- .../retrieve/retrieval_mid_structs.py | 1 + .../retrieve/retrieve_utils.py | 2 +- .../tree_text_memory/retrieve/searcher.py | 17 ++- .../retrieve/task_goal_parser.py | 13 ++- .../templates/mem_reader_strategy_prompts.py | 109 ++++++------------ 9 files changed, 70 insertions(+), 95 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d9db93c1a..7f61d54ac 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -427,7 +427,7 @@ def get_reader_config() -> dict[str, Any]: "config": { "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), - "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), + "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 10)), "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), }, } diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 49320fbf5..34967849a 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,7 +184,7 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) - search_strategy: dict[str, bool] | None = Field( + search_strategy: dict[str, Any] | None = Field( default=None, description=( 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py index 2cac1652a..1fc21461e 100644 --- a/src/memos/mem_reader/strategy_struct.py +++ b/src/memos/mem_reader/strategy_struct.py @@ -43,7 +43,7 @@ def _get_llm_response(self, mem_str: str) -> dict: template = STRATEGY_PROMPT_DICT["chat"][lang] examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) - if self.config.remove_prompt_example: + if self.config.remove_prompt_example: # TODO unused prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] try: @@ -112,6 +112,19 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: results.append([overlap_item, item]) current_length = overlap_length + content_length + else: + cut_size, cut_overlap = ( + self.chat_chunker["chunk_session"], + self.chat_chunker["chunk_overlap"], + ) + for items in scene_data: + step = cut_size - cut_overlap + end = len(items) - cut_overlap + if end <= 0: + results.extend([items[:]]) + else: + results.extend([items[i : i + cut_size] for i in range(0, end, step)]) + elif type == "doc": parser_config = ParserConfigFactory.model_validate( { diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 992b7bfab..313989cd2 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -66,7 +66,9 @@ def __init__( time_start_bm = time.time() self.search_strategy = config.search_strategy self.bm25_retriever = ( - EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None + EnhancedBM25() + if self.search_strategy and self.search_strategy.get("bm25", False) + else None ) logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py index 6accc4a16..7aefaa1a3 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py @@ -13,3 +13,4 @@ class ParsedTaskGoal: rephrased_query: str | None = None internet_search: bool = False goal_type: str | None = None # e.g., 'default', 'explanation', etc. + context: str = "" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index eec827c86..3f2b41a47 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -17,7 +17,7 @@ def find_project_root(marker=".git"): if (current / marker).exists(): return current current = current.parent - logger.warn(f"The project root directory tag file was not found: {marker}") + return Path(".") PROJECT_ROOT = find_project_root() diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 0974d67f2..2f6ef6afa 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -30,8 +30,8 @@ logger = get_logger(__name__) COT_DICT = { - "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, - "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, + "fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, + "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, } @@ -59,12 +59,8 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = ( - search_strategy.get("vec_cot", "false") == "true" if search_strategy else False - ) - self.use_fast_graph = ( - search_strategy.get("fast_graph", "false") == "true" if search_strategy else False - ) + self.vec_cot = search_strategy.get("cot", False) if search_strategy else False + self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -287,6 +283,7 @@ def _retrieve_paths( search_filter, user_name, id_filter, + mode=mode, ) ) tasks.append( @@ -369,6 +366,7 @@ def _retrieve_from_long_term_and_user( search_filter: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, + mode: str = "fast", ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -377,7 +375,7 @@ def _retrieve_from_long_term_and_user( # chain of thinking cot_embeddings = [] if self.vec_cot: - queries = self._cot_query(query) + queries = self._cot_query(query, mode=mode, context=parsed_goal.context) if len(queries) > 1: cot_embeddings = self.embedder.embed(queries) cot_embeddings.extend(query_embedding) @@ -566,7 +564,6 @@ def _cot_query( prompt = template.replace("${original_query}", query).replace( "${split_num_threshold}", str(split_num) ) - logger.info("COT process") messages = [{"role": "user", "content": prompt}] try: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 5d706559c..55e33494c 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -39,7 +39,7 @@ def parse( - mode == 'fine': use LLM to parse structured topic/keys/tags """ if mode == "fast": - return self._parse_fast(task_description, **kwargs) + return self._parse_fast(task_description, context=context, **kwargs) elif mode == "fine": if not self.llm: raise ValueError("LLM not provided for slow mode.") @@ -51,6 +51,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: """ Fast mode: simple jieba word split. """ + context = kwargs.get("context", "") use_fast_graph = kwargs.get("use_fast_graph", False) if use_fast_graph: desc_tokenized = self.tokenizer.tokenize_mixed(task_description) @@ -61,6 +62,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: goal_type="default", rephrased_query=task_description, internet_search=False, + context=context, ) else: return ParsedTaskGoal( @@ -70,6 +72,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: goal_type="default", rephrased_query=task_description, internet_search=False, + context=context, ) def _parse_fine( @@ -91,16 +94,17 @@ def _parse_fine( logger.info(f"Parsing Goal... LLM input is {prompt}") response = self.llm.generate(messages=[{"role": "user", "content": prompt}]) logger.info(f"Parsing Goal... LLM Response is {response}") - return self._parse_response(response) + return self._parse_response(response, context=context) except Exception: logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}") - return self._parse_fast(query) + return self._parse_fast(query, context=context) - def _parse_response(self, response: str) -> ParsedTaskGoal: + def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: """ Parse LLM JSON output safely. """ try: + context = kwargs.get("context", "") response = response.replace("```", "").replace("json", "").strip() response_json = eval(response) return ParsedTaskGoal( @@ -110,6 +114,7 @@ def _parse_response(self, response: str) -> ParsedTaskGoal: rephrased_query=response_json.get("rephrased_instruction", None), internet_search=response_json.get("internet_search", False), goal_type=response_json.get("goal_type", "default"), + context=context, ) except Exception as e: raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py index fca4d717b..ba4a00d0a 100644 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -16,8 +16,13 @@ - Always set "model_type" to "UserMemory" for this output. 3. Resolve all references to time, persons, and events clearly - - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. + - Temporal Resolution: Convert relative time (e.g., "yesterday") to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. + > Where feasible, use the message timestamp to convert relative time expressions into absolute dates (e.g., "yesterday" in a message dated January 15, 2023, can be converted to "January 14, 2023," and "last week" can be described as "the week preceding January 15, 2023"). + > Explicitly differentiate between the time when the event occurred and the time the message was sent. + > Clearly indicate any uncertainty (e.g., "approximately June 2025", "exact date unknown"). - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. + > For example, "Melanie" uses the abbreviated name "Mel" in the paragraph; when extracting her name in the "value" field, it should be restored to "Melanie". + - Location resolution: If specific locations are mentioned, include them explicitly. 4. Adopt a Consistent Third-Person Observer Perspective - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. @@ -38,14 +43,14 @@ 7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. -Return a valid JSON object with the following structure: +Return a valid JSON object with the following structure: { "memory list": [ { "key": , "memory_type": , - "value": , + "value": , "tags": }, ... @@ -54,11 +59,11 @@ } Language rules: -- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** -- Keep `memory_type` in English. +- The `key`, `value`, `tags`, `summary` and `memory_type` fields must be in English. + Example: -Conversation: +Conversations: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. assistant: Oh Tom! Do you think the team can finish by December 15? user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. @@ -71,65 +76,19 @@ { "key": "Initial project meeting", "memory_type": "LongTermMemory", - "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", + "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom held a meeting with their team to discuss a new project. The conversation covered the timeline and raised concerns about the feasibility of the December 15, 2025 deadline.", "tags": ["Tom", "project", "timeline", "meeting", "deadline"] }, { - "key": "Jerry’s suggestion about the deadline", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", - "tags": ["Jerry", "deadline change", "suggestion"] + "key": "Planned scope adjustment", + "memory_type": "UserMemory", + "value": "Tom planned to suggest in a meeting on June 27, 2025 at 9:30 AM that the team should prioritize features and propose shifting the project deadline to January 5, 2026.", + "tags": ["Tom", "planning", "deadline change", "feature prioritization"] } ], - "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." + "summary": "Tom is currently focused on managing a new project with a tight schedule. After a team meeting on June 25, 2025, he realized the original deadline of December 15 might not be feasible due to backend delays. Concerned about insufficient testing time, he welcomed Jerry’s suggestion of proposing an extension. Tom plans to raise the idea of shifting the deadline to January 5, 2026 in the next morning’s meeting. His actions reflect both stress about timelines and a proactive, team-oriented problem-solving approach." } -Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - -{ - "memory list": [ - { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": "[user观点]用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", - "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } - ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" -} - -Always respond in the same language as the conversation. Conversation: ${conversation} @@ -155,7 +114,11 @@ 3. 明确解析所有指代关系 - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 + # 条件允许则使用消息时间戳将相对时间表达转换为绝对日期(如:2023年1月15日的“昨天”则转换为2023年1月14日);“上周”则转换为2023年1月15日前一周)。 + # 明确区分事件时间和消息时间。 + # 如果存在不确定性,需明确说明(例如,“约2025年6月”,“具体日期不详”)。 - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 + - 地点解析:若提及具体地点,请包含在内。 4. 采用统一的第三人称观察视角 - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 @@ -183,7 +146,7 @@ { "key": <字符串,唯一且简洁的记忆标题>, "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, - "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, + "value": <详细、独立且无歧义的记忆陈述>, "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> }, ... @@ -192,10 +155,10 @@ } 语言规则: -- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** -- `memory_type` 保持英文。 +- `key`、`value`、`tags`、`summary` 、`memory_type` 字段必须输出中文 -示例: + +示例1: 对话: user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 assistant: 哦Tom!你觉得团队能在12月15日前完成吗? @@ -209,25 +172,20 @@ { "key": "项目初期会议", "memory_type": "LongTermMemory", - "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry - 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 - 提议将截止日期推迟至2026年1月5日。", - "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] + "value": "2025年6月25日下午3:00,Tom与团队开会讨论新项目。会议涉及时间表,并提出了对2025年12月15日截止日期可行性的担忧。", + "tags": ["项目", "时间表", "会议", "截止日期"] }, { - "key": "Jerry对新项目截止日期的建议", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", - "tags": ["Jerry", "截止日期变更", "建议"] + "key": "计划调整范围", + "memory_type": "UserMemory", + "value": "Tom计划在2025年6月27日上午9:30的会议上建议团队优先处理功能,并提议将项目截止日期推迟至2026年1月5日。", + "tags": ["计划", "截止日期变更", "功能优先级"] } ], - "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 - 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 - 年1月5日。" + "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" } -另一个中文示例(注意:当用户语言为中文时,您也需输出中文): - +示例2: 对话(节选): user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 assistant|19:32 @@ -271,7 +229,6 @@ "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" } -请始终使用与对话相同的语言进行回复。 对话: ${conversation} From fafc7475f9d4e3527fac51e9d91e70629eaf9993 Mon Sep 17 00:00:00 2001 From: Hao <42795704+Nyakult@users.noreply.github.com> Date: Tue, 4 Nov 2025 15:56:08 +0800 Subject: [PATCH 50/64] eval result (#428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add breakpoint in eval scripts * feat(locomo): 支持断点续传 * format code * doc: Update readme eval result --- README.md | 30 ++++++++--------- evaluation/scripts/locomo/locomo_ingestion.py | 33 ++++++++++++++----- 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 6873ba2b1..50621b584 100644 --- a/README.md +++ b/README.md @@ -54,22 +54,20 @@ ## 📈 Performance Benchmark -MemOS demonstrates significant improvements over baseline memory solutions in multiple reasoning tasks. - -| Model | Avg. Score | Multi-Hop | Open Domain | Single-Hop | Temporal Reasoning | -|-------------|------------|-----------|-------------|------------|---------------------| -| **OpenAI** | 0.5275 | 0.6028 | 0.3299 | 0.6183 | 0.2825 | -| **MemOS** | **0.7331** | **0.6430** | **0.5521** | **0.7844** | **0.7321** | -| **Improvement** | **+38.98%** | **+6.67%** | **+67.35%** | **+26.86%** | **+159.15%** | - -> 💡 **Temporal reasoning accuracy improved by 159% compared to the OpenAI baseline.** - -### Details of End-to-End Evaluation on LOCOMO - -> [!NOTE] -> Comparison of LLM Judge Scores across five major tasks in the LOCOMO benchmark. Each bar shows the mean evaluation score judged by LLMs for a given method-task pair, with standard deviation as error bars. MemOS-0630 consistently outperforms baseline methods (LangMem, Zep, OpenAI, Mem0) across all task types, especially in multi-hop and temporal reasoning scenarios. - -END2END SCORE +MemOS demonstrates significant improvements over baseline memory solutions in multiple memory tasks, +showcasing its capabilities in **information extraction**, **temporal and cross-session reasoning**, and **personalized preference responses**. + +| Model | LOCOMO | LongMemEval | PrefEval-10 | PersonaMem | +|-----------------|-------------|-------------|-------------|-------------| +| **GPT-4o-mini** | 52.75 | 55.4 | 2.8 | 43.46 | +| **MemOS** | **75.80** | **77.80** | **71.90** | **61.17** | +| **Improvement** | **+43.70%** | **+40.43%** | **+2568%** | **+40.75%** | + +### Detailed Evaluation Results +- We use gpt-4o-mini as the processing and judging LLM and bge-m3 as embedding model in MemOS evaluation. +- The evaluation was conducted under conditions that align various settings as closely as possible. Reproduce the results with our scripts at [`evaluation`](./evaluation). +- Check the full search and response details at huggingface https://huggingface.co/datasets/MemTensor/MemOS_eval_result. +> 💡 **MemOS outperforms all other methods (Mem0, Zep, Memobase, SuperMemory et al.) across all benchmarks!** ## ✨ Key Features diff --git a/evaluation/scripts/locomo/locomo_ingestion.py b/evaluation/scripts/locomo/locomo_ingestion.py index 518d90c4c..a9e4d5f02 100644 --- a/evaluation/scripts/locomo/locomo_ingestion.py +++ b/evaluation/scripts/locomo/locomo_ingestion.py @@ -88,7 +88,7 @@ def ingest_session(client, session, frame, version, metadata): return elapsed_time -def process_user(conv_idx, frame, locomo_df, version): +def process_user(conv_idx, frame, locomo_df, version, success_records, f): conversation = locomo_df["conversation"].iloc[conv_idx] max_session_count = 35 start_time = time.time() @@ -149,11 +149,15 @@ def process_user(conv_idx, frame, locomo_df, version): print(f"Processing {valid_sessions} sessions for user {conv_idx}") - for session, metadata in sessions_to_process: - session_time = ingest_session(client, session, frame, version, metadata) - total_session_time += session_time - print(f"User {conv_idx}, {metadata['session_key']} processed in {session_time} seconds") - + for session_idx, (session, metadata) in enumerate(sessions_to_process): + if f"{conv_idx}_{session_idx}" not in success_records: + session_time = ingest_session(client, session, frame, version, metadata) + total_session_time += session_time + print(f"User {conv_idx}, {metadata['session_key']} processed in {session_time} seconds") + f.write(f"{conv_idx}_{session_idx}\n") + f.flush() + else: + print(f"Session {conv_idx}_{session_idx} already ingested") end_time = time.time() elapsed_time = round(end_time - start_time, 2) print(f"User {conv_idx} processed successfully in {elapsed_time} seconds") @@ -170,9 +174,20 @@ def main(frame, version="default", num_workers=4): print( f"Starting processing for {num_users} users in serial mode, each user using {num_workers} workers for sessions..." ) - with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + os.makedirs(f"results/locomo/{frame}-{version}/", exist_ok=True) + success_records = [] + record_file = f"results/locomo/{frame}-{version}/success_records.txt" + if os.path.exists(record_file): + with open(record_file) as f: + for i in f.readlines(): + success_records.append(i.strip()) + + with ( + concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor, + open(record_file, "a+") as f, + ): futures = [ - executor.submit(process_user, user_id, frame, locomo_df, version) + executor.submit(process_user, user_id, frame, locomo_df, version, success_records, f) for user_id in range(num_users) ] for future in concurrent.futures.as_completed(futures): @@ -216,7 +231,7 @@ def main(frame, version="default", num_workers=4): help="Version identifier for saving results (e.g., 1010)", ) parser.add_argument( - "--workers", type=int, default=3, help="Number of parallel workers to process users" + "--workers", type=int, default=10, help="Number of parallel workers to process users" ) args = parser.parse_args() lib = args.lib From e62f72dab772d239c93c0a2ff84e731cfeae0f3a Mon Sep 17 00:00:00 2001 From: Wustzdy <67457465+wustzdy@users.noreply.github.com> Date: Tue, 4 Nov 2025 19:09:34 +0800 Subject: [PATCH 51/64] Useless quotes (#450) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: * fix: * format * fix:id --------- Co-authored-by: ccl <13282138256@163.com> --- src/memos/graph_dbs/polardb.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ac49228e2..a7245a625 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -859,9 +859,9 @@ def get_node( if result: if include_embedding: - properties_json, embedding_json = result + _, properties_json, embedding_json = result else: - properties_json = result + _, properties_json = result embedding_json = None # Parse properties from JSONB if it's a string @@ -885,12 +885,14 @@ def get_node( properties["embedding"] = embedding except (json.JSONDecodeError, TypeError): logger.warning(f"Failed to parse embedding for node {id}") - + properties.pop("id") + properties.pop("memory") + properties.pop("user_name", None) return self._parse_node( { "id": id, - "memory": json.loads(properties[1]).get("memory", ""), - **json.loads(properties[1]), + "memory": properties.get("memory", ""), + **properties, } ) return None @@ -1290,6 +1292,8 @@ def get_subgraph( user_name = user_name if user_name else self._get_config_value("user_name") + if center_id.startswith('"') and center_id.endswith('"'): + center_id = center_id[1:-1] # Use a simplified query to get the subgraph (temporarily only direct neighbors) """ SELECT * FROM cypher('{self.db_name}_graph', $$ From f3e73386aaed49c8f0414b99ef72fad3169da93f Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Tue, 4 Nov 2025 21:52:12 +0800 Subject: [PATCH 52/64] Revert "fix prompt error" (#452) Revert "fix prompt error (#447)" This reverts commit 4c8f89a2dbf0da5c7def95342d32de76f1898146. --- src/memos/graph_dbs/polardb.py | 2 +- src/memos/templates/prefer_complete_prompt.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index a7245a625..77f3c76e0 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -3028,7 +3028,7 @@ def format_param_value(self, value: str | None) -> str: """Format parameter value to handle both quoted and unquoted formats""" # Handle None value if value is None: - logger.warning("format_param_value: value is None") + logger.warning(f"format_param_value: value is None") return "null" # Remove outer quotes if they exist diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 8f3a62abb..ec06af27f 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -599,15 +599,15 @@ PREF_INSTRUCTIONS = """ # Note: -Fact memory are summaries of facts, while preference memory are summaries of user preferences. +Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. -When encountering preference conflicts, the priority is: explicit preference > implicit preference > fact memory. +When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. """ PREF_INSTRUCTIONS_ZH = """ # 注意: -事实记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 +明文记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 你的回复不得违反用户的任何偏好,无论是显式偏好还是隐式偏好,并简要解释你为什么这样回答以避免冲突。 -当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 事实记忆。 +当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 明文记忆。 """ From b8cd27b6a8df907f2cfe9e19cfe40b822f8ccb5f Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Tue, 4 Nov 2025 21:55:09 +0800 Subject: [PATCH 53/64] Revert "fix: fix search config input bug; patch retrieve_utils path set; adjust reader strategy template." (#453) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Revert "fix: fix search config input bug; patch retrieve_utils path set; adju…" This reverts commit 88699f9e8f240c857880b00e4a1e7446d75313ab. --- src/memos/api/config.py | 2 +- src/memos/configs/memory.py | 2 +- src/memos/mem_reader/strategy_struct.py | 15 +-- src/memos/memories/textual/simple_tree.py | 4 +- .../retrieve/retrieval_mid_structs.py | 1 - .../retrieve/retrieve_utils.py | 2 +- .../tree_text_memory/retrieve/searcher.py | 17 +-- .../retrieve/task_goal_parser.py | 13 +-- .../templates/mem_reader_strategy_prompts.py | 109 ++++++++++++------ 9 files changed, 95 insertions(+), 70 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7f61d54ac..d9db93c1a 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -427,7 +427,7 @@ def get_reader_config() -> dict[str, Any]: "config": { "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), - "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 10)), + "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), }, } diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 34967849a..49320fbf5 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,7 +184,7 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) - search_strategy: dict[str, Any] | None = Field( + search_strategy: dict[str, bool] | None = Field( default=None, description=( 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py index 1fc21461e..2cac1652a 100644 --- a/src/memos/mem_reader/strategy_struct.py +++ b/src/memos/mem_reader/strategy_struct.py @@ -43,7 +43,7 @@ def _get_llm_response(self, mem_str: str) -> dict: template = STRATEGY_PROMPT_DICT["chat"][lang] examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) - if self.config.remove_prompt_example: # TODO unused + if self.config.remove_prompt_example: prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] try: @@ -112,19 +112,6 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: results.append([overlap_item, item]) current_length = overlap_length + content_length - else: - cut_size, cut_overlap = ( - self.chat_chunker["chunk_session"], - self.chat_chunker["chunk_overlap"], - ) - for items in scene_data: - step = cut_size - cut_overlap - end = len(items) - cut_overlap - if end <= 0: - results.extend([items[:]]) - else: - results.extend([items[i : i + cut_size] for i in range(0, end, step)]) - elif type == "doc": parser_config = ParserConfigFactory.model_validate( { diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 313989cd2..992b7bfab 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -66,9 +66,7 @@ def __init__( time_start_bm = time.time() self.search_strategy = config.search_strategy self.bm25_retriever = ( - EnhancedBM25() - if self.search_strategy and self.search_strategy.get("bm25", False) - else None + EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None ) logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py index 7aefaa1a3..6accc4a16 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py @@ -13,4 +13,3 @@ class ParsedTaskGoal: rephrased_query: str | None = None internet_search: bool = False goal_type: str | None = None # e.g., 'default', 'explanation', etc. - context: str = "" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 3f2b41a47..eec827c86 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -17,7 +17,7 @@ def find_project_root(marker=".git"): if (current / marker).exists(): return current current = current.parent - return Path(".") + logger.warn(f"The project root directory tag file was not found: {marker}") PROJECT_ROOT = find_project_root() diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 2f6ef6afa..0974d67f2 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -30,8 +30,8 @@ logger = get_logger(__name__) COT_DICT = { - "fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, - "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, + "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, + "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, } @@ -59,8 +59,12 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = search_strategy.get("cot", False) if search_strategy else False - self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False + self.vec_cot = ( + search_strategy.get("vec_cot", "false") == "true" if search_strategy else False + ) + self.use_fast_graph = ( + search_strategy.get("fast_graph", "false") == "true" if search_strategy else False + ) self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -283,7 +287,6 @@ def _retrieve_paths( search_filter, user_name, id_filter, - mode=mode, ) ) tasks.append( @@ -366,7 +369,6 @@ def _retrieve_from_long_term_and_user( search_filter: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, - mode: str = "fast", ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -375,7 +377,7 @@ def _retrieve_from_long_term_and_user( # chain of thinking cot_embeddings = [] if self.vec_cot: - queries = self._cot_query(query, mode=mode, context=parsed_goal.context) + queries = self._cot_query(query) if len(queries) > 1: cot_embeddings = self.embedder.embed(queries) cot_embeddings.extend(query_embedding) @@ -564,6 +566,7 @@ def _cot_query( prompt = template.replace("${original_query}", query).replace( "${split_num_threshold}", str(split_num) ) + logger.info("COT process") messages = [{"role": "user", "content": prompt}] try: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 55e33494c..5d706559c 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -39,7 +39,7 @@ def parse( - mode == 'fine': use LLM to parse structured topic/keys/tags """ if mode == "fast": - return self._parse_fast(task_description, context=context, **kwargs) + return self._parse_fast(task_description, **kwargs) elif mode == "fine": if not self.llm: raise ValueError("LLM not provided for slow mode.") @@ -51,7 +51,6 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: """ Fast mode: simple jieba word split. """ - context = kwargs.get("context", "") use_fast_graph = kwargs.get("use_fast_graph", False) if use_fast_graph: desc_tokenized = self.tokenizer.tokenize_mixed(task_description) @@ -62,7 +61,6 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: goal_type="default", rephrased_query=task_description, internet_search=False, - context=context, ) else: return ParsedTaskGoal( @@ -72,7 +70,6 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: goal_type="default", rephrased_query=task_description, internet_search=False, - context=context, ) def _parse_fine( @@ -94,17 +91,16 @@ def _parse_fine( logger.info(f"Parsing Goal... LLM input is {prompt}") response = self.llm.generate(messages=[{"role": "user", "content": prompt}]) logger.info(f"Parsing Goal... LLM Response is {response}") - return self._parse_response(response, context=context) + return self._parse_response(response) except Exception: logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}") - return self._parse_fast(query, context=context) + return self._parse_fast(query) - def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: + def _parse_response(self, response: str) -> ParsedTaskGoal: """ Parse LLM JSON output safely. """ try: - context = kwargs.get("context", "") response = response.replace("```", "").replace("json", "").strip() response_json = eval(response) return ParsedTaskGoal( @@ -114,7 +110,6 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: rephrased_query=response_json.get("rephrased_instruction", None), internet_search=response_json.get("internet_search", False), goal_type=response_json.get("goal_type", "default"), - context=context, ) except Exception as e: raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py index ba4a00d0a..fca4d717b 100644 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -16,13 +16,8 @@ - Always set "model_type" to "UserMemory" for this output. 3. Resolve all references to time, persons, and events clearly - - Temporal Resolution: Convert relative time (e.g., "yesterday") to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. - > Where feasible, use the message timestamp to convert relative time expressions into absolute dates (e.g., "yesterday" in a message dated January 15, 2023, can be converted to "January 14, 2023," and "last week" can be described as "the week preceding January 15, 2023"). - > Explicitly differentiate between the time when the event occurred and the time the message was sent. - > Clearly indicate any uncertainty (e.g., "approximately June 2025", "exact date unknown"). + - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. - > For example, "Melanie" uses the abbreviated name "Mel" in the paragraph; when extracting her name in the "value" field, it should be restored to "Melanie". - - Location resolution: If specific locations are mentioned, include them explicitly. 4. Adopt a Consistent Third-Person Observer Perspective - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. @@ -43,14 +38,14 @@ 7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. - Return a valid JSON object with the following structure: + { "memory list": [ { "key": , "memory_type": , - "value": , + "value": , "tags": }, ... @@ -59,11 +54,11 @@ } Language rules: -- The `key`, `value`, `tags`, `summary` and `memory_type` fields must be in English. - +- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** +- Keep `memory_type` in English. Example: -Conversations: +Conversation: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. assistant: Oh Tom! Do you think the team can finish by December 15? user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. @@ -76,19 +71,65 @@ { "key": "Initial project meeting", "memory_type": "LongTermMemory", - "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom held a meeting with their team to discuss a new project. The conversation covered the timeline and raised concerns about the feasibility of the December 15, 2025 deadline.", + "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", "tags": ["Tom", "project", "timeline", "meeting", "deadline"] }, { - "key": "Planned scope adjustment", - "memory_type": "UserMemory", - "value": "Tom planned to suggest in a meeting on June 27, 2025 at 9:30 AM that the team should prioritize features and propose shifting the project deadline to January 5, 2026.", - "tags": ["Tom", "planning", "deadline change", "feature prioritization"] + "key": "Jerry’s suggestion about the deadline", + "memory_type": "LongTermMemory", + "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", + "tags": ["Jerry", "deadline change", "suggestion"] } ], - "summary": "Tom is currently focused on managing a new project with a tight schedule. After a team meeting on June 25, 2025, he realized the original deadline of December 15 might not be feasible due to backend delays. Concerned about insufficient testing time, he welcomed Jerry’s suggestion of proposing an extension. Tom plans to raise the idea of shifting the deadline to January 5, 2026 in the next morning’s meeting. His actions reflect both stress about timelines and a proactive, team-oriented problem-solving approach." + "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." } +Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): + +对话(节选): +user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 +assistant|19:32 +:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? +user|19:35:不喜欢亮色。国贸方便。 +assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 +user|19:40:165cm,S码;最好有口袋。 +assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 +user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 +assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 +user|19:52:行,周六(7/19)去国贸试,合适就买。 +assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 + +{ + "memory list": [ + { + "key": "参加婚礼购买裙子", + "memory_type": "UserMemory", + "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", + "tags": ["婚礼", "预算", "国贸", "计划"] + }, + { + "key": "审美与版型偏好", + "memory_type": "UserMemory", + "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", + "tags": ["偏好", "颜色", "版型"] + }, + { + "key": "体型尺码", + "memory_type": "UserMemory", + "value": "[user观点]用户身高约165cm、常穿S码", + "tags": ["体型", "尺码"] + }, + { + "key": "关于用户选购裙子的建议", + "memory_type": "LongTermMemory", + "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", + "tags": ["婚礼穿着", "门店", "选购路线"] + } + ], + "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" +} + +Always respond in the same language as the conversation. Conversation: ${conversation} @@ -114,11 +155,7 @@ 3. 明确解析所有指代关系 - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 - # 条件允许则使用消息时间戳将相对时间表达转换为绝对日期(如:2023年1月15日的“昨天”则转换为2023年1月14日);“上周”则转换为2023年1月15日前一周)。 - # 明确区分事件时间和消息时间。 - # 如果存在不确定性,需明确说明(例如,“约2025年6月”,“具体日期不详”)。 - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 - - 地点解析:若提及具体地点,请包含在内。 4. 采用统一的第三人称观察视角 - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 @@ -146,7 +183,7 @@ { "key": <字符串,唯一且简洁的记忆标题>, "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, - "value": <详细、独立且无歧义的记忆陈述>, + "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> }, ... @@ -155,10 +192,10 @@ } 语言规则: -- `key`、`value`、`tags`、`summary` 、`memory_type` 字段必须输出中文 +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 - -示例1: +示例: 对话: user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 assistant: 哦Tom!你觉得团队能在12月15日前完成吗? @@ -172,20 +209,25 @@ { "key": "项目初期会议", "memory_type": "LongTermMemory", - "value": "2025年6月25日下午3:00,Tom与团队开会讨论新项目。会议涉及时间表,并提出了对2025年12月15日截止日期可行性的担忧。", - "tags": ["项目", "时间表", "会议", "截止日期"] + "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry + 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 + 提议将截止日期推迟至2026年1月5日。", + "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] }, { - "key": "计划调整范围", - "memory_type": "UserMemory", - "value": "Tom计划在2025年6月27日上午9:30的会议上建议团队优先处理功能,并提议将项目截止日期推迟至2026年1月5日。", - "tags": ["计划", "截止日期变更", "功能优先级"] + "key": "Jerry对新项目截止日期的建议", + "memory_type": "LongTermMemory", + "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", + "tags": ["Jerry", "截止日期变更", "建议"] } ], - "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" + "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 + 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 + 年1月5日。" } -示例2: +另一个中文示例(注意:当用户语言为中文时,您也需输出中文): + 对话(节选): user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 assistant|19:32 @@ -229,6 +271,7 @@ "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" } +请始终使用与对话相同的语言进行回复。 对话: ${conversation} From e07a1b43624d1b2061c2f2bf2593607afd6b1475 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Tue, 4 Nov 2025 22:01:21 +0800 Subject: [PATCH 54/64] Revert "Useless quotes" (#454) Revert "Useless quotes (#450)" This reverts commit e62f72dab772d239c93c0a2ff84e731cfeae0f3a. --- src/memos/graph_dbs/polardb.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 77f3c76e0..552b30241 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -859,9 +859,9 @@ def get_node( if result: if include_embedding: - _, properties_json, embedding_json = result + properties_json, embedding_json = result else: - _, properties_json = result + properties_json = result embedding_json = None # Parse properties from JSONB if it's a string @@ -885,14 +885,12 @@ def get_node( properties["embedding"] = embedding except (json.JSONDecodeError, TypeError): logger.warning(f"Failed to parse embedding for node {id}") - properties.pop("id") - properties.pop("memory") - properties.pop("user_name", None) + return self._parse_node( { "id": id, - "memory": properties.get("memory", ""), - **properties, + "memory": json.loads(properties[1]).get("memory", ""), + **json.loads(properties[1]), } ) return None @@ -1292,8 +1290,6 @@ def get_subgraph( user_name = user_name if user_name else self._get_config_value("user_name") - if center_id.startswith('"') and center_id.endswith('"'): - center_id = center_id[1:-1] # Use a simplified query to get the subgraph (temporarily only direct neighbors) """ SELECT * FROM cypher('{self.db_name}_graph', $$ From f67ca36408418a8f2c8200f5130f101c938d9304 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:41:27 +0800 Subject: [PATCH 55/64] fix prompt error (#455) Co-authored-by: yuan.wang Co-authored-by: CaralHsi --- src/memos/graph_dbs/polardb.py | 2 +- src/memos/templates/prefer_complete_prompt.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 552b30241..ac49228e2 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -3024,7 +3024,7 @@ def format_param_value(self, value: str | None) -> str: """Format parameter value to handle both quoted and unquoted formats""" # Handle None value if value is None: - logger.warning(f"format_param_value: value is None") + logger.warning("format_param_value: value is None") return "null" # Remove outer quotes if they exist diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index ec06af27f..8f3a62abb 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -599,15 +599,15 @@ PREF_INSTRUCTIONS = """ # Note: -Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. +Fact memory are summaries of facts, while preference memory are summaries of user preferences. Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. -When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +When encountering preference conflicts, the priority is: explicit preference > implicit preference > fact memory. """ PREF_INSTRUCTIONS_ZH = """ # 注意: -明文记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 +事实记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 你的回复不得违反用户的任何偏好,无论是显式偏好还是隐式偏好,并简要解释你为什么这样回答以避免冲突。 -当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 明文记忆。 +当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 事实记忆。 """ From 7c4a74ca3e3cfc4606619608ea75c8fac4b45cf7 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:45:29 +0800 Subject: [PATCH 56/64] Feat/remove pref rank prompt (#456) modify prompt Co-authored-by: yuan.wang --- src/memos/templates/prefer_complete_prompt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 8f3a62abb..9e0274cba 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -601,7 +601,6 @@ # Note: Fact memory are summaries of facts, while preference memory are summaries of user preferences. Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. -When encountering preference conflicts, the priority is: explicit preference > implicit preference > fact memory. """ @@ -609,5 +608,4 @@ # 注意: 事实记忆是事实的摘要,而偏好记忆是用户偏好的摘要。 你的回复不得违反用户的任何偏好,无论是显式偏好还是隐式偏好,并简要解释你为什么这样回答以避免冲突。 -当遇到偏好冲突时,优先级为:显式偏好 > 隐式偏好 > 事实记忆。 """ From a0f3a0019581814620218c6288d465fff0a4abe6 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:46:21 +0800 Subject: [PATCH 57/64] fix: fix strategy reader input; code reformat (#457) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- README.md | 2 +- src/memos/api/config.py | 2 +- src/memos/configs/memory.py | 2 +- src/memos/mem_reader/strategy_struct.py | 15 ++- src/memos/memories/textual/simple_tree.py | 4 +- .../retrieve/retrieval_mid_structs.py | 1 + .../retrieve/retrieve_utils.py | 2 +- .../tree_text_memory/retrieve/searcher.py | 17 ++- .../retrieve/task_goal_parser.py | 13 ++- .../templates/mem_reader_strategy_prompts.py | 109 ++++++------------ 10 files changed, 71 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index 50621b584..a08177676 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ MemOS demonstrates significant improvements over baseline memory solutions in multiple memory tasks, showcasing its capabilities in **information extraction**, **temporal and cross-session reasoning**, and **personalized preference responses**. -| Model | LOCOMO | LongMemEval | PrefEval-10 | PersonaMem | +| Model | LOCOMO | LongMemEval | PrefEval-10 | PersonaMem | |-----------------|-------------|-------------|-------------|-------------| | **GPT-4o-mini** | 52.75 | 55.4 | 2.8 | 43.46 | | **MemOS** | **75.80** | **77.80** | **71.90** | **61.17** | diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d9db93c1a..7f61d54ac 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -427,7 +427,7 @@ def get_reader_config() -> dict[str, Any]: "config": { "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), - "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), + "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 10)), "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), }, } diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 49320fbf5..34967849a 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,7 +184,7 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) - search_strategy: dict[str, bool] | None = Field( + search_strategy: dict[str, Any] | None = Field( default=None, description=( 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py index 2cac1652a..1fc21461e 100644 --- a/src/memos/mem_reader/strategy_struct.py +++ b/src/memos/mem_reader/strategy_struct.py @@ -43,7 +43,7 @@ def _get_llm_response(self, mem_str: str) -> dict: template = STRATEGY_PROMPT_DICT["chat"][lang] examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) - if self.config.remove_prompt_example: + if self.config.remove_prompt_example: # TODO unused prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] try: @@ -112,6 +112,19 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: results.append([overlap_item, item]) current_length = overlap_length + content_length + else: + cut_size, cut_overlap = ( + self.chat_chunker["chunk_session"], + self.chat_chunker["chunk_overlap"], + ) + for items in scene_data: + step = cut_size - cut_overlap + end = len(items) - cut_overlap + if end <= 0: + results.extend([items[:]]) + else: + results.extend([items[i : i + cut_size] for i in range(0, end, step)]) + elif type == "doc": parser_config = ParserConfigFactory.model_validate( { diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 992b7bfab..313989cd2 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -66,7 +66,9 @@ def __init__( time_start_bm = time.time() self.search_strategy = config.search_strategy self.bm25_retriever = ( - EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None + EnhancedBM25() + if self.search_strategy and self.search_strategy.get("bm25", False) + else None ) logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py index 6accc4a16..7aefaa1a3 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py @@ -13,3 +13,4 @@ class ParsedTaskGoal: rephrased_query: str | None = None internet_search: bool = False goal_type: str | None = None # e.g., 'default', 'explanation', etc. + context: str = "" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index eec827c86..3f2b41a47 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -17,7 +17,7 @@ def find_project_root(marker=".git"): if (current / marker).exists(): return current current = current.parent - logger.warn(f"The project root directory tag file was not found: {marker}") + return Path(".") PROJECT_ROOT = find_project_root() diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 0974d67f2..2f6ef6afa 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -30,8 +30,8 @@ logger = get_logger(__name__) COT_DICT = { - "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, - "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, + "fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, + "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, } @@ -59,12 +59,8 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = ( - search_strategy.get("vec_cot", "false") == "true" if search_strategy else False - ) - self.use_fast_graph = ( - search_strategy.get("fast_graph", "false") == "true" if search_strategy else False - ) + self.vec_cot = search_strategy.get("cot", False) if search_strategy else False + self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -287,6 +283,7 @@ def _retrieve_paths( search_filter, user_name, id_filter, + mode=mode, ) ) tasks.append( @@ -369,6 +366,7 @@ def _retrieve_from_long_term_and_user( search_filter: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, + mode: str = "fast", ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -377,7 +375,7 @@ def _retrieve_from_long_term_and_user( # chain of thinking cot_embeddings = [] if self.vec_cot: - queries = self._cot_query(query) + queries = self._cot_query(query, mode=mode, context=parsed_goal.context) if len(queries) > 1: cot_embeddings = self.embedder.embed(queries) cot_embeddings.extend(query_embedding) @@ -566,7 +564,6 @@ def _cot_query( prompt = template.replace("${original_query}", query).replace( "${split_num_threshold}", str(split_num) ) - logger.info("COT process") messages = [{"role": "user", "content": prompt}] try: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 5d706559c..55e33494c 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -39,7 +39,7 @@ def parse( - mode == 'fine': use LLM to parse structured topic/keys/tags """ if mode == "fast": - return self._parse_fast(task_description, **kwargs) + return self._parse_fast(task_description, context=context, **kwargs) elif mode == "fine": if not self.llm: raise ValueError("LLM not provided for slow mode.") @@ -51,6 +51,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: """ Fast mode: simple jieba word split. """ + context = kwargs.get("context", "") use_fast_graph = kwargs.get("use_fast_graph", False) if use_fast_graph: desc_tokenized = self.tokenizer.tokenize_mixed(task_description) @@ -61,6 +62,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: goal_type="default", rephrased_query=task_description, internet_search=False, + context=context, ) else: return ParsedTaskGoal( @@ -70,6 +72,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: goal_type="default", rephrased_query=task_description, internet_search=False, + context=context, ) def _parse_fine( @@ -91,16 +94,17 @@ def _parse_fine( logger.info(f"Parsing Goal... LLM input is {prompt}") response = self.llm.generate(messages=[{"role": "user", "content": prompt}]) logger.info(f"Parsing Goal... LLM Response is {response}") - return self._parse_response(response) + return self._parse_response(response, context=context) except Exception: logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}") - return self._parse_fast(query) + return self._parse_fast(query, context=context) - def _parse_response(self, response: str) -> ParsedTaskGoal: + def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: """ Parse LLM JSON output safely. """ try: + context = kwargs.get("context", "") response = response.replace("```", "").replace("json", "").strip() response_json = eval(response) return ParsedTaskGoal( @@ -110,6 +114,7 @@ def _parse_response(self, response: str) -> ParsedTaskGoal: rephrased_query=response_json.get("rephrased_instruction", None), internet_search=response_json.get("internet_search", False), goal_type=response_json.get("goal_type", "default"), + context=context, ) except Exception as e: raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py index fca4d717b..ba4a00d0a 100644 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -16,8 +16,13 @@ - Always set "model_type" to "UserMemory" for this output. 3. Resolve all references to time, persons, and events clearly - - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. + - Temporal Resolution: Convert relative time (e.g., "yesterday") to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. + > Where feasible, use the message timestamp to convert relative time expressions into absolute dates (e.g., "yesterday" in a message dated January 15, 2023, can be converted to "January 14, 2023," and "last week" can be described as "the week preceding January 15, 2023"). + > Explicitly differentiate between the time when the event occurred and the time the message was sent. + > Clearly indicate any uncertainty (e.g., "approximately June 2025", "exact date unknown"). - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. + > For example, "Melanie" uses the abbreviated name "Mel" in the paragraph; when extracting her name in the "value" field, it should be restored to "Melanie". + - Location resolution: If specific locations are mentioned, include them explicitly. 4. Adopt a Consistent Third-Person Observer Perspective - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. @@ -38,14 +43,14 @@ 7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. -Return a valid JSON object with the following structure: +Return a valid JSON object with the following structure: { "memory list": [ { "key": , "memory_type": , - "value": , + "value": , "tags": }, ... @@ -54,11 +59,11 @@ } Language rules: -- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** -- Keep `memory_type` in English. +- The `key`, `value`, `tags`, `summary` and `memory_type` fields must be in English. + Example: -Conversation: +Conversations: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. assistant: Oh Tom! Do you think the team can finish by December 15? user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. @@ -71,65 +76,19 @@ { "key": "Initial project meeting", "memory_type": "LongTermMemory", - "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", + "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom held a meeting with their team to discuss a new project. The conversation covered the timeline and raised concerns about the feasibility of the December 15, 2025 deadline.", "tags": ["Tom", "project", "timeline", "meeting", "deadline"] }, { - "key": "Jerry’s suggestion about the deadline", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", - "tags": ["Jerry", "deadline change", "suggestion"] + "key": "Planned scope adjustment", + "memory_type": "UserMemory", + "value": "Tom planned to suggest in a meeting on June 27, 2025 at 9:30 AM that the team should prioritize features and propose shifting the project deadline to January 5, 2026.", + "tags": ["Tom", "planning", "deadline change", "feature prioritization"] } ], - "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." + "summary": "Tom is currently focused on managing a new project with a tight schedule. After a team meeting on June 25, 2025, he realized the original deadline of December 15 might not be feasible due to backend delays. Concerned about insufficient testing time, he welcomed Jerry’s suggestion of proposing an extension. Tom plans to raise the idea of shifting the deadline to January 5, 2026 in the next morning’s meeting. His actions reflect both stress about timelines and a proactive, team-oriented problem-solving approach." } -Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - -{ - "memory list": [ - { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": "[user观点]用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", - "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } - ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" -} - -Always respond in the same language as the conversation. Conversation: ${conversation} @@ -155,7 +114,11 @@ 3. 明确解析所有指代关系 - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 + # 条件允许则使用消息时间戳将相对时间表达转换为绝对日期(如:2023年1月15日的“昨天”则转换为2023年1月14日);“上周”则转换为2023年1月15日前一周)。 + # 明确区分事件时间和消息时间。 + # 如果存在不确定性,需明确说明(例如,“约2025年6月”,“具体日期不详”)。 - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 + - 地点解析:若提及具体地点,请包含在内。 4. 采用统一的第三人称观察视角 - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 @@ -183,7 +146,7 @@ { "key": <字符串,唯一且简洁的记忆标题>, "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, - "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, + "value": <详细、独立且无歧义的记忆陈述>, "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> }, ... @@ -192,10 +155,10 @@ } 语言规则: -- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** -- `memory_type` 保持英文。 +- `key`、`value`、`tags`、`summary` 、`memory_type` 字段必须输出中文 -示例: + +示例1: 对话: user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 assistant: 哦Tom!你觉得团队能在12月15日前完成吗? @@ -209,25 +172,20 @@ { "key": "项目初期会议", "memory_type": "LongTermMemory", - "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry - 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 - 提议将截止日期推迟至2026年1月5日。", - "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] + "value": "2025年6月25日下午3:00,Tom与团队开会讨论新项目。会议涉及时间表,并提出了对2025年12月15日截止日期可行性的担忧。", + "tags": ["项目", "时间表", "会议", "截止日期"] }, { - "key": "Jerry对新项目截止日期的建议", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", - "tags": ["Jerry", "截止日期变更", "建议"] + "key": "计划调整范围", + "memory_type": "UserMemory", + "value": "Tom计划在2025年6月27日上午9:30的会议上建议团队优先处理功能,并提议将项目截止日期推迟至2026年1月5日。", + "tags": ["计划", "截止日期变更", "功能优先级"] } ], - "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 - 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 - 年1月5日。" + "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" } -另一个中文示例(注意:当用户语言为中文时,您也需输出中文): - +示例2: 对话(节选): user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 assistant|19:32 @@ -271,7 +229,6 @@ "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" } -请始终使用与对话相同的语言进行回复。 对话: ${conversation} From 65a2daf06de5893ec93759b590d9476009d5d0c4 Mon Sep 17 00:00:00 2001 From: lijicode <34564964+lijicode@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:52:01 +0800 Subject: [PATCH 58/64] Dev ccl1103 (#449) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix:parse_node * fix: * fix:conn * fix:getnode * fix: --------- Co-authored-by: ccl <13282138256@163.com> Co-authored-by: CaralHsi --- src/memos/graph_dbs/polardb.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ac49228e2..5d34ff03f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -859,9 +859,9 @@ def get_node( if result: if include_embedding: - properties_json, embedding_json = result + _, properties_json, embedding_json = result else: - properties_json = result + _, properties_json = result embedding_json = None # Parse properties from JSONB if it's a string @@ -889,8 +889,8 @@ def get_node( return self._parse_node( { "id": id, - "memory": json.loads(properties[1]).get("memory", ""), - **json.loads(properties[1]), + "memory": properties.get("memory", ""), + **properties, } ) return None @@ -1290,6 +1290,8 @@ def get_subgraph( user_name = user_name if user_name else self._get_config_value("user_name") + if center_id.startswith('"') and center_id.endswith('"'): + center_id = center_id[1:-1] # Use a simplified query to get the subgraph (temporarily only direct neighbors) """ SELECT * FROM cypher('{self.db_name}_graph', $$ From b37939c7e90f82abfa387acda4642fcd66d141d0 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:38:05 +0800 Subject: [PATCH 59/64] Feat/fix bug 1031 (#459) * modify bug * modify bug * remove print --------- Co-authored-by: yuan.wang --- src/memos/mem_reader/simple_struct.py | 23 +++++++++++-------- src/memos/templates/instruction_completion.py | 5 +++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 0f74adead..13515c038 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -67,9 +67,18 @@ def detect_lang(text): try: if not text or not isinstance(text, str): return "en" + cleaned_text = text + # remove role and timestamp + cleaned_text = re.sub( + r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE + ) + cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) + + # extract chinese characters chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" - chinese_chars = re.findall(chinese_pattern, text) - if len(chinese_chars) / len(re.sub(r"[\s\d\W]", "", text)) > 0.3: + chinese_chars = re.findall(chinese_pattern, cleaned_text) + text_without_special = re.sub(r"[\s\d\W]", "", cleaned_text) + if text_without_special and len(chinese_chars) / len(text_without_special) > 0.3: return "zh" return "en" except Exception: @@ -466,15 +475,11 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: if type == "chat": for items in scene_data: result = [] - for item in items: - # Convert dictionary to string - if "chat_time" in item: - result.append(item) - else: - result.append(item) + for i, item in enumerate(items): + result.append(item) if len(result) >= 10: results.append(result) - context = copy.deepcopy(result[-2:]) + context = copy.deepcopy(result[-2:]) if i + 1 < len(items) else [] result = context if result: results.append(result) diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index 03ae52c77..b88ff474c 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -45,7 +45,10 @@ def instruct_completion( "zh": "隐式偏好 > ", "en": "implicit preference > ", } - lang = detect_lang(explicit_pref_str + implicit_pref_str) + lang = detect_lang( + explicit_pref_str.replace("Explicit Preference:\n", "") + + implicit_pref_str.replace("Implicit Preference:\n", "") + ) if not explicit_pref_str and not implicit_pref_str: return "", "" From ccbffae86c87c42bdd0ea99038c795cf493aeab1 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Thu, 6 Nov 2025 11:08:49 +0800 Subject: [PATCH 60/64] Fix/no response (#463) * fix: response error * fix: response error * fix: response error --------- Co-authored-by: harvey_xiang --- src/memos/api/middleware/request_context.py | 13 ++++++++++--- src/memos/api/routers/server_router.py | 5 ++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 443aa1f3d..025a0f9eb 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -72,10 +72,18 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: f"headers: {request.headers}" ) + response = await call_next(request) + end_time = time.time() + # Process the request try: - response = await call_next(request) - end_time = time.time() + if not response: + logger.error( + f"Request Failed No Response, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + + return response + if response.status_code == 200: logger.info( f"Request completed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" @@ -89,6 +97,5 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: logger.error( f"Request Exception Error: source: {self.source}, path: {request.url.path}, error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" ) - raise e return response diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 684e02a0c..8df383bfb 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -349,7 +349,7 @@ def search_memories(search_req: APISearchRequest): mem_cube_id=search_req.mem_cube_id, session_id=search_req.session_id or "default_session", ) - logger.info(f"Search user_id is: {user_context.mem_cube_id}") + logger.info(f"Search Req is: {search_req}") memories_result: MOSSearchResult = { "text_mem": [], "act_mem": [], @@ -502,6 +502,9 @@ def add_memories(add_req: APIADDRequest): mem_cube_id=add_req.mem_cube_id, session_id=add_req.session_id or "default_session", ) + + logger.info(f"Add Req is: {add_req}") + target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" From 4e500a9c0dafe9e54fd9fe4e5a5ce7551584c3cf Mon Sep 17 00:00:00 2001 From: Hao <42795704+Nyakult@users.noreply.github.com> Date: Thu, 6 Nov 2025 12:03:07 +0800 Subject: [PATCH 61/64] doc: Update readme (#458) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * doc: Update readme * feat(env): update .env.example * refactor(client): pref mem * fix pref type * chore(env): 更新环境变量配置文件 --- README.md | 23 ++++++- docker/.env.example | 69 ++++++++++++++----- evaluation/.env-example | 23 ------- evaluation/README.md | 5 +- .../configs-example/mem_cube_config.json | 51 -------------- .../configs-example/mos_memos_config.json | 51 -------------- evaluation/scripts/utils/client.py | 42 +++++------ src/memos/log.py | 2 +- 8 files changed, 90 insertions(+), 176 deletions(-) delete mode 100644 evaluation/configs-example/mem_cube_config.json delete mode 100644 evaluation/configs-example/mos_memos_config.json diff --git a/README.md b/README.md index a08177676..30f01f49c 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,27 @@ showcasing its capabilities in **information extraction**, **temporal and cross- ## 🚀 Getting Started +### ⭐️ MemOS online API +The easiest way to use MemOS. Equip your agent with memory **in minutes**! + +Sign up and get started on[`MemOS dashboard`](https://memos-dashboard.openmem.net/cn/quickstart/?source=landing). + + +### Self-Hosted Server +1. Get the repository. +```bash +git clone https://github.com/MemTensor/MemOS.git +cd MemOS +pip install -r ./docker/requirements.txt +``` + +2. Configure `docker/.env.example` and copy to `MemOS/.env` +3. Start the service. +```bash +uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8001 --workers 8 +``` + +### Local SDK Here's a quick example of how to create a **`MemCube`**, load it from a directory, access its memories, and save it. ```python @@ -102,7 +123,7 @@ for item in mem_cube.act_mem.get_all(): mem_cube.dump("tmp/mem_cube") ``` -What about **`MOS`** (Memory Operating System)? It's a higher-level orchestration layer that manages multiple MemCubes and provides a unified API for memory operations. Here's a quick example of how to use MOS: +**`MOS`** (Memory Operating System) is a higher-level orchestration layer that manages multiple MemCubes and provides a unified API for memory operations. Here's a quick example of how to use MOS: ```python from memos.configs.mem_os import MOSConfig diff --git a/docker/.env.example b/docker/.env.example index 33f7ae853..0f4fcb65d 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1,29 +1,60 @@ # MemOS Environment Variables Configuration +TZ=Asia/Shanghai -# Path to memory storage (e.g. /tmp/data_test) -MOS_CUBE_PATH= +MOS_CUBE_PATH="/tmp/data_test" # Path to memory storage (e.g. /tmp/data_test) +MOS_ENABLE_DEFAULT_CUBE_CONFIG="true" # Enable default cube config (true/false) # OpenAI Configuration -OPENAI_API_KEY= # Your OpenAI API key -OPENAI_API_BASE= # OpenAI API base URL (default: https://api.openai.com/v1) +OPENAI_API_KEY="sk-xxx" # Your OpenAI API key +OPENAI_API_BASE="http://xxx" # OpenAI API base URL (default: https://api.openai.com/v1) -# MemOS Feature Toggles -MOS_ENABLE_DEFAULT_CUBE_CONFIG= # Enable default cube config (true/false) -MOS_ENABLE_SCHEDULER= # Enable background scheduler (true/false) +# MemOS Chat Model Configuration +MOS_CHAT_MODEL=gpt-4o-mini +MOS_CHAT_TEMPERATURE=0.8 +MOS_MAX_TOKENS=8000 +MOS_TOP_P=0.9 +MOS_TOP_K=50 +MOS_CHAT_MODEL_PROVIDER=openai -# Neo4j Configuration -NEO4J_URI= # Neo4j connection URI (e.g. bolt://localhost:7687) -NEO4J_USER= # Neo4j username -NEO4J_PASSWORD= # Neo4j password -MOS_NEO4J_SHARED_DB= # Shared Neo4j database name (if using multi-db) +# graph db +# neo4j +NEO4J_BACKEND=xxx +NEO4J_URI=bolt://xxx +NEO4J_USER=xxx +NEO4J_PASSWORD=xxx +MOS_NEO4J_SHARED_DB=xxx +NEO4J_DB_NAME=xxx + +# tetxmem reog +MOS_ENABLE_REORGANIZE=false # MemOS User Configuration -MOS_USER_ID= # Unique user ID -MOS_SESSION_ID= # Session ID for current chat -MOS_MAX_TURNS_WINDOW= # Max number of turns to keep in memory +MOS_USER_ID=root +MOS_SESSION_ID=default_session +MOS_MAX_TURNS_WINDOW=20 + +# MemRader Configuration +MEMRADER_MODEL=gpt-4o-mini +MEMRADER_API_KEY=sk-xxx +MEMRADER_API_BASE=http://xxx:3000/v1 +MEMRADER_MAX_TOKENS=5000 + +#embedding & rerank +EMBEDDING_DIMENSION=1024 +MOS_EMBEDDER_BACKEND=universal_api +MOS_EMBEDDER_MODEL=bge-m3 +MOS_EMBEDDER_API_BASE=http://xxx +MOS_EMBEDDER_API_KEY=EMPTY +MOS_RERANKER_BACKEND=http_bge +MOS_RERANKER_URL=http://xxx +# Ollama Configuration (for embeddings) +#OLLAMA_API_BASE=http://xxx -# Ollama Configuration (for local embedding models) -OLLAMA_API_BASE= # Ollama API base URL (e.g. http://localhost:11434) +# milvus for pref mem +MILVUS_URI=http://xxx +MILVUS_USER_NAME=xxx +MILVUS_PASSWORD=xxx -# Embedding Configuration -MOS_EMBEDDER_BACKEND= # Embedding backend: openai, ollama, etc. +# pref mem +ENABLE_PREFERENCE_MEMORY=true +RETURN_ORIGINAL_PREF_MEM=true diff --git a/evaluation/.env-example b/evaluation/.env-example index 0e94e9caa..5381532c2 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -22,26 +22,3 @@ SUPERMEMORY_API_KEY="sm_xxx" MEMOBASE_API_KEY="xxx" MEMOBASE_PROJECT_URL="http://***.***.***.***:8019" -# eval settings -PRE_SPLIT_CHUNK=false - -# Configuration Only For Scheduler -# RabbitMQ Configuration -MEMSCHEDULER_RABBITMQ_HOST_NAME=rabbitmq-cn-***.cn-***.amqp-32.net.mq.amqp.aliyuncs.com -MEMSCHEDULER_RABBITMQ_USER_NAME=*** -MEMSCHEDULER_RABBITMQ_PASSWORD=*** -MEMSCHEDULER_RABBITMQ_VIRTUAL_HOST=memos -MEMSCHEDULER_RABBITMQ_ERASE_ON_CONNECT=true -MEMSCHEDULER_RABBITMQ_PORT=5672 - -# OpenAI Configuration -MEMSCHEDULER_OPENAI_API_KEY=sk-*** -MEMSCHEDULER_OPENAI_BASE_URL=http://***.***.***.***:3000/v1 -MEMSCHEDULER_OPENAI_DEFAULT_MODEL=gpt-4o-mini - -# Graph DB Configuration -MEMSCHEDULER_GRAPHDBAUTH_URI=bolt://localhost:7687 -MEMSCHEDULER_GRAPHDBAUTH_USER=neo4j -MEMSCHEDULER_GRAPHDBAUTH_PASSWORD=*** -MEMSCHEDULER_GRAPHDBAUTH_DB_NAME=neo4j -MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true diff --git a/evaluation/README.md b/evaluation/README.md index 8683c60b2..a5a4f32ca 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -16,10 +16,7 @@ This repository provides tools and scripts for evaluating the `LoCoMo`, `LongMem ``` ## Configuration - -1. Copy the `.env-example` file to `.env`, and fill in the required environment variables according to your environment and API keys. - -2. Copy the `configs-example/` directory to a new directory named `configs/`, and modify the configuration files inside it as needed. This directory contains model and API-specific settings. +Copy the `.env-example` file to `.env`, and fill in the required environment variables according to your environment and API keys. ## Setup MemOS ### local server diff --git a/evaluation/configs-example/mem_cube_config.json b/evaluation/configs-example/mem_cube_config.json deleted file mode 100644 index d609d27b0..000000000 --- a/evaluation/configs-example/mem_cube_config.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "user_id": "__USER_ID__", - "cube_id": "__USER_ID__", - "text_mem": { - "backend": "tree_text", - "config": { - "extractor_llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "temperature": 0.8, - "max_tokens": 1024, - "top_p": 0.9, - "top_k": 50, - "api_key": "sk-***REDACTED***", - "api_base": "http://***.***.***.***:3000/v1" - } - }, - "dispatcher_llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "temperature": 0.8, - "max_tokens": 1024, - "top_p": 0.9, - "top_k": 50, - "api_key": "sk-***REDACTED***", - "api_base": "http://***.***.***.***:3000/v1" - } - }, - "graph_db": { - "backend": "neo4j", - "config": { - "uri": "bolt://***.***.***.***:7687", - "user": "***REDACTED***", - "password": "***REDACTED***", - "db_name": "__DB_NAME__", - "auto_create": true - } - }, - "embedder": { - "backend": "ollama", - "config": { - "model_name_or_path": "nomic-embed-text:latest" - } - } - } - }, - "act_mem": {}, - "para_mem": {} -} diff --git a/evaluation/configs-example/mos_memos_config.json b/evaluation/configs-example/mos_memos_config.json deleted file mode 100644 index b7f2767b7..000000000 --- a/evaluation/configs-example/mos_memos_config.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "user_id": "root", - "chat_model": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "api_key": "sk-***REDACTED***", - "api_base": "http://***.***.***.***:3000/v1", - "temperature": 0.1, - "remove_think_prefix": true, - "max_tokens": 4096 - } - }, - "mem_reader": { - "backend": "simple_struct", - "config": { - "llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "temperature": 0.8, - "max_tokens": 1024, - "top_p": 0.9, - "top_k": 50, - "api_key": "sk-***REDACTED***", - "api_base": "http://***.***.***.***:3000/v1" - } - }, - "embedder": { - "backend": "ollama", - "config": { - "model_name_or_path": "nomic-embed-text:latest" - } - }, - "chunker": { - "backend": "sentence", - "config": { - "tokenizer_or_token_counter": "gpt2", - "chunk_size": 512, - "chunk_overlap": 128, - "min_sentences_per_chunk": 1 - } - } - } - }, - "max_turns_window": 30, - "top_k": "__TOP_K__", - "enable_textual_memory": true, - "enable_activation_memory": false, - "enable_parametric_memory": false -} diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 9aa527903..157c3f8ea 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -250,33 +250,23 @@ def search(self, query, user_id, top_k): preference_note = json.loads(response.text)["data"]["preference_note"] for i in text_mem_res: i.update({"memory": i.pop("memory_value")}) + explicit_pref_string = "Explicit Preference:" + implicit_pref_string = "\n\nImplicit Preference:" + explicit_idx = 0 + implicit_idx = 0 + for pref in pref_mem_res: + if pref["preference_type"] == "explicit_preference": + explicit_pref_string += f"\n{explicit_idx + 1}. {pref['preference']}" + explicit_idx += 1 + if pref["preference_type"] == "implicit_preference": + implicit_pref_string += f"\n{implicit_idx + 1}. {pref['preference']}" + implicit_idx += 1 + + return { + "text_mem": [{"memories": text_mem_res}], + "pref_string": explicit_pref_string + implicit_pref_string + preference_note, + } - explicit_prefs = [ - p["preference"] - for p in pref_mem_res - if p.get("preference_type", "") == "explicit_preference" - ] - implicit_prefs = [ - p["preference"] - for p in pref_mem_res - if p.get("preference_type", "") == "implicit_preference" - ] - - pref_parts = [] - if explicit_prefs: - pref_parts.append( - "Explicit Preference:\n" - + "\n".join(f"{i + 1}. {p}" for i, p in enumerate(explicit_prefs)) - ) - if implicit_prefs: - pref_parts.append( - "Implicit Preference:\n" - + "\n".join(f"{i + 1}. {p}" for i, p in enumerate(implicit_prefs)) - ) - - pref_string = "\n".join(pref_parts) + preference_note - - return {"text_mem": [{"memories": text_mem_res}], "pref_string": pref_string} except Exception as e: if attempt < max_retries - 1: time.sleep(2**attempt) diff --git a/src/memos/log.py b/src/memos/log.py index 2a538fdde..faa808414 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -187,7 +187,7 @@ def close(self): }, "handlers": { "console": { - "level": "DEBUG", + "level": selected_log_level, "class": "logging.StreamHandler", "stream": stdout, "formatter": "no_datetime", From 5e4f695deb6ba1697b2d488f14a2dbca3ca7bb28 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Thu, 6 Nov 2025 13:06:34 +0800 Subject: [PATCH 62/64] Fix/no response (#464) * fix: response error * fix: response error * fix: response error * feat: replace context thread --------- Co-authored-by: harvey_xiang Co-authored-by: CaralHsi --- src/memos/api/config.py | 4 ++-- src/memos/mem_scheduler/base_scheduler.py | 3 ++- src/memos/mem_scheduler/general_modules/task_threads.py | 5 +++-- src/memos/mem_scheduler/monitors/dispatcher_monitor.py | 4 ++-- .../mem_scheduler/webservice_modules/rabbitmq_service.py | 3 ++- .../mem_scheduler/webservice_modules/redis_service.py | 6 +++--- .../textual/tree_text_memory/organize/reorganizer.py | 7 +++---- 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7f61d54ac..1a3c328f1 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -15,6 +15,7 @@ from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig +from memos.context.context import ContextThread from memos.mem_cube.general import GeneralMemCube @@ -178,7 +179,6 @@ def start_watch_if_enabled(cls) -> None: if not enable: return interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60")) - import threading def _loop() -> None: while True: @@ -188,7 +188,7 @@ def _loop() -> None: logger.error(f"❌ Nacos watch loop error: {e}") time.sleep(interval) - threading.Thread(target=_loop, daemon=True).start() + ContextThread(target=_loop, daemon=True).start() logger.info(f"Nacos watch thread started (interval={interval}s).") @classmethod diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index b3b457c36..028fe8e3f 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -12,6 +12,7 @@ from sqlalchemy.engine import Engine from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig +from memos.context.context import ContextThread from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube @@ -689,7 +690,7 @@ def start(self) -> None: logger.info("Message consumer process started") else: # Default to thread mode - self._consumer_thread = threading.Thread( + self._consumer_thread = ContextThread( target=self._message_consumer, daemon=True, name="MessageConsumerThread", diff --git a/src/memos/mem_scheduler/general_modules/task_threads.py b/src/memos/mem_scheduler/general_modules/task_threads.py index 551e8b726..73b570a8b 100644 --- a/src/memos/mem_scheduler/general_modules/task_threads.py +++ b/src/memos/mem_scheduler/general_modules/task_threads.py @@ -5,6 +5,7 @@ from concurrent.futures import as_completed from typing import Any, TypeVar +from memos.context.context import ContextThread from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule @@ -138,7 +139,7 @@ def worker(task_name: str, func: Callable, args: tuple): # Start all threads for task_name, (func, args) in tasks.items(): - thread = threading.Thread( + thread = ContextThread( target=worker, args=(task_name, func, args), name=f"task-{task_name}" ) threads[task_name] = thread @@ -283,7 +284,7 @@ def run_race( # Create and start threads for each task for task_name, task_func in tasks.items(): - thread = threading.Thread( + thread = ContextThread( target=self.worker, args=(task_func, task_name), name=f"race-{task_name}" ) self.threads[task_name] = thread diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 0ebb7da4f..46c4e2d49 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -4,7 +4,7 @@ from time import perf_counter from memos.configs.mem_scheduler import BaseSchedulerConfig -from memos.context.context import ContextThreadPoolExecutor +from memos.context.context import ContextThread, ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher @@ -340,7 +340,7 @@ def start(self) -> bool: return False self._running = True - self._monitor_thread = threading.Thread( + self._monitor_thread = ContextThread( target=self._monitor_loop, name="threadpool_monitor", daemon=True ) self._monitor_thread.start() diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index b240f4369..3c0dff907 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -6,6 +6,7 @@ from pathlib import Path from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig +from memos.context.context import ContextThread from memos.dependency import require_python_package from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule @@ -96,7 +97,7 @@ def initialize_rabbitmq( ) # Start IOLoop in dedicated thread - self._io_loop_thread = threading.Thread( + self._io_loop_thread = ContextThread( target=self.rabbitmq_connection.ioloop.start, daemon=True ) self._io_loop_thread.start() diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index d86911e82..5439af9c6 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -1,12 +1,12 @@ import asyncio import os import subprocess -import threading import time from collections.abc import Callable from typing import Any +from memos.context.context import ContextThread from memos.dependency import require_python_package from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule @@ -41,7 +41,7 @@ def __init__(self): self.query_list_capacity = 1000 self._redis_listener_running = False - self._redis_listener_thread: threading.Thread | None = None + self._redis_listener_thread: ContextThread | None = None self._redis_listener_loop: asyncio.AbstractEventLoop | None = None @property @@ -336,7 +336,7 @@ def redis_start_listening(self, handler: Callable | None = None): if handler is None: handler = self.redis_consume_message_stream - self._redis_listener_thread = threading.Thread( + self._redis_listener_thread = ContextThread( target=self._redis_run_listener_async, args=(handler,), daemon=True, diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index 0337225d1..ea06a7c60 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -1,5 +1,4 @@ import json -import threading import time import traceback @@ -10,7 +9,7 @@ import numpy as np -from memos.context.context import ContextThreadPoolExecutor +from memos.context.context import ContextThread, ContextThreadPoolExecutor from memos.dependency import require_python_package from memos.embedders.factory import OllamaEmbedder from memos.graph_dbs.item import GraphDBEdge, GraphDBNode @@ -94,12 +93,12 @@ def __init__( self._reorganize_needed = True if self.is_reorganize: # ____ 1. For queue message driven thread ___________ - self.thread = threading.Thread(target=self._run_message_consumer_loop) + self.thread = ContextThread(target=self._run_message_consumer_loop) self.thread.start() # ____ 2. For periodic structure optimization _______ self._stop_scheduler = False self._is_optimizing = {"LongTermMemory": False, "UserMemory": False} - self.structure_optimizer_thread = threading.Thread( + self.structure_optimizer_thread = ContextThread( target=self._run_structure_organizer_loop ) self.structure_optimizer_thread.start() From f6408550445d7e6ac782c5e0f7fdd6ede93bfe78 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Thu, 6 Nov 2025 14:48:52 +0800 Subject: [PATCH 63/64] Fix: fix pg query for group error string (#465) * fix:get_grouped_counts * fix: code format --- src/memos/api/config.py | 12 ++++++------ src/memos/graph_dbs/polardb.py | 9 +++++---- .../textual/tree_text_memory/organize/manager.py | 4 +++- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 1a3c328f1..f02edaad6 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -861,9 +861,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true", "memory_size": { - "WorkingMemory": os.getenv("NEBULAR_WORKING_MEMORY", 20), - "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), - "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), + "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -933,9 +933,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: == "true", "internet_retriever": internet_config, "memory_size": { - "WorkingMemory": os.getenv("NEBULAR_WORKING_MEMORY", 20), - "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), - "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), + "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 5d34ff03f..60902420f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1740,9 +1740,11 @@ def get_grouped_counts( for field in group_fields: alias = field.replace(".", "_") return_fields.append( - f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype) AS {alias}" + f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text AS {alias}" + ) + group_by_fields.append( + f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text" ) - group_by_fields.append(alias) # Full SQL query construction query = f""" @@ -1751,7 +1753,6 @@ def get_grouped_counts( {where_clause} GROUP BY {", ".join(group_by_fields)} """ - conn = self._get_connection() try: with conn.cursor() as cursor: @@ -1772,7 +1773,7 @@ def get_grouped_counts( else: group_values[field] = str(value) count_value = row[-1] # Last column is count - output.append({**group_values, "count": count_value}) + output.append({**group_values, "count": int(count_value)}) return output diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 01ccc382b..0c41717ea 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -156,7 +156,9 @@ def _refresh_memory_size(self, user_name: str | None = None) -> None: results = self.graph_store.get_grouped_counts( group_fields=["memory_type"], user_name=user_name ) - self.current_memory_size = {record["memory_type"]: record["count"] for record in results} + self.current_memory_size = { + record["memory_type"]: int(record["count"]) for record in results + } logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): From 119bbe250f1e1620534ab8325b76fdb8d23ad273 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 6 Nov 2025 16:22:08 +0800 Subject: [PATCH 64/64] feat: freeze usage update in Searcher (#466) * feat: tmp-close usage update, maybe reopen it later * test: withdraw update usage test --- .../memories/textual/tree_text_memory/retrieve/searcher.py | 6 ++---- tests/memories/textual/test_tree_searcher.py | 7 ------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 2f6ef6afa..f408755fd 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -1,8 +1,5 @@ -import json import traceback -from datetime import datetime - from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import OllamaEmbedder from memos.graph_dbs.factory import Neo4jGraphDB @@ -508,7 +505,7 @@ def _sort_and_trim(self, results, top_k): @timed def _update_usage_history(self, items, info, user_name: str | None = None): - """Update usage history in graph DB""" + """Update usage history in graph DB now_time = datetime.now().isoformat() info_copy = dict(info or {}) info_copy.pop("chat_history", None) @@ -532,6 +529,7 @@ def _update_usage_history(self, items, info, user_name: str | None = None): self._usage_executor.submit( self._update_usage_history_worker, payload, usage_record, user_name ) + """ def _update_usage_history_worker( self, payload, usage_record: str, user_name: str | None = None diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py index d99664817..2a5536cf8 100644 --- a/tests/memories/textual/test_tree_searcher.py +++ b/tests/memories/textual/test_tree_searcher.py @@ -69,13 +69,6 @@ def test_searcher_fast_path(mock_searcher): assert len(result) <= 2 assert all(isinstance(item, TextualMemoryItem) for item in result) - # Should update usage and call update_node - for item in result: - assert len(item.metadata.usage) > 0 - mock_searcher.graph_store.update_node.assert_any_call( - item.id, {"usage": item.metadata.usage}, user_name=None - ) - def test_searcher_fine_mode_triggers_reasoner(mock_searcher): parsed_goal = MagicMock()