From d534bf01b25a7d53e3f358f41200da19f4e5252b Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Tue, 28 Oct 2025 17:05:19 +0800 Subject: [PATCH 01/11] fix --- src/memos/api/config.py | 200 +++++++++++++++++++++++++++++++++ src/memos/graph_dbs/polardb.py | 1 - 2 files changed, 200 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d1bc6efff..1bb17c842 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -1,8 +1,16 @@ +import base64 +import hashlib +import hmac import json +import logging import os +import re +import time from typing import Any +import requests + from dotenv import load_dotenv from memos.configs.mem_cube import GeneralMemCubeConfig @@ -13,6 +21,198 @@ # Load environment variables load_dotenv() +logger = logging.getLogger(__name__) + + +def _update_env_from_dict(data: dict[str, Any]) -> None: + """Apply a dict to environment variables, with change logging.""" + + def _is_sensitive(name: str) -> bool: + n = name.upper() + return any(s in n for s in ["PASSWORD", "SECRET", "AK", "SK", "TOKEN", "KEY"]) + + for k, v in data.items(): + if isinstance(v, dict): + new_val = json.dumps(v, ensure_ascii=False) + elif isinstance(v, bool): + new_val = "true" if v else "false" + elif v is None: + new_val = "" + else: + new_val = str(v) + + old_val = os.environ.get(k) + os.environ[k] = new_val + + try: + log_old = "***" if _is_sensitive(k) else (old_val if old_val is not None else "") + log_new = "***" if _is_sensitive(k) else new_val + if old_val != new_val: + logger.info(f"Nacos config update: {k}={log_new} (was {log_old})") + except Exception as e: + # Avoid logging failures blocking config updates + logger.debug(f"Skip logging change for {k}: {e}") + + +def get_config_json(name: str, default: Any | None = None) -> Any: + """Read JSON object/array from env and parse. Returns default on missing/invalid.""" + raw = os.getenv(name) + if not raw: + return default + try: + return json.loads(raw) + except Exception: + logger.warning(f"Invalid JSON in env '{name}', returning default.") + return default + + +def get_config_value(path: str, default: Any | None = None) -> Any: + """Read value from env with optional dot-path for structured configs. + + Examples: + - get_config_value("MONGODB_CONFIG.base_uri") + - get_config_value("MONGODB_BASE_URI") + """ + if "." 
not in path: + val = os.getenv(path) + return val if val is not None else default + root, *subkeys = path.split(".") + data = get_config_json(root, default=None) + if not isinstance(data, dict): + return default + cur: Any = data + for key in subkeys: + if isinstance(cur, dict) and key in cur: + cur = cur[key] + else: + return default + return cur + + +class NacosConfigManager: + _client = None + _data_id = None + _group = None + _enabled = False + + @classmethod + def _sign(cls, secret_key: str, data: str) -> str: + """HMAC-SHA1 sgin""" + signature = hmac.new(secret_key.encode("utf-8"), data.encode("utf-8"), hashlib.sha1) + return base64.b64encode(signature.digest()).decode() + + @staticmethod + def parse_properties(content: str) -> dict[str, Any]: + """parse properties to dict""" + data: dict[str, Any] = {} + for line in content.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + match = re.match(r"^([^=]+)=(.*)$", line) + if match: + key = match.group(1).strip() + value = match.group(2).strip() + val_lower = value.lower() + if val_lower in ("true", "false"): + value_parsed: Any = val_lower == "true" + elif re.match(r"^[+-]?\d+$", value): + try: + value_parsed = int(value) + except Exception: + value_parsed = value + else: + value_parsed = value + data[key] = value_parsed + return data + + @classmethod + def start_config_watch(cls): + while True: + cls.init() + time.sleep(60) + + @classmethod + def start_watch_if_enabled(cls) -> None: + enable = os.getenv("NACOS_ENABLE_WATCH", "false").lower() == "true" + print("enable:", enable) + if not enable: + return + interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60")) + import threading + + def _loop() -> None: + while True: + try: + cls.init() + except Exception as e: + logger.error(f"❌ Nacos watch loop error: {e}") + time.sleep(interval) + + threading.Thread(target=_loop, daemon=True).start() + logger.info(f"Nacos watch thread started (interval={interval}s).") + + @classmethod + def init(cls) -> None: + server_addr = os.getenv("NACOS_SERVER_ADDR") + data_id = os.getenv("NACOS_DATA_ID") + group = os.getenv("NACOS_GROUP", "DEFAULT_GROUP") + namespace = os.getenv("NACOS_NAMESPACE", "") + ak = os.getenv("AK") + sk = os.getenv("SK") + + if not (server_addr and data_id and ak and sk): + logger.warning("❌ missing NACOS_SERVER_ADDR / AK / SK / DATA_ID") + return + + base_url = f"http://{server_addr}/nacos/v1/cs/configs" + + def _auth_headers(): + ts = str(int(time.time() * 1000)) + + sign_data = namespace + "+" + group + "+" + ts if namespace else group + "+" + ts + signature = cls._sign(sk, sign_data) + return { + "Spas-AccessKey": ak, + "Spas-Signature": signature, + "timeStamp": ts, + } + + try: + params = { + "dataId": data_id, + "group": group, + "tenant": namespace, + } + + headers = _auth_headers() + resp = requests.get(base_url, headers=headers, params=params, timeout=10) + + if resp.status_code != 200: + logger.error(f"Nacos AK/SK fail: {resp.status_code} {resp.text}") + return + + content = resp.text.strip() + if not content: + logger.warning("⚠️ Nacos is empty") + return + try: + data_props = cls.parse_properties(content) + logger.info("nacos config:", data_props) + _update_env_from_dict(data_props) + logger.info("✅ parse Nacos setting is Properties ") + except Exception as e: + logger.error(f"⚠️ Nacos parse fail(not JSON/YAML/Properties): {e}") + return + + except Exception as e: + logger.error(f"❌ Nacos AK/SK init fail: {e}") + + +# init Nacos +NacosConfigManager.init() 
+NacosConfigManager.start_watch_if_enabled() + class APIConfig: """Centralized configuration management for MemOS APIs.""" diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 38e71298f..2a83ec1cc 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -835,7 +835,6 @@ def get_nodes( # Parse embedding from JSONB if it exists if embedding_json is not None: try: - print("embedding_json:", embedding_json) # remove embedding """ embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json From 2336b6d3f3ba5630fedded2a6d2a321c1ba0c45c Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Wed, 29 Oct 2025 17:47:40 +0800 Subject: [PATCH 02/11] =?UTF-8?q?fix=EF=BC=9Anacos?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/api/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index c6970d0f7..292e7ffb5 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -203,10 +203,12 @@ def _auth_headers(): logger.info("✅ parse Nacos setting is Properties ") except Exception as e: logger.error(f"⚠️ Nacos parse fail(not JSON/YAML/Properties): {e}") - return + raise Exception(f"Nacos configuration parsing failed: {e}") + except Exception as e: logger.error(f"❌ Nacos AK/SK init fail: {e}") + raise Exception(f"❌ Nacos AK/SK init fail: {e}") # init Nacos From 464abed3d239d6281fb5dc5a4a76ff7d9d8b54a8 Mon Sep 17 00:00:00 2001 From: liji <532311301@qq.com> Date: Wed, 29 Oct 2025 18:56:07 +0800 Subject: [PATCH 03/11] feat: fix config Exception --- src/memos/api/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 292e7ffb5..ca7525c41 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -203,12 +203,12 @@ def _auth_headers(): logger.info("✅ parse Nacos setting is Properties ") except Exception as e: logger.error(f"⚠️ Nacos parse fail(not JSON/YAML/Properties): {e}") - raise Exception(f"Nacos configuration parsing failed: {e}") + raise Exception(f"Nacos configuration parsing failed: {e}") from e except Exception as e: logger.error(f"❌ Nacos AK/SK init fail: {e}") - raise Exception(f"❌ Nacos AK/SK init fail: {e}") + raise Exception(f"❌ Nacos AK/SK init fail: {e}") from e # init Nacos From a25434d2b52eaa7b9899c0ae59f7b7d074330a0d Mon Sep 17 00:00:00 2001 From: liji <532311301@qq.com> Date: Wed, 29 Oct 2025 19:24:40 +0800 Subject: [PATCH 04/11] feat: format config --- src/memos/api/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index ca7525c41..bb02b99be 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -205,7 +205,6 @@ def _auth_headers(): logger.error(f"⚠️ Nacos parse fail(not JSON/YAML/Properties): {e}") raise Exception(f"Nacos configuration parsing failed: {e}") from e - except Exception as e: logger.error(f"❌ Nacos AK/SK init fail: {e}") raise Exception(f"❌ Nacos AK/SK init fail: {e}") from e From 364148419d1d77f8af4de8cb6f99257c9e19aa42 Mon Sep 17 00:00:00 2001 From: liji <532311301@qq.com> Date: Wed, 29 Oct 2025 21:00:23 +0800 Subject: [PATCH 05/11] feat: format config --- src/memos/api/config.py | 65 ++++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 92ce8eeb9..7ac882d6c 100644 --- 
a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -95,35 +95,74 @@ class NacosConfigManager: _group = None _enabled = False + # Pre-compile regex patterns for better performance + _KEY_VALUE_PATTERN = re.compile(r"^([^=]+)=(.*)$") + _INTEGER_PATTERN = re.compile(r"^[+-]?\d+$") + _FLOAT_PATTERN = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$") + @classmethod def _sign(cls, secret_key: str, data: str) -> str: """HMAC-SHA1 sgin""" signature = hmac.new(secret_key.encode("utf-8"), data.encode("utf-8"), hashlib.sha1) return base64.b64encode(signature.digest()).decode() + @staticmethod + def _parse_value(value: str) -> Any: + """Parse string value to appropriate Python type. + + Supports: bool, int, float, and string. + """ + if not value: + return value + + val_lower = value.lower() + + # Boolean + if val_lower in ("true", "false"): + return val_lower == "true" + + # Integer + if NacosConfigManager._INTEGER_PATTERN.match(value): + try: + return int(value) + except (ValueError, OverflowError): + return value + + # Float + if NacosConfigManager._FLOAT_PATTERN.match(value): + try: + return float(value) + except (ValueError, OverflowError): + return value + + # Default to string + return value + @staticmethod def parse_properties(content: str) -> dict[str, Any]: - """parse properties to dict""" + """Parse properties file content to dictionary with type inference. + + Supports: + - Comments (lines starting with #) + - Key-value pairs (KEY=VALUE) + - Type inference (bool, int, float, string) + """ data: dict[str, Any] = {} + for line in content.splitlines(): line = line.strip() + + # Skip empty lines and comments if not line or line.startswith("#"): continue - match = re.match(r"^([^=]+)=(.*)$", line) + + # Parse key-value pair + match = NacosConfigManager._KEY_VALUE_PATTERN.match(line) if match: key = match.group(1).strip() value = match.group(2).strip() - val_lower = value.lower() - if val_lower in ("true", "false"): - value_parsed: Any = val_lower == "true" - elif re.match(r"^[+-]?\d+$", value): - try: - value_parsed = int(value) - except Exception: - value_parsed = value - else: - value_parsed = value - data[key] = value_parsed + data[key] = NacosConfigManager._parse_value(value) + return data @classmethod From 60a1ed16c26680af9440512afc4c0246450aa852 Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Thu, 30 Oct 2025 11:45:05 +0800 Subject: [PATCH 06/11] =?UTF-8?q?fix=EF=BC=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/graph_dbs/polardb.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 971a56e04..183ee67b5 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -91,6 +91,7 @@ def clean_properties(props): return {} return {k: v for k, v in props.items() if k not in vector_keys} +global_connection_pool = None class PolarDBGraphDB(BaseGraphDB): """PolarDB-based implementation using Apache AGE graph database extension.""" @@ -137,11 +138,20 @@ def __init__(self, config: PolarDBGraphDBConfig): port = config.port user = config.user password = config.password - + """ # Create connection self.connection = psycopg2.connect( + host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 + ) + """ + + import psycopg2.pool + global global_connection_pool + global_connection_pool = psycopg2.pool.SimpleConnectionPool( + 10, 2000, host=host, 
port=port, user=user, password=password, dbname=self.db_name ) + self.connection = global_connection_pool.getconn() self.connection.autocommit = True """ @@ -736,9 +746,9 @@ def format_param_value(value: str) -> str: if result: if include_embedding: - node_id, properties_json, embedding_json = result + properties_json, embedding_json = result else: - node_id, properties_json = result + properties_json = result embedding_json = None # Parse properties from JSONB if it's a string @@ -1727,9 +1737,9 @@ def export_graph( for row in node_results: if include_embedding: - node_id, properties_json, embedding_json = row + properties_json, embedding_json = row else: - node_id, properties_json = row + properties_json = row embedding_json = None # Parse properties from JSONB if it's a string From 31fcae798ce93f20247bf9a8243c51a8f975ef75 Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Thu, 30 Oct 2025 11:58:16 +0800 Subject: [PATCH 07/11] =?UTF-8?q?fix=EF=BC=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/graph_dbs/polardb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 183ee67b5..42be55646 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2859,6 +2859,8 @@ def _convert_graph_edges(self, core_node: dict) -> dict: data = copy.deepcopy(core_node) id_map = {} core_node = data.get("core_node", {}) + if not core_node: + return core_node core_meta = core_node.get("metadata", {}) if "graph_id" in core_meta and "id" in core_node: id_map[core_meta["graph_id"]] = core_node["id"] From 59c081378725fec3f9b808dd50e546f6c743e889 Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Thu, 30 Oct 2025 15:07:33 +0800 Subject: [PATCH 08/11] =?UTF-8?q?fix=EF=BC=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/graph_dbs/polardb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 42be55646..8ab2966f8 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -774,7 +774,7 @@ def format_param_value(value: str) -> str: logger.warning(f"Failed to parse embedding for node {id}") return self._parse_node( - {"id": id, "memory": properties.get("memory", ""), **properties} + {"id": id, "memory": json.loads(properties[1]).get("memory", ""), **json.loads(properties[1])} ) return None From 5515f6cf6a056028dcdd8c7cee68b8d1c5d2cf7c Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Thu, 30 Oct 2025 15:08:48 +0800 Subject: [PATCH 09/11] =?UTF-8?q?fix=EF=BC=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/graph_dbs/polardb.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 8ab2966f8..4f6d6f184 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -15,10 +15,6 @@ logger = get_logger(__name__) -# Graph database configuration -GRAPH_NAME = "test_memos_graph" - - def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: node_id = item["id"] memory = item["memory"] From b46ac072195b76760ce265fa0048be52485fa527 Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Thu, 30 Oct 2025 16:58:14 +0800 Subject: [PATCH 10/11] add polardb pool --- src/memos/graph_dbs/polardb.py | 410 
+++++++++++++++++++++------------ 1 file changed, 264 insertions(+), 146 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 4f6d6f184..a4f9709ba 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -116,6 +116,7 @@ def __init__(self, config: PolarDBGraphDBConfig): but it will be removed automatically before returning to external consumers. """ import psycopg2 + import psycopg2.pool self.config = config @@ -141,14 +142,19 @@ def __init__(self, config: PolarDBGraphDBConfig): ) """ - import psycopg2.pool - global global_connection_pool - global_connection_pool = psycopg2.pool.SimpleConnectionPool( - 10, 2000, - host=host, port=port, user=user, password=password, dbname=self.db_name + # Create connection pool + self.connection_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=5, + maxconn=2000, + host=host, + port=port, + user=user, + password=password, + dbname=self.db_name ) - self.connection = global_connection_pool.getconn() - self.connection.autocommit = True + + # Keep a reference to the pool for cleanup + self._pool_closed = False """ # Handle auto_create @@ -173,6 +179,17 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) + def _get_connection(self): + """Get a connection from the pool.""" + if self._pool_closed: + raise RuntimeError("Connection pool has been closed") + return self.connection_pool.getconn() + + def _return_connection(self, connection): + """Return a connection to the pool.""" + if not self._pool_closed and connection: + self.connection_pool.putconn(connection) + def _ensure_database_exists(self): """Create database if it doesn't exist.""" try: @@ -186,8 +203,10 @@ def _ensure_database_exists(self): @timed def _create_graph(self): """Create PostgreSQL schema and table for graph storage.""" + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Create schema if it doesn't exist cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') logger.info(f"Schema '{self.db_name}_graph' ensured.") @@ -235,6 +254,8 @@ def _create_graph(self): except Exception as e: logger.error(f"Failed to create graph schema: {e}") raise e + finally: + self._return_connection(conn) def create_index( self, @@ -247,8 +268,10 @@ def create_index( Create indexes for embedding and other fields. Note: This creates PostgreSQL indexes on the underlying tables. 
""" + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" @@ -268,6 +291,8 @@ def create_index( logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") + finally: + self._return_connection(conn) def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: """Get count of memory nodes by type.""" @@ -282,14 +307,18 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in print(f"[get_memory_count] Query: {query}, Params: {params}") + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 + finally: + self._return_connection(conn) @timed def node_not_exist(self, scope: str, user_name: str | None = None) -> int: @@ -305,9 +334,10 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: params = [f'"{scope}"', f'"{user_name}"'] print(f"[node_not_exist] Query: {query}, Params: {params}") - + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() print(f"[node_not_exist] Query result: {result}") @@ -315,6 +345,8 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def remove_oldest_memory( @@ -340,9 +372,9 @@ def remove_oldest_memory( OFFSET %s """ select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Execute query to get IDs to delete cursor.execute(select_query, select_params) ids_to_delete = [row[0] for row in cursor.fetchall()] @@ -368,6 +400,8 @@ def remove_oldest_memory( except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: @@ -426,12 +460,16 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N params.append(f'"{user_name}"') print(f"[update_node] query: {query}, params: {params}") + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def delete_node(self, id: str, user_name: str | None = None) -> None: @@ -453,18 +491,24 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: params.append(f'"{user_name}"') print(f"[delete_node] query: {query}, params: {params}") + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) 
except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Ensure in the correct database context cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] @@ -485,11 +529,15 @@ def create_extension(self): except Exception as e: print(f"⚠️ Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def create_graph(self): + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(f""" SELECT COUNT(*) FROM ag_catalog.ag_graph WHERE name = '{self.db_name}_graph'; @@ -504,6 +552,8 @@ def create_graph(self): except Exception as e: print(f"⚠️ Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def create_edge(self): @@ -513,8 +563,9 @@ def create_edge(self): for label_name in valid_rel_types: print(f"🪶 Creating elabel: {label_name}") + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") print(f"✅ Successfully created elabel: {label_name}") except Exception as e: @@ -523,6 +574,8 @@ def create_edge(self): else: print(f"⚠️ Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def add_edge( @@ -554,14 +607,16 @@ def add_edge( ); """ print(f"Executing add_edge: {query}") - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def delete_edge(self, source_id: str, target_id: str, type: str) -> None: @@ -576,10 +631,13 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - - with self.connection.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + finally: + self._return_connection(conn) @timed def edge_exists_old( @@ -634,11 +692,14 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - - with self.connection.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None + finally: + 
self._return_connection(conn) @timed def edge_exists( @@ -688,10 +749,14 @@ def edge_exists( query += "\n$$) AS (r agtype)" print(f"edge_exists query: {query}") - with self.connection.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None + finally: + self._return_connection(conn) @timed def get_node( @@ -735,8 +800,9 @@ def format_param_value(value: str) -> str: params.append(format_param_value(user_name)) print(f"[get_node] query: {query}, params: {params}") + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -777,6 +843,8 @@ def format_param_value(value: str) -> str: except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) return None + finally: + self._return_connection(conn) @timed def get_nodes( @@ -819,43 +887,47 @@ def get_nodes( params.append(f'"{user_name}"') print(f"[get_nodes] query: {query}, params: {params}") - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists - if embedding_json is not None: - try: - # remove embedding - """ - embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json - # properties["embedding"] = embedding - """ - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding from JSONB if it exists + if embedding_json is not None: + try: + # remove embedding + """ + embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json + # properties["embedding"] = embedding + """ + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) ) - ) - return nodes + return nodes + finally: + self._return_connection(conn) @timed def get_edges_old( @@ -1074,9 +1146,9 @@ def get_children_with_embeddings( """ 
print("[get_children_with_embeddings] query:", query) - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1131,6 +1203,8 @@ def get_children_with_embeddings( except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: """Get the path of nodes from source to target within a limited depth.""" @@ -1192,9 +1266,9 @@ def get_subgraph( r) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() print("[get_subgraph] result:", result) @@ -1269,6 +1343,8 @@ def get_subgraph( except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) return {"core_node": None, "neighbors": [], "edges": []} + finally: + self._return_connection(conn) def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" @@ -1355,24 +1431,29 @@ def search_by_embedding( print( f"[search_by_embedding] query: {query}, params: {params}, where_clause: {where_clause}" ) - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output[:top_k] + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + """ + polarId = row[0] # id + properties = row[1] # properties + # embedding = row[3] # embedding + """ + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + return output[:top_k] + finally: + self._return_connection(conn) + @timed def get_by_metadata( @@ -1463,8 +1544,9 @@ def get_by_metadata( print(f"[get_by_metadata] query: {cypher_query}, where_str: {where_str}") ids = [] + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() print("[get_by_metadata] result:", results) @@ -1472,6 +1554,8 @@ def get_by_metadata( except Exception as e: print("Failed to get metadata:", {e}) logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + finally: + self._return_connection(conn) return ids @@ -1627,9 +1711,9 @@ def get_grouped_counts( """ print("[get_grouped_counts] query:", query) - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Handle parameterized query if params and isinstance(params, list): cursor.execute(query, params) @@ -1654,6 +1738,10 @@ def get_grouped_counts( except Exception as e: 
logger.error(f"Failed to get grouped counts: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) + + def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" @@ -1685,10 +1773,13 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - - with self.connection.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") + finally: + self._return_connection(conn) except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -1710,7 +1801,7 @@ def export_graph( } """ user_name = user_name if user_name else self._get_config_value("user_name") - + conn = self._get_connection() try: # Export nodes if include_embedding: @@ -1726,7 +1817,7 @@ def export_graph( WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype """ - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(node_query) node_results = cursor.fetchall() nodes = [] @@ -1765,7 +1856,10 @@ def export_graph( except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + finally: + self._return_connection(conn) + conn = self._get_connection() try: # Export edges using cypher query edge_query = f""" @@ -1776,7 +1870,7 @@ def export_graph( $$) AS (source agtype, target agtype, edge agtype) """ - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(edge_query) edge_results = cursor.fetchall() edges = [] @@ -1838,6 +1932,9 @@ def export_graph( except Exception as e: logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e + finally: + self._return_connection(conn) + return {"nodes": nodes, "edges": edges} @timed @@ -1852,9 +1949,12 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - - result = self.execute_query(query) - return int(result.one_or_none()["count"].value) + conn = self._get_connection() + try: + result = self.execute_query(query, conn) + return int(result.one_or_none()["count"].value) + finally: + self._return_connection(conn) @timed def get_all_memory_items( @@ -1896,8 +1996,9 @@ def get_all_memory_items( nodes = [] node_ids = set() print("[get_all_memory_items embedding true ] cypher_query:", cypher_query) + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -1919,6 +2020,8 @@ def get_all_memory_items( except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) return nodes else: @@ -1933,8 +2036,9 @@ def get_all_memory_items( print("[get_all_memory_items embedding false ] cypher_query:", cypher_query) nodes = [] + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -1951,6 +2055,8 @@ def get_all_memory_items( except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) return nodes @@ -2162,8 +2268,9 @@ def 
get_structure_optimization_candidates( candidates = [] node_ids = set() + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() print("result------", len(results)) @@ -2240,6 +2347,8 @@ def get_structure_optimization_candidates( except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) + finally: + self._return_connection(conn) return candidates @@ -2362,44 +2471,48 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - with self.connection.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s, - %s - ) + conn = self._get_connection() + try: + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) - logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + finally: + self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): """ @@ -2506,8 +2619,9 @@ def get_neighbors_by_tag( print(f"[get_neighbors_by_tag] query: {query}, params: {params}") + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) results = 
cursor.fetchall() @@ -2556,6 +2670,8 @@ def get_neighbors_by_tag( except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def get_neighbors_by_tag_ccl( self, @@ -2801,9 +2917,9 @@ def get_edges( RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -2848,6 +2964,8 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def _convert_graph_edges(self, core_node: dict) -> dict: import copy From 298c0298945e0f81a54b30b99d2aeb49376dc3c3 Mon Sep 17 00:00:00 2001 From: ccl <13282138256@163.com> Date: Thu, 30 Oct 2025 17:36:37 +0800 Subject: [PATCH 11/11] =?UTF-8?q?fix=EF=BC=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/graph_dbs/polardb.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ab06517d4..f24f1072c 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -15,6 +15,7 @@ logger = get_logger(__name__) + def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: node_id = item["id"] memory = item["memory"] @@ -149,7 +150,7 @@ def __init__(self, config: PolarDBGraphDBConfig): port=port, user=user, password=password, - dbname=self.db_name + dbname=self.db_name, ) # Keep a reference to the pool for cleanup @@ -829,7 +830,11 @@ def format_param_value(value: str) -> str: logger.warning(f"Failed to parse embedding for node {id}") return self._parse_node( - {"id": id, "memory": json.loads(properties[1]).get("memory", ""), **json.loads(properties[1])} + { + "id": id, + "memory": json.loads(properties[1]).get("memory", ""), + **json.loads(properties[1]), + } ) return None @@ -1418,7 +1423,6 @@ def search_by_embedding( """ params = [vector] - conn = self._get_connection() try: with conn.cursor() as cursor: @@ -1442,7 +1446,6 @@ def search_by_embedding( finally: self._return_connection(conn) - @timed def get_by_metadata( self, filters: list[dict[str, Any]], user_name: str | None = None @@ -1720,8 +1723,6 @@ def get_grouped_counts( finally: self._return_connection(conn) - - def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" raise NotImplementedError
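
A note on the config helpers added in PATCH 01 and refined in PATCH 05: get_config_value() resolves dot-paths against JSON objects stored in environment variables, and parse_properties()/_parse_value() infer bool/int/float types from Nacos properties payloads. The snippet below is a minimal, self-contained sketch of that behaviour using only the standard library; the env-var name MONGODB_CONFIG and the helper names are illustrative placeholders, not code taken from the patches.

import json
import os
import re
from typing import Any

_INT = re.compile(r"^[+-]?\d+$")
_FLOAT = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$")


def parse_value(value: str) -> Any:
    """Infer bool/int/float from a properties value, falling back to str."""
    low = value.lower()
    if low in ("true", "false"):
        return low == "true"
    if _INT.match(value):
        return int(value)
    if _FLOAT.match(value):
        return float(value)
    return value


def get_config_value(path: str, default: Any = None) -> Any:
    """Resolve 'ROOT.key.subkey' against a JSON object stored in env var ROOT."""
    if "." not in path:
        return os.getenv(path, default)
    root, *keys = path.split(".")
    try:
        cur: Any = json.loads(os.environ[root])
    except (KeyError, json.JSONDecodeError):
        return default
    for key in keys:
        if not isinstance(cur, dict) or key not in cur:
            return default
        cur = cur[key]
    return cur


if __name__ == "__main__":
    # Placeholder value, mirroring how a Nacos properties entry would land in os.environ.
    os.environ["MONGODB_CONFIG"] = json.dumps({"base_uri": "mongodb://localhost:27017"})
    print(get_config_value("MONGODB_CONFIG.base_uri"))  # mongodb://localhost:27017
    print(parse_value("42"), parse_value("0.5"), parse_value("true"))  # 42 0.5 True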
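
PATCH 10 replaces the single psycopg2 connection with a ThreadedConnectionPool and pairs every cursor with a _get_connection()/_return_connection() try/finally block. The same borrow/return discipline can be centralised in a context manager; the sketch below assumes only the public psycopg2.pool API, with placeholder DSN values and illustrative names (make_pool, pooled_connection, count_memory_nodes) rather than the methods the series actually adds.

from contextlib import contextmanager

import psycopg2.pool


def make_pool() -> psycopg2.pool.ThreadedConnectionPool:
    # Placeholder DSN values; in the patches these come from PolarDBGraphDBConfig.
    return psycopg2.pool.ThreadedConnectionPool(
        minconn=5,
        maxconn=50,
        host="localhost",
        port=5432,
        user="memos",
        password="secret",
        dbname="memos_db",
    )


@contextmanager
def pooled_connection(pool: psycopg2.pool.AbstractConnectionPool):
    """Borrow a connection from the pool and always hand it back."""
    conn = pool.getconn()
    try:
        yield conn
    finally:
        pool.putconn(conn)


def count_memory_nodes(pool: psycopg2.pool.AbstractConnectionPool, graph_schema: str) -> int:
    # Same borrow/return pattern as the try/finally blocks in PATCH 10, expressed once.
    with pooled_connection(pool) as conn, conn.cursor() as cursor:
        cursor.execute(f'SELECT COUNT(*) FROM "{graph_schema}"."Memory"')
        return cursor.fetchone()[0]

Calling putconn() in the finally block guarantees the connection is returned even when the query raises, which is the invariant each patched method maintains by hand.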