From 9f1c37bc6cecef841c73229d491bc03f31f915ab Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Fri, 6 Feb 2026 16:55:48 -0800 Subject: [PATCH 01/20] fix: Add Unicode sanitization for cloud embedders MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _sanitize_unicode() function to remove surrogates - Apply sanitization before all embedding API calls - Add comprehensive tests for Unicode handling Fixes production crashes with VoyageAI/OpenAI when texts contain emoji or Unicode surrogates (U+D800-U+DFFF). Tested with: - Emoji: '👋 🔥' - Surrogates: '\ud800' - International text: 中文, العربية, Тест Co-Authored-By: Claude Opus 4.6 --- src/memos/embedders/universal_api.py | 17 ++++ .../embedders/test_unicode_sanitization.py | 88 +++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 tests/unit/embedders/test_unicode_sanitization.py diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 538d913ea..318f0ac80 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -14,6 +14,21 @@ logger = get_logger(__name__) +def _sanitize_unicode(text: str) -> str: + """ + Remove Unicode surrogates and other problematic characters. + Surrogates (U+D800-U+DFFF) cause UnicodeEncodeError with some APIs. + """ + try: + # Encode with 'surrogatepass' then decode, replacing invalid chars + cleaned = text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace") + # Replace replacement char with empty string for cleaner output + return cleaned.replace("\ufffd", "") + except Exception: + # Fallback: remove all non-BMP characters + return "".join(c for c in text if ord(c) < 0x10000) + + class UniversalAPIEmbedder(BaseEmbedder): def __init__(self, config: UniversalAPIEmbedderConfig): self.provider = config.provider @@ -54,6 +69,8 @@ def __init__(self, config: UniversalAPIEmbedderConfig): def embed(self, texts: list[str]) -> list[list[float]]: if isinstance(texts, str): texts = [texts] + # Sanitize Unicode to prevent encoding errors with emoji/surrogates + texts = [_sanitize_unicode(t) for t in texts] # Truncate texts if max_tokens is configured texts = self._truncate_texts(texts) logger.info(f"Embeddings request with input: {texts}") diff --git a/tests/unit/embedders/test_unicode_sanitization.py b/tests/unit/embedders/test_unicode_sanitization.py new file mode 100644 index 000000000..015f0a361 --- /dev/null +++ b/tests/unit/embedders/test_unicode_sanitization.py @@ -0,0 +1,88 @@ +""" +Tests for Unicode sanitization in embedders. +""" + +import pytest + + +def _sanitize_unicode(text: str) -> str: + """ + Remove Unicode surrogates and other problematic characters. + Surrogates (U+D800-U+DFFF) cause UnicodeEncodeError with some APIs. + """ + try: + # Encode with 'surrogatepass' then decode, replacing invalid chars + cleaned = text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace") + # Replace replacement char with empty string for cleaner output + return cleaned.replace("\ufffd", "") + except Exception: + # Fallback: remove all non-BMP characters + return "".join(c for c in text if ord(c) < 0x10000) + + +class TestUnicodeSanitization: + """Test Unicode sanitization function.""" + + def test_emoji_handling(self): + """Test that emoji are preserved.""" + text = "Hello 👋 world 🌍" + result = _sanitize_unicode(text) + assert "Hello" in result + assert "world" in result + # Emoji should be present (though they might be sanitized differently) + + def test_surrogate_removal(self): + """Test that surrogates are removed.""" + text = "Hello\ud800world" # Surrogate in the middle + result = _sanitize_unicode(text) + assert "Hello" in result + assert "world" in result + # Surrogate should be removed + assert "\ud800" not in result + + def test_mixed_unicode(self): + """Test mixed Unicode characters.""" + text = "Test 中文 العربية Тест" + result = _sanitize_unicode(text) + assert "Test" in result + # International characters should be preserved + + def test_empty_string(self): + """Test empty string handling.""" + assert _sanitize_unicode("") == "" + + def test_ascii_only(self): + """Test that ASCII text is unchanged.""" + text = "Hello World 123" + assert _sanitize_unicode(text) == text + + def test_multiple_surrogates(self): + """Test multiple surrogates are handled.""" + text = "\ud800\udc00test\ud83d\ude00" + result = _sanitize_unicode(text) + assert "test" in result + # Should not raise UnicodeEncodeError + + def test_list_of_texts(self): + """Test sanitizing a list of texts.""" + texts = ["Normal text", "Emoji 👋", "Surrogate\ud800test", "Mixed 中文 🔥"] + results = [_sanitize_unicode(t) for t in texts] + assert len(results) == 4 + assert all(isinstance(r, str) for r in results) + + def test_encoding_to_utf8(self): + """Test that result can be encoded to UTF-8.""" + problematic_texts = [ + "Hello\ud800world", + "Test\ud83dEmoji", + "\ud800\udc00\ud83d\ude00", + ] + for text in problematic_texts: + result = _sanitize_unicode(text) + # Should not raise UnicodeEncodeError + encoded = result.encode("utf-8") + assert isinstance(encoded, bytes) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From e26098237d3292b40ecd293c6444ff39066a0458 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Fri, 6 Feb 2026 22:17:51 -0800 Subject: [PATCH 02/20] fix: activation memory config crashes get_default() with OpenAI backend Two bugs in `get_default_config()` / `get_default_cube_config()`: 1. `get_default_config()` injects `act_mem` dict into MOSConfig when `enable_activation_memory=True`, but MOSConfig has no `act_mem` field and inherits `extra="forbid"` from BaseConfig. This causes a `ValidationError: Extra inputs are not permitted` for any user calling `get_default()` with activation memory enabled. 2. `get_default_cube_config()` hardcodes `extractor_llm` backend to `"openai"` for KV cache activation memory, but `KVCacheMemoryConfig` validator requires `huggingface`/`huggingface_singleton`/`vllm` (KV cache needs local model access for attention tensor extraction). This causes `ConfigurationError` even if bug #1 is fixed. Fix: Remove `act_mem` from MOSConfig dict (the `enable_activation_memory` bool flag is sufficient). In MemCube config, require explicit `activation_memory_backend` kwarg instead of hardcoding `"openai"`. Co-Authored-By: Claude Opus 4.6 --- src/memos/mem_os/utils/default_config.py | 56 +++++++++++++----------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py index edb7875d4..9898cbe8c 100644 --- a/src/memos/mem_os/utils/default_config.py +++ b/src/memos/mem_os/utils/default_config.py @@ -3,12 +3,15 @@ Provides simplified configuration generation for users. """ +import logging from typing import Literal from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.mem_cube.general import GeneralMemCube +logger = logging.getLogger(__name__) + def get_default_config( openai_api_key: str, @@ -116,20 +119,9 @@ def get_default_config( }, } - # Add activation memory if enabled - if config_dict.get("enable_activation_memory", False): - config_dict["act_mem"] = { - "backend": "kv_cache", - "config": { - "memory_filename": kwargs.get( - "activation_memory_filename", "activation_memory.pickle" - ), - "extractor_llm": { - "backend": "openai", - "config": openai_config, - }, - }, - } + # Note: act_mem configuration belongs in MemCube config (get_default_cube_config), + # not in MOSConfig which doesn't have an act_mem field (extra="forbid"). + # The enable_activation_memory flag above is sufficient for MOSConfig. return MOSConfig(**config_dict) @@ -237,21 +229,33 @@ def get_default_cube_config( }, } - # Configure activation memory if enabled + # Configure activation memory if enabled. + # KV cache activation memory requires a local HuggingFace/vLLM model (it + # extracts internal attention KV tensors via build_kv_cache), so it cannot + # work with remote API backends like OpenAI. + # Only create act_mem when activation_memory_backend is explicitly provided. act_mem_config = {} if kwargs.get("enable_activation_memory", False): - act_mem_config = { - "backend": "kv_cache", - "config": { - "memory_filename": kwargs.get( - "activation_memory_filename", "activation_memory.pickle" - ), - "extractor_llm": { - "backend": "openai", - "config": openai_config, + extractor_backend = kwargs.get("activation_memory_backend") + if extractor_backend in ("huggingface", "huggingface_singleton", "vllm"): + act_mem_config = { + "backend": "kv_cache", + "config": { + "memory_filename": kwargs.get( + "activation_memory_filename", "activation_memory.pickle" + ), + "extractor_llm": { + "backend": extractor_backend, + "config": kwargs.get("activation_memory_llm_config", {}), + }, }, - }, - } + } + else: + logger.info( + "Activation memory (kv_cache) requires a local model backend " + "(huggingface/vllm) via activation_memory_backend kwarg. " + "Skipping act_mem in MemCube config." + ) # Create MemCube configuration cube_config_dict = { From 975dece061f94f2b415bb644bb76637f347d0e70 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Fri, 6 Feb 2026 22:18:57 -0800 Subject: [PATCH 03/20] fix: downgrade AuthConfig partial-init log from WARNING to INFO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AuthConfig.validate_partial_initialization() logs a WARNING on every startup when some components (openai, graph_db) are not configured. This is noisy because partial configuration is a valid and common state — the individual from_local_env() methods already log appropriate warnings for actual initialization failures. Change the partial-init log to INFO level; keep WARNING only for the edge case where ALL components are None. Co-Authored-By: Claude Opus 4.6 --- src/memos/configs/mem_scheduler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index a28f3bdce..9807f42c3 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -250,8 +250,12 @@ def validate_partial_initialization(self) -> "AuthConfig": "All configuration components are None. This may indicate missing environment variables or configuration files." ) elif failed_components: - logger.warning( - f"Failed to initialize components: {', '.join(failed_components)}. Successfully initialized: {', '.join(initialized_components)}" + # Use info level: individual from_local_env() methods already log + # warnings for actual initialization failures. Components that are + # simply not configured (no env vars) are not errors. + logger.info( + f"Components not configured: {', '.join(failed_components)}. " + f"Successfully initialized: {', '.join(initialized_components)}" ) return self From 9d2f9be9b89b01cf416e5fa883359dbf314a3499 Mon Sep 17 00:00:00 2001 From: Wenqiang Wei <46308778+endxxxx@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:31:19 +0800 Subject: [PATCH 04/20] test: remove unrelated unittest --- .../embedders/test_unicode_sanitization.py | 88 ------------------- 1 file changed, 88 deletions(-) delete mode 100644 tests/unit/embedders/test_unicode_sanitization.py diff --git a/tests/unit/embedders/test_unicode_sanitization.py b/tests/unit/embedders/test_unicode_sanitization.py deleted file mode 100644 index 015f0a361..000000000 --- a/tests/unit/embedders/test_unicode_sanitization.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for Unicode sanitization in embedders. -""" - -import pytest - - -def _sanitize_unicode(text: str) -> str: - """ - Remove Unicode surrogates and other problematic characters. - Surrogates (U+D800-U+DFFF) cause UnicodeEncodeError with some APIs. - """ - try: - # Encode with 'surrogatepass' then decode, replacing invalid chars - cleaned = text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace") - # Replace replacement char with empty string for cleaner output - return cleaned.replace("\ufffd", "") - except Exception: - # Fallback: remove all non-BMP characters - return "".join(c for c in text if ord(c) < 0x10000) - - -class TestUnicodeSanitization: - """Test Unicode sanitization function.""" - - def test_emoji_handling(self): - """Test that emoji are preserved.""" - text = "Hello 👋 world 🌍" - result = _sanitize_unicode(text) - assert "Hello" in result - assert "world" in result - # Emoji should be present (though they might be sanitized differently) - - def test_surrogate_removal(self): - """Test that surrogates are removed.""" - text = "Hello\ud800world" # Surrogate in the middle - result = _sanitize_unicode(text) - assert "Hello" in result - assert "world" in result - # Surrogate should be removed - assert "\ud800" not in result - - def test_mixed_unicode(self): - """Test mixed Unicode characters.""" - text = "Test 中文 العربية Тест" - result = _sanitize_unicode(text) - assert "Test" in result - # International characters should be preserved - - def test_empty_string(self): - """Test empty string handling.""" - assert _sanitize_unicode("") == "" - - def test_ascii_only(self): - """Test that ASCII text is unchanged.""" - text = "Hello World 123" - assert _sanitize_unicode(text) == text - - def test_multiple_surrogates(self): - """Test multiple surrogates are handled.""" - text = "\ud800\udc00test\ud83d\ude00" - result = _sanitize_unicode(text) - assert "test" in result - # Should not raise UnicodeEncodeError - - def test_list_of_texts(self): - """Test sanitizing a list of texts.""" - texts = ["Normal text", "Emoji 👋", "Surrogate\ud800test", "Mixed 中文 🔥"] - results = [_sanitize_unicode(t) for t in texts] - assert len(results) == 4 - assert all(isinstance(r, str) for r in results) - - def test_encoding_to_utf8(self): - """Test that result can be encoded to UTF-8.""" - problematic_texts = [ - "Hello\ud800world", - "Test\ud83dEmoji", - "\ud800\udc00\ud83d\ude00", - ] - for text in problematic_texts: - result = _sanitize_unicode(text) - # Should not raise UnicodeEncodeError - encoded = result.encode("utf-8") - assert isinstance(encoded, bytes) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From a40cacf6d61035b42e07b5c0822ccd37976f1055 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Mon, 2 Mar 2026 14:46:34 +0800 Subject: [PATCH 05/20] chore: sync from main (#1147) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add return_fields parameter to search methods (#955) Add optional return_fields parameter to search_by_embedding, search_by_keywords_like, search_by_keywords_tfidf, and search_by_fulltext methods across all graph DB backends (neo4j, neo4j_community, polardb). When return_fields is specified (e.g., ['memory', 'status', 'tags']), the requested fields are included in each result dict alongside 'id' and 'score', eliminating the need for N+1 get_node() calls. Default is None, preserving full backward compatibility. Changes: - base.py: Updated docstring for search_by_embedding - neo4j.py: Added return_fields to search_by_embedding, modified Cypher RETURN clause and record construction - neo4j_community.py: Added return_fields to search_by_embedding, added _fetch_return_fields helper for direct vec_db path - polardb.py: Added return_fields to all 4 search methods, added _extract_fields_from_properties helper for JSON property extraction Closes #955 fix: add field name validation to prevent query injection in return_fields - Add _validate_return_fields() to BaseGraphDB base class with regex validation - Apply validation in neo4j.py, neo4j_community.py, polardb.py before field name concatenation - Add return_fields parameter to base class abstract method signature - Revert unrelated .get(node_id) change back to .get(node_id, None) - Add TestFieldNameValidation and TestNeo4jCommunitySearchReturnFields test classes (7 new tests) fix: resolve ruff lint and format issues for CI compliance * feat: add optimize user_name && key words && log (#1111) * feat:optimize user_name && key_words * feat:optimize user_name && key_words * hotfix: An error occurred when adding the edge of the graph. (#1109) * skip edge * skip edges --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> * fix: fix add/update log judgment — scope by user_name an… (#1116) fix(add_handler): fix add/update log judgment — scope by user_name and exclude self-match - Add user_name=msg.mem_cube_id to get_by_metadata() to prevent cross-user key matching - Exclude mem_item.id from candidates to prevent the just-persisted node from always triggering the UPDATE path (self-match bug) - Wrap get_by_metadata() in try/except so Cypher parse errors (e.g. from special chars in key) are logged as warnings instead of silently swallowing the item as missing - Add null-safety check on original_mem_item before accessing .memory Co-authored-by: glin1993@outlook.com <> * fix: set default relativity (#1119) * fix: add full text search for neo4j db (#1095) * feat: add full_text_search for neo4j * test: 改回去 --------- Co-authored-by: CaralHsi * fix: Add toggle for fulltext retrieval path (#1096) feat: Add toggle for fulltext retrieval path (FULLTEXT_CALL), default off * test: 将 relativity 设置为 0.5 * test: 将 relativity 设置为 0.4 * test: 将 relativity 设置为 0.45 * test: 将 relativity 设置为 0.475 * fix: set relativity 0.45 * fix: translate comments to english --------- Co-authored-by: CaralHsi Co-authored-by: jiang * fix: remove unused edges & optimize logs (#1123) * feat: add return_fields for search_by_embedding * feat: remove unused edges * feat: optimize log * feat: optimize log * feat: optimize log * feat: optimize log * fix: File memory parsing to output a list-type result (#1125) Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> * fix: use merged_from to correctly identify add/update memory logs (#1118) * fix(handlers): replace get_by_metadata with merged_from for add/update log detection add_handler: - Remove get_by_metadata graph DB query (eliminated self-match, cross-user matching, and Cypher escaping bugs) - Use metadata.info['merged_from'] to determine ADD vs UPDATE — set upstream by mem_reader during fine extraction when LLM merges memories - Remove unused key/transform_name_to_key computation in log_add_messages mem_read_handler: - Cloud env: mark operation as 'UPDATE' when merged_from is set, 'ADD' otherwise - Local env: split items into addMemory / updateMemory events based on merged_from, emitting separate scheduler log events for each * style(handlers): satisfy ruff formatting and isinstance union types --------- Co-authored-by: glin1993@outlook.com <> Co-authored-by: CaralHsi * fix: fix memory_type validation error (#1130) * fix: user_name (#1131) * fix: user_name * fix: user_name * feat: delete some logs * fix: test_history_manager * fix: get embedding of pref mem from db instead of recompute (#1133) * fix: add embedding for pref memory * fix: recompute embedding for missing memory instead of all memory * reformat * fix: get embedding from pref instead of pref.payload * fix: add include embedding params to pref mem * fix: decrease top k --------- Co-authored-by: jiang * fix: replace bare except with except Exception in thread_safe_dict_segment (#1112) The bare except in acquire_write() catches KeyboardInterrupt and SystemExit, which could leave the lock in an inconsistent state. Using except Exception ensures system-level exceptions propagate. Co-authored-by: haosenwang1018 Co-authored-by: CaralHsi * docs: update README.md (#898) update contribution guidelines link Co-authored-by: CaralHsi * feat: optimize get_edges (#1138) * feat: optimize get_edges * feat: optimize get_edges * chore: change version number to 2.0.7 (#1140) --------- Co-authored-by: damaozi <1811866786@qq.com> Co-authored-by: Hustzdy <67457465+wustzdy@users.noreply.github.com> Co-authored-by: Dubberman <48425266+whipser030@users.noreply.github.com> Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: Zehao Lin Co-authored-by: Jiang <33757498+hijzy@users.noreply.github.com> Co-authored-by: jiang Co-authored-by: Qi Weng Co-authored-by: Sense_wang <167664334+haosenwang1018@users.noreply.github.com> Co-authored-by: haosenwang1018 Co-authored-by: Ikko Eltociear Ashimine --- README.md | 4 +- pyproject.toml | 2 +- src/memos/__init__.py | 2 +- src/memos/api/handlers/search_handler.py | 72 ++- src/memos/api/middleware/__init__.py | 9 +- src/memos/api/product_models.py | 12 +- src/memos/api/utils/api_keys.py | 2 +- src/memos/embedders/universal_api.py | 1 - src/memos/graph_dbs/base.py | 32 +- src/memos/graph_dbs/neo4j.py | 31 +- src/memos/graph_dbs/neo4j_community.py | 91 +++- src/memos/graph_dbs/polardb.py | 474 ++++++------------ src/memos/mem_reader/multi_modal_struct.py | 1 + .../read_multi_modal/file_content_parser.py | 117 +++-- .../read_skill_memory/process_skill_memory.py | 4 +- .../handlers/add_handler.py | 40 +- .../handlers/mem_read_handler.py | 67 ++- .../textual/prefer_text_memory/retrievers.py | 52 +- src/memos/memories/textual/tree.py | 14 +- .../tree_text_memory/organize/handler.py | 64 ++- .../organize/history_manager.py | 2 + .../tree_text_memory/organize/manager.py | 41 +- .../tree_text_memory/organize/reorganizer.py | 77 ++- .../tree_text_memory/retrieve/searcher.py | 2 +- .../memos_tools/thread_safe_dict_segment.py | 2 +- src/memos/multi_mem_cube/single_cube.py | 5 +- src/memos/templates/mem_reader_prompts.py | 34 +- tests/graph_dbs/test_search_return_fields.py | 306 +++++++++++ .../memories/textual/test_history_manager.py | 8 +- 29 files changed, 1036 insertions(+), 532 deletions(-) create mode 100644 tests/graph_dbs/test_search_return_fields.py diff --git a/README.md b/README.md index 70562fec7..45a372ed1 100644 --- a/README.md +++ b/README.md @@ -345,10 +345,10 @@ url = {https://global-sci.com/article/91443/memory3-language-modeling-with-expli ## 🙌 Contributing -We welcome contributions from the community! Please read our [contribution guidelines](https://memos-docs.openmem.net/contribution/overview) to get started. +We welcome contributions from the community! Please read our [contribution guidelines](https://memos-docs.openmem.net/open_source/contribution/overview/) to get started.
## 📄 License -MemOS is licensed under the [Apache 2.0 License](./LICENSE). \ No newline at end of file +MemOS is licensed under the [Apache 2.0 License](./LICENSE). diff --git a/pyproject.toml b/pyproject.toml index b4b01e0e1..4a9ea8852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.6" +version = "2.0.7" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index b568ae0c2..fefa3b2ab 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.6" +__version__ = "2.0.7" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 267d1bb28..8e7785ad5 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -64,7 +64,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Expand top_k for deduplication (5x to ensure enough candidates) if search_req_local.dedup in ("sim", "mmr"): - search_req_local.top_k = search_req_local.top_k * 5 + search_req_local.top_k = search_req_local.top_k * 3 # Search and deduplicate cube_view = self._build_cube_view(search_req_local) @@ -152,9 +152,6 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di return results embeddings = self._extract_embeddings([mem for _, mem, _ in flat]) - if embeddings is None: - documents = [mem.get("memory", "") for _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) similarity_matrix = cosine_similarity_matrix(embeddings) @@ -235,12 +232,39 @@ def _mmr_dedup_text_memories( if len(flat) <= 1: return results + total_by_type: dict[str, int] = {"text": 0, "preference": 0} + existing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_indices: list[int] = [] + for idx, (mem_type, _, mem, _) in enumerate(flat): + if mem_type not in total_by_type: + total_by_type[mem_type] = 0 + existing_by_type[mem_type] = 0 + missing_by_type[mem_type] = 0 + total_by_type[mem_type] += 1 + + embedding = mem.get("metadata", {}).get("embedding") + if embedding: + existing_by_type[mem_type] += 1 + else: + missing_by_type[mem_type] += 1 + missing_indices.append(idx) + + self.logger.info( + "[SearchHandler] MMR embedding metadata scan: total=%s total_by_type=%s existing_by_type=%s missing_by_type=%s", + len(flat), + total_by_type, + existing_by_type, + missing_by_type, + ) + if missing_indices: + self.logger.warning( + "[SearchHandler] MMR embedding metadata missing; will compute missing embeddings: missing_total=%s", + len(missing_indices), + ) + # Get or compute embeddings embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat]) - if embeddings is None: - self.logger.warning("[SearchHandler] Embedding is missing; recomputing embeddings") - documents = [mem.get("memory", "") for _, _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) # Compute similarity matrix using NumPy-optimized method # Returns numpy array but compatible with list[i][j] indexing @@ -404,14 +428,32 @@ def _max_similarity( return 0.0 return max(similarity_matrix[index][j] for j in selected_indices) - @staticmethod - def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None: + def _extract_embeddings(self, memories: list[dict[str, Any]]) -> list[list[float]]: embeddings: list[list[float]] = [] - for mem in memories: - embedding = mem.get("metadata", {}).get("embedding") - if not embedding: - return None - embeddings.append(embedding) + missing_indices: list[int] = [] + missing_documents: list[str] = [] + + for idx, mem in enumerate(memories): + metadata = mem.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + mem["metadata"] = metadata + + embedding = metadata.get("embedding") + if embedding: + embeddings.append(embedding) + continue + + embeddings.append([]) + missing_indices.append(idx) + missing_documents.append(mem.get("memory", "")) + + if missing_indices: + computed = self.searcher.embedder.embed(missing_documents) + for idx, embedding in zip(missing_indices, computed, strict=False): + embeddings[idx] = embedding + memories[idx]["metadata"]["embedding"] = embedding + return embeddings @staticmethod diff --git a/src/memos/api/middleware/__init__.py b/src/memos/api/middleware/__init__.py index 64cbc5c60..fd39252f5 100644 --- a/src/memos/api/middleware/__init__.py +++ b/src/memos/api/middleware/__init__.py @@ -1,13 +1,14 @@ """Krolik middleware extensions for MemOS.""" -from .auth import verify_api_key, require_scope, require_admin, require_read, require_write +from .auth import require_admin, require_read, require_scope, require_write, verify_api_key from .rate_limit import RateLimitMiddleware + __all__ = [ - "verify_api_key", - "require_scope", + "RateLimitMiddleware", "require_admin", "require_read", + "require_scope", "require_write", - "RateLimitMiddleware", + "verify_api_key", ] diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 6fc03e735..5bf27e985 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -99,12 +99,12 @@ class ChatRequest(BaseRequest): manager_user_id: str | None = Field(None, description="Manager User ID") project_id: str | None = Field(None, description="Project ID") relativity: float = Field( - 0.0, + 0.45, ge=0, description=( "Relevance threshold for recalled memories. " "Only memories with metadata.relativity >= relativity will be returned. " - "Use 0 to disable threshold filtering. Default: 0.3." + "Use 0 to disable threshold filtering. Default: 0.45." ), ) @@ -339,12 +339,12 @@ class APISearchRequest(BaseRequest): ) relativity: float = Field( - 0.0, + 0.45, ge=0, description=( "Relevance threshold for recalled memories. " "Only memories with metadata.relativity >= relativity will be returned. " - "Use 0 to disable threshold filtering. Default: 0.3." + "Use 0 to disable threshold filtering. Default: 0.45." ), ) @@ -785,12 +785,12 @@ class APIChatCompleteRequest(BaseRequest): manager_user_id: str | None = Field(None, description="Manager User ID") project_id: str | None = Field(None, description="Project ID") relativity: float = Field( - 0.0, + 0.45, ge=0, description=( "Relevance threshold for recalled memories. " "Only memories with metadata.relativity >= relativity will be returned. " - "Use 0 to disable threshold filtering. Default: 0.3." + "Use 0 to disable threshold filtering. Default: 0.45." ), ) diff --git a/src/memos/api/utils/api_keys.py b/src/memos/api/utils/api_keys.py index 559ddd355..29b493fd0 100644 --- a/src/memos/api/utils/api_keys.py +++ b/src/memos/api/utils/api_keys.py @@ -5,8 +5,8 @@ """ import hashlib -import os import secrets + from dataclasses import dataclass from datetime import datetime, timedelta diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 318f0ac80..c71ed6b5a 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -90,7 +90,6 @@ async def _create_embeddings(): ) ) logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds") - logger.info(f"Embeddings request response: {response}") return [r.embedding for r in response.data] except Exception as e: if self.use_backup_client: diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index 130b66a3d..0bc4a54f8 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -1,12 +1,35 @@ +import re + from abc import ABC, abstractmethod from typing import Any, Literal +# Pattern for valid field names: alphanumeric and underscores, must start with letter or underscore +_VALID_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + class BaseGraphDB(ABC): """ Abstract base class for a graph database interface used in a memory-augmented RAG system. """ + @staticmethod + def _validate_return_fields(return_fields: list[str] | None) -> list[str]: + """Validate and sanitize return_fields to prevent query injection. + + Only allows alphanumeric characters and underscores in field names. + Silently drops invalid field names. + + Args: + return_fields: List of field names to validate. + + Returns: + List of valid field names. + """ + if not return_fields: + return [] + return [f for f in return_fields if _VALID_FIELD_NAME_RE.match(f)] + # Node (Memory) Management @abstractmethod def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @@ -144,16 +167,23 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: # Search / recall operations @abstractmethod - def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]: + def search_by_embedding( + self, vector: list[float], top_k: int = 5, return_fields: list[str] | None = None, **kwargs + ) -> list[dict]: """ Retrieve node IDs based on vector similarity. Args: vector (list[float]): The embedding vector representing query semantics. top_k (int): Number of top similar nodes to retrieve. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method may internally call a VecDB (e.g., Qdrant) or store embeddings in the graph DB itself. diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 746051187..33eb39692 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -818,6 +818,7 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -832,9 +833,14 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result + dict will contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method uses Neo4j native vector indexing to search for similar nodes. @@ -886,11 +892,20 @@ def search_by_embedding( if where_clauses: where_clause = "WHERE " + " AND ".join(where_clauses) + return_clause = "RETURN node.id AS id, score" + if return_fields: + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"node.{field} AS {field}" for field in validated_fields if field != "id" + ) + if extra_fields: + return_clause = f"RETURN node.id AS id, score, {extra_fields}" + query = f""" CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) YIELD node, score {where_clause} - RETURN node.id AS id, score + {return_clause} """ parameters = {"embedding": vector, "k": top_k} @@ -920,7 +935,15 @@ def search_by_embedding( print(f"[search_by_embedding] query: {query},parameters: {parameters}") with self.driver.session(database=self.db_name) as session: result = session.run(query, parameters) - records = [{"id": record["id"], "score": record["score"]} for record in result] + records = [] + for record in result: + item = {"id": record["id"], "score": record["score"]} + if return_fields: + record_keys = record.keys() + for field in return_fields: + if field != "id" and field in record_keys: + item[field] = record[field] + records.append(item) # Threshold filtering after retrieval if threshold is not None: @@ -943,8 +966,8 @@ def search_by_fulltext( **kwargs, ) -> list[dict]: """ - TODO: 实现 Neo4j 的关键词检索, 以兼容 TreeTextMemory 的 keyword/fulltext 召回路径. - 目前先返回空列表, 避免切换到 Neo4j 后因缺失方法导致运行时报错. + TODO: Implement fulltext search for Neo4j to be compatible with TreeTextMemory's keyword/fulltext recall path. + Currently, return an empty list to avoid runtime errors due to missing methods when switching to Neo4j. """ return [] diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index cae7d6ca5..09ad46c42 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -246,6 +246,39 @@ def get_children_with_embeddings( return child_nodes + def _fetch_return_fields( + self, + ids: list[str], + score_map: dict[str, float], + return_fields: list[str], + ) -> list[dict]: + """Fetch additional fields from Neo4j for given node IDs.""" + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"n.{field} AS {field}" for field in validated_fields if field != "id" + ) + return_clause = "RETURN n.id AS id" + if extra_fields: + return_clause = f"RETURN n.id AS id, {extra_fields}" + + query = f""" + MATCH (n:Memory) + WHERE n.id IN $ids + {return_clause} + """ + with self.driver.session(database=self.db_name) as session: + neo4j_results = session.run(query, {"ids": ids}) + results = [] + for record in neo4j_results: + node_id = record["id"] + item = {"id": node_id, "score": score_map.get(node_id)} + record_keys = record.keys() + for field in return_fields: + if field != "id" and field in record_keys: + item[field] = record[field] + results.append(item) + return results + # Search / recall operations def search_by_embedding( self, @@ -258,6 +291,7 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -273,9 +307,14 @@ def search_by_embedding( filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method uses an external vector database (not Neo4j) to perform the search. @@ -320,7 +359,14 @@ def search_by_embedding( # If no filter or knowledgebase_ids provided, return vector search results directly if not filter and not knowledgebase_ids: - return [{"id": r.id, "score": r.score} for r in vec_results] + if not return_fields: + return [{"id": r.id, "score": r.score} for r in vec_results] + # Need to fetch additional fields from Neo4j + vec_ids = [r.id for r in vec_results] + if not vec_ids: + return [] + score_map = {r.id: r.score for r in vec_results} + return self._fetch_return_fields(vec_ids, score_map, return_fields) # Extract IDs from vector search results vec_ids = [r.id for r in vec_results] @@ -363,22 +409,49 @@ def search_by_embedding( if filter_params: params.update(filter_params) + # Build RETURN clause with optional extra fields + return_clause = "RETURN n.id AS id" + if return_fields: + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"n.{field} AS {field}" for field in validated_fields if field != "id" + ) + if extra_fields: + return_clause = f"RETURN n.id AS id, {extra_fields}" + # Query Neo4j to filter results query = f""" MATCH (n:Memory) {where_clause} - RETURN n.id AS id + {return_clause} """ logger.info(f"[search_by_embedding] query: {query}, params: {params}") with self.driver.session(database=self.db_name) as session: neo4j_results = session.run(query, params) - filtered_ids = {record["id"] for record in neo4j_results} + if return_fields: + # Build a map of id -> extra fields from Neo4j results + neo4j_data = {} + for record in neo4j_results: + node_id = record["id"] + record_keys = record.keys() + neo4j_data[node_id] = { + field: record[field] + for field in return_fields + if field != "id" and field in record_keys + } + filtered_ids = set(neo4j_data.keys()) + else: + filtered_ids = {record["id"] for record in neo4j_results} # Filter vector results by Neo4j filtered IDs and return with scores - filtered_results = [ - {"id": r.id, "score": r.score} for r in vec_results if r.id in filtered_ids - ] + filtered_results = [] + for r in vec_results: + if r.id in filtered_ids: + item = {"id": r.id, "score": r.score} + if return_fields and r.id in neo4j_data: + item.update(neo4j_data[r.id]) + filtered_results.append(item) return filtered_results @@ -397,8 +470,8 @@ def search_by_fulltext( **kwargs, ) -> list[dict]: """ - TODO: 实现 Neo4j Community 的关键词检索, 以兼容 TreeTextMemory 的 keyword/fulltext 召回路径. - 目前先返回空列表, 避免切换到 Neo4j 后因缺失方法导致运行时报错. + TODO: Implement fulltext search for Neo4j to be compatible with TreeTextMemory's keyword/fulltext recall path. + Currently, return an empty list to avoid runtime errors due to missing methods when switching to Neo4j. """ return [] @@ -1122,7 +1195,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]] # Merge embeddings into parsed nodes for parsed_node in parsed_nodes: node_id = parsed_node["id"] - parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None) + parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id) return parsed_nodes diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f0a23e39b..592f45a7f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -204,21 +204,6 @@ def _get_connection_old(self): return conn def _get_connection(self): - """ - Get a connection from the pool. - - This function: - 1. Gets a connection from ThreadedConnectionPool - 2. Checks if connection is closed or unhealthy - 3. Returns healthy connection or retries (max 3 times) - 4. Handles connection pool exhaustion gracefully - - Returns: - psycopg2 connection object - - Raises: - RuntimeError: If connection pool is closed or exhausted after retries - """ logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") if self._pool_closed: raise RuntimeError("Connection pool has been closed") @@ -229,13 +214,9 @@ def _get_connection(self): for attempt in range(max_retries): conn = None try: - # Try to get connection from pool - # This may raise PoolError if pool is exhausted conn = self.connection_pool.getconn() - # Check if connection is closed if conn.closed != 0: - # Connection is closed, return it to pool with close flag and try again logger.warning( f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" ) @@ -295,19 +276,17 @@ def _get_connection(self): return conn except psycopg2.pool.PoolError as pool_error: - # Pool exhausted or other pool-related error - # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly error_msg = str(pool_error).lower() if "exhausted" in error_msg or "pool" in error_msg: # Log pool status for debugging try: # Try to get pool stats if available pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" - logger.error( - f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" + logger.info( + f" polardb get_connection Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" ) except Exception: - logger.error( + logger.warning( f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" ) @@ -323,7 +302,6 @@ def _get_connection(self): raise RuntimeError( f"Connection pool exhausted after {max_retries} attempts. " f"This usually means connections are not being returned to the pool. " - f"Check for connection leaks in your code." ) from pool_error else: # Other pool errors - retry with normal backoff @@ -337,12 +315,8 @@ def _get_connection(self): ) from pool_error except Exception as e: - # Other exceptions (not pool-related) - # Only try to return connection if we actually got one - # If getconn() failed (e.g., pool exhausted), conn will be None if conn is not None: try: - # Return connection to pool if it's valid self.connection_pool.putconn(conn, close=True) except Exception as putconn_error: logger.warning( @@ -363,20 +337,7 @@ def _get_connection(self): raise RuntimeError("Failed to get connection after all retries") def _return_connection(self, connection): - """ - Return a connection to the pool. - - This function safely returns a connection to the pool, handling: - - Closed connections (close them instead of returning) - - Pool closed state (close connection directly) - - None connections (no-op) - - putconn() failures (close connection as fallback) - - Args: - connection: psycopg2 connection object or None - """ if self._pool_closed: - # Pool is closed, just close the connection if it exists if connection: try: connection.close() @@ -388,13 +349,10 @@ def _return_connection(self, connection): return if not connection: - # No connection to return - this is normal if _get_connection() failed return try: - # Check if connection is closed if hasattr(connection, "closed") and connection.closed != 0: - # Connection is closed, just close it explicitly and don't return to pool logger.debug( "[_return_connection] Connection is closed, closing it instead of returning to pool" ) @@ -404,12 +362,9 @@ def _return_connection(self, connection): logger.warning(f"[_return_connection] Failed to close closed connection: {e}") return - # Connection is valid, return to pool self.connection_pool.putconn(connection) logger.debug("[_return_connection] Successfully returned connection to pool") except Exception as e: - # If putconn fails, try to close the connection - # This prevents connection leaks if putconn() fails logger.error( f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True ) @@ -841,8 +796,8 @@ def add_edge( start_time = time.time() if not source_id or not target_id: - logger.warning(f"Edge '{source_id}' and '{target_id}' are both None") - raise ValueError("[add_edge] source_id and target_id must be provided") + logger.error(f"Edge '{source_id}' and '{target_id}' are both None") + return source_exists = self.get_node(source_id) is not None target_exists = self.get_node(target_id) is not None @@ -851,7 +806,7 @@ def add_edge( logger.warning( "[add_edge] Source %s or target %s does not exist.", source_exists, target_exists ) - raise ValueError("[add_edge] source_id and target_id must be provided") + return properties = {} if user_name is not None: @@ -1116,9 +1071,7 @@ def get_node( self._return_connection(conn) @timed - def get_nodes( - self, ids: list[str], user_name: str | None = None, **kwargs - ) -> list[dict[str, Any]]: + def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: @@ -1690,6 +1643,36 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + def _extract_fields_from_properties( + self, properties: Any, return_fields: list[str] + ) -> dict[str, Any]: + """Extract requested fields from a PolarDB properties agtype/JSON value. + + Args: + properties: The raw properties value from a PolarDB row (agtype or JSON string). + return_fields: List of field names to extract. + + Returns: + dict with field_name -> value for each requested field found in properties. + """ + result = {} + return_fields = self._validate_return_fields(return_fields) + if not properties or not return_fields: + return result + try: + if isinstance(properties, str): + props = json.loads(properties) + elif isinstance(properties, dict): + props = properties + else: + props = json.loads(str(properties)) + except (json.JSONDecodeError, TypeError, ValueError): + return result + for field in return_fields: + if field != "id" and field in props: + result[field] = props[field] + return result + @timed def search_by_keywords_like( self, @@ -1700,6 +1683,7 @@ def search_by_keywords_like( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: where_clauses = [] @@ -1751,10 +1735,14 @@ def search_by_keywords_like( where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - query = f""" - SELECT + select_clause = """SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text + agtype_object_field_text(properties, 'memory') as memory_text""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} """ @@ -1775,7 +1763,11 @@ def search_by_keywords_like( id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): id_val = id_val[1:-1] - output.append({"id": id_val}) + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) logger.info( f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) @@ -1795,6 +1787,7 @@ def search_by_keywords_tfidf( knowledgebase_ids: list[str] | None = None, tsvector_field: str = "properties_tsvector_zh", tsquery_config: str = "jiebaqry", + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: where_clauses = [] @@ -1850,10 +1843,14 @@ def search_by_keywords_tfidf( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" # Build fulltext search query - query = f""" - SELECT + select_clause = """SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text + agtype_object_field_text(properties, 'memory') as memory_text""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} """ @@ -1874,7 +1871,11 @@ def search_by_keywords_tfidf( id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): id_val = id_val[1:-1] - output.append({"id": id_val}) + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) logger.info( f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" @@ -1897,6 +1898,7 @@ def search_by_fulltext( knowledgebase_ids: list[str] | None = None, tsvector_field: str = "properties_tsvector_zh", tsquery_config: str = "jiebacfg", + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -1914,15 +1916,16 @@ def search_by_fulltext( filter: filter conditions with 'and' or 'or' logic for search results. tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + return_fields: additional node fields to include in results **kwargs: other parameters (e.g. cube_name) Returns: - list[dict]: result list containing id and score + list[dict]: result list containing id and score. + If return_fields is specified, each dict also includes the requested fields. """ logger.info( f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" ) - # Build WHERE clause dynamically, same as search_by_embedding start_time = time.time() where_clauses = [] @@ -1966,13 +1969,10 @@ def search_by_fulltext( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) - # Add fulltext search condition - # Convert query_text to OR query format: "word1 | word2 | word3" tsquery_string = " | ".join(query_words) where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") @@ -1981,19 +1981,31 @@ def search_by_fulltext( logger.info(f"[search_by_fulltext] where_clause: {where_clause}") - # Build fulltext search query + select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id, + ts_rank(m.{tsvector_field}, q.fq) AS rank""" + if return_fields: + select_cols += ", m.properties" + where_with_q = [] + for w in where_clauses: + if f"{tsvector_field} @@ to_tsquery(" in w: + where_with_q.append(f"m.{tsvector_field} @@ q.fq") + else: + where_with_q.append( + w.replace("(properties,", "(m.properties,") + .replace("(properties)", "(m.properties)") + .replace("ARRAY[properties,", "ARRAY[m.properties,") + ) + where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" - SELECT - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text, - ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank - FROM "{self.db_name}_graph"."Memory" - {where_clause} + WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) + SELECT {select_cols} + FROM "{self.db_name}_graph"."Memory" m + CROSS JOIN q + {where_clause_cte} ORDER BY rank DESC LIMIT {top_k}; """ - - params = [tsquery_string, tsquery_string] + params = [tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") conn = None try: @@ -2004,7 +2016,7 @@ def search_by_fulltext( output = [] for row in results: oldid = row[0] # old_id - rank = row[2] # rank score + rank = row[1] # rank score (no memory_text column) id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): @@ -2013,10 +2025,16 @@ def search_by_fulltext( # Apply threshold filter if specified if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[2] # properties column + item.update( + self._extract_fields_from_properties(properties, return_fields) + ) + output.append(item) elapsed_time = time.time() - start_time logger.info( - f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" + f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" ) return output[:top_k] finally: @@ -2026,23 +2044,21 @@ def search_by_fulltext( def search_by_embedding( self, vector: list[float], + user_name: str, top_k: int = 5, scope: str | None = None, status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, - user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: - """ - Retrieve node IDs based on vector similarity using PostgreSQL vector operations. - """ - # Build WHERE clause dynamically like nebular.py logger.info( - f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}" ) + start_time = time.time() where_clauses = [] if scope: where_clauses.append( @@ -2057,31 +2073,18 @@ def search_by_embedding( "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" ) where_clauses.append("embedding is not null") - # Add user_name filter like nebular.py - - """ - # user_name = self._get_config_value("user_name") - # if not self.config.use_multi_db and user_name: - # if kwargs.get("cube_name"): - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kwargs['cube_name']}\"'::agtype") - # else: - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype") - """ - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( user_name=user_name, knowledgebase_ids=knowledgebase_ids, default_user_name=self.config.user_name, ) - # Add OR condition if we have any user_name conditions if user_name_conditions: if len(user_name_conditions) == 1: where_clauses.append(user_name_conditions[0]) else: where_clauses.append(f"({' OR '.join(user_name_conditions)})") - # Add search_filter conditions like nebular.py if search_filter: for key, value in search_filter.items(): if isinstance(value, str): @@ -2093,14 +2096,12 @@ def search_by_embedding( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - # Keep original simple query structure but add dynamic WHERE clause query = f""" WITH t AS ( SELECT id, @@ -2117,19 +2118,12 @@ def search_by_embedding( FROM t WHERE scope > 0.1; """ - # Convert vector to string format for PostgreSQL vector type - # PostgreSQL vector type expects a string format like '[1,2,3]' vector_str = convert_to_vector(vector) - # Use string format directly in query instead of parameterized query - # Replace %s with the vector string, but need to quote it properly - # PostgreSQL vector type needs the string to be quoted query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") params = [] - # Split query by lines and wrap long lines to prevent terminal truncation query_lines = query.strip().split("\n") for line in query_lines: - # Wrap lines longer than 200 characters to prevent terminal truncation if len(line) > 200: wrapped_lines = textwrap.wrap( line, width=200, break_long_words=False, break_on_hyphens=False @@ -2145,28 +2139,13 @@ def search_by_embedding( try: conn = self._get_connection() with conn.cursor() as cursor: - try: - # If params is empty, execute query directly without parameters - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - except Exception as e: - logger.error(f"[search_by_embedding] Error executing query: {e}") - logger.error(f"[search_by_embedding] Query length: {len(query)}") - logger.error( - f"[search_by_embedding] Params type: {type(params)}, length: {len(params)}" - ) - logger.error(f"[search_by_embedding] Query contains %s: {'%s' in query}") - raise + if params: + cursor.execute(query, params) + else: + cursor.execute(query) results = cursor.fetchall() output = [] for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ if len(row) < 5: logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") continue @@ -2178,7 +2157,17 @@ def search_by_embedding( score_val = float(score) score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[1] # properties column + item.update( + self._extract_fields_from_properties(properties, return_fields) + ) + output.append(item) + elapsed_time = time.time() - start_time + logger.info( + f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + ) return output[:top_k] finally: self._return_connection(conn) @@ -2187,7 +2176,7 @@ def search_by_embedding( def get_by_metadata( self, filters: list[dict[str, Any]], - user_name: str | None = None, + user_name: str, filter: dict | None = None, knowledgebase_ids: list | None = None, user_name_flag: bool = True, @@ -2209,7 +2198,9 @@ def get_by_metadata( Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). """ - logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + logger.info( + f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}" + ) user_name = user_name if user_name else self._get_config_value("user_name") @@ -2264,9 +2255,6 @@ def get_by_metadata( else: raise ValueError(f"Unsupported operator: {op}") - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, knowledgebase_ids=knowledgebase_ids, @@ -2306,7 +2294,7 @@ def get_by_metadata( results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: - logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") finally: self._return_connection(conn) @@ -2536,8 +2524,8 @@ def clear(self, user_name: str | None = None) -> None: @timed def export_graph( self, + user_name: str, include_embedding: bool = False, - user_name: str | None = None, user_id: str | None = None, page: int | None = None, page_size: int | None = None, @@ -2576,7 +2564,7 @@ def export_graph( } """ logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" + f" export_graph include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" ) user_id = user_id if user_id else self._get_config_value("user_id") @@ -2724,159 +2712,7 @@ def export_graph( finally: self._return_connection(conn) - conn = None - try: - conn = self._get_connection() - # Build Cypher WHERE conditions for edges - cypher_where_conditions = [] - if user_name: - cypher_where_conditions.append(f"a.user_name = '{user_name}'") - cypher_where_conditions.append(f"b.user_name = '{user_name}'") - if user_id: - cypher_where_conditions.append(f"a.user_id = '{user_id}'") - cypher_where_conditions.append(f"b.user_id = '{user_id}'") - - # Add memory_type filter condition for edges (apply to both source and target nodes) - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape single quotes in memory_type values for Cypher - escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type] - memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types]) - # Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory'] - cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]") - cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]") - - # Add status filter for edges: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - cypher_where_conditions.append("a.status <> 'deleted' AND b.status <> 'deleted'") - elif isinstance(status, list) and len(status) > 0: - escaped_statuses = [st.replace("'", "\\'") for st in status] - status_list_str = ", ".join([f"'{st}'" for st in escaped_statuses]) - cypher_where_conditions.append(f"a.status IN [{status_list_str}]") - cypher_where_conditions.append(f"b.status IN [{status_list_str}]") - - # Build filter conditions for edges (apply to both source and target nodes) - filter_where_clause = self._build_filter_conditions_cypher(filter) - logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") - if filter_where_clause: - # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists - # Remove the leading " AND " and replace n. with a. for source node and b. for target node - filter_clause = filter_where_clause.strip() - if filter_clause.startswith("AND "): - filter_clause = filter_clause[4:].strip() - # Replace n. with a. for source node and create a copy for target node - source_filter = filter_clause.replace("n.", "a.") - target_filter = filter_clause.replace("n.", "b.") - # Combine source and target filters with AND - combined_filter = f"({source_filter}) AND ({target_filter})" - cypher_where_conditions.append(combined_filter) - - cypher_where_clause = "" - if cypher_where_conditions: - cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" - - # Get total count of edges before pagination - count_edge_query = f""" - SELECT COUNT(*) - FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - """ - logger.info(f"[export_graph edges count] Query: {count_edge_query}") - with conn.cursor() as cursor: - cursor.execute(count_edge_query) - total_edges = cursor.fetchone()[0] - - # Export edges using cypher query - # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery - # Build pagination clause if needed - edge_pagination_clause = "" - if use_pagination: - edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - edge_query = f""" - SELECT source, target, edge FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, - COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, - a.id DESC, b.id DESC - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - {edge_pagination_clause} - """ - logger.info(f"[export_graph edges] Query: {edge_query}") - with conn.cursor() as cursor: - cursor.execute(edge_query) - edge_results = cursor.fetchall() - edges = [] - - for row in edge_results: - source_agtype, target_agtype, edge_agtype = row - - # Extract and clean source - source_raw = ( - source_agtype.value - if hasattr(source_agtype, "value") - else str(source_agtype) - ) - if ( - isinstance(source_raw, str) - and source_raw.startswith('"') - and source_raw.endswith('"') - ): - source = source_raw[1:-1] - else: - source = str(source_raw) - - # Extract and clean target - target_raw = ( - target_agtype.value - if hasattr(target_agtype, "value") - else str(target_agtype) - ) - if ( - isinstance(target_raw, str) - and target_raw.startswith('"') - and target_raw.endswith('"') - ): - target = target_raw[1:-1] - else: - target = str(target_raw) - - # Extract and clean edge type - type_raw = ( - edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) - ) - if ( - isinstance(type_raw, str) - and type_raw.startswith('"') - and type_raw.endswith('"') - ): - edge_type = type_raw[1:-1] - else: - edge_type = str(type_raw) - - edges.append( - { - "source": source, - "target": target, - "type": edge_type, - } - ) - - except Exception as e: - logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - finally: - self._return_connection(conn) - + edges = [] return { "nodes": nodes, "edges": edges, @@ -2908,8 +2744,8 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: def get_all_memory_items( self, scope: str, + user_name: str, include_embedding: bool = False, - user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list | None = None, status: str | None = None, @@ -2930,14 +2766,13 @@ def get_all_memory_items( list[dict]: Full list of memory items under this scope. """ logger.info( - f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" + f"[get_all_memory_items] user_name: {user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status},scope:{scope}" ) user_name = user_name if user_name else self._get_config_value("user_name") if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, knowledgebase_ids=knowledgebase_ids, @@ -3015,7 +2850,7 @@ def get_all_memory_items( node_ids.add(node_id) except Exception as e: - logger.error(f"Failed to get memories: {e}", exc_info=True) + logger.warning(f"Failed to get memories: {e}", exc_info=True) finally: self._return_connection(conn) @@ -4199,34 +4034,47 @@ def get_edges( ... ] """ + start_time = time.time() + logger.info(f" get_edges id:{id},type:{type},direction:{direction},user_name:{user_name}") user_name = user_name if user_name else self._get_config_value("user_name") - - if direction == "OUTGOING": - pattern = "(a:Memory)-[r]->(b:Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "INCOMING": - pattern = "(a:Memory)<-[r]-(b:Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "ANY": - pattern = "(a:Memory)-[r]-(b:Memory)" - where_clause = f"a.id = '{id}' OR b.id = '{id}'" - else: + if direction not in ("OUTGOING", "INCOMING", "ANY"): raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - # Add type filter - if type != "ANY": - where_clause += f" AND type(r) = '{type}'" - - # Add user filter - where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + # Escape single quotes for safe embedding in Cypher string + id_esc = (id or "").replace("'", "''") + user_esc = (user_name or "").replace("'", "''") + type_esc = (type or "").replace("'", "''") + type_filter = f" AND type(r) = '{type_esc}'" if type != "ANY" else "" + logger.info(f"type_filter:{type_filter}") + if direction == "OUTGOING": + cypher_body = f""" + MATCH (a:Memory)-[r:{type}]->(b:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}' + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + """ + elif direction == "INCOMING": + cypher_body = f""" + MATCH (b:Memory)<-[r:{type}]-(a:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}' + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + """ + else: # ANY: union of OUTGOING and INCOMING + cypher_body = f""" + MATCH (a:Memory)-[r]->(b:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + UNION ALL + MATCH (b:Memory)<-[r]-(a:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + """ query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH {pattern} - WHERE {where_clause} - RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + {cypher_body.strip()} $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ + logger.info(f"get_edges query:{query}") conn = None try: conn = self._get_connection() @@ -4270,6 +4118,8 @@ def get_edges( edge_type = str(edge_type_raw) edges.append({"from": from_id, "to": to_id, "type": edge_type}) + elapsed_time = time.time() - start_time + logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s") return edges except Exception as e: diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index a27d64758..e3d2bece9 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -721,6 +721,7 @@ def _process_one_item( m_maybe_merged.get("memory_type", "LongTermMemory") .replace("长期记忆", "LongTermMemory") .replace("用户记忆", "UserMemory") + .replace("pref", "UserMemory") ) node = self._make_memory_item( value=m_maybe_merged.get("value", ""), diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 2b49d63ba..1b4add398 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -50,7 +50,9 @@ class FileContentParser(BaseMessageParser): """Parser for file content parts.""" - def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None = None) -> dict: + def _get_doc_llm_response( + self, chunk_text: str, custom_tags: list[str] | None = None + ) -> dict | list: """ Call LLM to extract memory from document chunk. Uses doc prompts from DOC_PROMPT_DICT. @@ -60,7 +62,7 @@ def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None = custom_tags: Optional list of custom tags for LLM extraction Returns: - Parsed JSON response from LLM or empty dict if failed + Parsed JSON response from LLM (dict or list) or empty dict if failed """ if not self.llm: logger.warning("[FileContentParser] LLM not available for fine mode") @@ -777,35 +779,49 @@ def _make_fallback( return [_make_fallback(idx, text, "no_llm") for idx, text in valid_chunks] # Process single chunk with LLM extraction (worker function) - def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: - """Process chunk with LLM, fallback to raw on failure.""" + def _process_chunk(chunk_idx: int, chunk_text: str) -> list[TextualMemoryItem]: + """Process chunk with LLM, fallback to raw on failure. Returns list of memory items.""" try: response_json = self._get_doc_llm_response(chunk_text, custom_tags) if response_json: - value = response_json.get("value", "").strip() - if value: - tags = response_json.get("tags", []) - tags = tags if isinstance(tags, list) else [] - tags.extend(["mode:fine", "multimodal:file"]) - - llm_mem_type = response_json.get("memory_type", memory_type) - if llm_mem_type not in ["LongTermMemory", "UserMemory"]: - llm_mem_type = memory_type - - return _make_memory_item( - value=value, - mem_type=llm_mem_type, - tags=tags, - key=response_json.get("key"), - chunk_idx=chunk_idx, - chunk_content=chunk_text, - ) + # Handle list format response + response_list = response_json.get("memory list", []) + memory_items = [] + for item_data in response_list: + if not isinstance(item_data, dict): + continue + + value = item_data.get("value", "").strip() + if value: + tags = item_data.get("tags", []) + tags = tags if isinstance(tags, list) else [] + tags.extend(["mode:fine", "multimodal:file"]) + key_str = item_data.get("key", "") + + llm_mem_type = item_data.get("memory_type", memory_type) + if llm_mem_type not in ["LongTermMemory", "UserMemory"]: + llm_mem_type = memory_type + + memory_item = _make_memory_item( + value=value, + mem_type=llm_mem_type, + tags=tags, + key=key_str, + chunk_idx=chunk_idx, + chunk_content=chunk_text, + ) + memory_items.append(memory_item) + + if memory_items: + return memory_items + else: + return [_make_fallback(chunk_idx, chunk_text)] except Exception as e: logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}") # Fallback to raw chunk logger.warning(f"[FileContentParser] Fallback to raw for chunk {chunk_idx}") - return _make_fallback(chunk_idx, chunk_text) + return [_make_fallback(chunk_idx, chunk_text)] def _relate_chunks(items: list[TextualMemoryItem]) -> None: """ @@ -853,30 +869,37 @@ def get_chunk_idx(item: TextualMemoryItem) -> int: ): chunk_idx = futures[future] try: - node = future.result() - memory_items.append(node) - - # Check if this node is a fallback by checking tags - is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags) - if is_fallback: - fallback_count += 1 - - # save raw file - node_id = node.id - if node.memory != node.metadata.sources[0].content: - chunk_node = _make_memory_item( - value=node.metadata.sources[0].content, - mem_type="RawFileMemory", - tags=[ - "mode:fine", - "multimodal:file", - f"chunk:{chunk_idx + 1}/{total_chunks}", - ], - chunk_idx=chunk_idx, - chunk_content="", - ) - chunk_node.metadata.summary_ids = [node_id] - memory_items.append(chunk_node) + nodes = future.result() + memory_items.extend(nodes) + + # Check if any node is a fallback by checking tags + has_fallback = False + for node in nodes: + is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags) + if is_fallback: + fallback_count += 1 + has_fallback = True + + # save raw file only if no fallback (all nodes are LLM-extracted) + if not has_fallback and nodes: + # Use first node's source info for raw file + first_node = nodes[0] + if first_node.metadata.sources and len(first_node.metadata.sources) > 0: + # Collect all node IDs for summary_ids + node_ids = [node.id for node in nodes] + chunk_node = _make_memory_item( + value=first_node.metadata.sources[0].content, + mem_type="RawFileMemory", + tags=[ + "mode:fine", + "multimodal:file", + f"chunk:{chunk_idx + 1}/{total_chunks}", + ], + chunk_idx=chunk_idx, + chunk_content="", + ) + chunk_node.metadata.summary_ids = node_ids + memory_items.append(chunk_node) except Exception as e: tqdm.write(f"[ERROR] Chunk {chunk_idx} failed: {e}") diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index d39955ac2..a9a727b08 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -1019,7 +1019,9 @@ def process_skill_memory_fine( **kwargs, ) -> list[TextualMemoryItem]: skills_repo_backend = _get_skill_file_storage_location() - oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config) + oss_client, _missing_keys, flag = _skill_init( + skills_repo_backend, oss_config, skills_dir_config + ) if not flag: return [] diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py index 63718fd92..e4a88a635 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py @@ -68,33 +68,35 @@ def log_add_messages(self, msg: ScheduleMessageItem): mem_item = mem_cube.text_mem.get(memory_id=memory_id, user_name=msg.mem_cube_id) if mem_item is None: raise ValueError(f"Memory {memory_id} not found after retries") - key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( - name=mem_item.memory - ) - exists = False original_content = None original_item_id = None - if key and hasattr(mem_cube.text_mem, "graph_store"): - candidates = mem_cube.text_mem.graph_store.get_by_metadata( - [ - {"field": "key", "op": "=", "value": key}, - { - "field": "memory_type", - "op": "=", - "value": mem_item.metadata.memory_type, - }, - ] + # Determine add vs update from the merged_from field set by the upstream + # mem_reader during fine extraction. When the LLM merges a new memory with + # existing ones it writes their IDs into metadata.info["merged_from"]. + # This avoids an extra graph DB query and the self-match / cross-user + # matching bugs that came with the old get_by_metadata approach. + merged_from = (getattr(mem_item.metadata, "info", None) or {}).get("merged_from") + if merged_from: + merged_ids = ( + merged_from + if isinstance(merged_from, list | tuple | set) + else [merged_from] ) - if candidates: - exists = True - original_item_id = candidates[0] + original_item_id = merged_ids[0] + try: original_mem_item = mem_cube.text_mem.get( memory_id=original_item_id, user_name=msg.mem_cube_id ) - original_content = original_mem_item.memory + original_content = original_mem_item.memory if original_mem_item else None + except Exception as e: + logger.warning( + "Failed to fetch original memory %s for update log: %s", + original_item_id, + e, + ) - if exists: + if merged_from: prepared_update_items_with_original.append( { "new_item": mem_item, diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 5d86c5589..20dbb63b2 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -259,13 +259,19 @@ def _process_memories_with_reader( source_doc_id = ( file_ids[0] if isinstance(file_ids, list) and file_ids else None ) + # Use merged_from to determine ADD vs UPDATE. + # The upstream mem_reader sets this during fine extraction when + # the new memory was merged with an existing one. + item_merged_from = (getattr(item.metadata, "info", None) or {}).get( + "merged_from" + ) kb_log_content.append( { "log_source": "KNOWLEDGE_BASE_LOG", "trigger_source": info.get("trigger_source", "Messages") if info else "Messages", - "operation": "ADD", + "operation": "UPDATE" if item_merged_from else "ADD", "memory_id": item.id, "content": item.memory, "original_content": None, @@ -302,29 +308,39 @@ def _process_memories_with_reader( else: add_content_legacy: list[dict] = [] add_meta_legacy: list[dict] = [] + update_content_legacy: list[dict] = [] + update_meta_legacy: list[dict] = [] for item_id, item in zip( enhanced_mem_ids, flattened_memories, strict=False ): key = getattr(item.metadata, "key", None) or transform_name_to_key( name=item.memory ) - add_content_legacy.append( - {"content": f"{key}: {item.memory}", "ref_id": item_id} - ) - add_meta_legacy.append( - { - "ref_id": item_id, - "id": item_id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } + item_merged_from = (getattr(item.metadata, "info", None) or {}).get( + "merged_from" ) + meta_entry = { + "ref_id": item_id, + "id": item_id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + if item_merged_from: + update_content_legacy.append( + {"content": f"{key}: {item.memory}", "ref_id": item_id} + ) + update_meta_legacy.append(meta_entry) + else: + add_content_legacy.append( + {"content": f"{key}: {item.memory}", "ref_id": item_id} + ) + add_meta_legacy.append(meta_entry) if add_content_legacy: event = self.scheduler_context.services.create_event_log( label="addMemory", @@ -342,6 +358,23 @@ def _process_memories_with_reader( ) event.task_id = task_id self.scheduler_context.services.submit_web_logs([event]) + if update_content_legacy: + event = self.scheduler_context.services.create_event_log( + label="updateMemory", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.scheduler_context.get_mem_cube(), + memcube_log_content=update_content_legacy, + metadata=update_meta_legacy, + memory_len=len(update_content_legacy), + memcube_name=self.scheduler_context.services.map_memcube_name( + mem_cube_id + ), + ) + event.task_id = task_id + self.scheduler_context.services.submit_web_logs([event]) else: logger.info("No enhanced memories generated by mem_reader") else: diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 6352d5840..8483a5151 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -124,25 +124,45 @@ def retrieve( explicit_prefs.sort(key=lambda x: x.score, reverse=True) implicit_prefs.sort(key=lambda x: x.score, reverse=True) - explicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + explicit_prefs_mem = [] + for pref in explicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + explicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in explicit_prefs - if pref.payload.get("preference", None) - ] - implicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + implicit_prefs_mem = [] + for pref in implicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + implicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in implicit_prefs - if pref.payload.get("preference", None) - ] reranker_map = { "naive": self._naive_reranker, diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5faf8aa09..5b210ba61 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -404,10 +404,10 @@ def delete_by_memory_ids(self, memory_ids: list[str]) -> None: except Exception as e: logger.error(f"An error occurred while deleting memories by memory_ids: {e}") - def delete_all(self) -> None: + def delete_all(self, user_name: str | None = None) -> None: """Delete all memories and their relationships from the graph store.""" try: - self.graph_store.clear() + self.graph_store.clear(user_name=user_name) logger.info("All memories and edges have been deleted from the graph.") except Exception as e: logger.error(f"An error occurred while deleting all memories: {e}") @@ -424,7 +424,7 @@ def delete_by_filter( writable_cube_ids=writable_cube_ids, file_ids=file_ids, filter=filter ) - def load(self, dir: str) -> None: + def load(self, dir: str, user_name: str | None = None) -> None: try: memory_file = os.path.join(dir, self.config.memory_filename) @@ -435,7 +435,7 @@ def load(self, dir: str) -> None: with open(memory_file, encoding="utf-8") as f: memories = json.load(f) - self.graph_store.import_graph(memories) + self.graph_store.import_graph(memories, user_name=user_name) logger.info(f"Loaded {len(memories)} memories from {memory_file}") except FileNotFoundError: @@ -445,10 +445,12 @@ def load(self, dir: str) -> None: except Exception as e: logger.error(f"An error occurred while loading memories: {e}") - def dump(self, dir: str, include_embedding: bool = False) -> None: + def dump(self, dir: str, include_embedding: bool = False, user_name: str | None = None) -> None: """Dump memories to os.path.join(dir, self.config.memory_filename)""" try: - json_memories = self.graph_store.export_graph(include_embedding=include_embedding) + json_memories = self.graph_store.export_graph( + include_embedding=include_embedding, user_name=user_name + ) os.makedirs(dir, exist_ok=True) memory_file = os.path.join(dir, self.config.memory_filename) diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py index 595cf099c..2d776912b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/handler.py +++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py @@ -27,18 +27,24 @@ def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedd self.llm = llm self.embedder = embedder - def detect(self, memory, top_k: int = 5, scope=None): + def detect(self, memory, top_k: int = 5, scope=None, user_name: str | None = None): # 1. Search for similar memories based on embedding embedding = memory.metadata.embedding embedding_candidates_info = self.graph_store.search_by_embedding( - embedding, top_k=top_k, scope=scope, threshold=self.EMBEDDING_THRESHOLD + embedding, + top_k=top_k, + scope=scope, + threshold=self.EMBEDDING_THRESHOLD, + user_name=user_name, ) # 2. Filter based on similarity threshold embedding_candidates_ids = [ info["id"] for info in embedding_candidates_info if info["id"] != memory.id ] # 3. Judge conflicts using LLM - embedding_candidates = self.graph_store.get_nodes(embedding_candidates_ids) + embedding_candidates = self.graph_store.get_nodes( + embedding_candidates_ids, user_name=user_name + ) detected_relationships = [] for embedding_candidate in embedding_candidates: embedding_candidate = TextualMemoryItem.from_dict(embedding_candidate) @@ -67,13 +73,20 @@ def detect(self, memory, top_k: int = 5, scope=None): pass return detected_relationships - def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, relation) -> None: + def resolve( + self, + memory_a: TextualMemoryItem, + memory_b: TextualMemoryItem, + relation, + user_name: str | None = None, + ) -> None: """ Resolve detected conflicts between two memory items using LLM fusion. Args: memory_a: The first conflicting memory item. memory_b: The second conflicting memory item. relation: relation + user_name: Optional user name for multi-tenant isolation. Returns: A fused TextualMemoryItem representing the resolved memory. """ @@ -105,17 +118,22 @@ def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, rela logger.warning( f"{relation} between {memory_a.id} and {memory_b.id} could not be resolved. " ) - self._hard_update(memory_a, memory_b) + self._hard_update(memory_a, memory_b, user_name=user_name) # —————— 2.2 Conflict resolved, update metadata and memory ———— else: fixed_metadata = self._merge_metadata(answer, memory_a.metadata, memory_b.metadata) merged_memory = TextualMemoryItem(memory=answer, metadata=fixed_metadata) logger.info(f"Resolved result: {merged_memory}") - self._resolve_in_graph(memory_a, memory_b, merged_memory) + self._resolve_in_graph(memory_a, memory_b, merged_memory, user_name=user_name) except json.decoder.JSONDecodeError: logger.error(f"Failed to parse LLM response: {response}") - def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem): + def _hard_update( + self, + memory_a: TextualMemoryItem, + memory_b: TextualMemoryItem, + user_name: str | None = None, + ): """ Hard update: compare updated_at, keep the newer one, overwrite the older one's metadata. """ @@ -125,7 +143,7 @@ def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) newer_mem = memory_a if time_a >= time_b else memory_b older_mem = memory_b if time_a >= time_b else memory_a - self.graph_store.delete_node(older_mem.id) + self.graph_store.delete_node(older_mem.id, user_name=user_name) logger.warning( f"Delete older memory {older_mem.id}: <{older_mem.memory}> due to conflict with {newer_mem.id}: <{newer_mem.memory}>" ) @@ -135,13 +153,21 @@ def _resolve_in_graph( conflict_a: TextualMemoryItem, conflict_b: TextualMemoryItem, merged: TextualMemoryItem, + user_name: str | None = None, ): - edges_a = self.graph_store.get_edges(conflict_a.id, type="ANY", direction="ANY") - edges_b = self.graph_store.get_edges(conflict_b.id, type="ANY", direction="ANY") + edges_a = self.graph_store.get_edges( + conflict_a.id, type="ANY", direction="ANY", user_name=user_name + ) + edges_b = self.graph_store.get_edges( + conflict_b.id, type="ANY", direction="ANY", user_name=user_name + ) all_edges = edges_a + edges_b self.graph_store.add_node( - merged.id, merged.memory, merged.metadata.model_dump(exclude_none=True) + merged.id, + merged.memory, + merged.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for edge in all_edges: @@ -150,13 +176,15 @@ def _resolve_in_graph( if new_from == new_to: continue # Check if the edge already exists before adding - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) - - self.graph_store.update_node(conflict_a.id, {"status": "archived"}) - self.graph_store.update_node(conflict_b.id, {"status": "archived"}) - self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO") - self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO") + if not self.graph_store.edge_exists( + new_from, new_to, edge["type"], direction="ANY", user_name=user_name + ): + self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name) + + self.graph_store.update_node(conflict_a.id, {"status": "archived"}, user_name=user_name) + self.graph_store.update_node(conflict_b.id, {"status": "archived"}, user_name=user_name) + self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO", user_name=user_name) + self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO", user_name=user_name) logger.debug( f"Archive {conflict_a.id} and {conflict_b.id}, and inherit their edges to {merged.id}." ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 1afdc9281..132582a0d 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -141,6 +141,7 @@ def mark_memory_status( self, memory_items: list[TextualMemoryItem], status: Literal["activated", "resolving", "archived", "deleted"], + user_name: str | None = None, ) -> None: """ Support status marking operations during history management. Common usages are: @@ -157,6 +158,7 @@ def mark_memory_status( self.graph_db.update_node, id=mem.id, fields={"status": status}, + user_name=user_name, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index cbc349d67..4ca30c7b8 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -238,7 +238,9 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: - self.reorganizer.add_message(QueueMessage(op="add", after_node=graph_node_ids)) + self.reorganizer.add_message( + QueueMessage(op="add", after_node=graph_node_ids, user_name=user_name) + ) return added_ids @@ -411,16 +413,19 @@ def _add_to_graph_memory( QueueMessage( op="add", after_node=[node_id], + user_name=user_name, ) ) return node_id - def _inherit_edges(self, from_id: str, to_id: str) -> None: + def _inherit_edges(self, from_id: str, to_id: str, user_name: str | None = None) -> None: """ Migrate all non-lineage edges from `from_id` to `to_id`, and remove them from `from_id` after copying. """ - edges = self.graph_store.get_edges(from_id, type="ANY", direction="ANY") + edges = self.graph_store.get_edges( + from_id, type="ANY", direction="ANY", user_name=user_name + ) for edge in edges: if edge["type"] == "MERGED_TO": @@ -433,20 +438,29 @@ def _inherit_edges(self, from_id: str, to_id: str) -> None: continue # Add edge to merged node if it doesn't already exist - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) + if not self.graph_store.edge_exists( + new_from, new_to, edge["type"], direction="ANY", user_name=user_name + ): + self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name) # Remove original edge if it involved the archived node - self.graph_store.delete_edge(edge["from"], edge["to"], edge["type"]) + self.graph_store.delete_edge( + edge["from"], edge["to"], edge["type"], user_name=user_name + ) def _ensure_structure_path( - self, memory_type: str, metadata: TreeNodeTextualMemoryMetadata + self, + memory_type: str, + metadata: TreeNodeTextualMemoryMetadata, + user_name: str | None = None, ) -> str: """ Ensure structural path exists (ROOT → ... → final node), return last node ID. Args: - path: like ["hobby", "photography"] + memory_type: Memory type for the structure node. + metadata: Metadata containing key and other fields. + user_name: Optional user name for multi-tenant isolation. Returns: Final node ID of the structure path. @@ -456,7 +470,8 @@ def _ensure_structure_path( [ {"field": "memory", "op": "=", "value": metadata.key}, {"field": "memory_type", "op": "=", "value": memory_type}, - ] + ], + user_name=user_name, ) if existing: node_id = existing[0] # Use the first match @@ -479,14 +494,16 @@ def _ensure_structure_path( ), ) self.graph_store.add_node( - id=new_node.id, - memory=new_node.memory, - metadata=new_node.metadata.model_dump(exclude_none=True), + new_node.id, + new_node.memory, + new_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( op="add", after_node=[new_node.id], + user_name=user_name, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index ea06a7c60..b7fb6b1a0 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -52,12 +52,14 @@ def __init__( before_edge: list[str] | list[GraphDBEdge] | None = None, after_node: list[str] | list[GraphDBNode] | None = None, after_edge: list[str] | list[GraphDBEdge] | None = None, + user_name: str | None = None, ): self.op = op self.before_node = before_node self.before_edge = before_edge self.after_node = after_node self.after_edge = after_edge + self.user_name = user_name def __str__(self) -> str: return f"QueueMessage(op={self.op}, before_node={self.before_node if self.before_node is None else len(self.before_node)}, after_node={self.after_node if self.after_node is None else len(self.after_node)})" @@ -191,11 +193,15 @@ def handle_add(self, message: QueueMessage): logger.debug(f"Handling add operation: {str(message)[:500]}") added_node = message.after_node[0] detected_relationships = self.resolver.detect( - added_node, scope=added_node.metadata.memory_type + added_node, + scope=added_node.metadata.memory_type, + user_name=message.user_name, ) if detected_relationships: for added_node, existing_node, relation in detected_relationships: - self.resolver.resolve(added_node, existing_node, relation) + self.resolver.resolve( + added_node, existing_node, relation, user_name=message.user_name + ) self._reorganize_needed = True @@ -209,6 +215,7 @@ def optimize_structure( min_cluster_size: int = 4, min_group_size: int = 20, max_duration_sec: int = 600, + user_name: str | None = None, ): """ Periodically reorganize the graph: @@ -232,7 +239,7 @@ def _check_deadline(where: str): logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.") return - if self.graph_store.node_not_exist(scope): + if self.graph_store.node_not_exist(scope, user_name=user_name): logger.debug(f"[GraphStructureReorganize] No nodes for scope={scope}. Skip.") return @@ -244,12 +251,14 @@ def _check_deadline(where: str): logger.debug( f"[GraphStructureReorganize] Num of scope in self.graph_store is" - f" {self.graph_store.get_memory_count(scope)}" + f" {self.graph_store.get_memory_count(scope, user_name=user_name)}" ) # Load candidate nodes if _check_deadline("[GraphStructureReorganize] Before loading candidates"): return - raw_nodes = self.graph_store.get_structure_optimization_candidates(scope) + raw_nodes = self.graph_store.get_structure_optimization_candidates( + scope, user_name=user_name + ) nodes = [GraphDBNode(**n) for n in raw_nodes] if not nodes: @@ -281,6 +290,7 @@ def _check_deadline(where: str): scope, local_tree_threshold, min_cluster_size, + user_name, ) ) @@ -307,6 +317,7 @@ def _process_cluster_and_write( scope: str, local_tree_threshold: int, min_cluster_size: int, + user_name: str | None = None, ): if len(cluster_nodes) <= min_cluster_size: return @@ -319,15 +330,17 @@ def _process_cluster_and_write( if len(sub_nodes) < min_cluster_size: continue # Skip tiny noise sub_parent_node = self._summarize_cluster(sub_nodes, scope) - self._create_parent_node(sub_parent_node) - self._link_cluster_nodes(sub_parent_node, sub_nodes) + self._create_parent_node(sub_parent_node, user_name=user_name) + self._link_cluster_nodes(sub_parent_node, sub_nodes, user_name=user_name) sub_parents.append(sub_parent_node) if sub_parents and len(sub_parents) >= min_cluster_size: cluster_parent_node = self._summarize_cluster(cluster_nodes, scope) - self._create_parent_node(cluster_parent_node) + self._create_parent_node(cluster_parent_node, user_name=user_name) for sub_parent in sub_parents: - self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT") + self.graph_store.add_edge( + cluster_parent_node.id, sub_parent.id, "PARENT", user_name=user_name + ) logger.info("Adding relations/reasons") nodes_to_check = cluster_nodes @@ -351,10 +364,16 @@ def _process_cluster_and_write( # 1) Add pairwise relations for rel in results["relations"]: if not self.graph_store.edge_exists( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ): self.graph_store.add_edge( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ) # 2) Add inferred nodes and link to sources @@ -363,14 +382,21 @@ def _process_cluster_and_write( inf_node.id, inf_node.memory, inf_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for src_id in inf_node.metadata.sources: - self.graph_store.add_edge(src_id, inf_node.id, "INFERS") + self.graph_store.add_edge( + src_id, inf_node.id, "INFERS", user_name=user_name + ) # 3) Add sequence links for seq in results["sequence_links"]: - if not self.graph_store.edge_exists(seq["from_id"], seq["to_id"], "FOLLOWS"): - self.graph_store.add_edge(seq["from_id"], seq["to_id"], "FOLLOWS") + if not self.graph_store.edge_exists( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ): + self.graph_store.add_edge( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ) # 4) Add aggregate concept nodes for agg_node in results["aggregate_nodes"]: @@ -378,9 +404,12 @@ def _process_cluster_and_write( agg_node.id, agg_node.memory, agg_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for child_id in agg_node.metadata.sources: - self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO") + self.graph_store.add_edge( + agg_node.id, child_id, "AGGREGATE_TO", user_name=user_name + ) logger.info("[Reorganizer] Cluster relation/reasoning done.") @@ -577,7 +606,7 @@ def _parse_json_result(self, response_text): ) return {} - def _create_parent_node(self, parent_node: GraphDBNode) -> None: + def _create_parent_node(self, parent_node: GraphDBNode, user_name: str | None = None) -> None: """ Create a new parent node for the cluster. """ @@ -585,17 +614,23 @@ def _create_parent_node(self, parent_node: GraphDBNode) -> None: parent_node.id, parent_node.memory, parent_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) - def _link_cluster_nodes(self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode]): + def _link_cluster_nodes( + self, + parent_node: GraphDBNode, + child_nodes: list[GraphDBNode], + user_name: str | None = None, + ): """ Add PARENT edges from the parent node to all nodes in the cluster. """ for child in child_nodes: if not self.graph_store.edge_exists( - parent_node.id, child.id, "PARENT", direction="OUTGOING" + parent_node.id, child.id, "PARENT", direction="OUTGOING", user_name=user_name ): - self.graph_store.add_edge(parent_node.id, child.id, "PARENT") + self.graph_store.add_edge(parent_node.id, child.id, "PARENT", user_name=user_name) def _preprocess_message(self, message: QueueMessage) -> bool: message = self._convert_id_to_node(message) @@ -613,7 +648,9 @@ def _convert_id_to_node(self, message: QueueMessage) -> QueueMessage: for i, node in enumerate(message.after_node or []): if not isinstance(node, str): continue - raw_node = self.graph_store.get_node(node, include_embedding=True) + raw_node = self.graph_store.get_node( + node, include_embedding=True, user_name=message.user_name + ) if raw_node is None: logger.debug(f"Node with ID {node} not found in the graph store.") message.after_node[i] = None diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9dcbe8c56..cc269e8c4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -524,7 +524,7 @@ def _retrieve_from_keyword( user_name=user_name, tsquery_config="jiebaqry", ) - except Exception as e: + except Exception: logger.warning( f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}" ) diff --git a/src/memos/memos_tools/thread_safe_dict_segment.py b/src/memos/memos_tools/thread_safe_dict_segment.py index c1c10e3e1..bf918889f 100644 --- a/src/memos/memos_tools/thread_safe_dict_segment.py +++ b/src/memos/memos_tools/thread_safe_dict_segment.py @@ -71,7 +71,7 @@ def acquire_write(self) -> bool: self._waiting_writers -= 1 self._last_write_time = time.time() return True - except: + except Exception: self._waiting_writers -= 1 raise diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 1678d9d15..d890c77bf 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -443,7 +443,10 @@ def _search_pref( }, search_filter=search_req.filter, ) - formatted_results = self._postformat_memories(results, user_context.mem_cube_id) + include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true" + formatted_results = self._postformat_memories( + results, user_context.mem_cube_id, include_embedding=include_embedding + ) # For each returned item, tackle with metadata.info project_id / # operation / manager_user_id diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index e4f1ca334..f431bd041 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -244,12 +244,17 @@ Return a single valid JSON object with the following structure: -Return valid JSON: { - "key": , - "memory_type": "LongTermMemory", - "value": , - "tags": + "memory list": [ + { + "key": , + "memory_type": "LongTermMemory", + "value": , + "tags": + } + ... + ], + "summary": } Language rules: @@ -264,7 +269,7 @@ Your Output:""" SIMPLE_STRUCT_DOC_READER_PROMPT_ZH = """您是搜索与检索系统的文本分析专家。 -您的任务是处理文档片段,并生成一个结构化的 JSON 对象。 +您的任务是处理文档片段,并生成一个结构化的 JSON 列表对象。 请执行以下操作: 1. 识别反映文档中事实内容、见解、决策或含义的关键信息——包括任何显著的主题、结论或数据点,使读者无需阅读原文即可充分理解该片段的核心内容。 @@ -281,14 +286,19 @@ - 优先考虑完整性和保真度,而非简洁性。 - 不要泛化或跳过可能具有上下文意义的细节。 -返回一个有效的 JSON 对象,结构如下: +返回有效的 JSON 对象: -返回有效的 JSON: { - "key": <字符串,`value` 字段的简洁标题>, - "memory_type": "LongTermMemory", - "value": <一段清晰准确的段落,全面总结文档片段中的主要观点、论据和信息——若输入摘要为英文,则用英文;若为中文,则用中文>, - "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])> + "memory list": [ + { + "key": <字符串,`value` 字段的简洁标题>, + "memory_type": "LongTermMemory", + "value": <一段清晰准确的段落,全面总结文档片段中的主要观点、论据和信息——若输入摘要为英文,则用英文;若为中文,则用中文>, + "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])> + } + ... + ], + "summary": <简洁总结原文内容,与输入语言一致> } 语言规则: diff --git a/tests/graph_dbs/test_search_return_fields.py b/tests/graph_dbs/test_search_return_fields.py new file mode 100644 index 000000000..82a50308b --- /dev/null +++ b/tests/graph_dbs/test_search_return_fields.py @@ -0,0 +1,306 @@ +""" +Regression tests for issue #955: search methods support specifying return fields. + +Tests that search_by_embedding (and other search methods) accept a `return_fields` +parameter and include the requested fields in the result dicts, eliminating the +need for N+1 get_node() calls. +""" + +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.graph_db import Neo4jGraphDBConfig + + +@pytest.fixture +def neo4j_config(): + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_memory_db", + auto_create=False, + embedding_dimension=3, + ) + + +@pytest.fixture +def neo4j_db(neo4j_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(neo4j_config) + db.driver = mock_driver + yield db + + +class TestNeo4jSearchReturnFields: + """Tests for Neo4jGraphDB.search_by_embedding with return_fields.""" + + def test_return_fields_included_in_results(self, neo4j_db): + """return_fields values are present in each result dict.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + node_id = str(uuid.uuid4()) + session_mock.run.return_value = [ + {"id": node_id, "score": 0.95, "memory": "hello", "status": "activated"}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory", "status"], + ) + + assert len(results) == 1 + assert results[0]["id"] == node_id + assert results[0]["score"] == 0.95 + assert results[0]["memory"] == "hello" + assert results[0]["status"] == "activated" + + def test_backward_compatible_without_return_fields(self, neo4j_db): + """Without return_fields, only id and score are returned (old behavior).""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.9}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + ) + + assert len(results) == 1 + assert set(results[0].keys()) == {"id", "score"} + + def test_cypher_return_clause_includes_fields(self, neo4j_db): + """Cypher RETURN clause contains the requested fields.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory", "tags"], + ) + + query = session_mock.run.call_args[0][0] + assert "node.memory AS memory" in query + assert "node.tags AS tags" in query + + def test_cypher_return_clause_default(self, neo4j_db): + """Without return_fields, RETURN clause only has id and score.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + ) + + query = session_mock.run.call_args[0][0] + assert "RETURN node.id AS id, score" in query + assert "node.memory" not in query + + def test_return_fields_skips_id_field(self, neo4j_db): + """Passing 'id' in return_fields does not duplicate it in RETURN clause.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["id", "memory"], + ) + + query = session_mock.run.call_args[0][0] + # 'id' should appear only once (as node.id AS id), not duplicated + assert query.count("node.id AS id") == 1 + assert "node.memory AS memory" in query + + def test_threshold_filtering_still_works_with_return_fields(self, neo4j_db): + """Threshold filtering works correctly when return_fields is specified.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.9, "memory": "high score"}, + {"id": str(uuid.uuid4()), "score": 0.3, "memory": "low score"}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + threshold=0.5, + return_fields=["memory"], + ) + + assert len(results) == 1 + assert results[0]["memory"] == "high score" + + +class TestPolarDBExtractFieldsFromProperties: + """Tests for PolarDBGraphDB._extract_fields_from_properties helper.""" + + @pytest.fixture + def polardb_instance(self): + """Create a minimal PolarDB instance for testing the helper method.""" + with patch("memos.graph_dbs.polardb.PolarDBGraphDB.__init__", return_value=None): + from memos.graph_dbs.polardb import PolarDBGraphDB + + db = PolarDBGraphDB.__new__(PolarDBGraphDB) + yield db + + def test_extract_from_json_string(self, polardb_instance): + """Extract fields from a JSON string properties value.""" + props = '{"id": "abc", "memory": "hello", "status": "activated", "tags": ["a"]}' + result = polardb_instance._extract_fields_from_properties( + props, ["memory", "status", "tags"] + ) + assert result == {"memory": "hello", "status": "activated", "tags": ["a"]} + + def test_extract_from_dict(self, polardb_instance): + """Extract fields from a dict properties value.""" + props = {"id": "abc", "memory": "hello", "status": "activated"} + result = polardb_instance._extract_fields_from_properties(props, ["memory", "status"]) + assert result == {"memory": "hello", "status": "activated"} + + def test_extract_skips_id(self, polardb_instance): + """'id' field is skipped even if requested.""" + props = '{"id": "abc", "memory": "hello"}' + result = polardb_instance._extract_fields_from_properties(props, ["id", "memory"]) + assert result == {"memory": "hello"} + + def test_extract_missing_fields(self, polardb_instance): + """Missing fields are silently skipped.""" + props = '{"id": "abc", "memory": "hello"}' + result = polardb_instance._extract_fields_from_properties(props, ["memory", "nonexistent"]) + assert result == {"memory": "hello"} + + def test_extract_empty_properties(self, polardb_instance): + """Empty/None properties return empty dict.""" + assert polardb_instance._extract_fields_from_properties(None, ["memory"]) == {} + assert polardb_instance._extract_fields_from_properties("", ["memory"]) == {} + + def test_extract_invalid_json(self, polardb_instance): + """Invalid JSON returns empty dict without raising.""" + result = polardb_instance._extract_fields_from_properties("not-json", ["memory"]) + assert result == {} + + +class TestFieldNameValidation: + """Tests for _validate_return_fields injection prevention.""" + + def test_valid_field_names_pass(self): + from memos.graph_dbs.base import BaseGraphDB + + result = BaseGraphDB._validate_return_fields(["memory", "status", "tags", "user_name"]) + assert result == ["memory", "status", "tags", "user_name"] + + def test_invalid_field_names_rejected(self): + from memos.graph_dbs.base import BaseGraphDB + + # Cypher injection attempts + result = BaseGraphDB._validate_return_fields( + [ + "memory} RETURN n //", + "status; DROP", + "valid_field", + "a.b", + "field name", + "", + ] + ) + assert result == ["valid_field"] + + def test_none_returns_empty(self): + from memos.graph_dbs.base import BaseGraphDB + + assert BaseGraphDB._validate_return_fields(None) == [] + + def test_empty_list_returns_empty(self): + from memos.graph_dbs.base import BaseGraphDB + + assert BaseGraphDB._validate_return_fields([]) == [] + + def test_injection_in_cypher_query_prevented(self, neo4j_db): + """Malicious field names should not appear in the Cypher query.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory} RETURN n //", "valid_field"], + ) + + query = session_mock.run.call_args[0][0] + # Injection attempt should NOT appear in query + assert "memory}" not in query + assert "RETURN n //" not in query + # Valid field should appear + assert "node.valid_field AS valid_field" in query + + +class TestNeo4jCommunitySearchReturnFields: + """Tests for Neo4jCommunityGraphDB._fetch_return_fields with return_fields.""" + + @pytest.fixture + def neo4j_community_db(self): + """Create a minimal Neo4jCommunityGraphDB instance by patching __init__.""" + with patch( + "memos.graph_dbs.neo4j_community.Neo4jCommunityGraphDB.__init__", return_value=None + ): + from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB + + db = Neo4jCommunityGraphDB.__new__(Neo4jCommunityGraphDB) + db.driver = MagicMock() + db.db_name = "test_memory_db" + yield db + + def test_fetch_return_fields_queries_neo4j(self, neo4j_community_db): + """_fetch_return_fields builds correct Cypher and returns fields.""" + session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": "node-1", "memory": "hello", "status": "activated"}, + ] + + results = neo4j_community_db._fetch_return_fields( + ids=["node-1"], + score_map={"node-1": 0.95}, + return_fields=["memory", "status"], + ) + + assert len(results) == 1 + assert results[0]["id"] == "node-1" + assert results[0]["score"] == 0.95 + assert results[0]["memory"] == "hello" + assert results[0]["status"] == "activated" + + query = session_mock.run.call_args[0][0] + assert "n.memory AS memory" in query + assert "n.status AS status" in query + + def test_fetch_return_fields_validates_names(self, neo4j_community_db): + """_fetch_return_fields rejects invalid field names.""" + session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_community_db._fetch_return_fields( + ids=["node-1"], + score_map={"node-1": 0.95}, + return_fields=["memory} RETURN n //", "valid_field"], + ) + + query = session_mock.run.call_args[0][0] + assert "memory}" not in query + assert "n.valid_field AS valid_field" in query diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 46cf3a1f6..a6ac186b7 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -131,7 +131,7 @@ def test_mark_memory_status(history_manager, mock_graph_db): # Assert assert mock_graph_db.update_node.call_count == 3 - # Verify we called it correctly - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}) + # Verify we called it correctly (user_name=None is passed by mark_memory_status) + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None) + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None) + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None) From 6a91bcd4d25314ff1815f589ce237f777c7b9fc3 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Mon, 2 Mar 2026 15:06:07 +0800 Subject: [PATCH 06/20] feat: Optimize the chunk logic of the knowledge base (#1135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fix three feature issues * fix:add uncommitted changes for the previous fix * fix: optimize chunk strategy * Optimize chunk strategy * add some comments * fix: update is_markdown * chunker fix * fix: add context during document processing * reformat * fix tests info --------- Co-authored-by: mozuyun Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> --- .gitignore | 2 + src/memos/chunkers/base.py | 41 ++++ src/memos/chunkers/charactertext_chunker.py | 4 +- src/memos/chunkers/markdown_chunker.py | 101 ++++++++- src/memos/chunkers/sentence_chunker.py | 4 +- src/memos/chunkers/simple_chunker.py | 15 +- .../read_multi_modal/file_content_parser.py | 193 ++++++++++++++++-- .../mem_reader/read_multi_modal/utils.py | 3 +- .../organize/history_manager.py | 2 +- src/memos/templates/mem_reader_prompts.py | 8 + tests/chunkers/test_sentence_chunker.py | 17 +- tests/utils.py | 20 +- 12 files changed, 378 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index ac31eb41a..97af509ea 100644 --- a/.gitignore +++ b/.gitignore @@ -216,3 +216,5 @@ cython_debug/ outputs evaluation/data/temporal_locomo +test_add_pipeline.py +test_file_pipeline.py diff --git a/src/memos/chunkers/base.py b/src/memos/chunkers/base.py index c2a783baa..e858132e1 100644 --- a/src/memos/chunkers/base.py +++ b/src/memos/chunkers/base.py @@ -1,3 +1,5 @@ +import re + from abc import ABC, abstractmethod from memos.configs.chunker import BaseChunkerConfig @@ -22,3 +24,42 @@ def __init__(self, config: BaseChunkerConfig): @abstractmethod def chunk(self, text: str) -> list[Chunk]: """Chunk the given text into smaller chunks.""" + + def protect_urls(self, text: str) -> tuple[str, dict[str, str]]: + """ + Protect URLs in text from being split during chunking. + + Args: + text: Text to process + + Returns: + tuple: (Text with URLs replaced by placeholders, URL mapping dictionary) + """ + url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+' + url_map = {} + + def replace_url(match): + url = match.group(0) + placeholder = f"__URL_{len(url_map)}__" + url_map[placeholder] = url + return placeholder + + protected_text = re.sub(url_pattern, replace_url, text) + return protected_text, url_map + + def restore_urls(self, text: str, url_map: dict[str, str]) -> str: + """ + Restore protected URLs in text back to their original form. + + Args: + text: Text with URL placeholders + url_map: URL mapping dictionary from protect_urls + + Returns: + str: Text with URLs restored + """ + restored_text = text + for placeholder, url in url_map.items(): + restored_text = restored_text.replace(placeholder, url) + + return restored_text diff --git a/src/memos/chunkers/charactertext_chunker.py b/src/memos/chunkers/charactertext_chunker.py index 15c0958ba..25739d96f 100644 --- a/src/memos/chunkers/charactertext_chunker.py +++ b/src/memos/chunkers/charactertext_chunker.py @@ -36,6 +36,8 @@ def __init__( def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - chunks = self.chunker.split_text(text) + protected_text, url_map = self.protect_urls(text) + chunks = self.chunker.split_text(protected_text) + chunks = [self.restore_urls(chunk, url_map) for chunk in chunks] logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py index b7771ac35..a37370200 100644 --- a/src/memos/chunkers/markdown_chunker.py +++ b/src/memos/chunkers/markdown_chunker.py @@ -1,3 +1,5 @@ +import re + from memos.configs.chunker import MarkdownChunkerConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -22,6 +24,7 @@ def __init__( chunk_size: int = 1000, chunk_overlap: int = 200, recursive: bool = False, + auto_fix_headers: bool = True, ): from langchain_text_splitters import ( MarkdownHeaderTextSplitter, @@ -29,6 +32,7 @@ def __init__( ) self.config = config + self.auto_fix_headers = auto_fix_headers self.chunker = MarkdownHeaderTextSplitter( headers_to_split_on=config.headers_to_split_on if config @@ -46,17 +50,110 @@ def __init__( def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - md_header_splits = self.chunker.split_text(text) + # Protect URLs first + protected_text, url_map = self.protect_urls(text) + # Auto-detect and fix malformed header hierarchy if enabled + if self.auto_fix_headers and self._detect_malformed_headers(protected_text): + logger.info("[Chunker:] detected malformed header hierarchy, attempting to fix...") + protected_text = self._fix_header_hierarchy(protected_text) + logger.info("[Chunker:] Header hierarchy fix completed") + + md_header_splits = self.chunker.split_text(protected_text) chunks = [] if self.chunker_recursive: md_header_splits = self.chunker_recursive.split_documents(md_header_splits) for doc in md_header_splits: try: chunk = " ".join(list(doc.metadata.values())) + "\n" + doc.page_content + chunk = self.restore_urls(chunk, url_map) chunks.append(chunk) except Exception as e: logger.warning(f"warning chunking document: {e}") - chunks.append(doc.page_content) + restored_chunk = self.restore_urls(doc.page_content, url_map) + chunks.append(restored_chunk) logger.info(f"Generated chunks: {chunks[:5]}") logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks + + def _detect_malformed_headers(self, text: str) -> bool: + """Detect if markdown has improper header hierarchy usage.""" + # Extract all valid markdown header lines + header_levels = [] + pattern = re.compile(r"^#{1,6}\s+.+") + for line in text.split("\n"): + stripped_line = line.strip() + if pattern.match(stripped_line): + hash_match = re.match(r"^(#+)", stripped_line) + if hash_match: + level = len(hash_match.group(1)) + header_levels.append(level) + + total_headers = len(header_levels) + if total_headers == 0: + logger.debug("No valid headers detected, skipping check") + return False + + # Calculate level-1 header ratio + level1_count = sum(1 for level in header_levels if level == 1) + + # Determine if malformed: >90% are level-1 when total > 5 + # OR all headers are level-1 when total ≤ 5 + if total_headers > 5: + level1_ratio = level1_count / total_headers + if level1_ratio > 0.9: + logger.warning( + f"Detected header hierarchy issue: {level1_count}/{total_headers} " + f"({level1_ratio:.1%}) of headers are level 1" + ) + return True + elif total_headers <= 5 and level1_count == total_headers: + logger.warning( + f"Detected header hierarchy issue: all {total_headers} headers are level 1" + ) + return True + return False + + def _fix_header_hierarchy(self, text: str) -> str: + """ + Fix markdown header hierarchy by adjusting levels. + + Strategy: + 1. Keep the first header unchanged as level-1 parent + 2. Increment all subsequent headers by 1 level (max level 6) + """ + header_pattern = re.compile(r"^(#{1,6})\s+(.+)$") + lines = text.split("\n") + fixed_lines = [] + first_valid_header = False + + for line in lines: + stripped_line = line.strip() + # Match valid header lines (invalid # lines kept as-is) + header_match = header_pattern.match(stripped_line) + if header_match: + current_hashes, title_content = header_match.groups() + current_level = len(current_hashes) + + if not first_valid_header: + # First valid header: keep original level unchanged + fixed_line = f"{current_hashes} {title_content}" + first_valid_header = True + logger.debug( + f"Keep first header at level {current_level}: {title_content[:50]}..." + ) + else: + # Subsequent headers: increment by 1, cap at level 6 + new_level = min(current_level + 1, 6) + new_hashes = "#" * new_level + fixed_line = f"{new_hashes} {title_content}" + logger.debug( + f"Adjust header level: {current_level} -> {new_level}: {title_content[:50]}..." + ) + fixed_lines.append(fixed_line) + else: + fixed_lines.append(line) + + # Join with newlines to preserve original formatting + fixed_text = "\n".join(fixed_lines) + logger.info(f"[Chunker:] Header hierarchy fix completed: {fixed_text[:50]}...") + return fixed_text diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index f39dfb8e2..e695d0d9a 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -43,11 +43,13 @@ def __init__(self, config: SentenceChunkerConfig): def chunk(self, text: str) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - chonkie_chunks = self.chunker.chunk(text) + protected_text, url_map = self.protect_urls(text) + chonkie_chunks = self.chunker.chunk(protected_text) chunks = [] for c in chonkie_chunks: chunk = Chunk(text=c.text, token_count=c.token_count, sentences=c.sentences) + chunk = self.restore_urls(chunk.text, url_map) chunks.append(chunk) logger.debug(f"Generated {len(chunks)} chunks from input text") diff --git a/src/memos/chunkers/simple_chunker.py b/src/memos/chunkers/simple_chunker.py index cc0dc40d0..58e12e2f1 100644 --- a/src/memos/chunkers/simple_chunker.py +++ b/src/memos/chunkers/simple_chunker.py @@ -20,12 +20,15 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> Returns: List of text chunks """ - if not text or len(text) <= chunk_size: - return [text] if text.strip() else [] + protected_text, url_map = self.protect_urls(text) + + if not protected_text or len(protected_text) <= chunk_size: + chunks = [protected_text] if protected_text.strip() else [] + return [self.restore_urls(chunk, url_map) for chunk in chunks] chunks = [] start = 0 - text_len = len(text) + text_len = len(protected_text) while start < text_len: # Calculate end position @@ -35,16 +38,16 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> if end < text_len: # Try to break at newline, sentence end, or space for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " "]: - last_sep = text.rfind(separator, start, end) + last_sep = protected_text.rfind(separator, start, end) if last_sep != -1: end = last_sep + len(separator) break - chunk = text[start:end].strip() + chunk = protected_text[start:end].strip() if chunk: chunks.append(chunk) # Move start position with overlap start = max(start + 1, end - chunk_overlap) - return chunks + return [self.restore_urls(chunk, url_map) for chunk in chunks] diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 1b4add398..00e02abda 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -51,8 +51,11 @@ class FileContentParser(BaseMessageParser): """Parser for file content parts.""" def _get_doc_llm_response( - self, chunk_text: str, custom_tags: list[str] | None = None - ) -> dict | list: + self, + chunk_text: str, + custom_tags: list[str] | None = None, + message_text_context: str | None = None, + ) -> dict: """ Call LLM to extract memory from document chunk. Uses doc prompts from DOC_PROMPT_DICT. @@ -60,6 +63,8 @@ def _get_doc_llm_response( Args: chunk_text: Text chunk to extract memory from custom_tags: Optional list of custom tags for LLM extraction + message_text_context: Optional text from the same message that + provides user intent / context for understanding this document Returns: Parsed JSON response from LLM (dict or list) or empty dict if failed @@ -79,6 +84,10 @@ def _get_doc_llm_response( ) prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt) + # Inject sibling text context into prompt placeholder + context_text = message_text_context.strip() if message_text_context else "" + prompt = prompt.replace("{context}", context_text) + messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) @@ -109,14 +118,25 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, boo return response.text, None, True file_ext = os.path.splitext(filename)[1].lower() - if file_ext in [".md", ".markdown", ".txt"]: + if file_ext in [".md", ".markdown", ".txt"] or self._is_oss_md(url_str): return response.text, None, True with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_ext) as temp_file: temp_file.write(response.content) return "", temp_file.name, False except Exception as e: logger.error(f"[FileContentParser] URL processing error: {e}") - return f"[File URL download failed: {url_str}]", None + return f"[File URL download failed: {url_str}]", None, False + + def _is_oss_md(self, url: str) -> bool: + """Check if URL is an OSS markdown file based on pattern.""" + loose_pattern = re.compile(r"^https?://[^/]*\.aliyuncs\.com/.*/([^/?#]+)") + match = loose_pattern.search(url) + if not match: + return False + + file_name = match.group(1) + lower_name = file_name.lower() + return lower_name.endswith((".md", ".markdown", ".txt")) def _is_base64(self, data: str) -> bool: """Quick heuristic to check base64-like string.""" @@ -139,7 +159,12 @@ def _handle_local(self, data: str) -> str: return "" def _process_single_image( - self, image_url: str, original_ref: str, info: dict[str, Any], **kwargs + self, + image_url: str, + original_ref: str, + info: dict[str, Any], + header_context: list[str] | None = None, + **kwargs, ) -> tuple[str, str]: """ Process a single image and return (original_ref, replacement_text). @@ -148,6 +173,7 @@ def _process_single_image( image_url: URL of the image to process original_ref: Original markdown image reference to replace info: Dictionary containing user_id and session_id + header_context: Optional list of header titles providing context for the image **kwargs: Additional parameters for ImageParser Returns: @@ -173,20 +199,31 @@ def _process_single_image( if hasattr(item, "memory") and item.memory: extracted_texts.append(str(item.memory)) + # Prepare header context string if available + header_context_str = "" + if header_context: + # Join headers with " > " to show hierarchy + header_hierarchy = " > ".join(header_context) + header_context_str = f"[Section: {header_hierarchy}]\n\n" + if extracted_texts: # Combine all extracted texts extracted_content = "\n".join(extracted_texts) + # build final replacement text + replacement_text = ( + f"{header_context_str}[Image Content from {image_url}]:\n{extracted_content}\n" + ) # Replace image with extracted content return ( original_ref, - f"\n[Image Content from {image_url}]:\n{extracted_content}\n", + replacement_text, ) else: # If no content extracted, keep original with a note logger.warning(f"[FileContentParser] No content extracted from image: {image_url}") return ( original_ref, - f"\n[Image: {image_url} - No content extracted]\n", + f"{header_context_str}[Image: {image_url} - No content extracted]\n", ) except Exception as e: @@ -194,7 +231,9 @@ def _process_single_image( # On error, keep original image reference return (original_ref, original_ref) - def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) -> str: + def _extract_and_process_images( + self, text: str, info: dict[str, Any], headers: dict[int, dict] | None = None, **kwargs + ) -> str: """ Extract all images from markdown text and process them using ImageParser in parallel. Replaces image references with extracted text content. @@ -202,6 +241,7 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) Args: text: Markdown text containing image references info: Dictionary containing user_id and session_id + headers: Optional dictionary mapping line numbers to header info **kwargs: Additional parameters for ImageParser Returns: @@ -225,7 +265,13 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) for match in image_matches: image_url = match.group(2) original_ref = match.group(0) - tasks.append((image_url, original_ref)) + image_position = match.start() + + header_context = None + if headers: + header_context = self._get_header_context(text, image_position, headers) + + tasks.append((image_url, original_ref, header_context)) # Process images in parallel replacements = {} @@ -234,9 +280,14 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) with ContextThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit( - self._process_single_image, image_url, original_ref, info, **kwargs + self._process_single_image, + image_url, + original_ref, + info, + header_context, + **kwargs, ): (image_url, original_ref) - for image_url, original_ref in tasks + for image_url, original_ref, header_context in tasks } # Collect results with progress tracking @@ -603,6 +654,18 @@ def parse_fine( # Extract custom_tags from kwargs (for LLM extraction) custom_tags = kwargs.get("custom_tags") + # Extract sibling text context . + message_text_context = None + context_items = kwargs.get("context_items") + if context_items: + sibling_texts = [] + for ctx_item in context_items: + for src in getattr(ctx_item.metadata, "sources", None) or []: + if src.type == "chat" and src.content: + sibling_texts.append(src.content.strip()) + if sibling_texts: + message_text_context = "\n".join(sibling_texts) + # Use parser from utils parser = self.parser or get_parser() if not parser: @@ -663,9 +726,20 @@ def parse_fine( ) if not parsed_text: return [] + + # Extract markdown headers if applicable + headers = {} + if is_markdown: + headers = self._extract_markdown_headers(parsed_text) + logger.info( + f"[Chunker: FileContentParser] Extracted {len(headers)} headers from markdown" + ) + # Extract and process images from parsed_text if is_markdown and parsed_text and self.image_parser: - parsed_text = self._extract_and_process_images(parsed_text, info, **kwargs) + parsed_text = self._extract_and_process_images( + parsed_text, info, headers=headers if headers else None, **kwargs + ) # Extract info fields if not info: @@ -782,7 +856,9 @@ def _make_fallback( def _process_chunk(chunk_idx: int, chunk_text: str) -> list[TextualMemoryItem]: """Process chunk with LLM, fallback to raw on failure. Returns list of memory items.""" try: - response_json = self._get_doc_llm_response(chunk_text, custom_tags) + response_json = self._get_doc_llm_response( + chunk_text, custom_tags, message_text_context=message_text_context + ) if response_json: # Handle list format response response_list = response_json.get("memory list", []) @@ -932,3 +1008,94 @@ def get_chunk_idx(item: TextualMemoryItem) -> int: chunk_idx=None, ) ] + + def _extract_markdown_headers(self, text: str) -> dict[int, dict]: + """ + Extract markdown headers and their positions. + + Args: + text: Markdown text to parse + """ + if not text: + return {} + + headers = {} + # Pattern to match markdown headers: # Title, ## Title, etc. + header_pattern = r"^(#{1,6})\s+(.+)$" + + lines = text.split("\n") + char_position = 0 + + for line_num, line in enumerate(lines): + # Match header pattern (must be at start of line) + match = re.match(header_pattern, line.strip()) + if match: + level = len(match.group(1)) # Number of # symbols (1-6) + title = match.group(2).strip() # Extract title text + + # Store header info with its position + headers[line_num] = {"level": level, "title": title, "position": char_position} + + logger.debug(f"[FileContentParser] Found H{level} at line {line_num}: {title}") + + # Update character position for next line (+1 for newline character) + char_position += len(line) + 1 + + logger.info(f"[Chunker: FileContentParser] Extracted {len(headers)} headers from markdown") + return headers + + def _get_header_context( + self, text: str, image_position: int, headers: dict[int, dict] + ) -> list[str]: + """ + Get all header levels above an image position in hierarchical order. + + Finds the image's line number, then identifies all preceding headers + and constructs the hierarchical path to the image location. + + Args: + text: Full markdown text + image_position: Character position of the image in text + headers: Dict of headers from _extract_markdown_headers + """ + if not headers: + return [] + + # Find the line number corresponding to the image position + lines = text.split("\n") + char_count = 0 + image_line = 0 + + for i, line in enumerate(lines): + if char_count >= image_position: + image_line = i + break + char_count += len(line) + 1 # +1 for newline + + # Filter headers that appear before the image + preceding_headers = { + line_num: info for line_num, info in headers.items() if line_num < image_line + } + + if not preceding_headers: + return [] + + # Build hierarchical header stack + header_stack = [] + + for line_num in sorted(preceding_headers.keys()): + header = preceding_headers[line_num] + level = header["level"] + title = header["title"] + + # Pop headers of same or lower level + while header_stack and header_stack[-1]["level"] >= level: + removed = header_stack.pop() + logger.debug(f"[FileContentParser] Popped H{removed['level']}: {removed['title']}") + + # Push current header onto stack + header_stack.append({"level": level, "title": title}) + + # Return titles in order + result = [h["title"] for h in header_stack] + return result diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index a6d910e54..be82587bf 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -346,7 +346,8 @@ def detect_lang(text): r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE ) cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) - + # remove URLs to prevent the dilution of Chinese characters + cleaned_text = re.sub(r'https?://[^\s<>"{}|\\^`\[\]]+', "", cleaned_text) # extract chinese characters chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" chinese_chars = re.findall(chinese_pattern, cleaned_text) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 132582a0d..98094877c 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -128,7 +128,7 @@ def resolve_history_via_nli( ) new_item.metadata.history.append(archived) logger.info( - f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" + f"[Chunker: MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" ) # 3. Concat duplicate/conflict memories to new_item.memory diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index f431bd041..63e4c1538 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -263,6 +263,10 @@ {custom_tags_prompt} +If given context, use it as a supplement to the document information extraction; if no context is given, directly process the document information. +Reference context: +{context} + Document chunk: {chunk_text} @@ -307,6 +311,10 @@ {custom_tags_prompt} +如果给定了上下文,就结合上下文信息作为文档信息提取的补充,如果没有给定上下文,请直接处理文档信息。 +参考的上下文: +{context} + 示例: 输入的文本片段: 在Kalamang语中,亲属名词在所有格构式中的行为并不一致。名词 esa“父亲”和 ema“母亲”只能在技术称谓(teknonym)中与第三人称所有格后缀共现,而在非技术称谓用法中,带有所有格后缀是不合语法的。相比之下,大多数其他亲属名词并不允许所有格构式,只有极少数例外。 diff --git a/tests/chunkers/test_sentence_chunker.py b/tests/chunkers/test_sentence_chunker.py index 28aaeabb9..7ff6b2ccd 100644 --- a/tests/chunkers/test_sentence_chunker.py +++ b/tests/chunkers/test_sentence_chunker.py @@ -47,6 +47,17 @@ def test_sentence_chunker(self): self.assertEqual(len(chunks), 2) # Validate the properties of the first chunk mock_chunker.chunk.assert_called_once_with(text) - self.assertEqual(chunks[0].text, "This is the first sentence.") - self.assertEqual(chunks[0].token_count, 6) - self.assertEqual(chunks[0].sentences, ["This is the first sentence."]) + + # Handle both return types: list[str] | list[Chunk] + if isinstance(chunks[0], str): + # If returns list[str], check the string value + self.assertEqual(chunks[0], "This is the first sentence.") + self.assertEqual(chunks[1], "This is the second sentence.") + else: + # If returns list[Chunk], check the Chunk properties + from memos.chunkers.base import Chunk + + self.assertIsInstance(chunks[0], Chunk) + self.assertEqual(chunks[0].text, "This is the first sentence.") + self.assertEqual(chunks[0].token_count, 6) + self.assertEqual(chunks[0].sentences, ["This is the first sentence."]) diff --git a/tests/utils.py b/tests/utils.py index 132cd7138..ec8a32799 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,8 @@ def check_module_base_class(cls: Any) -> None: General function to test the correctness of an abstract base class. - It should inherit from ABC. - It should define at least one method. - - All methods should be marked as @abstractmethod. + - It should have at least one abstract method. + - Abstract methods (those in __abstractmethods__) should be marked as @abstractmethod. - It should not be instantiable. - All methods should have docstrings. @@ -31,14 +32,25 @@ def check_module_base_class(cls: Any) -> None: assert all_class_methods, f"{cls.__name__} should define at least one method" # Check 3: Verify abstract methods + # Get the set of abstract methods from the class + abstract_methods = getattr(cls, "__abstractmethods__", set()) + + # Ensure there is at least one abstract method + assert len(abstract_methods) > 0, f"{cls.__name__} should have at least one abstract method" + + # Verify that all methods in __abstractmethods__ are actually marked as abstract for method_name in all_class_methods: method = getattr(cls, method_name) # Skip private methods (starting with _) as they are typically helper methods if method_name.startswith("_") and method_name != "__init__": continue - assert getattr(method, "__isabstractmethod__", False), ( - f"The method '{method_name}' in {cls.__name__} should be marked as @abstractmethod" - ) + + # If the method is in __abstractmethods__, it must be marked as abstract + if method_name in abstract_methods: + assert getattr(method, "__isabstractmethod__", False), ( + f"The method '{method_name}' in {cls.__name__} is in __abstractmethods__ " + f"but should be marked as @abstractmethod" + ) # Check 4: Test that the class cannot be instantiated directly with pytest.raises(TypeError) as excinfo: From a7eb7dfa4ed0c717c859a653dcaec0e8f8ade628 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:05:51 +0800 Subject: [PATCH 07/20] feat: transfer pref to polar db (#1145) * feat: transfer pref * feat: search debug * feat: modify init components in scheduler * feat: modify pref code in feedback * feat: remove redundant comment * feat: modify some code --------- Co-authored-by: yuan.wang --- examples/mem_feedback/example_feedback.py | 2 +- src/memos/api/handlers/__init__.py | 2 - src/memos/api/handlers/add_handler.py | 4 +- src/memos/api/handlers/component_init.py | 85 +---- src/memos/api/handlers/formatters_handler.py | 57 ++-- src/memos/api/handlers/memory_handler.py | 225 ++++--------- src/memos/api/handlers/search_handler.py | 2 +- src/memos/api/product_models.py | 2 +- src/memos/api/routers/server_router.py | 11 +- src/memos/mem_cube/navie.py | 39 +-- src/memos/mem_feedback/feedback.py | 99 ++---- src/memos/mem_feedback/simple_feedback.py | 3 - src/memos/mem_reader/multi_modal_struct.py | 22 +- .../process_preference_memory.py | 296 ++++++++++++++++++ .../init_components_for_scheduler.py | 107 +------ src/memos/memories/textual/item.py | 3 +- src/memos/memories/textual/preference.py | 2 +- .../memories/textual/simple_preference.py | 126 -------- src/memos/memories/textual/tree.py | 4 + .../tree_text_memory/organize/manager.py | 2 + .../tree_text_memory/retrieve/recall.py | 1 + .../tree_text_memory/retrieve/searcher.py | 109 ++++++- src/memos/multi_mem_cube/single_cube.py | 205 +----------- src/memos/search/search_service.py | 2 + 24 files changed, 560 insertions(+), 850 deletions(-) create mode 100644 src/memos/mem_reader/read_pref_memory/process_preference_memory.py diff --git a/examples/mem_feedback/example_feedback.py b/examples/mem_feedback/example_feedback.py index 8f4446863..794ddf111 100644 --- a/examples/mem_feedback/example_feedback.py +++ b/examples/mem_feedback/example_feedback.py @@ -144,7 +144,7 @@ def init_components(): mem_reader=mem_reader, searcher=searcher, reranker=mem_reranker, - pref_mem=None, + pref_feedback=True, ) return feedback_server, memory_manager, embedder diff --git a/src/memos/api/handlers/__init__.py b/src/memos/api/handlers/__init__.py index 90347768c..bd4c9f4b0 100644 --- a/src/memos/api/handlers/__init__.py +++ b/src/memos/api/handlers/__init__.py @@ -32,7 +32,6 @@ ) from memos.api.handlers.formatters_handler import ( format_memory_item, - post_process_pref_mem, to_iter, ) @@ -54,7 +53,6 @@ "formatters_handler", "init_server", "memory_handler", - "post_process_pref_mem", "scheduler_handler", "search_handler", "suggestion_handler", diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..e9ed4f955 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -22,7 +22,7 @@ class AddHandler(BaseHandler): """ Handler for memory addition operations. - Handles both text and preference memory additions with sync/async support. + Handles text memory additions with sync/async support. """ def __init__(self, dependencies: HandlerDependencies): @@ -41,7 +41,7 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. - Orchestrates the addition of both text and preference memories, + Orchestrates the addition of text memories, supporting concurrent processing. Args: diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index ba527d602..aa2525878 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -19,11 +19,7 @@ build_llm_config, build_mem_reader_config, build_nli_client_config, - build_pref_adder_config, - build_pref_extractor_config, - build_pref_retriever_config, build_reranker_config, - build_vec_db_config, ) from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.embedders.factory import EmbedderFactory @@ -36,12 +32,6 @@ from memos.mem_reader.factory import MemReaderFactory from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.memories.textual.prefer_text_memory.factory import ( - AdderFactory, - ExtractorFactory, - RetrieverFactory, -) -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager @@ -56,7 +46,6 @@ InternetRetrieverFactory, ) from memos.reranker.factory import RerankerFactory -from memos.vec_dbs.factory import VecDBFactory if TYPE_CHECKING: @@ -125,7 +114,7 @@ def init_server() -> dict[str, Any]: required by the MemOS server, including: - Database connections (graph DB, vector DB) - Language models and embedders - - Memory systems (text, preference) + - Memory systems (text) - Scheduler and related modules Returns: @@ -169,20 +158,11 @@ def init_server() -> dict[str, Any]: reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() internet_retriever_config = build_internet_retriever_config() - vector_db_config = build_vec_db_config() - pref_extractor_config = build_pref_extractor_config() - pref_adder_config = build_pref_adder_config() - pref_retriever_config = build_pref_retriever_config() logger.debug("Component configurations built successfully") # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = ( - VecDBFactory.from_config(vector_db_config) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) llm = LLMFactory.from_config(llm_config) chat_llms = ( _init_chat_llms(chat_llm_config) @@ -231,61 +211,6 @@ def init_server() -> dict[str, Any]: logger.debug("Text memory initialized") - # Initialize preference memory components - pref_extractor = ( - ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_adder = ( - AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_retriever = ( - RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=feedback_reranker, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - logger.debug("Preference memory components initialized") - - # Initialize preference memory - pref_mem = ( - SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=feedback_reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - logger.debug("Preference memory initialized") - # Initialize MOS Server mos_server = MOSServer( mem_reader=mem_reader, @@ -298,7 +223,6 @@ def init_server() -> dict[str, Any]: # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=pref_mem, act_mem=None, para_mem=None, ) @@ -325,7 +249,7 @@ def init_server() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, - pref_mem=pref_mem, + pref_feedback=True, ) # Initialize Scheduler @@ -384,12 +308,7 @@ def init_server() -> dict[str, Any]: "naive_mem_cube": naive_mem_cube, "searcher": searcher, "api_module": api_module, - "vector_db": vector_db, - "pref_extractor": pref_extractor, - "pref_adder": pref_adder, - "pref_retriever": pref_retriever, "text_mem": text_mem, - "pref_mem": pref_mem, "online_bot": online_bot, "feedback_server": feedback_server, "redis_client": redis_client, diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 06c4fd223..ee88ae639 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -65,49 +65,14 @@ def format_memory_item( return memory -def post_process_pref_mem( - memories_result: dict[str, Any], - pref_formatted_mem: list[dict[str, Any]], - mem_cube_id: str, - include_preference: bool, -) -> dict[str, Any]: - """ - Post-process preference memory results. - - Adds formatted preference memories to the result dictionary and generates - instruction completion strings if preferences are included. - - Args: - memories_result: Result dictionary to update - pref_formatted_mem: List of formatted preference memories - mem_cube_id: Memory cube ID - include_preference: Whether to include preferences in result - - Returns: - Updated memories_result dictionary - """ - if include_preference: - memories_result["pref_mem"].append( - { - "cube_id": mem_cube_id, - "memories": pref_formatted_mem, - "total_nodes": len(pref_formatted_mem), - } - ) - pref_instruction, pref_note = instruct_completion(pref_formatted_mem) - memories_result["pref_string"] = pref_instruction - memories_result["pref_note"] = pref_note - - return memories_result - - def post_process_textual_mem( memories_result: dict[str, Any], text_formatted_mem: list[dict[str, Any]], mem_cube_id: str, ) -> dict[str, Any]: """ - Post-process text and tool memory results. + Post-process text, tool, skill and preference memory results. + Now automatically handles preference memories. """ fact_mem = [ mem @@ -124,6 +89,11 @@ def post_process_textual_mem( mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] == "SkillMemory" ] + # Extract preference memories + pref_mem = [ + mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] == "PreferenceMemory" + ] + memories_result["text_mem"].append( { "cube_id": mem_cube_id, @@ -145,6 +115,19 @@ def post_process_textual_mem( "total_nodes": len(skill_mem), } ) + + memories_result["pref_mem"].append( + { + "cube_id": mem_cube_id, + "memories": pref_mem, + "total_nodes": len(pref_mem), + } + ) + if pref_mem: + pref_instruction, pref_note = instruct_completion(pref_mem) + memories_result["pref_string"] = pref_instruction + memories_result["pref_note"] = pref_note + return memories_result diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index ef56c7489..2ab8f50c7 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -4,12 +4,8 @@ This module handles retrieving all memories or specific subgraphs based on queries. """ -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal -from memos.api.handlers.formatters_handler import ( - format_memory_item, - post_process_pref_mem, -) from memos.api.product_models import ( DeleteMemoryRequest, DeleteMemoryResponse, @@ -29,10 +25,6 @@ ) -if TYPE_CHECKING: - from memos.memories.textual.preference import TextualMemoryItem - - logger = get_logger(__name__) @@ -171,8 +163,7 @@ def handle_get_subgraph( def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemoryResponse: """ Handler for getting a single memory by its ID. - - Tries to retrieve from text memory first, then preference memory if not found. + Now unified to retrieve from text_mem only (includes preferences). Args: memory_id: The ID of the memory to retrieve @@ -184,37 +175,12 @@ def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemory try: memory = naive_mem_cube.text_mem.get(memory_id) - except Exception: + except Exception as e: + logger.error(f"Failed to get memory {memory_id}: {e}") memory = None - # If not found in text memory, try preference memory - pref = None - if memory is None and naive_mem_cube.pref_mem is not None: - collection_names = ["explicit_preference", "implicit_preference"] - for collection_name in collection_names: - try: - pref = naive_mem_cube.pref_mem.get_with_collection_name(collection_name, memory_id) - if pref is not None: - break - except Exception: - continue - - # Get the data from whichever memory source succeeded - data = (memory or pref).model_dump() if (memory or pref) else None - - if data is not None: - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - metadata = data.get("metadata", None) - if metadata is not None and isinstance(metadata, dict): - info = metadata.get("info", None) - if info is not None and isinstance(info, dict): - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value + # Get the data + data = memory.model_dump() if memory else None return GetMemoryResponse( message="Memory retrieved successfully" @@ -230,50 +196,20 @@ def handle_get_memory_by_ids( ) -> GetMemoryResponse: """ Handler for getting multiple memories by their IDs. + Now unified to retrieve from text_mem only (includes preferences). Retrieves multiple memories and formats them as a list of dictionaries. """ try: memories = naive_mem_cube.text_mem.get_by_ids(memory_ids=memory_ids) - except Exception: + except Exception as e: + logger.error(f"Failed to get memories: {e}") memories = [] # Ensure memories is not None if memories is None: memories = [] - if naive_mem_cube.pref_mem is not None: - collection_names = ["explicit_preference", "implicit_preference"] - for collection_name in collection_names: - try: - result = naive_mem_cube.pref_mem.get_by_ids_with_collection_name( - collection_name, memory_ids - ) - if result is not None: - result = [format_memory_item(item, save_sources=False) for item in result] - memories.extend(result) - except Exception: - continue - - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - for item in memories: - if not isinstance(item, dict): - continue - metadata = item.get("metadata") - if not isinstance(metadata, dict): - continue - info = metadata.get("info") - if not isinstance(info, dict): - continue - - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value - return GetMemoryResponse( message="Memories retrieved successfully", code=200, data={"memories": memories} ) @@ -343,67 +279,31 @@ def handle_get_memories( "total_nodes": total_skill_nodes, } ] - preferences: list[TextualMemoryItem] = [] - total_preference_nodes = 0 - format_preferences = [] - if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: - filter_params: dict[str, Any] = {} - if get_mem_req.user_id is not None: - filter_params["user_id"] = get_mem_req.user_id - if get_mem_req.mem_cube_id is not None: - filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - if get_mem_req.filter is not None: - # Check and remove user_id/mem_cube_id from filter if present - filter_copy = get_mem_req.filter.copy() - removed_fields = [] - - if "user_id" in filter_copy: - filter_copy.pop("user_id") - removed_fields.append("user_id") - if "mem_cube_id" in filter_copy: - filter_copy.pop("mem_cube_id") - removed_fields.append("mem_cube_id") - - if removed_fields: - logger.warning( - f"Fields {removed_fields} found in filter will be ignored. " - f"Use request-level user_id/mem_cube_id parameters instead." - ) - - filter_params.update(filter_copy) - - preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( - filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + # Get preference memories (same pattern as other memory types) + if get_mem_req.include_preference: + pref_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["PreferenceMemory"], ) - format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] - - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - for item in format_preferences: - if not isinstance(item, dict): - continue - metadata = item.get("metadata") - if not isinstance(metadata, dict): - continue - info = metadata.get("info") - if not isinstance(info, dict): - continue - - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value - - results = post_process_pref_mem( - results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference - ) - if total_preference_nodes > 0 and results.get("pref_mem", []): - results["pref_mem"][0]["total_nodes"] = total_preference_nodes + pref_memories, total_pref_nodes = ( + pref_memories_info["nodes"], + pref_memories_info["total_nodes"], + ) + + results["pref_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": pref_memories, + "total_nodes": total_pref_nodes, + } + ] - # Filter to only keep text_mem, pref_mem, tool_mem + # Filter to only keep text_mem, pref_mem, tool_mem, skill_mem filtered_results = { "text_mem": results.get("text_mem", []), "pref_mem": results.get("pref_mem", []), @@ -415,6 +315,10 @@ def handle_get_memories( def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube): + """ + Handler for deleting memories. + Now unified to delete from text_mem only (includes preferences). + """ logger.info( f"[Delete memory request] writable_cube_ids: {delete_mem_req.writable_cube_ids}, memory_ids: {delete_mem_req.memory_ids}" ) @@ -432,17 +336,14 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: try: if delete_mem_req.memory_ids is not None: + # Unified deletion from text_mem (includes preferences) naive_mem_cube.text_mem.delete_by_memory_ids(delete_mem_req.memory_ids) - if naive_mem_cube.pref_mem is not None: - naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) elif delete_mem_req.file_ids is not None: naive_mem_cube.text_mem.delete_by_filter( writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids ) elif delete_mem_req.filter is not None: naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter) - if naive_mem_cube.pref_mem is not None: - naive_mem_cube.pref_mem.delete_by_filter(filter=delete_mem_req.filter) except Exception as e: logger.error(f"Failed to delete memories: {e}", exc_info=True) return DeleteMemoryResponse( @@ -572,49 +473,29 @@ def handle_get_memories_dashboard( for cube_id, memories in skill_mem_by_cube.items() ] - preferences: list[TextualMemoryItem] = [] - - format_preferences = [] - if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: - filter_params: dict[str, Any] = {} - if get_mem_req.user_id is not None: - filter_params["user_id"] = get_mem_req.user_id - if get_mem_req.mem_cube_id is not None: - filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - if get_mem_req.filter is not None: - # Check and remove user_id/mem_cube_id from filter if present - filter_copy = get_mem_req.filter.copy() - removed_fields = [] - - if "user_id" in filter_copy: - filter_copy.pop("user_id") - removed_fields.append("user_id") - if "mem_cube_id" in filter_copy: - filter_copy.pop("mem_cube_id") - removed_fields.append("mem_cube_id") - - if removed_fields: - logger.warning( - f"Fields {removed_fields} found in filter will be ignored. " - f"Use request-level user_id/mem_cube_id parameters instead." - ) - - filter_params.update(filter_copy) - - preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( - filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + if get_mem_req.include_preference: + pref_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["PreferenceMemory"], + ) + pref_memories, total_preference_nodes = ( + pref_memories_info["nodes"], + pref_memories_info["total_nodes"], ) - format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] - # Group preferences by cube_id from metadata.mem_cube_id + # Group preference memories by cube_id from metadata.user_name pref_mem_by_cube: dict[str, list] = {} - for pref in format_preferences: - cube_id = pref.get("metadata", {}).get("mem_cube_id", get_mem_req.mem_cube_id) + for memory in pref_memories: + cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id) if cube_id not in pref_mem_by_cube: pref_mem_by_cube[cube_id] = [] - pref_mem_by_cube[cube_id].append(pref) + pref_mem_by_cube[cube_id].append(memory) - # If no preferences found, create a default entry with the requested cube_id + # If no memories found, create a default entry with the requested cube_id if not pref_mem_by_cube and get_mem_req.mem_cube_id: pref_mem_by_cube[get_mem_req.mem_cube_id] = [] diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 8e7785ad5..58121776e 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -49,7 +49,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse Main handler for search memories endpoint. Orchestrates the search process based on the requested search mode, - supporting both text and preference memory searches. + supporting text memory searches. Args: search_req: Search request containing query and parameters diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 5bf27e985..6f112b9a7 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -434,7 +434,7 @@ class APISearchRequest(BaseRequest): # Internal field for search memory type search_memory_type: str = Field( "All", - description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, RawFileMemory, AllSummaryMemory, SkillMemory", + description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, RawFileMemory, AllSummaryMemory, SkillMemory, PreferenceMemory", ) # ==== Context ==== diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index af6ae4fe5..fa8a0b396 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -94,7 +94,6 @@ redis_client = components["redis_client"] status_tracker = TaskStatusTracker(redis_client=redis_client) graph_db = components["graph_db"] -vector_db = components["vector_db"] # ============================================================================= @@ -369,15 +368,9 @@ def feedback_memories(feedback_req: APIFeedbackRequest): response_model=GetUserNamesByMemoryIdsResponse, ) def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): - """Get user names by memory ids.""" + """Get user names by memory ids. Now unified to query from graph_db only.""" result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids) - if vector_db: - prefs = [] - for collection_name in ["explicit_preference", "implicit_preference"]: - prefs.extend( - vector_db.get_by_ids(collection_name=collection_name, ids=request.memory_ids) - ) - result.update({pref.id: pref.payload.get("mem_cube_id", None) for pref in prefs}) + return GetUserNamesByMemoryIdsResponse( code=200, message="Successfully", diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index 3afa78bab..b9395ea0d 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -20,7 +20,6 @@ class NaiveMemCube(BaseMemCube): def __init__( self, text_mem: BaseTextMemory | None = None, - pref_mem: BaseTextMemory | None = None, act_mem: BaseActMemory | None = None, para_mem: BaseParaMemory | None = None, ): @@ -28,19 +27,20 @@ def __init__( self._text_mem: BaseTextMemory = text_mem self._act_mem: BaseActMemory | None = act_mem self._para_mem: BaseParaMemory | None = para_mem - self._pref_mem: BaseTextMemory | None = pref_mem + # pref_mem removed - now handled by text_mem def load( self, dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, ) -> None: """Load memories. Args: dir (str): The directory containing the memory files. memory_types (list[str], optional): List of memory types to load. If None, loads all available memory types. - Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] + Options: ["text_mem", "act_mem", "para_mem"] + Note: pref_mem is now integrated into text_mem """ loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) if loaded_schema != self.config.model_schema: @@ -51,7 +51,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] + memory_types = ["text_mem", "act_mem", "para_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -66,23 +66,20 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") - if "pref_mem" in memory_types and self.pref_mem: - self.pref_mem.load(dir) - logger.info(f"Loaded pref_mem from {dir}") - logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( self, dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. - Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] + Options: ["text_mem", "act_mem", "para_mem"] + Note: pref_mem is now integrated into text_mem """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -94,7 +91,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] + memory_types = ["text_mem", "act_mem", "para_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -109,10 +106,6 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") - if "pref_mem" in memory_types and self.pref_mem: - self.pref_mem.dump(dir) - logger.info(f"Dumped pref_mem to {dir}") - logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @property @@ -157,16 +150,4 @@ def para_mem(self, value: BaseParaMemory) -> None: raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value - @property - def pref_mem(self) -> "BaseTextMemory | None": - """Get the preference memory.""" - if self._pref_mem is None: - logger.warning("Preference memory is not initialized. Returning None.") - return self._pref_mem - - @pref_mem.setter - def pref_mem(self, value: BaseTextMemory) -> None: - """Set the preference memory.""" - if not isinstance(value, BaseTextMemory): - raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") - self._pref_mem = value + # pref_mem property removed - preferences now handled by text_mem diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 6c6d1821f..18045af2c 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -2,7 +2,6 @@ import difflib import json import re -import uuid from datetime import datetime from typing import TYPE_CHECKING, Any, Literal @@ -36,7 +35,6 @@ if TYPE_CHECKING: - from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_feedback_prompts import ( FEEDBACK_ANSWER_PROMPT, @@ -95,7 +93,6 @@ def __init__(self, config: MemFeedbackConfig): self.stopword_manager = StopwordManager self.searcher: Searcher = None self.reranker = None - self.pref_mem: SimplePreferenceTextMemory = None self.pref_feedback: bool = False self.DB_IDX_READY = False @@ -239,6 +236,9 @@ def _single_add_operation( else: to_add_memory = new_memory_item.model_copy(deep=True) + if to_add_memory.metadata.memory_type == "PreferenceMemory": + to_add_memory.metadata.preference = new_memory_item.memory + to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( datetime.now().isoformat() ) @@ -274,13 +274,6 @@ def _single_update_operation( """ Individual update operations """ - if "preference" in old_memory_item.metadata.__dict__: - logger.info( - f"[0107 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" - ) - return self._single_update_pref( - old_memory_item, new_memory_item, user_id, user_name, operation - ) memory_type = old_memory_item.metadata.memory_type source_doc_id = ( @@ -329,68 +322,6 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } - def _single_update_pref( - self, - old_memory_item: TextualMemoryItem, - new_memory_item: TextualMemoryItem, - user_id: str, - user_name: str, - operation: dict, - ): - """update preference memory""" - - feedback_context = new_memory_item.memory - if operation and "text" in operation and operation["text"]: - new_memory_item.memory = operation["text"] - new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] - - to_add_memory = old_memory_item.model_copy(deep=True) - to_add_memory.metadata.key = new_memory_item.metadata.key - to_add_memory.metadata.tags = new_memory_item.metadata.tags - to_add_memory.memory = new_memory_item.memory - to_add_memory.metadata.preference = new_memory_item.memory - to_add_memory.metadata.embedding = new_memory_item.metadata.embedding - - to_add_memory.metadata.user_id = new_memory_item.metadata.user_id - to_add_memory.metadata.original_text = old_memory_item.memory - to_add_memory.metadata.covered_history = old_memory_item.id - - to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( - datetime.now().isoformat() - ) - to_add_memory.metadata.context_summary = ( - old_memory_item.metadata.context_summary + " \n" + feedback_context - ) - - # add new memory - to_add_memory.id = str(uuid.uuid4()) - added_ids = self._retry_db_operation(lambda: self.pref_mem.add([to_add_memory])) - # delete - deleted_id = old_memory_item.id - collection_name = old_memory_item.metadata.preference_type - self._retry_db_operation( - lambda: self.pref_mem.delete_with_collection_name(collection_name, [deleted_id]) - ) - # add archived - old_memory_item.metadata.status = "archived" - old_memory_item.metadata.original_text = "archived" - old_memory_item.metadata.embedding = [0.0] * 1024 - - archived_ids = self._retry_db_operation(lambda: self.pref_mem.add([old_memory_item])) - - logger.info( - f"[Memory Feedback UPDATE Pref] New Add:{added_ids!s} | Set archived:{archived_ids!s}" - ) - - return { - "id": to_add_memory.id, - "text": new_memory_item.memory, - "source_doc_id": "", - "archived_id": old_memory_item.id, - "origin_memory": old_memory_item.memory, - "type": "preference", - } - def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -460,7 +391,7 @@ def semantics_feedback( for chunk in memory_chunks: chunk_list = [] for item in chunk: - if "preference" in item.metadata.__dict__: + if item.metadata.memory_type == "PreferenceMemory": chunk_list.append(f"{item.id}: {item.metadata.preference}") else: chunk_list.append(f"{item.id}: {item.memory}") @@ -638,6 +569,19 @@ def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, boo ) text_mems = [item[0] for item in text_mems if float(item[1]) > 0.01] + if self.pref_feedback: + pref_mems = self.searcher.search( + query, + info=info, + memory_type="PreferenceMemory", + user_name=user_name, + top_k=top_k, + include_preference_memory=True, + full_recall=True, + ) + pref_mems = [item[0] for item in pref_mems if float(item[1]) > 0.01] + text_mems.extend(pref_mems) + # Memory with edges is not modified by feedback retrieved_mems = [] with ContextThreadPoolExecutor(max_workers=10) as executor: @@ -656,14 +600,7 @@ def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, boo f"text memories are not modified by feedback due to edges." ) - if self.pref_feedback: - pref_info = {} - if "user_id" in info: - pref_info = {"user_id": info["user_id"]} - retrieved_prefs = self.pref_mem.search(query, top_k, pref_info) - return retrieved_mems + retrieved_prefs - else: - return retrieved_mems + return retrieved_mems def _vec_query(self, new_memories_embedding: list[float], user_name=None): """Vector retrieval query""" diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index 2ac0a0a39..dfc9b9fdf 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -4,7 +4,6 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -24,7 +23,6 @@ def __init__( mem_reader: SimpleStructMemReader, searcher: Searcher, reranker: BaseReranker, - pref_mem: SimplePreferenceTextMemory, pref_feedback: bool = False, ): self.llm = llm @@ -34,7 +32,6 @@ def __init__( self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager - self.pref_mem = pref_mem self.reranker = reranker self.DB_IDX_READY = False self.pref_feedback = pref_feedback diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index e3d2bece9..62e8f2d75 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -10,6 +10,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang from memos.mem_reader.read_multi_modal.base import _derive_key +from memos.mem_reader.read_pref_memory.process_preference_memory import process_preference_fine from memos.mem_reader.read_skill_memory.process_skill_memory import process_skill_memory_fine from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader from memos.mem_reader.utils import parse_json_result @@ -993,7 +994,7 @@ def _process_multi_modal_data( # Part A: call llm in parallel using thread pool fine_memory_items = [] - with ContextThreadPoolExecutor(max_workers=3) as executor: + with ContextThreadPoolExecutor(max_workers=4) as executor: future_string = executor.submit( self._process_string_fine, fast_memory_items, info, custom_tags, **kwargs ) @@ -1012,15 +1013,25 @@ def _process_multi_modal_data( skills_dir_config=self.skills_dir_config, **kwargs, ) + future_pref = executor.submit( + process_preference_fine, + fast_memory_items, + info, + self.llm, + self.embedder, + **kwargs, + ) # Collect results fine_memory_items_string_parser = future_string.result() fine_memory_items_tool_trajectory_parser = future_tool.result() fine_memory_items_skill_memory_parser = future_skill.result() + fine_memory_items_pref_parser = future_pref.result() fine_memory_items.extend(fine_memory_items_string_parser) fine_memory_items.extend(fine_memory_items_tool_trajectory_parser) fine_memory_items.extend(fine_memory_items_skill_memory_parser) + fine_memory_items.extend(fine_memory_items_pref_parser) # Part B: get fine multimodal items for fast_item in fast_memory_items: @@ -1060,7 +1071,7 @@ def _process_transfer_multi_modal_data( fine_memory_items = [] # Part A: call llm in parallel using thread pool - with ContextThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=4) as executor: future_string = executor.submit( self._process_string_fine, raw_nodes, info, custom_tags, **kwargs ) @@ -1079,14 +1090,21 @@ def _process_transfer_multi_modal_data( skills_dir_config=self.skills_dir_config, **kwargs, ) + # Add preference memory extraction + future_pref = executor.submit( + process_preference_fine, raw_nodes, info, self.llm, self.embedder, **kwargs + ) # Collect results fine_memory_items_string_parser = future_string.result() fine_memory_items_tool_trajectory_parser = future_tool.result() fine_memory_items_skill_memory_parser = future_skill.result() + fine_memory_items_pref_parser = future_pref.result() + fine_memory_items.extend(fine_memory_items_string_parser) fine_memory_items.extend(fine_memory_items_tool_trajectory_parser) fine_memory_items.extend(fine_memory_items_skill_memory_parser) + fine_memory_items.extend(fine_memory_items_pref_parser) # Part B: get fine multimodal items for raw_node in raw_nodes: diff --git a/src/memos/mem_reader/read_pref_memory/process_preference_memory.py b/src/memos/mem_reader/read_pref_memory/process_preference_memory.py new file mode 100644 index 000000000..1ff1fba52 --- /dev/null +++ b/src/memos/mem_reader/read_pref_memory/process_preference_memory.py @@ -0,0 +1,296 @@ +"""Preference memory extractor.""" + +import json +import os +import uuid + +from concurrent.futures import as_completed +from typing import TYPE_CHECKING, Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_reader.read_multi_modal import detect_lang +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.templates.prefer_complete_prompt import ( + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, +) + + +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + +logger = get_logger(__name__) + + +def _extract_explicit_preference(qa_pair_str: str, llm) -> list[dict[str, Any]] | None: + """Extract explicit preference from a QA pair string.""" + lang = detect_lang(qa_pair_str) + _map = { + "zh": NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + "en": NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + } + prompt = _map[lang].replace("{qa_pair}", qa_pair_str) + + try: + response = llm.generate([{"role": "user", "content": prompt}]) + if not response: + logger.info( + f"[prefer_extractor]: (Error) LLM response content is {response} when extracting explicit preference" + ) + return None + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + for d in result: + d["preference"] = d.pop("explicit_preference") + return result + except Exception as e: + logger.info(f"Error extracting explicit preference: {e}, return None") + return None + + +def _extract_implicit_preference(qa_pair_str: str, llm) -> list[dict[str, Any]] | None: + """Extract implicit preferences from a QA pair string.""" + if not qa_pair_str: + return None + + lang = detect_lang(qa_pair_str) + _map = { + "zh": NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + "en": NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + } + prompt = _map[lang].replace("{qa_pair}", qa_pair_str) + + try: + response = llm.generate([{"role": "user", "content": prompt}]) + if not response: + logger.info( + f"[prefer_extractor]: (Error) LLM response content is {response} when extracting implicit preference" + ) + return None + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + for d in result: + d["preference"] = d.pop("implicit_preference") + return result + except Exception as e: + logger.info(f"Error extracting implicit preferences: {e}, return None") + return None + + +def _create_preference_memory_item( + preference_data: dict[str, Any], + preference_type: str, + fast_item: TextualMemoryItem | None, + info: dict[str, Any], + embedder, + **kwargs, +) -> TextualMemoryItem: + """ + Create a preference memory item with proper metadata. + + Args: + preference_data: Dictionary containing preference, context_summary, reasoning, topic + preference_type: "explicit_preference" or "implicit_preference" + fast_item: Original fast memory item (for extracting sources and other metadata) + info: Dictionary containing user_id, session_id, etc. + embedder: Embedder instance + kwargs: Additional parameters including user_context + + Returns: + TextualMemoryItem with TreeNodeTextualMemoryMetadata + """ + # Make a copy of info to avoid modifying the original + info_ = info.copy() + + # Extract fields that should be at metadata level + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + + # Extract manager_user_id, project_id, and operation from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + + # Generate embedding for context_summary + context_summary = preference_data.get("context_summary", "") + embedding = embedder.embed([context_summary])[0] if embedder and context_summary else None + + # Extract sources from fast_item + sources = getattr(fast_item.metadata, "sources", []) if fast_item else [] + + # Create metadata + metadata = TreeNodeTextualMemoryMetadata( + memory_type="PreferenceMemory", + embedding=embedding, + user_id=user_id, + session_id=session_id, + status="activated", + tags=[], + type="chat", + info=info_, + sources=sources, + usage=[], + background="", + # Preference-specific fields + preference_type=preference_type, + preference=preference_data.get("preference", ""), + reasoning=preference_data.get("reasoning", ""), + topic=preference_data.get("topic", ""), + # User-specific fields + manager_user_id=manager_user_id, + project_id=project_id, + ) + + # Create and return memory item + return TextualMemoryItem(id=str(uuid.uuid4()), memory=context_summary, metadata=metadata) + + +def _process_single_chunk_explicit( + original_text: str, + fast_item: TextualMemoryItem | None, + info: dict[str, Any], + llm, + embedder, + **kwargs, +) -> list[TextualMemoryItem]: + """Process a single chunk for explicit preferences.""" + if not original_text.strip(): + return [] + + explicit_pref = _extract_explicit_preference(original_text, llm) + if not explicit_pref: + return [] + + memories = [] + for pref in explicit_pref: + memory = _create_preference_memory_item( + preference_data=pref, + preference_type="explicit_preference", + fast_item=fast_item, + info=info, + embedder=embedder, + **kwargs, + ) + memories.append(memory) + + return memories + + +def _process_single_chunk_implicit( + original_text: str, + fast_item: TextualMemoryItem | None, + info: dict[str, Any], + llm, + embedder, + **kwargs, +) -> list[TextualMemoryItem]: + """Process a single chunk for implicit preferences.""" + if not original_text.strip(): + return [] + + implicit_pref = _extract_implicit_preference(original_text, llm) + if not implicit_pref: + return [] + + memories = [] + for pref in implicit_pref: + memory = _create_preference_memory_item( + preference_data=pref, + preference_type="implicit_preference", + fast_item=fast_item, + info=info, + embedder=embedder, + **kwargs, + ) + memories.append(memory) + + return memories + + +def process_preference_fine( + fast_memory_items: list[TextualMemoryItem], + info: dict[str, Any], + llm=None, + embedder=None, + **kwargs, +) -> list[TextualMemoryItem]: + """ + Extract preference memories from fast_memory_items (for fine mode processing). + + Args: + fast_memory_items: List of TextualMemoryItem from fast parsing + info: Dictionary containing user_id and session_id + llm: LLM instance + embedder: Embedder instance + kwargs: Additional parameters (including user_context) + + Returns: + List of preference memory items + """ + + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + if not fast_memory_items or not llm: + return [] + + try: + # Convert fast_memory_items to messages format + chunks = [] + for fast_item in fast_memory_items: + mem_str = fast_item.memory or "" + if not mem_str.strip(): + continue + chunks.append((mem_str, fast_item)) + + if not chunks: + return [] + + # Process chunks in parallel + memories = [] + with ContextThreadPoolExecutor(max_workers=min(10, len(chunks))) as executor: + futures = {} + + # Submit explicit extraction tasks + for chunk, fast_item in chunks: + future = executor.submit( + _process_single_chunk_explicit, chunk, fast_item, info, llm, embedder, **kwargs + ) + futures[future] = ("explicit_preference", chunk) + + # Submit implicit extraction tasks + for chunk, fast_item in chunks: + future = executor.submit( + _process_single_chunk_implicit, chunk, fast_item, info, llm, embedder, **kwargs + ) + futures[future] = ("implicit_preference", chunk) + + # Collect results + for future in as_completed(futures): + try: + memory = future.result() + if memory: + if isinstance(memory, list): + memories.extend(memory) + else: + memories.append(memory) + except Exception as e: + task_type, chunk = futures[future] + logger.warning( + f"[process_preference_fine] Error processing {task_type} chunk, original text: {chunk}: {e}" + ) + continue + + if memories: + logger.info(f"[process_preference_fine] Extracted {len(memories)} preference memories") + + return memories + except Exception as e: + logger.warning( + f"[process_preference_fine] Failed to extract preferences: {e}", exc_info=True + ) + return [] diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index b103acf3a..8777b9f2e 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -18,17 +18,6 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_feedback.simple_feedback import SimpleMemFeedback from memos.mem_reader.factory import MemReaderFactory -from memos.memories.textual.prefer_text_memory.config import ( - AdderConfigFactory, - ExtractorConfigFactory, - RetrieverConfigFactory, -) -from memos.memories.textual.prefer_text_memory.factory import ( - AdderFactory, - ExtractorFactory, - RetrieverFactory, -) -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( @@ -40,7 +29,6 @@ if TYPE_CHECKING: from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.factory import RerankerFactory -from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) @@ -182,36 +170,6 @@ def build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) -def build_pref_extractor_config() -> dict[str, Any]: - """ - Build preference memory extractor configuration. - - Returns: - Validated extractor configuration dictionary - """ - return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def build_pref_adder_config() -> dict[str, Any]: - """ - Build preference memory adder configuration. - - Returns: - Validated adder configuration dictionary - """ - return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def build_pref_retriever_config() -> dict[str, Any]: - """ - Build preference memory retriever configuration. - - Returns: - Validated retriever configuration dictionary - """ - return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) - - def _get_default_memory_size(cube_config: Any) -> dict[str, int]: """ Get default memory size configuration. @@ -291,20 +249,11 @@ def init_components() -> dict[str, Any]: reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() internet_retriever_config = build_internet_retriever_config() - vector_db_config = build_vec_db_config() - pref_extractor_config = build_pref_extractor_config() - pref_adder_config = build_pref_adder_config() - pref_retriever_config = build_pref_retriever_config() logger.debug("Component configurations built successfully") # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = ( - VecDBFactory.from_config(vector_db_config) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) @@ -345,63 +294,9 @@ def init_components() -> dict[str, Any]: logger.debug("Text memory initialized") - # Initialize preference memory components - pref_extractor = ( - ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_adder = ( - AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_retriever = ( - RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=feedback_reranker, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - logger.debug("Preference memory components initialized") - - # Initialize preference memory - pref_mem = ( - SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=feedback_reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=pref_mem, act_mem=None, para_mem=None, ) @@ -421,7 +316,7 @@ def init_components() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, - pref_mem=pref_mem, + pref_feedback=True, ) # Return all components as a dictionary for easy access and extension return {"naive_mem_cube": naive_mem_cube, "feedback_server": feedback_server} diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 7e40f1d50..60af67830 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -171,6 +171,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata): "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ] = Field(default="WorkingMemory", description="Memory lifecycle type.") sources: list[SourceMessage] | None = Field( default=None, description="Multiple origins of the memory (e.g., URLs, notes)." @@ -337,8 +338,6 @@ def _coerce_metadata(cls, v: Any): if v.get("relativity") is not None: return SearchedTreeNodeTextualMemoryMetadata(**v) - if v.get("preference_type") is not None: - return PreferenceTextualMemoryMetadata(**v) if any(k in v for k in ("sources", "memory_type", "embedding", "background", "usage")): return TreeNodeTextualMemoryMetadata(**v) return TextualMemoryMetadata(**v) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index dba321f55..0cc6d1930 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -74,6 +74,7 @@ def get_memory( messages (list[MessageList]): The messages to get memory from. type (str): The type of memory to get. info (dict[str, Any]): The info to get memory. + **kwargs: Additional keyword arguments to pass to the extractor. """ return self.extractor.extract(messages, type, info, **kwargs) @@ -91,7 +92,6 @@ def search( if not isinstance(search_filter, dict): search_filter = {} search_filter.update({"status": "activated"}) - logger.info(f"search_filter for preference memory: {search_filter}") return self.retriever.retrieve(query, top_k, info, search_filter) def load(self, dir: str) -> None: diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index db7101744..51523d364 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -1,5 +1,3 @@ -from typing import Any - from memos.embedders.factory import ( ArkEmbedder, OllamaEmbedder, @@ -8,9 +6,7 @@ ) from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger -from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.preference import PreferenceTextMemory -from memos.types import MessageList from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB @@ -38,125 +34,3 @@ def __init__( self.extractor = extractor self.adder = adder self.retriever = retriever - - def get_memory( - self, messages: list[MessageList], type: str, info: dict[str, Any], **kwargs - ) -> list[TextualMemoryItem]: - """Get memory based on the messages. - Args: - messages (MessageList): The messages to get memory from. - type (str): The type of memory to get. - info (dict[str, Any]): The info to get memory. - **kwargs: Additional keyword arguments to pass to the extractor. - """ - return self.extractor.extract(messages, type, info, **kwargs) - - def search( - self, query: str, top_k: int, info=None, search_filter=None, **kwargs - ) -> list[TextualMemoryItem]: - """Search for memories based on a query. - Args: - query (str): The query to search for. - top_k (int): The number of top results to return. - info (dict): Leave a record of memory consumption. - Returns: - list[TextualMemoryItem]: List of matching memories. - """ - if not isinstance(search_filter, dict): - search_filter = {} - search_filter.update({"status": "activated"}) - return self.retriever.retrieve(query, top_k, info, search_filter) - - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: - """Add memories. - - Args: - memories: List of TextualMemoryItem objects or dictionaries to add. - """ - return self.adder.add(memories) - - def get_with_collection_name( - self, collection_name: str, memory_id: str - ) -> TextualMemoryItem | None: - """Get a memory by its ID and collection name. - Args: - memory_id (str): The ID of the memory to retrieve. - collection_name (str): The name of the collection to retrieve the memory from. - Returns: - TextualMemoryItem: The memory with the given ID and collection name. - """ - try: - res = self.vector_db.get_by_id(collection_name, memory_id) - if res is None: - return None - return TextualMemoryItem( - id=res.id, - memory=res.memory, - metadata=PreferenceTextualMemoryMetadata(**res.payload), - ) - except Exception as e: - # Convert any other exception to ValueError for consistent error handling - raise ValueError( - f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" - ) from e - - def get_by_ids_with_collection_name( - self, collection_name: str, memory_ids: list[str] - ) -> list[TextualMemoryItem]: - """Get memories by their IDs and collection name. - Args: - collection_name (str): The name of the collection to retrieve the memory from. - memory_ids (list[str]): List of memory IDs to retrieve. - Returns: - list[TextualMemoryItem]: List of memories with the specified IDs and collection name. - """ - try: - res = self.vector_db.get_by_ids(collection_name, memory_ids) - if not res: - return [] - return [ - TextualMemoryItem( - id=memo.id, - memory=memo.memory, - metadata=PreferenceTextualMemoryMetadata(**memo.payload), - ) - for memo in res - ] - except Exception as e: - # Convert any other exception to ValueError for consistent error handling - raise ValueError( - f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" - ) from e - - def get_all(self) -> list[TextualMemoryItem]: - """Get all memories. - Returns: - list[TextualMemoryItem]: List of all memories. - """ - all_collections = ["explicit_preference", "implicit_preference"] - all_memories = {} - for collection_name in all_collections: - items = self.vector_db.get_all(collection_name) - all_memories[collection_name] = [ - TextualMemoryItem( - id=memo.id, - memory=memo.memory, - metadata=PreferenceTextualMemoryMetadata(**memo.payload), - ) - for memo in items - ] - return all_memories - - def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: - """Delete memories by their IDs and collection name. - Args: - collection_name (str): The name of the collection to delete the memory from. - memory_ids (list[str]): List of memory IDs to delete. - """ - self.vector_db.delete(collection_name, memory_ids) - - def delete_all(self) -> None: - """Delete all memories.""" - for collection_name in self.vector_db.config.collection_name: - self.vector_db.delete_collection(collection_name) - self.vector_db.create_collection() diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5b210ba61..8c896f538 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -169,6 +169,8 @@ def search( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, dedup: str | None = None, include_embedding: bool | None = None, **kwargs, @@ -222,6 +224,8 @@ def search( tool_mem_top_k=tool_mem_top_k, include_skill_memory=include_skill_memory, skill_mem_top_k=skill_mem_top_k, + include_preference_memory=include_preference_memory, + pref_mem_top_k=pref_mem_top_k, dedup=dedup, **kwargs, ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 4ca30c7b8..df419f0c1 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -185,6 +185,7 @@ def _add_memories_batch( "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ): graph_node_id = ( memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) @@ -341,6 +342,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ): f_graph = ex.submit( self._add_to_graph_memory, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index e5e96dd58..dd90b8932 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -69,6 +69,7 @@ def retrieve( "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ]: raise ValueError(f"Unsupported memory scope: {memory_scope}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index cc269e8c4..b4994671f 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -87,6 +87,8 @@ def retrieve( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: logger.info( @@ -116,6 +118,8 @@ def retrieve( tool_mem_top_k, include_skill_memory, skill_mem_top_k, + include_preference_memory, + pref_mem_top_k, ) return results @@ -129,6 +133,8 @@ def post_retrieve( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, dedup: str | None = None, plugin=False, ): @@ -144,6 +150,8 @@ def post_retrieve( tool_mem_top_k, include_skill_memory, skill_mem_top_k, + include_preference_memory, + pref_mem_top_k, ) self._update_usage_history(final_results, info, user_name) return final_results @@ -163,6 +171,8 @@ def search( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, dedup: str | None = None, **kwargs, ) -> list[TextualMemoryItem]: @@ -212,6 +222,8 @@ def search( tool_mem_top_k=tool_mem_top_k, include_skill_memory=include_skill_memory, skill_mem_top_k=skill_mem_top_k, + include_preference_memory=include_preference_memory, + pref_mem_top_k=pref_mem_top_k, **kwargs, ) @@ -229,6 +241,8 @@ def search( tool_mem_top_k=tool_mem_top_k, include_skill_memory=include_skill_memory, skill_mem_top_k=skill_mem_top_k, + include_preference_memory=include_preference_memory, + pref_mem_top_k=pref_mem_top_k, dedup=dedup, ) @@ -329,8 +343,10 @@ def _retrieve_paths( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, ): - """Run A/B/C/D/E retrieval paths in parallel""" + """Run A/B/C/D/E/F retrieval paths in parallel""" tasks = [] id_filter = { "user_id": info.get("user_id", None), @@ -428,6 +444,22 @@ def _retrieve_paths( mode=mode, ) ) + if include_preference_memory: + tasks.append( + executor.submit( + self._retrieve_from_preference_memory, + query, + parsed_goal, + query_embedding, + pref_mem_top_k, + memory_type, + search_filter, + search_priority, + user_name, + id_filter, + mode=mode, + ) + ) results = [] for t in tasks: results.extend(t.result()) @@ -863,6 +895,57 @@ def _retrieve_from_skill_memory( search_filter=search_filter, ) + @timed + def _retrieve_from_preference_memory( + self, + query, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter: dict | None = None, + search_priority: dict | None = None, + user_name: str | None = None, + id_filter: dict | None = None, + mode: str = "fast", + ): + """Retrieve and rerank from PreferenceMemory""" + if memory_type not in ["All", "PreferenceMemory"]: + logger.info(f"[PATH-F] '{query}' Skipped (memory_type does not match)") + return [] + + # chain of thinking + cot_embeddings = [] + if self.vec_cot: + queries = self._cot_query(query, mode=mode, context=parsed_goal.context) + if len(queries) > 1: + cot_embeddings = self.embedder.embed(queries) + cot_embeddings.extend(query_embedding) + else: + cot_embeddings = query_embedding + + items = self.graph_retriever.retrieve( + query=query, + parsed_goal=parsed_goal, + query_embedding=cot_embeddings, + top_k=top_k * 2, + memory_scope="PreferenceMemory", + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + id_filter=id_filter, + use_fast_graph=self.use_fast_graph, + ) + + return self.reranker.rerank( + query=query, + query_embedding=query_embedding[0], + graph_results=items, + top_k=top_k, + parsed_goal=parsed_goal, + search_filter=search_filter, + ) + @timed def _retrieve_simple( self, @@ -933,6 +1016,8 @@ def _sort_and_trim( tool_mem_top_k=6, include_skill_memory=False, skill_mem_top_k=3, + include_preference_memory=False, + pref_mem_top_k=6, ): """Sort results by score and trim to top_k""" final_items = [] @@ -1000,6 +1085,28 @@ def _sort_and_trim( ) ) + if include_preference_memory: + pref_results = [ + (item, score) + for item, score in results + if item.metadata.memory_type == "PreferenceMemory" + ] + sorted_pref_results = sorted(pref_results, key=lambda pair: pair[1], reverse=True)[ + :pref_mem_top_k + ] + for item, score in sorted_pref_results: + if plugin and round(score, 2) == 0.00: + continue + meta_data = item.metadata.model_dump() + meta_data["relativity"] = score + final_items.append( + TextualMemoryItem( + id=item.id, + memory=item.memory, + metadata=SearchedTreeNodeTextualMemoryMetadata(**meta_data), + ) + ) + # separate textual results results = [ (item, score) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index d890c77bf..6df410c19 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os import time import traceback @@ -11,10 +10,8 @@ from memos.api.handlers.formatters_handler import ( format_memory_item, - post_process_pref_mem, post_process_textual_mem, ) -from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_reader.utils import parse_keep_filter_response from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -22,7 +19,6 @@ ADD_TASK_LABEL, MEM_FEEDBACK_TASK_LABEL, MEM_READ_TASK_LABEL, - PREF_ADD_TASK_LABEL, ) from memos.memories.textual.item import TextualMemoryItem from memos.multi_mem_cube.views import MemCubeView @@ -78,38 +74,23 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: ) target_session_id = add_req.session_id or "default_session" - sync_mode = add_req.async_mode or self._get_sync_mode() - self.logger.info( f"[SingleCubeView] cube={self.cube_id} " f"Processing add with mode={sync_mode}, session={target_session_id}" ) - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) - pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) - - text_results = text_future.result() - pref_results = pref_future.result() - - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " - f"pref_results={len(pref_results)}" - ) - - for item in text_results: - item["cube_id"] = self.cube_id - for item in pref_results: - item["cube_id"] = self.cube_id - - all_memories = text_results + pref_results + all_memories = self._process_text_mem(add_req, user_context, sync_mode) - # TODO: search existing memories and compare + self.logger.info(f"[SingleCubeView] cube={self.cube_id} total_results={len(all_memories)}") return all_memories @timed def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + """ + Unified memory search handling (text + preference memories). + Preference memories are now searched through the same _search_text flow. + """ # Create UserContext object user_context = UserContext( user_id=search_req.user_id, @@ -131,28 +112,16 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: # Determine search mode search_mode = self._get_search_mode(search_req.mode) - # Execute search in parallel for text and preference memories - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._search_text, search_req, user_context, search_mode) - pref_future = executor.submit(self._search_pref, search_req, user_context) - - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() + # Unified search through _search_text (includes all memory types) + all_formatted_memories = self._search_text(search_req, user_context, search_mode) - # Build result + # Build result with unified processing memories_result = post_process_textual_mem( memories_result, - text_formatted_memories, + all_formatted_memories, self.cube_id, ) - memories_result = post_process_pref_mem( - memories_result, - pref_formatted_memories, - self.cube_id, - search_req.include_preference, - ) - self.logger.info(f"Search memories result: {memories_result}") self.logger.info(f"Search {len(memories_result)} memories.") return memories_result @@ -407,71 +376,6 @@ def _dedup_by_content(memories: list) -> list: return formatted_memories - @timed - def _search_pref( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list[dict[str, Any]]: - """ - Search preference memories. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of formatted preference memory items - TODO: ADD CUBE ID IN PREFERENCE MEMORY - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - if not search_req.include_preference: - return [] - - logger.info(f"search_req.filter for preference memory: {search_req.filter}") - logger.info(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}") - try: - results = self.naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "mem_cube_id": user_context.mem_cube_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - search_filter=search_req.filter, - ) - include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true" - formatted_results = self._postformat_memories( - results, user_context.mem_cube_id, include_embedding=include_embedding - ) - - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - for item in formatted_results: - if not isinstance(item, dict): - continue - metadata = item.get("metadata") - if not isinstance(metadata, dict): - continue - info = metadata.get("info") - if not isinstance(info, dict): - continue - - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value - - return formatted_results - except Exception as e: - self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - def _fast_search( self, search_req: APISearchRequest, @@ -645,89 +549,6 @@ def _schedule_memory_tasks( ) self.mem_scheduler.submit_messages(messages=[message_item_add]) - @timed - def _process_pref_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - sync_mode: str, - ) -> list[dict[str, Any]]: - """ - Process and add preference memories. - - Extracts preferences from messages and adds them to the preference memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted preference responses - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - if add_req.messages is None or isinstance(add_req.messages, str): - return [] - - for message in add_req.messages: - if isinstance(message, dict) and message.get("role", None) is None: - return [] - - target_session_id = add_req.session_id or "default_session" - - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=user_context.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=PREF_ADD_TASK_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - info=add_req.info, - user_name=self.cube_id, - task_id=add_req.task_id, - user_context=user_context, - ) - self.mem_scheduler.submit_messages(messages=[message_item_pref]) - self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") - except Exception as e: - self.logger.error( - f"[SingleCubeView] cube={self.cube_id} Failed to submit PREF_ADD: {e}", - exc_info=True, - ) - return [] - else: - pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - **(add_req.info or {}), - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": user_context.mem_cube_id, - }, - user_context=user_context, - ) - pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} " - f"added {len(pref_ids_local)} preferences for user {add_req.user_id}: {pref_ids_local}" - ) - - return [ - { - "memory": memory.metadata.preference, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - def add_before_search( self, messages: list[dict], @@ -834,7 +655,7 @@ def _process_text_mem( sync_mode: str, ) -> list[dict[str, Any]]: """ - Process and add text memories. + Process and add text memories (including preference memories). Extracts memories from messages and adds them to the text memory system. Handles both sync and async modes. @@ -959,13 +780,15 @@ def _process_text_mem( "[SingleCubeView] merged_from provided but graph_db is unavailable; skip archiving." ) + # Format results uniformly text_memories = [ { "memory": memory.memory, "memory_id": memory_id, "memory_type": memory.metadata.memory_type, + "cube_id": self.cube_id, } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + for memory_id, memory in zip(mem_ids_local, mem_group, strict=False) ] return text_memories diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py index 6d57e3605..fa713a7d1 100644 --- a/src/memos/search/search_service.py +++ b/src/memos/search/search_service.py @@ -62,6 +62,8 @@ def search_text_memories( tool_mem_top_k=search_req.tool_mem_top_k, include_skill_memory=search_req.include_skill_memory, skill_mem_top_k=search_req.skill_mem_top_k, + include_preference_memory=search_req.include_preference, + pref_mem_top_k=search_req.pref_top_k, dedup=search_req.dedup, include_embedding=include_embedding, ) From 1062fec86277b6252e672a196afa8712a8d2b747 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Tue, 3 Mar 2026 10:22:07 +0800 Subject: [PATCH 08/20] feat: optimzie polardb ThreadedConnectionPool (#1152) * feat:optimzie polardb ThreadedConnectionPool * feat:optimzie polardb ThreadedConnectionPool * feat:optimzie polardb ThreadedConnectionPool * feat:optimzie polardb ThreadedConnectionPool * feat:optimzie polardb ThreadedConnectionPool * feat:add _warm_up_on_startup * feat:add _warm_up_on_startup * feat:add _warm_up_on_startup * feat:add _warm_up_on_startup * feat:config format --------- Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/config.py | 12 +- src/memos/configs/graph_db.py | 27 + src/memos/graph_dbs/polardb.py | 1133 ++++++++++++-------------------- 3 files changed, 467 insertions(+), 705 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 65049b0c2..fa12bcf55 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -675,7 +675,17 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: "user_name": user_name, "use_multi_db": use_multi_db, "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "1024")), + # .env: CONNECTION_WAIT_TIMEOUT, SKIP_CONNECTION_HEALTH_CHECK, WARM_UP_ON_STARTUP_BY_FULL, WARM_UP_ON_STARTUP_BY_ALL + "connection_wait_timeout": int(os.getenv("CONNECTION_WAIT_TIMEOUT", "60")), + "skip_connection_health_check": os.getenv( + "SKIP_CONNECTION_HEALTH_CHECK", "false" + ).lower() + == "true", + "warm_up_on_startup_by_full": os.getenv("WARM_UP_ON_STARTUP_BY_FULL", "false").lower() + == "true", + "warm_up_on_startup_by_all": os.getenv("WARM_UP_ON_STARTUP_BY_ALL", "false").lower() + == "true", } @staticmethod diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 9b1ce7f9d..5900d2357 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -202,6 +202,33 @@ class PolarDBGraphDBConfig(BaseConfig): default=100, description="Maximum number of connections in the connection pool", ) + connection_wait_timeout: int = Field( + default=30, + ge=1, + le=3600, + description="Max seconds to wait for a connection slot before raising (0 = wait forever, not recommended)", + ) + skip_connection_health_check: bool = Field( + default=False, + description=( + "If True, skip SELECT 1 health check when getting connections (~1-2ms saved per request). " + "Use only when pool/network is reliable." + ), + ) + warm_up_on_startup_by_full: bool = Field( + default=True, + description=( + "If True, run search_by_fulltext warm-up on pool connections at init to reduce " + "first-query latency (~200ms planning). Requires user_name in config." + ), + ) + warm_up_on_startup_by_all: bool = Field( + default=False, + description=( + "If True, run all connection warm-up on pool connections at init to reduce " + "first-query latency (~200ms planning). Requires user_name in config." + ), + ) @model_validator(mode="after") def validate_config(self): diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 592f45a7f..ac03cda2e 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1,9 +1,10 @@ import json import random import textwrap +import threading import time -from contextlib import suppress +from contextlib import contextmanager from datetime import datetime from typing import Any, Literal @@ -136,7 +137,11 @@ def __init__(self, config: PolarDBGraphDBConfig): port = config.get("port") user = config.get("user") password = config.get("password") - maxconn = config.get("maxconn", 100) # De + maxconn = config.get("maxconn", 100) + self._connection_wait_timeout = config.get("connection_wait_timeout", 60) + self._skip_connection_health_check = config.get("skip_connection_health_check", False) + self._warm_up_on_startup_by_full = config.get("warm_up_on_startup_by_full", False) + self._warm_up_on_startup_by_all = config.get("warm_up_on_startup_by_all", False) else: self.db_name = config.db_name self.user_name = config.user_name @@ -145,13 +150,19 @@ def __init__(self, config: PolarDBGraphDBConfig): user = config.user password = config.password maxconn = config.maxconn if hasattr(config, "maxconn") else 100 - """ - # Create connection - self.connection = psycopg2.connect( - host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 + self._connection_wait_timeout = getattr(config, "connection_wait_timeout", 60) + self._skip_connection_health_check = getattr( + config, "skip_connection_health_check", False + ) + self._warm_up_on_startup_by_full = getattr(config, "warm_up_on_startup_by_full", False) + self._warm_up_on_startup_by_all = getattr(config, "warm_up_on_startup_by_all", False) + logger.info( + f"polardb init config connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup_by_full:{self._warm_up_on_startup_by_full},warm_up_on_startup_by_all:{self._warm_up_on_startup_by_all}" + ) + + logger.info( + f" db_name: {self.db_name} maxconn: {maxconn} connection_wait_timeout: {self._connection_wait_timeout}s" ) - """ - logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'") # Create connection pool self.connection_pool = psycopg2.pool.ThreadedConnectionPool( @@ -163,13 +174,16 @@ def __init__(self, config: PolarDBGraphDBConfig): password=password, dbname=self.db_name, connect_timeout=60, # Connection timeout in seconds - keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) + keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout) keepalives_interval=15, # Seconds between keepalive retries keepalives_count=5, # Number of keepalive retries before considering connection dead ) - # Keep a reference to the pool for cleanup - self._pool_closed = False + self._semaphore = threading.BoundedSemaphore(maxconn) + if self._warm_up_on_startup_by_full: + self._warm_up_search_connections_by_full() + if self._warm_up_on_startup_by_all: + self._warm_up_connections_by_all() """ # Handle auto_create @@ -194,194 +208,76 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) - def _get_connection_old(self): - """Get a connection from the pool.""" - if self._pool_closed: - raise RuntimeError("Connection pool has been closed") - conn = self.connection_pool.getconn() - # Set autocommit for PolarDB compatibility - conn.autocommit = True - return conn - - def _get_connection(self): - logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") - if self._pool_closed: - raise RuntimeError("Connection pool has been closed") - - max_retries = 500 - import psycopg2.pool - - for attempt in range(max_retries): - conn = None + def _warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: + logger.info("--warm_up_search_connections_by_full--start-up----") + user_name = user_name or self.user_name + if not user_name: + logger.debug("[warm_up] Skipped: no user_name for warm-up") + return + warm_count = min(5, self.connection_pool.minconn) + for _ in range(warm_count): try: - conn = self.connection_pool.getconn() - - if conn.closed != 0: - logger.warning( - f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" - ) - try: - self.connection_pool.putconn(conn, close=True) - except Exception as e: - logger.warning( - f"[_get_connection] Failed to return closed connection to pool: {e}" - ) - with suppress(Exception): - conn.close() - - conn = None - if attempt < max_retries - 1: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError("Pool returned a closed connection after all retries") - - # Set autocommit for PolarDB compatibility - conn.autocommit = True - - # Test connection health with SELECT 1 - try: - cursor = conn.cursor() - cursor.execute("SELECT 1") - cursor.fetchone() - cursor.close() - except Exception as health_check_error: - # Connection is not usable, return it to pool with close flag and try again - logger.warning( - f"[_get_connection] Connection health check failed (attempt {attempt + 1}/{max_retries}): {health_check_error}" - ) - try: - self.connection_pool.putconn(conn, close=True) - except Exception as putconn_error: - logger.warning( - f"[_get_connection] Failed to return unhealthy connection to pool: {putconn_error}" - ) - with suppress(Exception): - conn.close() - - conn = None - if attempt < max_retries - 1: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}" - ) from health_check_error - - # Connection is healthy, return it - return conn - - except psycopg2.pool.PoolError as pool_error: - error_msg = str(pool_error).lower() - if "exhausted" in error_msg or "pool" in error_msg: - # Log pool status for debugging - try: - # Try to get pool stats if available - pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" - logger.info( - f" polardb get_connection Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" - ) - except Exception: - logger.warning( - f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" - ) - - # For pool exhaustion, wait longer before retry (connections may be returned) - if attempt < max_retries - 1: - # Longer backoff for pool exhaustion: 0.5s, 1.0s, 2.0s - wait_time = 0.5 * (2**attempt) - logger.info(f"[_get_connection] Waiting {wait_time}s before retry...") - """time.sleep(wait_time)""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Connection pool exhausted after {max_retries} attempts. " - f"This usually means connections are not being returned to the pool. " - ) from pool_error - else: - # Other pool errors - retry with normal backoff - if attempt < max_retries - 1: - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Failed to get connection from pool: {pool_error}" - ) from pool_error - + self.search_by_fulltext( + query_words=["warmup"], + top_k=1, + user_name=user_name, + ) except Exception as e: - if conn is not None: - try: - self.connection_pool.putconn(conn, close=True) - except Exception as putconn_error: - logger.warning( - f"[_get_connection] Failed to return connection after error: {putconn_error}" - ) - with suppress(Exception): - conn.close() - - if attempt >= max_retries - 1: - raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e - else: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) + logger.debug(f"[warm_up] Warm-up query failed (non-fatal): {e}") + break + logger.info(f"[warm_up] Pre-warmed {warm_count} connections for search_by_fulltext") + + def warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: + self._warm_up_search_connections_by_full(user_name) + + def _warm_up_connections_by_all(self): + logger.info("--_warm_up_connections_by_all--start-up") + warm_count = self.connection_pool.minconn + preheated = 0 + logger.info(f"[warm_up] Pre-warming {warm_count} connections...") + for _ in range(warm_count): + try: + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute("SELECT 1") + preheated += 1 + except Exception as e: + logger.warning(f"[warm_up] Failed to pre-warm connection: {e}") continue + logger.info(f"[warm_up] Pre-warmed {preheated}/{warm_count} connections") - # Should never reach here, but just in case - raise RuntimeError("Failed to get connection after all retries") - - def _return_connection(self, connection): - if self._pool_closed: - if connection: - try: - connection.close() - logger.debug("[_return_connection] Closed connection (pool is closed)") - except Exception as e: - logger.warning( - f"[_return_connection] Failed to close connection after pool closed: {e}" - ) - return - - if not connection: - return + @contextmanager + def _get_connection(self): + timeout = getattr(self, "_connection_wait_timeout", 5) + if timeout <= 0: + self._semaphore.acquire() + else: + if not self._semaphore.acquire(timeout=timeout): + logger.warning(f"Timeout waiting for connection slot ({timeout}s)") + raise RuntimeError( + f"Connection pool busy: could not acquire a slot within {timeout}s (all connections in use)." + ) + conn = None + broken = False try: - if hasattr(connection, "closed") and connection.closed != 0: - logger.debug( - "[_return_connection] Connection is closed, closing it instead of returning to pool" - ) + conn = self.connection_pool.getconn() + logger.debug(f"Acquired connection {id(conn)} from pool") + conn.autocommit = True + with conn.cursor() as cur: + cur.execute("SELECT 1") + yield conn + except Exception as e: + broken = True + logger.exception(f"Connection failed or broken: {e}") + raise + finally: + if conn: try: - connection.close() + self.connection_pool.putconn(conn, close=broken) + logger.debug(f"Returned connection {id(conn)} to pool (broken={broken})") except Exception as e: - logger.warning(f"[_return_connection] Failed to close closed connection: {e}") - return - - self.connection_pool.putconn(connection) - logger.debug("[_return_connection] Successfully returned connection to pool") - except Exception as e: - logger.error( - f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True - ) - try: - connection.close() - logger.debug( - "[_return_connection] Closed connection as fallback after putconn failure" - ) - except Exception as close_error: - logger.warning( - f"[_return_connection] Failed to close connection after putconn error: {close_error}" - ) - - def _return_connection_old(self, connection): - """Return a connection to the pool.""" - if not self._pool_closed and connection: - self.connection_pool.putconn(connection) + logger.warning(f"Failed to return connection to pool: {e}") + self._semaphore.release() def _ensure_database_exists(self): """Create database if it doesn't exist.""" @@ -396,11 +292,8 @@ def _ensure_database_exists(self): @timed def _create_graph(self): """Create PostgreSQL schema and table for graph storage.""" - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Create schema if it doesn't exist cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') logger.info(f"Schema '{self.db_name}_graph' ensured.") @@ -448,8 +341,6 @@ def _create_graph(self): except Exception as e: logger.error(f"Failed to create graph schema: {e}") raise e - finally: - self._return_connection(conn) def create_index( self, @@ -462,11 +353,8 @@ def create_index( Create indexes for embedding and other fields. Note: This creates PostgreSQL indexes on the underlying tables. """ - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" @@ -486,8 +374,6 @@ def create_index( logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") - finally: - self._return_connection(conn) def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: """Get count of memory nodes by type.""" @@ -500,19 +386,14 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params = [self.format_param_value(memory_type), self.format_param_value(user_name)] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 - finally: - self._return_connection(conn) @timed def node_not_exist(self, scope: str, user_name: str | None = None) -> int: @@ -527,19 +408,14 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: query += "\nLIMIT 1" params = [self.format_param_value(scope), self.format_param_value(user_name)] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def remove_oldest_memory( @@ -569,10 +445,8 @@ def remove_oldest_memory( self.format_param_value(user_name), keep_latest, ] - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Execute query to get IDs to delete cursor.execute(select_query, select_params) ids_to_delete = [row[0] for row in cursor.fetchall()] @@ -584,9 +458,9 @@ def remove_oldest_memory( # Build delete query placeholders = ",".join(["%s"] * len(ids_to_delete)) delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE id IN ({placeholders}) - """ + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE id IN ({placeholders}) + """ delete_params = ids_to_delete # Execute deletion @@ -600,8 +474,6 @@ def remove_oldest_memory( except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: @@ -663,17 +535,12 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def delete_node(self, id: str, user_name: str | None = None) -> None: @@ -694,26 +561,18 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Ensure in the correct database context cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] @@ -736,20 +595,15 @@ def create_extension(self): except Exception as e: logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def create_graph(self): - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(f""" - SELECT COUNT(*) FROM ag_catalog.ag_graph - WHERE name = '{self.db_name}_graph'; - """) + SELECT COUNT(*) FROM ag_catalog.ag_graph + WHERE name = '{self.db_name}_graph'; + """) graph_exists = cursor.fetchone()[0] > 0 if graph_exists: @@ -760,8 +614,6 @@ def create_graph(self): except Exception as e: logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def create_edge(self): @@ -770,11 +622,9 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: - conn = None logger.info(f"Creating elabel: {label_name}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") logger.info(f"Successfully created elabel: {label_name}") except Exception as e: @@ -783,8 +633,6 @@ def create_edge(self): else: logger.warning(f"Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def add_edge( @@ -825,10 +673,8 @@ def add_edge( ); """ logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") @@ -837,8 +683,6 @@ def add_edge( except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def delete_edge(self, source_id: str, target_id: str, type: str) -> None: @@ -853,14 +697,9 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") @timed def edge_exists_old( @@ -915,15 +754,10 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None @timed def edge_exists( @@ -971,15 +805,10 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None @timed def get_node( @@ -1015,10 +844,8 @@ def get_node( params.append(self.format_param_value(user_name)) logger.info(f"polardb [get_node] query: {query},params: {params}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -1067,8 +894,6 @@ def get_node( except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) return None - finally: - self._return_connection(conn) @timed def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]: @@ -1105,50 +930,45 @@ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, logger.info(f"get_nodes query:{query},params:{params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists - if embedding_json is not None and kwargs.get("include_embedding"): - try: - # remove embedding - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding from JSONB if it exists + if embedding_json is not None and kwargs.get("include_embedding"): + try: + # remove embedding + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } ) - return nodes - finally: - self._return_connection(conn) + ) + return nodes @timed def get_edges_old( @@ -1366,10 +1186,8 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1424,8 +1242,6 @@ def get_children_with_embeddings( except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: """Get the path of nodes from source to target within a limited depth.""" @@ -1507,11 +1323,9 @@ def get_subgraph( RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - conn = None logger.info(f"[get_subgraph] Query: {query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1636,8 +1450,6 @@ def get_subgraph( except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) return {"core_node": None, "neighbors": [], "edges": []} - finally: - self._return_connection(conn) def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" @@ -1751,29 +1563,24 @@ def search_by_keywords_like( logger.info( f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - item = {"id": id_val} - if return_fields: - properties = row[2] # properties column - item.update(self._extract_fields_from_properties(properties, return_fields)) - output.append(item) - logger.info( - f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + logger.info( + f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output @timed def search_by_keywords_tfidf( @@ -1859,30 +1666,25 @@ def search_by_keywords_tfidf( logger.info( f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - item = {"id": id_val} - if return_fields: - properties = row[2] # properties column - item.update(self._extract_fields_from_properties(properties, return_fields)) - output.append(item) - - logger.info( - f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + + logger.info( + f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output @timed def search_by_fulltext( @@ -2007,38 +1809,29 @@ def search_by_fulltext( """ params = [tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] # old_id - rank = row[1] # rank score (no memory_text column) - - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - score_val = float(rank) - - # Apply threshold filter if specified - if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[2] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) - elapsed_time = time.time() - start_time - logger.info( - f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[1] # rank score (no memory_text column) + + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + elapsed_time = time.time() - start_time + logger.info(f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s") + return output[:top_k] @timed def search_by_embedding( @@ -2135,42 +1928,35 @@ def search_by_embedding( logger.info(f"[search_by_embedding] query: {query}, params: {params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - results = cursor.fetchall() - output = [] - for row in results: - if len(row) < 5: - logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") - continue - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[1] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) - elapsed_time = time.time() - start_time - logger.info( - f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() + output = [] + for row in results: + if len(row) < 5: + logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") + continue + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[1] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + elapsed_time = time.time() - start_time + logger.info( + f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + ) + return output[:top_k] @timed def get_by_metadata( @@ -2285,18 +2071,14 @@ def get_by_metadata( """ ids = [] - conn = None logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") - finally: - self._return_connection(conn) return ids @@ -2448,10 +2230,8 @@ def get_grouped_counts( {where_clause} GROUP BY {", ".join(group_by_fields)} """ - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Handle parameterized query if params and isinstance(params, list): cursor.execute(query, params) @@ -2476,8 +2256,6 @@ def get_grouped_counts( except Exception as e: logger.error(f"Failed to get grouped counts: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" @@ -2509,14 +2287,9 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -2585,132 +2358,129 @@ def export_graph( else: offset = None - conn = None try: - conn = self._get_connection() - # Build WHERE conditions - where_conditions = [] - if user_name: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" - ) - if user_id: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" - ) + with self._get_connection() as conn: + # Build WHERE conditions + where_conditions = [] + if user_name: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) - # Add memory_type filter condition - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape memory_type values and build IN clause - memory_type_values = [] - for mt in memory_type: - # Escape single quotes in memory_type value - escaped_memory_type = str(mt).replace("'", "''") - memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") - memory_type_in_clause = ", ".join(memory_type_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" - ) + # Add memory_type filter condition + if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: + # Escape memory_type values and build IN clause + memory_type_values = [] + for mt in memory_type: + # Escape single quotes in memory_type value + escaped_memory_type = str(mt).replace("'", "''") + memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") + memory_type_in_clause = ", ".join(memory_type_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" + ) - # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - where_conditions.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" - ) - elif isinstance(status, list) and len(status) > 0: - # status IN (list) - status_values = [] - for st in status: - escaped_status = str(st).replace("'", "''") - status_values.append(f"'\"{escaped_status}\"'::agtype") - status_in_clause = ", ".join(status_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" - ) + # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list + if status is None: + # Default behavior: exclude deleted entries + where_conditions.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" + ) + elif isinstance(status, list) and len(status) > 0: + # status IN (list) + status_values = [] + for st in status: + escaped_status = str(st).replace("'", "''") + status_values.append(f"'\"{escaped_status}\"'::agtype") + status_in_clause = ", ".join(status_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" + ) - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[export_graph] filter_conditions: {filter_conditions}") - if filter_conditions: - where_conditions.extend(filter_conditions) - - where_clause = "" - if where_conditions: - where_clause = f"WHERE {' AND '.join(where_conditions)}" - - # Get total count of nodes before pagination - count_node_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - logger.info(f"[export_graph nodes count] Query: {count_node_query}") - with conn.cursor() as cursor: - cursor.execute(count_node_query) - total_nodes = cursor.fetchone()[0] - - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - if include_embedding: - node_query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ - else: - node_query = f""" - SELECT id, properties + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) + + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" + + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} """ - logger.info(f"[export_graph nodes] Query: {node_query}") - with conn.cursor() as cursor: - cursor.execute(node_query) - node_results = cursor.fetchall() - nodes = [] - - for row in node_results: - if include_embedding: - """row is (id, properties, embedding)""" - _, properties_json, embedding_json = row - else: - """row is (id, properties)""" - _, properties_json = row - embedding_json = None + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + if include_embedding: + node_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + else: + node_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + logger.info(f"[export_graph nodes] Query: {node_query}") + with conn.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + + for row in node_results: + if include_embedding: + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row + else: + """row is (id, properties)""" + _, properties_json = row + embedding_json = None - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except json.JSONDecodeError: - properties = {} - else: - properties = properties_json if properties_json else {} + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} - # Remove embedding field if include_embedding is False - if not include_embedding: - properties.pop("embedding", None) - elif include_embedding and embedding_json is not None: - properties["embedding"] = embedding_json + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json - nodes.append(self._parse_node(properties)) + nodes.append(self._parse_node(properties)) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - finally: - self._return_connection(conn) edges = [] return { @@ -2732,13 +2502,9 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: result = self.execute_query(query, conn) return int(result.one_or_none()["count"].value) - finally: - self._return_connection(conn) @timed def get_all_memory_items( @@ -2825,18 +2591,16 @@ def get_all_memory_items( """ nodes = [] node_ids = set() - conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() for row in results: """ - if isinstance(row, (list, tuple)) and len(row) >= 2: - """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ if isinstance(row, list | tuple) and len(row) >= 2: embedding_val, node_val = row[0], row[1] else: @@ -2851,8 +2615,6 @@ def get_all_memory_items( except Exception as e: logger.warning(f"Failed to get memories: {e}", exc_info=True) - finally: - self._return_connection(conn) return nodes else: @@ -2879,29 +2641,25 @@ def get_all_memory_items( """ nodes = [] - conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() for row in results: """ - if isinstance(row[0], str): - memory_data = json.loads(row[0]) - else: - memory_data = row[0] # 如果已经是字典,直接使用 - nodes.append(self._parse_node(memory_data)) - """ + if isinstance(row[0], str): + memory_data = json.loads(row[0]) + else: + memory_data = row[0] # 如果已经是字典,直接使用 + nodes.append(self._parse_node(memory_data)) + """ memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0] nodes.append(self._parse_node(memory_data)) except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) - finally: - self._return_connection(conn) return nodes @@ -3104,10 +2862,8 @@ def get_structure_optimization_candidates( candidates = [] node_ids = set() - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() logger.info(f"Found {len(results)} structure optimization candidates") @@ -3115,8 +2871,8 @@ def get_structure_optimization_candidates( if include_embedding: # When include_embedding=True, return full node object """ - if isinstance(row, (list, tuple)) and len(row) >= 2: - """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ if isinstance(row, list | tuple) and len(row) >= 2: embedding_val, node_val = row[0], row[1] else: @@ -3184,8 +2940,6 @@ def get_structure_optimization_candidates( except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) - finally: - self._return_connection(conn) return candidates @@ -3355,60 +3109,59 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - conn = None insert_query = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s, - %s - ) + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - logger.info( - f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + if insert_query: logger.info( - f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + f"In add node polardb: id-{id} memory-{memory} query-{insert_query}" ) except Exception as e: logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) raise - finally: - if insert_query: - logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") - self._return_connection(conn) @timed def add_nodes_batch( @@ -3529,10 +3282,8 @@ def add_nodes_batch( nodes_by_embedding_column[col] = [] nodes_by_embedding_column[col].append(node) - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Process each group separately for embedding_column, nodes_group in nodes_by_embedding_column.items(): # Batch delete existing records using IN clause @@ -3625,7 +3376,8 @@ def add_nodes_batch( properties_json = json.dumps(node["properties"]) cursor.execute( - f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + f"EXECUTE {prepare_name}(%s, %s)", + (node["id"], properties_json), ) finally: # DEALLOCATE prepared statement (always execute, even on error) @@ -3650,8 +3402,6 @@ def add_nodes_batch( except Exception as e: logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) raise - finally: - self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): """ @@ -3763,10 +3513,8 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -3815,8 +3563,6 @@ def get_neighbors_by_tag( except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def get_neighbors_by_tag_ccl( self, @@ -4075,10 +3821,8 @@ def get_edges( $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ logger.info(f"get_edges query:{query}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -4125,8 +3869,6 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def _convert_graph_edges(self, core_node: dict) -> dict: import copy @@ -5132,11 +4874,9 @@ def delete_node_by_prams( ) return 0 - conn = None total_deleted_count = 0 try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Build WHERE conditions list where_conditions = [] @@ -5197,9 +4937,6 @@ def delete_node_by_prams( except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) raise - finally: - self._return_connection(conn) - logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") return total_deleted_count @@ -5263,11 +5000,9 @@ def escape_memory_id(mid: str) -> str: """ logger.info(f"[get_user_names_by_memory_ids] query: {query}") - conn = None result_dict = {} try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -5307,8 +5042,6 @@ def escape_memory_id(mid: str) -> str: f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) def exist_user_name(self, user_name: str) -> dict[str, bool]: """Check if user name exists in the graph. @@ -5342,10 +5075,8 @@ def escape_user_name(un: str) -> str: """ logger.info(f"[exist_user_name] query: {query}") result_dict = {} - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) count = cursor.fetchone()[0] result = count > 0 @@ -5356,8 +5087,6 @@ def escape_user_name(un: str) -> str: f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) @timed def delete_node_by_mem_cube_id( @@ -5381,10 +5110,8 @@ def delete_node_by_mem_cube_id( ) return 0 - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" user_name_param = self.format_param_value(mem_cube_id) @@ -5434,7 +5161,11 @@ def delete_node_by_mem_cube_id( logger.info( f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" ) - update_params = [json.dumps(update_properties), current_time, user_name_param] + update_params = [ + json.dumps(update_properties), + current_time, + user_name_param, + ] cursor.execute(update_query, update_params) updated_count = cursor.rowcount @@ -5448,8 +5179,6 @@ def delete_node_by_mem_cube_id( f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) @timed def recover_memory_by_mem_cube_id( @@ -5476,10 +5205,8 @@ def recover_memory_by_mem_cube_id( f"delete_record_id={delete_record_id}" ) - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" where_clause = f"{user_name_condition} AND {delete_record_id_condition}" @@ -5523,5 +5250,3 @@ def recover_memory_by_mem_cube_id( f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) From 063d8705ac812e62d8338b07ffc60bded3c0a495 Mon Sep 17 00:00:00 2001 From: Jiang <33757498+hijzy@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:57:45 +0800 Subject: [PATCH 09/20] fix: Use relativity instead of score for preference memory (#1153) * test: add rerank model * test: test reranker model * test: test reranker model * test: delete useless log * test: reformat --------- Co-authored-by: jiang --- src/memos/api/handlers/search_handler.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 58121776e..ba1c50b07 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -120,10 +120,7 @@ def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> d if not isinstance(mem, dict): continue meta = mem.get("metadata", {}) - if key == "text_mem": - score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0 - else: - score = meta.get("score", 1.0) if isinstance(meta, dict) else 1.0 + score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0 try: score_val = float(score) if score is not None else 1.0 except (TypeError, ValueError): From dcdc3cff6cb2b18581d590e2fa33dc0bd3c1e65c Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 4 Mar 2026 14:55:18 +0800 Subject: [PATCH 10/20] fix: image bug; single item in multi-mudal-reader has no embedding; (#1154) * fix: image url bug * fix: single item has no embedding bug * fix: image lang bug * fix: image lang bug --- src/memos/mem_reader/multi_modal_struct.py | 25 ++++++++----------- .../read_multi_modal/image_parser.py | 6 ++--- .../mem_reader/read_multi_modal/utils.py | 21 +++++++++++++++- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 62e8f2d75..0b3e19208 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -190,8 +190,16 @@ def _concat_multi_modal_memories( else: processed_items.append(item) - # If only one item after processing, return as-is + # If only one item after processing, compute embedding and return if len(processed_items) == 1: + single_item = processed_items[0] + if single_item and single_item.memory: + try: + single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0] + except Exception as e: + logger.error( + f"[MultiModalStruct] Error computing embedding for single item: {e}" + ) return processed_items windows = [] @@ -289,7 +297,6 @@ def _build_window_from_items( # Collect all memory texts and sources memory_texts = [] all_sources = [] - seen_content = set() # Track seen source content to avoid duplicates roles = set() aggregated_file_ids: list[str] = [] @@ -303,18 +310,8 @@ def _build_window_from_items( item_sources = [item_sources] for source in item_sources: - # Get content from source for deduplication - source_content = None - if isinstance(source, dict): - source_content = source.get("content", "") - else: - source_content = getattr(source, "content", "") or "" - - # Only add if content is different (empty content is considered unique) - content_key = source_content if source_content else None - if content_key and content_key not in seen_content: - seen_content.add(content_key) - all_sources.append(source) + # Add source to all_sources + all_sources.append(source) # Extract role from source if hasattr(source, "role") and source.role: diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 97400ca26..d66642edb 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -137,10 +137,10 @@ def parse_fine( # Get context items if available context_items = kwargs.get("context_items") - # Determine language: prioritize lang from source (passed via kwargs), - # fallback to detecting from context_items if lang not provided + # Determine language: prioritize lang from context_items, + # fallback to kwargs lang = kwargs.get("lang") - if lang is None and context_items: + if context_items: for item in context_items: if hasattr(item, "memory") and item.memory: lang = detect_lang(item.memory) diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index be82587bf..96918589b 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -341,13 +341,32 @@ def detect_lang(text): if not text or not isinstance(text, str): return "en" cleaned_text = text - # remove role and timestamp + # remove role and timestamp-like prefixes cleaned_text = re.sub( r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE ) + # timestamps like [11:32 AM on 04 March, 2026] + cleaned_text = re.sub( + r"\[\s*\d{1,2}:\d{2}\s*(?:AM|PM)\s+on\s+\d{2}\s+[A-Za-z]+\s*,\s*\d{4}\s*\]", + "", + cleaned_text, + flags=re.IGNORECASE, + ) + # purely numeric timestamps like [2025-01-01 10:00] cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) # remove URLs to prevent the dilution of Chinese characters cleaned_text = re.sub(r'https?://[^\s<>"{}|\\^`\[\]]+', "", cleaned_text) + # remove MessageType schema keywords (multimodal JSON noise) + cleaned_text = re.sub( + r"\b(text|type|image_url|imageurl|url)\b", "", cleaned_text, flags=re.IGNORECASE + ) + # remove schema keywords like text / type / image_url / url + cleaned_text = re.sub( + r"\b(text|type|image_url|imageurl|url|file|file_id)\b", + "", + cleaned_text, + flags=re.IGNORECASE, + ) # extract chinese characters chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" chinese_chars = re.findall(chinese_pattern, cleaned_text) From 150e77af331a16af2cd0a5bacd4995da1ffa2662 Mon Sep 17 00:00:00 2001 From: Qi Weng Date: Wed, 4 Mar 2026 15:42:03 +0800 Subject: [PATCH 11/20] =?UTF-8?q?fix:=20Stop=20throwing=20error=20when=20e?= =?UTF-8?q?mbedding=20is=20missing=20in=20add=5Fnodes=5Fbatch=E2=80=A6=20(?= =?UTF-8?q?#1157)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix: Stop throwing error when embedding is missing in add_nodes_batch function for Neo4j Community graphdb backend --- src/memos/graph_dbs/neo4j_community.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 09ad46c42..470d8cd8e 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -140,8 +140,6 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N metadata.setdefault("delete_record_id", "") embedding = metadata.pop("embedding", None) - if embedding is None: - raise ValueError(f"Missing 'embedding' in metadata for node {node_id}") vector_sync_status = "success" vec_items.append( From ec9eeada1dcee4f67e9685d2096974847067dd0b Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Thu, 5 Mar 2026 10:34:01 +0800 Subject: [PATCH 12/20] feat: align general/naive memory with tree-memory (#1159) --- .../core_memories/general_textual_memory.py | 16 +--- .../core_memories/naive_textual_memory.py | 93 +++++++++++-------- .../read_multi_modal/image_parser.py | 1 + 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/examples/core_memories/general_textual_memory.py b/examples/core_memories/general_textual_memory.py index d5c765b01..007736a6e 100644 --- a/examples/core_memories/general_textual_memory.py +++ b/examples/core_memories/general_textual_memory.py @@ -68,21 +68,9 @@ example_id = "a19b6caa-5d59-42ad-8c8a-e4f7118435b4" -print("===== Extract memories =====") -# Extract memories from a conversation -# The extractor LLM processes the conversation to identify relevant information. -memories = m.extract( - [ - {"role": "user", "content": "I love tomatoes."}, - {"role": "assistant", "content": "Great! Tomatoes are delicious."}, - ] -) -pprint.pprint(memories) -print() - print("==== Add memories ====") -# Add the extracted memories to the memory store -m.add(memories) +# Add example memories to the memory store +m.add(example_memories) # Add a manually created memory item m.add( [ diff --git a/examples/core_memories/naive_textual_memory.py b/examples/core_memories/naive_textual_memory.py index ab73060c7..1e7901e0f 100644 --- a/examples/core_memories/naive_textual_memory.py +++ b/examples/core_memories/naive_textual_memory.py @@ -1,20 +1,11 @@ -import json import os +import pprint import uuid from memos.configs.memory import MemoryConfigFactory from memos.memories.factory import MemoryFactory -def print_result(title, result): - """Helper function: Pretty print the result.""" - print(f"\n{'=' * 10} {title} {'=' * 10}") - if isinstance(result, list | dict): - print(json.dumps(result, indent=2, ensure_ascii=False, default=str)) - else: - print(result) - - # Configure memory backend with OpenAI extractor config = MemoryConfigFactory( backend="naive_text", @@ -38,39 +29,55 @@ def print_result(title, result): # Create memory instance m = MemoryFactory.from_config(config) +example_memories = [ + { + "memory": "I'm a RUCer, I'm happy.", + "metadata": { + "type": "event", + }, + }, + { + "memory": "MemOS is awesome!", + "metadata": { + "type": "opinion", + }, + }, +] + +example_id = str(uuid.uuid4()) -# Extract memories from a simulated conversation -memories = m.extract( +print("==== Add memories ====") +# Add example memories to the memory store +m.add(example_memories) +# Manually create a memory item and add it +m.add( [ - {"role": "user", "content": "I love tomatoes."}, - {"role": "assistant", "content": "Great! Tomatoes are delicious."}, + { + "id": example_id, + "memory": "User is Chinese.", + "metadata": {"type": "opinion"}, + } ] ) -print_result("Extract memories", memories) - - -# Add the extracted memories to storage -m.add(memories) - -# Manually create a memory item and add it -example_id = str(uuid.uuid4()) -manual_memory = [{"id": example_id, "memory": "User is Chinese.", "metadata": {"type": "opinion"}}] -m.add(manual_memory) - -# Print all current memories -print_result("Add memories (Check all after adding)", m.get_all()) - +print("All memories after addition:") +pprint.pprint(m.get_all()) +print() -# Search for relevant memories based on the query +print("==== Search memories ====") +# Search for memories related to a query search_results = m.search("Tell me more about the user", top_k=2) -print_result("Search memories", search_results) - +pprint.pprint(search_results) +print() +print("==== Get memories ====") # Get specific memory item by ID -memory_item = m.get(example_id) -print_result("Get memory", memory_item) - +print(f"Memory with ID {example_id}:") +pprint.pprint(m.get(example_id)) +print(f"Memories by IDs [{example_id}]:") +pprint.pprint(m.get_by_ids([example_id])) +print() +print("==== Update memories ====") # Update the memory content for the specified ID m.update( example_id, @@ -80,9 +87,9 @@ def print_result(title, result): "metadata": {"type": "opinion", "confidence": 85}, }, ) -updated_memory = m.get(example_id) -print_result("Update memory", updated_memory) - +print(f"Memory after update (ID {example_id}):") +pprint.pprint(m.get(example_id)) +print() print("==== Dump memory ====") # Dump the current state of memory to a file @@ -90,12 +97,16 @@ def print_result(title, result): print("Memory dumped to 'tmp/naive_mem'.") print() - +print("==== Delete memories ====") # Delete memory with the specified ID m.delete([example_id]) -print_result("Delete memory (Check all after deleting)", m.get_all()) - +print("All memories after deletion:") +pprint.pprint(m.get_all()) +print() +print("==== Delete all memories ====") # Delete all memories in storage m.delete_all() -print_result("Delete all", m.get_all()) +print("All memories after delete_all:") +pprint.pprint(m.get_all()) +print() diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index d66642edb..0d5e8bcc2 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -144,6 +144,7 @@ def parse_fine( for item in context_items: if hasattr(item, "memory") and item.memory: lang = detect_lang(item.memory) + source.lang = lang break if not lang: lang = "en" From 60147154d55ea301b0f889c56b17373fd3938b7b Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:44:40 +0800 Subject: [PATCH 13/20] feat:update logs (#1161) --- src/memos/graph_dbs/polardb.py | 135 +++++++++------------------------ 1 file changed, 34 insertions(+), 101 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ac03cda2e..6a0be0d32 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1703,32 +1703,19 @@ def search_by_fulltext( return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: - """ - Full-text search functionality using PostgreSQL's full-text search capabilities. - - Args: - query_text: query text - top_k: maximum number of results to return - scope: memory type filter (memory_type) - status: status filter, defaults to "activated" - threshold: similarity threshold filter - search_filter: additional property filter conditions - user_name: username filter - knowledgebase_ids: knowledgebase ids filter - filter: filter conditions with 'and' or 'or' logic for search results. - tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 - tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) - return_fields: additional node fields to include in results - **kwargs: other parameters (e.g. cube_name) - - Returns: - list[dict]: result list containing id and score. - If return_fields is specified, each dict also includes the requested fields. - """ + start_time = time.perf_counter() logger.info( - f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" + " search_by_fulltext query_words=%s top_k=%s scope=%s status=%s threshold=%s search_filter=%s user_name=%s knowledgebase_ids=%s filter=%s", + query_words, + top_k, + scope, + status, + threshold, + search_filter, + user_name, + knowledgebase_ids, + filter, ) - start_time = time.time() where_clauses = [] if scope: @@ -1744,22 +1731,18 @@ def search_by_fulltext( "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" ) - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( user_name=user_name, knowledgebase_ids=knowledgebase_ids, default_user_name=self.config.user_name, ) - logger.info(f"[search_by_fulltext] user_name_conditions: {user_name_conditions}") - # Add OR condition if we have any user_name conditions if user_name_conditions: if len(user_name_conditions) == 1: where_clauses.append(user_name_conditions[0]) else: where_clauses.append(f"({' OR '.join(user_name_conditions)})") - # Add search_filter conditions if search_filter: for key, value in search_filter.items(): if isinstance(value, str): @@ -1772,17 +1755,12 @@ def search_by_fulltext( ) filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) tsquery_string = " | ".join(query_words) where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - logger.info(f"[search_by_fulltext] where_clause: {where_clause}") - select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id, ts_rank(m.{tsvector_field}, q.fq) AS rank""" if return_fields: @@ -1808,7 +1786,8 @@ def search_by_fulltext( LIMIT {top_k}; """ params = [tsquery_string] - logger.info(f"[search_by_fulltext] query: {query}, params: {params}") + logger.info("search_by_fulltext query=%s params=%s", query, params) + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1829,8 +1808,8 @@ def search_by_fulltext( properties = row[2] # properties column item.update(self._extract_fields_from_properties(properties, return_fields)) output.append(item) - elapsed_time = time.time() - start_time - logger.info(f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s") + elapsed = (time.perf_counter() - start_time) * 1000 + logger.info("search_by_fulltext internal took %.1f ms", elapsed) return output[:top_k] @timed @@ -1849,9 +1828,18 @@ def search_by_embedding( **kwargs, ) -> list[dict]: logger.info( - f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}" + "search_by_embedding user_name:%s,filter: %s, knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s", + user_name, + filter, + knowledgebase_ids, + scope, + status, + search_filter, + filter, + knowledgebase_ids, + return_fields, ) - start_time = time.time() + start_time = time.perf_counter() where_clauses = [] if scope: where_clauses.append( @@ -1890,7 +1878,6 @@ def search_by_embedding( ) filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" @@ -1926,7 +1913,7 @@ def search_by_embedding( else: pass - logger.info(f"[search_by_embedding] query: {query}, params: {params}") + logger.info(" search_by_embedding query: %s", query) with self._get_connection() as conn, conn.cursor() as cursor: if params: @@ -1952,9 +1939,9 @@ def search_by_embedding( properties = row[1] # properties column item.update(self._extract_fields_from_properties(properties, return_fields)) output.append(item) - elapsed_time = time.time() - start_time + elapsed_time = time.perf_counter() - start_time logger.info( - f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + "search_by_embedding query embedding completed time took %.1f ms", elapsed_time ) return output[:top_k] @@ -3169,27 +3156,15 @@ def add_nodes_batch( nodes: list[dict[str, Any]], user_name: str | None = None, ) -> None: - """ - Batch add multiple memory nodes to the graph. + logger.info(f" add_nodes_batch Processing only first node (total nodes: {len(nodes)})") - Args: - nodes: List of node dictionaries, each containing: - - id: str - Node ID - - memory: str - Memory content - - metadata: dict[str, Any] - Node metadata - user_name: Optional user name (will use config default if not provided) - """ - batch_start_time = time.time() + batch_start_time = time.perf_counter() if not nodes: logger.warning("[add_nodes_batch] Empty nodes list, skipping") return - logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") - - # user_name comes from parameter; fallback to config if missing effective_user_name = user_name if user_name else self.config.user_name - # Prepare all nodes prepared_nodes = [] for node_data in nodes: try: @@ -3199,16 +3174,13 @@ def add_nodes_batch( logger.debug(f"[add_nodes_batch] Processing node id: {id}") - # Set user_name in metadata metadata["user_name"] = effective_user_name metadata = _prepare_node_metadata(metadata) - # Merge node and set metadata created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) - # Prepare properties properties = { "id": id, "memory": memory, @@ -3219,32 +3191,26 @@ def add_nodes_batch( **metadata, } - # Generate embedding if not provided if "embedding" not in properties or not properties["embedding"]: properties["embedding"] = generate_vector( self._get_config_value("embedding_dimension", 1024) ) - # Serialization - JSON-serialize sources and usage fields for field_name in ["sources", "usage"]: if properties.get(field_name): if isinstance(properties[field_name], list): for idx in range(len(properties[field_name])): - # Serialize only when element is not a string if not isinstance(properties[field_name][idx], str): properties[field_name][idx] = json.dumps( properties[field_name][idx] ) elif isinstance(properties[field_name], str): - # If already a string, leave as-is pass - # Extract embedding for separate column embedding_vector = properties.pop("embedding", []) if not isinstance(embedding_vector, list): embedding_vector = [] - # Select column name based on embedding dimension embedding_column = "embedding" # default column if len(embedding_vector) == 3072: embedding_column = "embedding_3072" @@ -3267,14 +3233,12 @@ def add_nodes_batch( f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", exc_info=True, ) - # Continue with other nodes continue if not prepared_nodes: logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") return - # Group nodes by embedding column to optimize batch inserts nodes_by_embedding_column = {} for node in prepared_nodes: col = node["embedding_column"] @@ -3284,9 +3248,7 @@ def add_nodes_batch( try: with self._get_connection() as conn, conn.cursor() as cursor: - # Process each group separately for embedding_column, nodes_group in nodes_by_embedding_column.items(): - # Batch delete existing records using IN clause ids_to_delete = [node["id"] for node in nodes_group] if ids_to_delete: delete_query = f""" @@ -3297,7 +3259,6 @@ def add_nodes_batch( """ cursor.execute(delete_query, (ids_to_delete,)) - # Batch get graph_ids for all nodes get_graph_ids_query = f""" SELECT id_val, @@ -3307,21 +3268,16 @@ def add_nodes_batch( cursor.execute(get_graph_ids_query, (ids_to_delete,)) graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} - # Add graph_id to properties for node in nodes_group: graph_id = graph_id_map.get(node["id"]) if graph_id: node["properties"]["graph_id"] = str(graph_id) - # Use PREPARE/EXECUTE for efficient batch insert - # Generate unique prepare statement name to avoid conflicts prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" - try: if embedding_column and any( node["embedding_vector"] for node in nodes_group ): - # PREPARE statement for insert with embedding prepare_query = f""" PREPARE {prepare_name} AS INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) @@ -3331,16 +3287,9 @@ def add_nodes_batch( $3::vector ) """ - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" - ) cursor.execute(prepare_query) - # Execute prepared statement for each node for node in nodes_group: properties_json = json.dumps(node["properties"]) embedding_json = ( @@ -3354,7 +3303,6 @@ def add_nodes_batch( (node["id"], properties_json, embedding_json), ) else: - # PREPARE statement for insert without embedding prepare_query = f""" PREPARE {prepare_name} AS INSERT INTO {self.db_name}_graph."Memory"(id, properties) @@ -3363,40 +3311,25 @@ def add_nodes_batch( $2::text::agtype ) """ - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" - ) cursor.execute(prepare_query) - # Execute prepared statement for each node for node in nodes_group: properties_json = json.dumps(node["properties"]) - cursor.execute( f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json), ) finally: - # DEALLOCATE prepared statement (always execute, even on error) try: cursor.execute(f"DEALLOCATE {prepare_name}") - logger.info( - f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" - ) except Exception as dealloc_error: logger.warning( f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" ) - - logger.info( - f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" - ) - elapsed_time = time.time() - batch_start_time + elapsed_time = time.perf_counter() - batch_start_time logger.info( - f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + "add_nodes_batch batch insert completed successfully in took %.1f ms", + elapsed_time, ) except Exception as e: From 355455c7bf57f75640355d4d3225dd2f85c7cdfd Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:29:30 +0800 Subject: [PATCH 14/20] feat:optimize search_by_fulltext (#1164) --- src/memos/graph_dbs/polardb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 6a0be0d32..21db8a47d 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1777,6 +1777,7 @@ def search_by_fulltext( ) where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" + /*+ Set(max_parallel_workers_per_gather 0) */ WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) SELECT {select_cols} FROM "{self.db_name}_graph"."Memory" m From 33d4d22a56edaca55268af965347951ca7f06792 Mon Sep 17 00:00:00 2001 From: Jiang <33757498+hijzy@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:31:20 +0800 Subject: [PATCH 15/20] fix: delete useless graph recall (#1167) * test: add rerank model * fix: delete useless graph recall --------- Co-authored-by: jiang --- .../textual/tree_text_memory/retrieve/task_goal_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index f4d6c4847..3b160a56e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -72,7 +72,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: else: return ParsedTaskGoal( memories=[task_description], - keys=[task_description], + keys=[], tags=[], goal_type="default", rephrased_query=task_description, From 8eb03620ce563bc656f87724726f61f9919ab238 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Thu, 5 Mar 2026 21:01:35 +0800 Subject: [PATCH 16/20] feat:optimize search_by_fulltext (#1170) --- src/memos/graph_dbs/polardb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 21db8a47d..6a0be0d32 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1777,7 +1777,6 @@ def search_by_fulltext( ) where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" - /*+ Set(max_parallel_workers_per_gather 0) */ WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) SELECT {select_cols} FROM "{self.db_name}_graph"."Memory" m From df12d5e2e94a003d9a2b04ea540ad63b1ae9f055 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Fri, 6 Mar 2026 10:43:23 +0800 Subject: [PATCH 17/20] feat:optimize search_by_fulltext (#1171) --- src/memos/graph_dbs/polardb.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 6a0be0d32..8332efbc1 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -254,8 +254,13 @@ def _get_connection(self): if not self._semaphore.acquire(timeout=timeout): logger.warning(f"Timeout waiting for connection slot ({timeout}s)") raise RuntimeError( - f"Connection pool busy: could not acquire a slot within {timeout}s (all connections in use)." + f"Connection pool busy: acquire a slot within {timeout}s (all connections in use)." ) + logger.info( + "Connection pool usage: %s/%s", + self.connection_pool.maxconn - self._semaphore._value, + self.connection_pool.maxconn, + ) conn = None broken = False @@ -264,7 +269,7 @@ def _get_connection(self): logger.debug(f"Acquired connection {id(conn)} from pool") conn.autocommit = True with conn.cursor() as cur: - cur.execute("SELECT 1") + cur.execute(f'SET search_path = {self.db_name}_graph, ag_catalog, "$user", public;') yield conn except Exception as e: broken = True @@ -1777,6 +1782,7 @@ def search_by_fulltext( ) where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" + /*+ Set(max_parallel_workers_per_gather 0) */ WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) SELECT {select_cols} FROM "{self.db_name}_graph"."Memory" m From 39623ab856f79030a2e8664e884543a32de8a6cd Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:29:40 +0800 Subject: [PATCH 18/20] fix: The feedback function fails when calling the search interface due to the failure in passing the 'user_name' parameter. (#1174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add log * add log * add log * hot fix --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> --- README.md | 34 +++++++++---------- docker/Dockerfile.krolik | 2 +- .../data/mem_cube_tree/textual_memory.json | 2 +- src/memos/graph_dbs/polardb.py | 2 +- src/memos/mem_feedback/feedback.py | 9 ++--- src/memos/mem_os/utils/default_config.py | 2 ++ .../tree_text_memory/retrieve/searcher.py | 3 +- 7 files changed, 29 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 45a372ed1..834e53021 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Awesome AI Memory - +

@@ -55,7 +55,7 @@

- + Get Free API: [Try API](https://memos-dashboard.openmem.net/quickstart/?source=github) @@ -68,11 +68,11 @@ Get Free API: [Try API](https://memos-dashboard.openmem.net/quickstart/?source=g ![](https://cdn.memtensor.com.cn/img/1770612303123_mnaisk_compressed.png) - [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) – intelligent memory retrieval instead of loading full chat history -- [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) – multi-instance agents share memory via same user_id. Automatic context handoff. +- [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) – multi-instance agents share memory via same user_id. Automatic context handoff. 🦞 Your lobster now has a working memory system. -Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) +Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin) ## 📌 MemOS: Memory Operating System for AI Agents @@ -92,7 +92,7 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe ### News -- **2025-12-24** · 🎉 **MemOS v2.0: Stardust (星尘) Release** +- **2025-12-24** · 🎉 **MemOS v2.0: Stardust (星尘) Release** Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
New Features @@ -139,7 +139,7 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe
-- **2025-08-07** · 🎉 **MemOS v1.0.0 (MemCube) Release** +- **2025-08-07** · 🎉 **MemOS v1.0.0 (MemCube) Release** First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch.
@@ -177,11 +177,11 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe
-- **2025-07-07** · 🎉 **MemOS v1.0: Stellar (星河) Preview Release** +- **2025-07-07** · 🎉 **MemOS v1.0: Stellar (星河) Preview Release** A SOTA Memory OS for LLMs is now open-sourced. -- **2025-07-04** · 🎉 **MemOS Paper Release** +- **2025-07-04** · 🎉 **MemOS Paper Release** [MemOS: A Memory OS for AI System](https://arxiv.org/abs/2507.03724) is available on arXiv. -- **2024-07-04** · 🎉 **Memory3 Model Release at WAIC 2024** +- **2024-07-04** · 🎉 **Memory3 Model Release at WAIC 2024** The Memory3 model, featuring a memory-layered architecture, was unveiled at the 2024 World Artificial Intelligence Conference.
@@ -194,9 +194,9 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe - Go to **API Keys** and copy your key #### Next Steps -- [MemOS Cloud Getting Started](https://memos-docs.openmem.net/memos_cloud/quick_start/) +- [MemOS Cloud Getting Started](https://memos-docs.openmem.net/memos_cloud/quick_start/) Connect to MemOS Cloud and enable memory in minutes. -- [MemOS Cloud Platform](https://memos.openmem.net/?from=/quickstart/) +- [MemOS Cloud Platform](https://memos.openmem.net/?from=/quickstart/) Explore the Cloud dashboard, features, and workflows. ### 🖥️ 2、Self-Hosted (Local/Private) @@ -234,7 +234,7 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe ```python import requests import json - + data = { "user_id": "8736b16e-1d20-4163-980b-a5063c3facdc", "mem_cube_id": "b32d0977-435d-4828-a86f-4f47f8b55bca", @@ -250,7 +250,7 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe "Content-Type": "application/json" } url = "http://localhost:8000/product/add" - + res = requests.post(url=url, headers=headers, data=json.dumps(data)) print(f"result: {res.json()}") ``` @@ -258,7 +258,7 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe ```python import requests import json - + data = { "query": "What do I like", "user_id": "8736b16e-1d20-4163-980b-a5063c3facdc", @@ -268,7 +268,7 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe "Content-Type": "application/json" } url = "http://localhost:8000/product/search" - + res = requests.post(url=url, headers=headers, data=json.dumps(data)) print(f"result: {res.json()}") ``` @@ -277,8 +277,8 @@ Try it: Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTe ## 📚 Resources -- **Awesome-AI-Memory** - This is a curated repository dedicated to resources on memory and memory systems for large language models. It systematically collects relevant research papers, frameworks, tools, and practical insights. The repository aims to organize and present the rapidly evolving research landscape of LLM memory, bridging multiple research directions including natural language processing, information retrieval, agentic systems, and cognitive science. +- **Awesome-AI-Memory** + This is a curated repository dedicated to resources on memory and memory systems for large language models. It systematically collects relevant research papers, frameworks, tools, and practical insights. The repository aims to organize and present the rapidly evolving research landscape of LLM memory, bridging multiple research directions including natural language processing, information retrieval, agentic systems, and cognitive science. - **Get started** 👉 [IAAR-Shanghai/Awesome-AI-Memory](https://github.com/IAAR-Shanghai/Awesome-AI-Memory) - **MemOS Cloud OpenClaw Plugin** Official OpenClaw lifecycle plugin for MemOS Cloud. It automatically recalls context from MemOS before the agent starts and saves the conversation back to MemOS after the agent finishes. diff --git a/docker/Dockerfile.krolik b/docker/Dockerfile.krolik index c475a6d30..dcae7e0d9 100644 --- a/docker/Dockerfile.krolik +++ b/docker/Dockerfile.krolik @@ -1,5 +1,5 @@ # MemOS with Krolik Security Extensions -# +# # This Dockerfile builds MemOS with authentication, rate limiting, and admin API. # It uses the overlay pattern to keep customizations separate from base code. diff --git a/examples/data/mem_cube_tree/textual_memory.json b/examples/data/mem_cube_tree/textual_memory.json index 91f426ca2..97a2b1dd0 100644 --- a/examples/data/mem_cube_tree/textual_memory.json +++ b/examples/data/mem_cube_tree/textual_memory.json @@ -4216,4 +4216,4 @@ "edges": [], "total_nodes": 4, "total_edges": 0 -} \ No newline at end of file +} diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 8332efbc1..c8ed0a97f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1919,7 +1919,7 @@ def search_by_embedding( else: pass - logger.info(" search_by_embedding query: %s", query) + logger.info(" search_by_embedding query: %s user_name: %s", query, user_name) with self._get_connection() as conn, conn.cursor() as cursor: if params: diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 18045af2c..b8019004d 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -559,23 +559,24 @@ def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, boo edges = self.searcher.graph_store.get_edges(mem_item.id, user_name=user_name) return (mem_item, len(edges) == 0) + logger.info(f"[feedback _retrieve] query: {query}, user_name: {user_name}") text_mems = self.searcher.search( - query, + query=query, + top_k=top_k, info=info, memory_type="AllSummaryMemory", user_name=user_name, - top_k=top_k, full_recall=True, ) text_mems = [item[0] for item in text_mems if float(item[1]) > 0.01] if self.pref_feedback: pref_mems = self.searcher.search( - query, + query=query, + top_k=top_k, info=info, memory_type="PreferenceMemory", user_name=user_name, - top_k=top_k, include_preference_memory=True, full_recall=True, ) diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py index 9898cbe8c..de79d535d 100644 --- a/src/memos/mem_os/utils/default_config.py +++ b/src/memos/mem_os/utils/default_config.py @@ -4,12 +4,14 @@ """ import logging + from typing import Literal from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.mem_cube.general import GeneralMemCube + logger = logging.getLogger(__name__) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index b4994671f..eb15b48ed 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -92,7 +92,7 @@ def retrieve( **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: logger.info( - f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" + f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}, user_name={user_name}" ) parsed_goal, query_embedding, _context, query = self._parse_task( query, @@ -859,6 +859,7 @@ def _retrieve_from_skill_memory( mode: str = "fast", ): """Retrieve and rerank from SkillMemory""" + if memory_type not in ["All", "SkillMemory"]: logger.info(f"[PATH-E] '{query}' Skipped (memory_type does not match)") return [] From 0a69ec21c033d5fb7907acbfee75880a13bf2c35 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:59:38 +0800 Subject: [PATCH 19/20] feat:optimize config (#1176) feat:optimize search_by_fulltext --- src/memos/graph_dbs/polardb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index c8ed0a97f..ad75f4b65 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -173,7 +173,7 @@ def __init__(self, config: PolarDBGraphDBConfig): user=user, password=password, dbname=self.db_name, - connect_timeout=60, # Connection timeout in seconds + connect_timeout=10, # Connection timeout in seconds keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout) keepalives_interval=15, # Seconds between keepalive retries keepalives_count=5, # Number of keepalive retries before considering connection dead @@ -247,7 +247,7 @@ def _warm_up_connections_by_all(self): @contextmanager def _get_connection(self): - timeout = getattr(self, "_connection_wait_timeout", 5) + timeout = self._connection_wait_timeout if timeout <= 0: self._semaphore.acquire() else: @@ -1919,7 +1919,7 @@ def search_by_embedding( else: pass - logger.info(" search_by_embedding query: %s user_name: %s", query, user_name) + logger.info(" search_by_embedding query: %s", query) with self._get_connection() as conn, conn.cursor() as cursor: if params: From b94159e6e69ebb625e85aa5af4060c5d224a2d81 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:53:19 +0800 Subject: [PATCH 20/20] chore: change version number to v2.0.8 (#1186) --- pyproject.toml | 2 +- src/memos/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a9ea8852..9f17c0000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.7" +version = "2.0.8" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index fefa3b2ab..36cc0b5b5 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.7" +__version__ = "2.0.8" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig