From fb753be3beb68af70e2371dbb6d40d46c85f4f29 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Mon, 2 Mar 2026 19:33:46 +0800
Subject: [PATCH 01/10] feat: optimize polardb ThreadedConnectionPool

---
 src/memos/graph_dbs/polardb.py | 2332 ++++++++++++++------------------
 1 file changed, 1013 insertions(+), 1319 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 592f45a7f..841839baf 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -3,7 +3,8 @@
 import textwrap
+import threading
 import time
 
-from contextlib import suppress
+from contextlib import contextmanager, suppress
 from datetime import datetime
 from typing import Any, Literal
 
@@ -145,12 +146,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
         user = config.user
         password = config.password
         maxconn = config.maxconn if hasattr(config, "maxconn") else 100
-        """
-        # Create connection
-        self.connection = psycopg2.connect(
-            host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000
-        )
-        """
+
         logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'")
 
         # Create connection pool
@@ -168,8 +164,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
             keepalives_count=5,  # Number of keepalive retries before considering connection dead
         )
 
-        # Keep a reference to the pool for cleanup
-        self._pool_closed = False
+        self._semaphore = threading.BoundedSemaphore(maxconn)
 
         """
         # Handle auto_create
@@ -194,194 +189,34 @@ def _get_config_value(self, key: str, default=None):
         else:
             return getattr(self.config, key, default)
 
-    def _get_connection_old(self):
-        """Get a connection from the pool."""
-        if self._pool_closed:
-            raise RuntimeError("Connection pool has been closed")
-        conn = self.connection_pool.getconn()
-        # Set autocommit for PolarDB compatibility
-        conn.autocommit = True
-        return conn
-
+    @contextmanager
     def _get_connection(self):
-        logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'")
-        if self._pool_closed:
-            raise RuntimeError("Connection pool has been closed")
-
-        max_retries = 500
-        import psycopg2.pool
-
-        for attempt in range(max_retries):
-            conn = None
-            try:
-                conn = self.connection_pool.getconn()
-
-                if conn.closed != 0:
-                    logger.warning(
-                        f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}"
-                    )
-                    try:
-                        self.connection_pool.putconn(conn, close=True)
-                    except Exception as e:
-                        logger.warning(
-                            f"[_get_connection] Failed to return closed connection to pool: {e}"
-                        )
-                        with suppress(Exception):
-                            conn.close()
-
-                    conn = None
-                    if attempt < max_retries - 1:
-                        # Exponential backoff: 0.1s, 0.2s, 0.4s
-                        """time.sleep(0.1 * (2**attempt))"""
-                        time.sleep(0.003)
-                        continue
-                    else:
-                        raise RuntimeError("Pool returned a closed connection after all retries")
-
-                # Set autocommit for PolarDB compatibility
-                conn.autocommit = True
-
-                # Test connection health with SELECT 1
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute("SELECT 1")
-                    cursor.fetchone()
-                    cursor.close()
-                except Exception as health_check_error:
-                    # Connection is not usable, return it to pool with close flag and try again
-                    logger.warning(
-                        f"[_get_connection] Connection health check failed (attempt {attempt + 1}/{max_retries}): {health_check_error}"
-                    )
-                    try:
-                        self.connection_pool.putconn(conn, close=True)
-                    except Exception as putconn_error:
-                        logger.warning(
-                            f"[_get_connection] Failed to return unhealthy connection to pool: {putconn_error}"
-                        )
-                        with suppress(Exception):
-                            conn.close()
-
-                    conn = None
-                    if attempt < max_retries - 1:
-                        # Exponential backoff: 0.1s, 0.2s, 0.4s
-                        """time.sleep(0.1 * (2**attempt))"""
-                        time.sleep(0.003)
-                        continue
-                    else:
-                        raise RuntimeError(
-                            f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}"
-                        ) from health_check_error
-
-                # Connection is healthy, return it
-                return conn
-
-            except psycopg2.pool.PoolError as pool_error:
-                error_msg = str(pool_error).lower()
-                if "exhausted" in error_msg or "pool" in error_msg:
-                    # Log pool status for debugging
-                    try:
-                        # Try to get pool stats if available
-                        pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}"
-                        logger.info(
-                            f" polardb get_connection Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}"
-                        )
-                    except Exception:
-                        logger.warning(
-                            f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})"
-                        )
-
-                    # For pool exhaustion, wait longer before retry (connections may be returned)
-                    if attempt < max_retries - 1:
-                        # Longer backoff for pool exhaustion: 0.5s, 1.0s, 2.0s
-                        wait_time = 0.5 * (2**attempt)
-                        logger.info(f"[_get_connection] Waiting {wait_time}s before retry...")
-                        """time.sleep(wait_time)"""
-                        time.sleep(0.003)
-                        continue
-                    else:
-                        raise RuntimeError(
-                            f"Connection pool exhausted after {max_retries} attempts. "
-                            f"This usually means connections are not being returned to the pool. "
-                        ) from pool_error
-                else:
-                    # Other pool errors - retry with normal backoff
-                    if attempt < max_retries - 1:
-                        """time.sleep(0.1 * (2**attempt))"""
-                        time.sleep(0.003)
-                        continue
-                    else:
-                        raise RuntimeError(
-                            f"Failed to get connection from pool: {pool_error}"
-                        ) from pool_error
-
-            except Exception as e:
-                if conn is not None:
-                    try:
-                        self.connection_pool.putconn(conn, close=True)
-                    except Exception as putconn_error:
-                        logger.warning(
-                            f"[_get_connection] Failed to return connection after error: {putconn_error}"
-                        )
-                        with suppress(Exception):
-                            conn.close()
-
-                if attempt >= max_retries - 1:
-                    raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e
-                else:
-                    # Exponential backoff: 0.1s, 0.2s, 0.4s
-                    """time.sleep(0.1 * (2**attempt))"""
-                    time.sleep(0.003)
-                    continue
-
-        # Should never reach here, but just in case
-        raise RuntimeError("Failed to get connection after all retries")
-
-    def _return_connection(self, connection):
-        if self._pool_closed:
-            if connection:
-                try:
-                    connection.close()
-                    logger.debug("[_return_connection] Closed connection (pool is closed)")
-                except Exception as e:
-                    logger.warning(
-                        f"[_return_connection] Failed to close connection after pool closed: {e}"
-                    )
-            return
-
-        if not connection:
-            return
-
+        """
+        Safely acquire a connection (block until one is free instead of raising "pool exhausted").
+        """
+        self._semaphore.acquire()
+        conn = None
         try:
-            if hasattr(connection, "closed") and connection.closed != 0:
-                logger.debug(
-                    "[_return_connection] Connection is closed, closing it instead of returning to pool"
-                )
+            conn = self.connection_pool.getconn()
+            conn.autocommit = True
+            with conn.cursor() as cur:
+                cur.execute("SELECT 1")
+            yield conn
+        except Exception:
+            if conn:
                 try:
-                    connection.close()
-                except Exception as e:
-                    logger.warning(f"[_return_connection] Failed to close closed connection: {e}")
-                return
-
-            self.connection_pool.putconn(connection)
-            logger.debug("[_return_connection] Successfully returned connection to pool")
-        except Exception as e:
-            logger.error(
-                f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True
-            )
-            try:
-                connection.close()
-                logger.debug(
-                    "[_return_connection] Closed connection as fallback after putconn failure"
-                )
-            except Exception as close_error:
-                logger.warning(
-                    f"[_return_connection] Failed to close connection after putconn error: {close_error}"
-                )
-
-    def _return_connection_old(self, connection):
-        """Return a connection to the pool."""
-        if not self._pool_closed and connection:
-            self.connection_pool.putconn(connection)
+                    self.connection_pool.putconn(conn, close=True)
+                except Exception:
+                    pass
+                conn = None  # Already discarded; skip the normal return in the finally block.
+            raise
+        finally:
+            if conn:
+                try:
+                    self.connection_pool.putconn(conn)
+                except Exception:
+                    pass
+            self._semaphore.release()
 
     def _ensure_database_exists(self):
         """Create database if it doesn't exist."""
@@ -396,60 +231,56 @@ def _ensure_database_exists(self):
 
     @timed
    def _create_graph(self):
         """Create PostgreSQL schema and table for graph storage."""
-        # Get a connection from the pool
-        conn = None
         try:
-            conn = self._get_connection()
-            with conn.cursor() as cursor:
-                # Create schema if it doesn't exist
-                cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";')
-                logger.info(f"Schema '{self.db_name}_graph' ensured.")
-
-                # Create Memory table if it doesn't exist
-                cursor.execute(f"""
-                    CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" (
-                        id TEXT PRIMARY KEY,
-                        properties JSONB NOT NULL,
-                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-                    );
-                """)
-                logger.info(f"Memory table created in schema '{self.db_name}_graph'.")
+            with self._get_connection() as conn:
+                with conn.cursor() as cursor:
+                    # Create schema if it doesn't exist
+                    cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";')
+                    logger.info(f"Schema '{self.db_name}_graph' ensured.")
 
-                # Add embedding column if it doesn't exist (using JSONB for compatibility)
-                try:
+                    # Create Memory table if it doesn't exist
                     cursor.execute(f"""
-                        ALTER TABLE "{self.db_name}_graph"."Memory"
-                        ADD COLUMN IF NOT EXISTS embedding JSONB;
+                        CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" (
+                            id TEXT PRIMARY KEY,
+                            properties JSONB NOT NULL,
+                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                        );
                     """)
-                    logger.info("Embedding column added to Memory table.")
-                except Exception as e:
-                    logger.warning(f"Failed to add embedding column: {e}")
+                    logger.info(f"Memory table created in schema '{self.db_name}_graph'.")
 
-                # Create indexes
-                cursor.execute(f"""
-                    CREATE INDEX IF NOT EXISTS idx_memory_properties
-                    ON "{self.db_name}_graph"."Memory" USING GIN (properties);
-                """)
+                    # Add embedding column if it doesn't exist (using JSONB for compatibility)
+                    try:
+                        cursor.execute(f"""
+                            ALTER TABLE "{self.db_name}_graph"."Memory"
+                            ADD COLUMN IF NOT EXISTS embedding JSONB;
+                        """)
+                        logger.info("Embedding column added to Memory table.")
+                    except Exception as e:
+                        logger.warning(f"Failed to add embedding column: {e}")
 
-                # Create vector index for embedding field
-                try:
+                    # Create indexes
                     cursor.execute(f"""
-                        CREATE INDEX IF NOT EXISTS idx_memory_embedding
-                        ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops)
-                        WITH (lists = 100);
+                        CREATE INDEX IF NOT EXISTS idx_memory_properties
+                        ON "{self.db_name}_graph"."Memory" USING GIN (properties);
                     """)
-                    logger.info("Vector index created for 
Memory table.") - except Exception as e: - logger.warning(f"Vector index creation failed (might not be supported): {e}") - logger.info("Indexes created for Memory table.") + # Create vector index for embedding field + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 100); + """) + logger.info("Vector index created for Memory table.") + except Exception as e: + logger.warning(f"Vector index creation failed (might not be supported): {e}") + + logger.info("Indexes created for Memory table.") except Exception as e: logger.error(f"Failed to create graph schema: {e}") raise e - finally: - self._return_connection(conn) def create_index( self, @@ -462,32 +292,28 @@ def create_index( Create indexes for embedding and other fields. Note: This creates PostgreSQL indexes on the underlying tables. """ - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Create indexes on the underlying PostgreSQL tables - # Apache AGE stores data in regular PostgreSQL tables - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties - ON "{self.db_name}_graph"."Memory" USING GIN (properties); - """) - - # Try to create vector index, but don't fail if it doesn't work - try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Create indexes on the underlying PostgreSQL tables + # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding - ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); """) - except Exception as ve: - logger.warning(f"Vector index creation failed (might not be supported): {ve}") - logger.debug("Indexes created successfully.") + # Try to create vector index, but don't fail if it doesn't work + try: + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); + """) + except Exception as ve: + logger.warning(f"Vector index creation failed (might not be supported): {ve}") + + logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") - finally: - self._return_connection(conn) def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: """Get count of memory nodes by type.""" @@ -500,19 +326,15 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params = [self.format_param_value(memory_type), self.format_param_value(user_name)] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result[0] if result else 0 + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 - finally: - self._return_connection(conn) @timed def node_not_exist(self, scope: str, user_name: str | None = None) -> int: @@ -527,19 +349,15 @@ def node_not_exist(self, scope: str, user_name: 
str | None = None) -> int: query += "\nLIMIT 1" params = [self.format_param_value(scope), self.format_param_value(user_name)] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return 1 if result else 0 + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def remove_oldest_memory( @@ -569,39 +387,36 @@ def remove_oldest_memory( self.format_param_value(user_name), keep_latest, ] - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Execute query to get IDs to delete - cursor.execute(select_query, select_params) - ids_to_delete = [row[0] for row in cursor.fetchall()] - - if not ids_to_delete: - logger.info(f"No {memory_type} memories to remove for user {user_name}") - return - - # Build delete query - placeholders = ",".join(["%s"] * len(ids_to_delete)) - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE id IN ({placeholders}) - """ - delete_params = ids_to_delete + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Execute query to get IDs to delete + cursor.execute(select_query, select_params) + ids_to_delete = [row[0] for row in cursor.fetchall()] - # Execute deletion - cursor.execute(delete_query, delete_params) - deleted_count = cursor.rowcount - logger.info( - f"Removed {deleted_count} oldest {memory_type} memories, " - f"keeping {keep_latest} latest for user {user_name}, " - f"removed ids: {ids_to_delete}" - ) + if not ids_to_delete: + logger.info(f"No {memory_type} memories to remove for user {user_name}") + return + + # Build delete query + placeholders = ",".join(["%s"] * len(ids_to_delete)) + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE id IN ({placeholders}) + """ + delete_params = ids_to_delete + + # Execute deletion + cursor.execute(delete_query, delete_params) + deleted_count = cursor.rowcount + logger.info( + f"Removed {deleted_count} oldest {memory_type} memories, " + f"keeping {keep_latest} latest for user {user_name}, " + f"removed ids: {ids_to_delete}" + ) except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: @@ -663,17 +478,13 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def delete_node(self, id: str, user_name: str | None = None) -> None: @@ -694,74 +505,62 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = 
%s::agtype" params.append(self.format_param_value(user_name)) - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Ensure in the correct database context - cursor.execute("SELECT current_database();") - current_db = cursor.fetchone()[0] - logger.info(f"Current database context: {current_db}") + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Ensure in the correct database context + cursor.execute("SELECT current_database();") + current_db = cursor.fetchone()[0] + logger.info(f"Current database context: {current_db}") - for ext_name, ext_desc in extensions: - try: - cursor.execute(f"create extension if not exists {ext_name};") - logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") - except Exception as e: - if "already exists" in str(e): - logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") - else: - logger.warning( - f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" - ) - logger.error( - f"Failed to create extension '{ext_name}': {e}", exc_info=True - ) + for ext_name, ext_desc in extensions: + try: + cursor.execute(f"create extension if not exists {ext_name};") + logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") + except Exception as e: + if "already exists" in str(e): + logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") + else: + logger.warning( + f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" + ) + logger.error( + f"Failed to create extension '{ext_name}': {e}", exc_info=True + ) except Exception as e: logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def create_graph(self): - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(f""" - SELECT COUNT(*) FROM ag_catalog.ag_graph - WHERE name = '{self.db_name}_graph'; - """) - graph_exists = cursor.fetchone()[0] > 0 + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(f""" + SELECT COUNT(*) FROM ag_catalog.ag_graph + WHERE name = '{self.db_name}_graph'; + """) + graph_exists = cursor.fetchone()[0] > 0 - if graph_exists: - logger.info(f"Graph '{self.db_name}_graph' already exists.") - else: - cursor.execute(f"select create_graph('{self.db_name}_graph');") - logger.info(f"Graph database '{self.db_name}_graph' created.") + if graph_exists: + logger.info(f"Graph '{self.db_name}_graph' already exists.") + else: + cursor.execute(f"select create_graph('{self.db_name}_graph');") + logger.info(f"Graph database '{self.db_name}_graph' created.") except Exception as e: logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def create_edge(self): @@ -770,21 +569,18 @@ def create_edge(self): 
valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: - conn = None logger.info(f"Creating elabel: {label_name}") try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") - logger.info(f"Successfully created elabel: {label_name}") + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") + logger.info(f"Successfully created elabel: {label_name}") except Exception as e: if "already exists" in str(e): logger.info(f"Label '{label_name}' already exists, skipping.") else: logger.warning(f"Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def add_edge( @@ -825,20 +621,17 @@ def add_edge( ); """ logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) - logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) + logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") - elapsed_time = time.time() - start_time - logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") + elapsed_time = time.time() - start_time + logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def delete_edge(self, source_id: str, target_id: str, type: str) -> None: @@ -853,14 +646,10 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type)) logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") - finally: - self._return_connection(conn) @timed def edge_exists_old( @@ -915,15 +704,11 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return result is not None - finally: - self._return_connection(conn) @timed def edge_exists( @@ -971,15 +756,11 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() return result is not None and result[0] is not None - finally: - self._return_connection(conn) @timed def get_node( @@ -1015,60 +796,57 @@ def get_node( params.append(self.format_param_value(user_name)) logger.info(f"polardb [get_node] query: {query},params: {params}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() + with self._get_connection() as conn: + 
with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() - if result: - if include_embedding: - _, properties_json, embedding_json = result - else: - _, properties_json = result - embedding_json = None + if result: + if include_embedding: + _, properties_json, embedding_json = result + else: + _, properties_json = result + embedding_json = None - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {id}") - properties = {} - else: - properties = properties_json if properties_json else {} + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists and include_embedding is True - if include_embedding and embedding_json is not None: - try: - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {id}") + # Parse embedding from JSONB if it exists and include_embedding is True + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {id}") - elapsed_time = time.time() - start_time - logger.info( - f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" - ) - return self._parse_node( - { - "id": id, - "memory": properties.get("memory", ""), - **properties, - } - ) - return None + elapsed_time = time.time() - start_time + logger.info( + f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" + ) + return self._parse_node( + { + "id": id, + "memory": properties.get("memory", ""), + **properties, + } + ) + return None except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) return None - finally: - self._return_connection(conn) @timed def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]: @@ -1105,9 +883,7 @@ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, logger.info(f"get_nodes query:{query},params:{params}") - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1147,8 +923,6 @@ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, ) ) return nodes - finally: - self._return_connection(conn) @timed def get_edges_old( @@ -1366,66 +1140,63 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - children = [] - for row in results: - # Handle child_id - remove possible quotes - child_id_raw = 
row[0].value if hasattr(row[0], "value") else str(row[0]) - if isinstance(child_id_raw, str): - # If string starts and ends with quotes, remove quotes - if child_id_raw.startswith('"') and child_id_raw.endswith('"'): - child_id = child_id_raw[1:-1] + children = [] + for row in results: + # Handle child_id - remove possible quotes + child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) + if isinstance(child_id_raw, str): + # If string starts and ends with quotes, remove quotes + if child_id_raw.startswith('"') and child_id_raw.endswith('"'): + child_id = child_id_raw[1:-1] + else: + child_id = child_id_raw else: - child_id = child_id_raw - else: - child_id = str(child_id_raw) + child_id = str(child_id_raw) - # Handle embedding - get from database embedding column - embedding_raw = row[1] - embedding = [] - if embedding_raw is not None: - try: - if isinstance(embedding_raw, str): - # If it is a JSON string, parse it - embedding = json.loads(embedding_raw) - elif isinstance(embedding_raw, list): - # If already a list, use directly - embedding = embedding_raw + # Handle embedding - get from database embedding column + embedding_raw = row[1] + embedding = [] + if embedding_raw is not None: + try: + if isinstance(embedding_raw, str): + # If it is a JSON string, parse it + embedding = json.loads(embedding_raw) + elif isinstance(embedding_raw, list): + # If already a list, use directly + embedding = embedding_raw + else: + # Try converting to list + embedding = list(embedding_raw) + except (json.JSONDecodeError, TypeError, ValueError) as e: + logger.warning( + f"Failed to parse embedding for child node {child_id}: {e}" + ) + embedding = [] + + # Handle memory - remove possible quotes + memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) + if isinstance(memory_raw, str): + # If string starts and ends with quotes, remove quotes + if memory_raw.startswith('"') and memory_raw.endswith('"'): + memory = memory_raw[1:-1] else: - # Try converting to list - embedding = list(embedding_raw) - except (json.JSONDecodeError, TypeError, ValueError) as e: - logger.warning( - f"Failed to parse embedding for child node {child_id}: {e}" - ) - embedding = [] - - # Handle memory - remove possible quotes - memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) - if isinstance(memory_raw, str): - # If string starts and ends with quotes, remove quotes - if memory_raw.startswith('"') and memory_raw.endswith('"'): - memory = memory_raw[1:-1] + memory = memory_raw else: - memory = memory_raw - else: - memory = str(memory_raw) + memory = str(memory_raw) - children.append({"id": child_id, "embedding": embedding, "memory": memory}) + children.append({"id": child_id, "embedding": embedding, "memory": memory}) - return children + return children except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: """Get the path of nodes from source to target within a limited depth.""" @@ -1507,137 +1278,134 @@ def get_subgraph( RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - conn = None logger.info(f"[get_subgraph] Query: {query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - if not results: - return {"core_node": None, "neighbors": 
[], "edges": []} + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - # Merge results from all UNION ALL rows - all_centers_list = [] - all_neighbors_list = [] - all_edges_list = [] + if not results: + return {"core_node": None, "neighbors": [], "edges": []} - for result in results: - if not result or not result[0]: - continue + # Merge results from all UNION ALL rows + all_centers_list = [] + all_neighbors_list = [] + all_edges_list = [] - centers_data = result[0] if result[0] else "[]" - neighbors_data = result[1] if result[1] else "[]" - edges_data = result[2] if result[2] else "[]" + for result in results: + if not result or not result[0]: + continue - # Parse JSON data - try: - # Clean ::vertex and ::edge suffixes in data - if isinstance(centers_data, str): - centers_data = centers_data.replace("::vertex", "") - if isinstance(neighbors_data, str): - neighbors_data = neighbors_data.replace("::vertex", "") - if isinstance(edges_data, str): - edges_data = edges_data.replace("::edge", "") - - centers_list = ( - json.loads(centers_data) - if isinstance(centers_data, str) - else centers_data - ) - neighbors_list = ( - json.loads(neighbors_data) - if isinstance(neighbors_data, str) - else neighbors_data - ) - edges_list = ( - json.loads(edges_data) if isinstance(edges_data, str) else edges_data - ) + centers_data = result[0] if result[0] else "[]" + neighbors_data = result[1] if result[1] else "[]" + edges_data = result[2] if result[2] else "[]" - # Collect data from this row - if isinstance(centers_list, list): - all_centers_list.extend(centers_list) - if isinstance(neighbors_list, list): - all_neighbors_list.extend(neighbors_list) - if isinstance(edges_list, list): - all_edges_list.extend(edges_list) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON data: {e}") - continue + # Parse JSON data + try: + # Clean ::vertex and ::edge suffixes in data + if isinstance(centers_data, str): + centers_data = centers_data.replace("::vertex", "") + if isinstance(neighbors_data, str): + neighbors_data = neighbors_data.replace("::vertex", "") + if isinstance(edges_data, str): + edges_data = edges_data.replace("::edge", "") + + centers_list = ( + json.loads(centers_data) + if isinstance(centers_data, str) + else centers_data + ) + neighbors_list = ( + json.loads(neighbors_data) + if isinstance(neighbors_data, str) + else neighbors_data + ) + edges_list = ( + json.loads(edges_data) if isinstance(edges_data, str) else edges_data + ) - # Deduplicate centers by ID - centers_dict = {} - for center_data in all_centers_list: - if isinstance(center_data, dict) and "properties" in center_data: - center_id_key = center_data["properties"].get("id") - if center_id_key and center_id_key not in centers_dict: - centers_dict[center_id_key] = center_data - - # Parse center node (use first center) - core_node = None - if centers_dict: - center_data = next(iter(centers_dict.values())) - if isinstance(center_data, dict) and "properties" in center_data: - core_node = self._parse_node(center_data["properties"]) - - # Deduplicate neighbors by ID - neighbors_dict = {} - for neighbor_data in all_neighbors_list: - if isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_id = neighbor_data["properties"].get("id") - if neighbor_id and neighbor_id not in neighbors_dict: - neighbors_dict[neighbor_id] = neighbor_data - - # Parse neighbor nodes - neighbors = [] - for neighbor_data in neighbors_dict.values(): - if 
isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_parsed = self._parse_node(neighbor_data["properties"]) - neighbors.append(neighbor_parsed) - - # Deduplicate edges by (source, target, type) - edges_dict = {} - for edge_group in all_edges_list: - if isinstance(edge_group, list): - for edge_data in edge_group: - if isinstance(edge_data, dict): - edge_key = ( - edge_data.get("start_id", ""), - edge_data.get("end_id", ""), - edge_data.get("label", ""), - ) - if edge_key not in edges_dict: - edges_dict[edge_key] = { - "type": edge_data.get("label", ""), - "source": edge_data.get("start_id", ""), - "target": edge_data.get("end_id", ""), - } - elif isinstance(edge_group, dict): - # Handle single edge (not in a list) - edge_key = ( - edge_group.get("start_id", ""), - edge_group.get("end_id", ""), - edge_group.get("label", ""), - ) - if edge_key not in edges_dict: - edges_dict[edge_key] = { - "type": edge_group.get("label", ""), - "source": edge_group.get("start_id", ""), - "target": edge_group.get("end_id", ""), - } + # Collect data from this row + if isinstance(centers_list, list): + all_centers_list.extend(centers_list) + if isinstance(neighbors_list, list): + all_neighbors_list.extend(neighbors_list) + if isinstance(edges_list, list): + all_edges_list.extend(edges_list) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON data: {e}") + continue + + # Deduplicate centers by ID + centers_dict = {} + for center_data in all_centers_list: + if isinstance(center_data, dict) and "properties" in center_data: + center_id_key = center_data["properties"].get("id") + if center_id_key and center_id_key not in centers_dict: + centers_dict[center_id_key] = center_data + + # Parse center node (use first center) + core_node = None + if centers_dict: + center_data = next(iter(centers_dict.values())) + if isinstance(center_data, dict) and "properties" in center_data: + core_node = self._parse_node(center_data["properties"]) + + # Deduplicate neighbors by ID + neighbors_dict = {} + for neighbor_data in all_neighbors_list: + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_id = neighbor_data["properties"].get("id") + if neighbor_id and neighbor_id not in neighbors_dict: + neighbors_dict[neighbor_id] = neighbor_data + + # Parse neighbor nodes + neighbors = [] + for neighbor_data in neighbors_dict.values(): + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_parsed = self._parse_node(neighbor_data["properties"]) + neighbors.append(neighbor_parsed) + + # Deduplicate edges by (source, target, type) + edges_dict = {} + for edge_group in all_edges_list: + if isinstance(edge_group, list): + for edge_data in edge_group: + if isinstance(edge_data, dict): + edge_key = ( + edge_data.get("start_id", ""), + edge_data.get("end_id", ""), + edge_data.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_data.get("label", ""), + "source": edge_data.get("start_id", ""), + "target": edge_data.get("end_id", ""), + } + elif isinstance(edge_group, dict): + # Handle single edge (not in a list) + edge_key = ( + edge_group.get("start_id", ""), + edge_group.get("end_id", ""), + edge_group.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_group.get("label", ""), + "source": edge_group.get("start_id", ""), + "target": edge_group.get("end_id", ""), + } - edges = list(edges_dict.values()) + edges = list(edges_dict.values()) - return 
self._convert_graph_edges( - {"core_node": core_node, "neighbors": neighbors, "edges": edges} - ) + return self._convert_graph_edges( + {"core_node": core_node, "neighbors": neighbors, "edges": edges} + ) except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) return {"core_node": None, "neighbors": [], "edges": []} - finally: - self._return_connection(conn) def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" @@ -1751,9 +1519,7 @@ def search_by_keywords_like( logger.info( f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1772,8 +1538,6 @@ def search_by_keywords_like( f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) return output - finally: - self._return_connection(conn) @timed def search_by_keywords_tfidf( @@ -1859,9 +1623,7 @@ def search_by_keywords_tfidf( logger.info( f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1881,8 +1643,6 @@ def search_by_keywords_tfidf( f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) return output - finally: - self._return_connection(conn) @timed def search_by_fulltext( @@ -2007,9 +1767,7 @@ def search_by_fulltext( """ params = [tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -2037,8 +1795,6 @@ def search_by_fulltext( f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" ) return output[:top_k] - finally: - self._return_connection(conn) @timed def search_by_embedding( @@ -2135,9 +1891,7 @@ def search_by_embedding( logger.info(f"[search_by_embedding] query: {query}, params: {params}") - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: if params: cursor.execute(query, params) @@ -2169,8 +1923,6 @@ def search_by_embedding( f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" ) return output[:top_k] - finally: - self._return_connection(conn) @timed def get_by_metadata( @@ -2285,18 +2037,15 @@ def get_by_metadata( """ ids = [] - conn = None logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - ids = [str(item[0]).strip('"') for item in results] + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + ids = [str(item[0]).strip('"') for item in results] except Exception as e: logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") - finally: - self._return_connection(conn) return ids @@ -2448,36 +2197,33 @@ def get_grouped_counts( {where_clause} GROUP BY {", ".join(group_by_fields)} """ - conn = None try: - conn = 
self._get_connection() - with conn.cursor() as cursor: - # Handle parameterized query - if params and isinstance(params, list): - cursor.execute(query, params) - else: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Handle parameterized query + if params and isinstance(params, list): + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() - output = [] - for row in results: - group_values = {} - for i, field in enumerate(group_fields): - value = row[i] - if hasattr(value, "value"): - group_values[field] = value.value - else: - group_values[field] = str(value) - count_value = row[-1] # Last column is count - output.append({**group_values, "count": int(count_value)}) + output = [] + for row in results: + group_values = {} + for i, field in enumerate(group_fields): + value = row[i] + if hasattr(value, "value"): + group_values[field] = value.value + else: + group_values[field] = str(value) + count_value = row[-1] # Last column is count + output.append({**group_values, "count": int(count_value)}) - return output + return output except Exception as e: logger.error(f"Failed to get grouped counts: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" @@ -2509,14 +2255,10 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query) logger.info("Cleared all nodes from database.") - finally: - self._return_connection(conn) except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -2585,132 +2327,129 @@ def export_graph( else: offset = None - conn = None try: - conn = self._get_connection() - # Build WHERE conditions - where_conditions = [] - if user_name: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" - ) - if user_id: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" - ) + with self._get_connection() as conn: + # Build WHERE conditions + where_conditions = [] + if user_name: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) - # Add memory_type filter condition - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape memory_type values and build IN clause - memory_type_values = [] - for mt in memory_type: - # Escape single quotes in memory_type value - escaped_memory_type = str(mt).replace("'", "''") - memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") - memory_type_in_clause = ", ".join(memory_type_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" - ) + # Add memory_type filter condition + if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: + # Escape memory_type values and build IN clause + memory_type_values = [] + for mt in memory_type: + # Escape single quotes in memory_type value + escaped_memory_type = 
str(mt).replace("'", "''") + memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") + memory_type_in_clause = ", ".join(memory_type_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" + ) - # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - where_conditions.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" - ) - elif isinstance(status, list) and len(status) > 0: - # status IN (list) - status_values = [] - for st in status: - escaped_status = str(st).replace("'", "''") - status_values.append(f"'\"{escaped_status}\"'::agtype") - status_in_clause = ", ".join(status_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" - ) + # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list + if status is None: + # Default behavior: exclude deleted entries + where_conditions.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" + ) + elif isinstance(status, list) and len(status) > 0: + # status IN (list) + status_values = [] + for st in status: + escaped_status = str(st).replace("'", "''") + status_values.append(f"'\"{escaped_status}\"'::agtype") + status_in_clause = ", ".join(status_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" + ) - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[export_graph] filter_conditions: {filter_conditions}") - if filter_conditions: - where_conditions.extend(filter_conditions) - - where_clause = "" - if where_conditions: - where_clause = f"WHERE {' AND '.join(where_conditions)}" - - # Get total count of nodes before pagination - count_node_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - logger.info(f"[export_graph nodes count] Query: {count_node_query}") - with conn.cursor() as cursor: - cursor.execute(count_node_query) - total_nodes = cursor.fetchone()[0] - - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - if include_embedding: - node_query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ - else: - node_query = f""" - SELECT id, properties + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) + + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" + + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} """ - logger.info(f"[export_graph nodes] Query: {node_query}") - with conn.cursor() as cursor: - 
cursor.execute(node_query) - node_results = cursor.fetchall() - nodes = [] - - for row in node_results: - if include_embedding: - """row is (id, properties, embedding)""" - _, properties_json, embedding_json = row - else: - """row is (id, properties)""" - _, properties_json = row - embedding_json = None + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + if include_embedding: + node_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + else: + node_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + logger.info(f"[export_graph nodes] Query: {node_query}") + with conn.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + + for row in node_results: + if include_embedding: + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row + else: + """row is (id, properties)""" + _, properties_json = row + embedding_json = None - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except json.JSONDecodeError: - properties = {} - else: - properties = properties_json if properties_json else {} + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} - # Remove embedding field if include_embedding is False - if not include_embedding: - properties.pop("embedding", None) - elif include_embedding and embedding_json is not None: - properties["embedding"] = embedding_json + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json - nodes.append(self._parse_node(properties)) + nodes.append(self._parse_node(properties)) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - finally: - self._return_connection(conn) edges = [] return { @@ -2732,13 +2471,9 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: result = self.execute_query(query, conn) return int(result.one_or_none()["count"].value) - finally: - self._return_connection(conn) @timed def get_all_memory_items( @@ -2825,34 +2560,31 @@ def get_all_memory_items( """ nodes = [] node_ids = set() - conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() + with self._get_connection() as conn: + with 
conn.cursor() as cursor:
+                    cursor.execute(cypher_query)
+                    results = cursor.fetchall()
 
+                    for row in results:
+                        """
+                        if isinstance(row, (list, tuple)) and len(row) >= 2:
+                        """
+                        if isinstance(row, list | tuple) and len(row) >= 2:
+                            embedding_val, node_val = row[0], row[1]
+                        else:
+                            embedding_val, node_val = None, row[0]
 
+                        node = self._build_node_from_agtype(node_val, embedding_val)
+                        if node:
+                            node_id = node["id"]
+                            if node_id not in node_ids:
+                                nodes.append(node)
+                                node_ids.add(node_id)
         except Exception as e:
             logger.warning(f"Failed to get memories: {e}", exc_info=True)
-        finally:
-            self._return_connection(conn)
 
             return nodes
 
         else:
@@ -2879,29 +2611,26 @@ def get_all_memory_items(
         """
 
         nodes = []
-        conn = None
 
         logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}")
         try:
-            conn = self._get_connection()
-            with conn.cursor() as cursor:
-                cursor.execute(cypher_query)
-                results = cursor.fetchall()
+            with self._get_connection() as conn:
+                with conn.cursor() as cursor:
+                    cursor.execute(cypher_query)
+                    results = cursor.fetchall()
 
+                    for row in results:
+                        """
+                        if isinstance(row[0], str):
+                            memory_data = json.loads(row[0])
+                        else:
+                            memory_data = row[0]  # if it is already a dict, use it directly
+                        nodes.append(self._parse_node(memory_data))
+                        """
+                        memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0]
+                        nodes.append(self._parse_node(memory_data))
         except Exception as e:
             logger.error(f"Failed to get memories: {e}", exc_info=True)
-        finally:
-            self._return_connection(conn)
 
         return nodes
 
@@ -3104,88 +2833,85 @@ def get_structure_optimization_candidates(
 
         candidates = []
         node_ids = set()
-        conn = None
         try:
-            conn = self._get_connection()
-            with conn.cursor() as cursor:
-                cursor.execute(cypher_query)
-                results = cursor.fetchall()
-                logger.info(f"Found {len(results)} structure optimization candidates")
-                for row in results:
-                    if include_embedding:
-                        # When include_embedding=True, return full node object
-                        """
-                        if isinstance(row, (list, tuple)) and len(row) >= 2:
-                        """
-                        if isinstance(row, list | tuple) and len(row) >= 2:
-                            embedding_val, node_val = row[0], row[1]
-                        else:
-                            embedding_val, node_val = None, row[0]
-
+            with self._get_connection() as conn:
+                with conn.cursor() as cursor:
+                    cursor.execute(cypher_query)
+                    results = cursor.fetchall()
+                    logger.info(f"Found {len(results)} structure optimization candidates")
+                    for row in results:
+                        if include_embedding:
+                            # When include_embedding=True, return full node object
+                            """
+                            if isinstance(row, (list, tuple)) and len(row) >= 2:
+                            """
+                            if isinstance(row, list | tuple) and len(row) >= 2:
+                                embedding_val, node_val = row[0], row[1]
+                            else:
+                                embedding_val, node_val = None, row[0]
 
+                            
node = self._build_node_from_agtype(node_val, embedding_val) - if node: - node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - else: - # When include_embedding=False, return field dictionary - # Define field names matching the RETURN clause - field_names = [ - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - "graph_id", - ] - - # Convert row to dictionary - node_data = {} - for i, field_name in enumerate(field_names): - if i < len(row): - value = row[i] - # Handle special fields - if field_name in ["tags", "sources", "usage"] and isinstance( - value, str - ): - try: - # Try parsing JSON string - node_data[field_name] = json.loads(value) - except (json.JSONDecodeError, TypeError): + # When include_embedding=False, return field dictionary + # Define field names matching the RETURN clause + field_names = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + + # Convert row to dictionary + node_data = {} + for i, field_name in enumerate(field_names): + if i < len(row): + value = row[i] + # Handle special fields + if field_name in ["tags", "sources", "usage"] and isinstance( + value, str + ): + try: + # Try parsing JSON string + node_data[field_name] = json.loads(value) + except (json.JSONDecodeError, TypeError): + node_data[field_name] = value + else: node_data[field_name] = value - else: - node_data[field_name] = value - # Parse node using _parse_node_new - try: - node = self._parse_node_new(node_data) - node_id = node["id"] + # Parse node using _parse_node_new + try: + node = self._parse_node_new(node_data) + node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - logger.debug(f"Parsed node successfully: {node_id}") - except Exception as e: - logger.error(f"Failed to parse node: {e}") + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + logger.debug(f"Parsed node successfully: {node_id}") + except Exception as e: + logger.error(f"Failed to parse node: {e}") except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) - finally: - self._return_connection(conn) return candidates @@ -3355,60 +3081,57 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - conn = None insert_query = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, 
%s::text::cstring), - %s, - %s - ) + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - logger.info( - f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) - logger.info( - f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) + logger.info( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + if insert_query: + logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") except Exception as e: logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) raise - finally: - if insert_query: - logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") - self._return_connection(conn) @timed def add_nodes_batch( @@ -3529,129 +3252,126 @@ def add_nodes_batch( nodes_by_embedding_column[col] = [] nodes_by_embedding_column[col].append(node) - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Process each group separately - for embedding_column, nodes_group in nodes_by_embedding_column.items(): - # Batch delete existing records using IN clause - ids_to_delete = [node["id"] for node in nodes_group] - if ids_to_delete: - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id IN ( - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, unnest(%s::text[])::cstring) - ) + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Process each group separately + for embedding_column, nodes_group in nodes_by_embedding_column.items(): + # Batch delete existing records using IN clause + ids_to_delete = [node["id"] for node in nodes_group] + if ids_to_delete: + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id IN ( + SELECT 
ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, unnest(%s::text[])::cstring) + ) + """ + cursor.execute(delete_query, (ids_to_delete,)) + + # Batch get graph_ids for all nodes + get_graph_ids_query = f""" + SELECT + id_val, + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, id_val::text::cstring) as graph_id + FROM unnest(%s::text[]) as id_val """ - cursor.execute(delete_query, (ids_to_delete,)) - - # Batch get graph_ids for all nodes - get_graph_ids_query = f""" - SELECT - id_val, - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, id_val::text::cstring) as graph_id - FROM unnest(%s::text[]) as id_val - """ - cursor.execute(get_graph_ids_query, (ids_to_delete,)) - graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} + cursor.execute(get_graph_ids_query, (ids_to_delete,)) + graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} - # Add graph_id to properties - for node in nodes_group: - graph_id = graph_id_map.get(node["id"]) - if graph_id: - node["properties"]["graph_id"] = str(graph_id) + # Add graph_id to properties + for node in nodes_group: + graph_id = graph_id_map.get(node["id"]) + if graph_id: + node["properties"]["graph_id"] = str(graph_id) - # Use PREPARE/EXECUTE for efficient batch insert - # Generate unique prepare statement name to avoid conflicts - prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" + # Use PREPARE/EXECUTE for efficient batch insert + # Generate unique prepare statement name to avoid conflicts + prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" - try: - if embedding_column and any( - node["embedding_vector"] for node in nodes_group - ): - # PREPARE statement for insert with embedding - prepare_query = f""" - PREPARE {prepare_name} AS - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), - $2::text::agtype, - $3::vector + try: + if embedding_column and any( + node["embedding_vector"] for node in nodes_group + ): + # PREPARE statement for insert with embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), + $2::text::agtype, + $3::vector + ) + """ + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" ) - """ - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" - ) - cursor.execute(prepare_query) + cursor.execute(prepare_query) - # Execute prepared statement for each node - for node in nodes_group: - properties_json = json.dumps(node["properties"]) - embedding_json = ( - json.dumps(node["embedding_vector"]) - if node["embedding_vector"] - else None - ) + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + embedding_json = ( + json.dumps(node["embedding_vector"]) + if node["embedding_vector"] + else None + ) - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s, %s)", - (node["id"], properties_json, embedding_json), + cursor.execute( + f"EXECUTE {prepare_name}(%s, 
%s, %s)", + (node["id"], properties_json, embedding_json), + ) + else: + # PREPARE statement for insert without embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), + $2::text::agtype + ) + """ + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" ) - else: - # PREPARE statement for insert without embedding - prepare_query = f""" - PREPARE {prepare_name} AS - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), - $2::text::agtype + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" ) - """ - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" - ) - cursor.execute(prepare_query) + cursor.execute(prepare_query) - # Execute prepared statement for each node - for node in nodes_group: - properties_json = json.dumps(node["properties"]) + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + ) + finally: + # DEALLOCATE prepared statement (always execute, even on error) + try: + cursor.execute(f"DEALLOCATE {prepare_name}") + logger.info( + f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" + ) + except Exception as dealloc_error: + logger.warning( + f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" ) - finally: - # DEALLOCATE prepared statement (always execute, even on error) - try: - cursor.execute(f"DEALLOCATE {prepare_name}") - logger.info( - f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" - ) - except Exception as dealloc_error: - logger.warning( - f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" - ) - logger.info( - f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" - ) - elapsed_time = time.time() - batch_start_time - logger.info( - f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" - ) + logger.info( + f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" + ) + elapsed_time = time.time() - batch_start_time + logger.info( + f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + ) except Exception as e: logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) raise - finally: - self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): """ @@ -3763,60 +3483,57 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - - nodes_with_overlap = [] - for row in results: - node_id, properties_json, embedding_json = row - properties = properties_json if properties_json else {} - - # Parse embedding - if include_embedding and embedding_json is not None: - try: - embedding 
= ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - # Compute tag overlap - node_tags = properties.get("tags", []) - if isinstance(node_tags, str): - try: - node_tags = json.loads(node_tags) - except (json.JSONDecodeError, TypeError): - node_tags = [] + nodes_with_overlap = [] + for row in results: + node_id, properties_json, embedding_json = row + properties = properties_json if properties_json else {} - overlap_tags = [tag for tag in tags if tag in node_tags] - overlap_count = len(overlap_tags) + # Parse embedding + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") - if overlap_count >= min_overlap: - node_data = self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } - ) - nodes_with_overlap.append((node_data, overlap_count)) + # Compute tag overlap + node_tags = properties.get("tags", []) + if isinstance(node_tags, str): + try: + node_tags = json.loads(node_tags) + except (json.JSONDecodeError, TypeError): + node_tags = [] + + overlap_tags = [tag for tag in tags if tag in node_tags] + overlap_count = len(overlap_tags) + + if overlap_count >= min_overlap: + node_data = self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + nodes_with_overlap.append((node_data, overlap_count)) - # Sort by overlap count and return top_k items - nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) - return [node for node, _ in nodes_with_overlap[:top_k]] + # Sort by overlap count and return top_k items + nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) + return [node for node, _ in nodes_with_overlap[:top_k]] except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def get_neighbors_by_tag_ccl( self, @@ -4075,59 +3792,55 @@ def get_edges( $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ logger.info(f"get_edges query:{query}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - edges = [] - for row in results: - # Extract and clean from_id - from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] - if ( - isinstance(from_id_raw, str) - and from_id_raw.startswith('"') - and from_id_raw.endswith('"') - ): - from_id = from_id_raw[1:-1] - else: - from_id = str(from_id_raw) - - # Extract and clean to_id - to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] - if ( - isinstance(to_id_raw, str) - and to_id_raw.startswith('"') - and to_id_raw.endswith('"') - ): - to_id = to_id_raw[1:-1] - else: - to_id = str(to_id_raw) - - # Extract and clean edge_type - edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] - if ( - 
isinstance(edge_type_raw, str) - and edge_type_raw.startswith('"') - and edge_type_raw.endswith('"') - ): - edge_type = edge_type_raw[1:-1] - else: - edge_type = str(edge_type_raw) + edges = [] + for row in results: + # Extract and clean from_id + from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] + if ( + isinstance(from_id_raw, str) + and from_id_raw.startswith('"') + and from_id_raw.endswith('"') + ): + from_id = from_id_raw[1:-1] + else: + from_id = str(from_id_raw) + + # Extract and clean to_id + to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] + if ( + isinstance(to_id_raw, str) + and to_id_raw.startswith('"') + and to_id_raw.endswith('"') + ): + to_id = to_id_raw[1:-1] + else: + to_id = str(to_id_raw) + + # Extract and clean edge_type + edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] + if ( + isinstance(edge_type_raw, str) + and edge_type_raw.startswith('"') + and edge_type_raw.endswith('"') + ): + edge_type = edge_type_raw[1:-1] + else: + edge_type = str(edge_type_raw) - edges.append({"from": from_id, "to": to_id, "type": edge_type}) - elapsed_time = time.time() - start_time - logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s") - return edges + edges.append({"from": from_id, "to": to_id, "type": edge_type}) + elapsed_time = time.time() - start_time + logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s") + return edges except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) - def _convert_graph_edges(self, core_node: dict) -> dict: import copy @@ -5132,74 +4845,70 @@ def delete_node_by_prams( ) return 0 - conn = None total_deleted_count = 0 try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Build WHERE conditions list - where_conditions = [] - - # Add memory_ids conditions - if memory_ids: - logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") - id_conditions = [] - for node_id in memory_ids: - id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - where_conditions.append(f"({' OR '.join(id_conditions)})") - - # Add file_ids conditions - if file_ids: - logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") - file_id_conditions = [] - for file_id in file_ids: - file_id_conditions.append( - f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" - ) - where_conditions.append(f"({' OR '.join(file_id_conditions)})") + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Build WHERE conditions list + where_conditions = [] + + # Add memory_ids conditions + if memory_ids: + logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") + id_conditions = [] + for node_id in memory_ids: + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" + ) + where_conditions.append(f"({' OR '.join(id_conditions)})") + + # Add file_ids conditions + if file_ids: + logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") + file_id_conditions = [] + for file_id in file_ids: + file_id_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" + ) + where_conditions.append(f"({' OR '.join(file_id_conditions)})") - # Add filter 
conditions - if filter_conditions: - logger.info("[delete_node_by_prams] Processing filter conditions") - where_conditions.extend(filter_conditions) + # Add filter conditions + if filter_conditions: + logger.info("[delete_node_by_prams] Processing filter conditions") + where_conditions.extend(filter_conditions) - # Add user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_conditions.append(f"({user_name_where})") + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_conditions.append(f"({user_name_where})") - # Build final WHERE clause - if not where_conditions: - logger.warning("[delete_node_by_prams] No WHERE conditions to delete") - return 0 + # Build final WHERE clause + if not where_conditions: + logger.warning("[delete_node_by_prams] No WHERE conditions to delete") + return 0 - where_clause = " AND ".join(where_conditions) + where_clause = " AND ".join(where_conditions) - # Delete directly without counting - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + # Delete directly without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") - cursor.execute(delete_query) - deleted_count = cursor.rowcount - total_deleted_count = deleted_count + cursor.execute(delete_query) + deleted_count = cursor.rowcount + total_deleted_count = deleted_count - logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") + logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") - elapsed_time = time.time() - batch_start_time - logger.info( - f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" - ) + elapsed_time = time.time() - batch_start_time + logger.info( + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" + ) except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) raise - finally: - self._return_connection(conn) - logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") return total_deleted_count @@ -5263,53 +4972,49 @@ def escape_memory_id(mid: str) -> str: """ logger.info(f"[get_user_names_by_memory_ids] query: {query}") - conn = None result_dict = {} try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - # Build result dictionary from query results - for row in results: - memory_id_raw = row[0] - user_name_raw = row[1] + # Build result dictionary from query results + for row in results: + memory_id_raw = row[0] + user_name_raw = row[1] - # Remove quotes if present - if isinstance(memory_id_raw, str): - memory_id = memory_id_raw.strip('"').strip("'") - else: - memory_id = str(memory_id_raw).strip('"').strip("'") + # Remove quotes if present + if isinstance(memory_id_raw, str): + memory_id = memory_id_raw.strip('"').strip("'") + else: + memory_id = str(memory_id_raw).strip('"').strip("'") - if isinstance(user_name_raw, str): - user_name = user_name_raw.strip('"').strip("'") - else: - user_name = ( 
- str(user_name_raw).strip('"').strip("'") if user_name_raw else None - ) + if isinstance(user_name_raw, str): + user_name = user_name_raw.strip('"').strip("'") + else: + user_name = ( + str(user_name_raw).strip('"').strip("'") if user_name_raw else None + ) - result_dict[memory_id] = user_name if user_name else None + result_dict[memory_id] = user_name if user_name else None - # Set None for memory_ids that were not found - for mid in normalized_memory_ids: - if mid not in result_dict: - result_dict[mid] = None + # Set None for memory_ids that were not found + for mid in normalized_memory_ids: + if mid not in result_dict: + result_dict[mid] = None - logger.info( - f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " - f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" - ) + logger.info( + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" + ) - return result_dict + return result_dict except Exception as e: logger.error( f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) - def exist_user_name(self, user_name: str) -> dict[str, bool]: """Check if user name exists in the graph. @@ -5342,23 +5047,19 @@ def escape_user_name(un: str) -> str: """ logger.info(f"[exist_user_name] query: {query}") result_dict = {} - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - count = cursor.fetchone()[0] - result = count > 0 - result_dict[user_name] = result - return result_dict + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + count = cursor.fetchone()[0] + result = count > 0 + result_dict[user_name] = result + return result_dict except Exception as e: logger.error( f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) - @timed def delete_node_by_mem_cube_id( self, @@ -5381,76 +5082,72 @@ def delete_node_by_mem_cube_id( ) return 0 - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - - user_name_param = self.format_param_value(mem_cube_id) + with self._get_connection() as conn: + with conn.cursor() as cursor: + user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - if hard_delete: - delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" - where_clause = f"{user_name_condition} AND {delete_record_id_condition}" + user_name_param = self.format_param_value(mem_cube_id) - where_params = [user_name_param, self.format_param_value(delete_record_id)] + if hard_delete: + delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" + where_clause = f"{user_name_condition} AND {delete_record_id_condition}" - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}") + where_params = [user_name_param, self.format_param_value(delete_record_id)] - 
cursor.execute(delete_query, where_params) - deleted_count = cursor.rowcount + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}") - logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") - return deleted_count - else: - delete_time_empty_condition = ( - "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL " - "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)" - ) - delete_record_id_empty_condition = ( - "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL " - "OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)" - ) - where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}" + cursor.execute(delete_query, where_params) + deleted_count = cursor.rowcount - current_time = datetime.utcnow().isoformat() - update_query = f""" - UPDATE "{self.db_name}_graph"."Memory" - SET properties = ( - properties::jsonb || %s::jsonb - )::text::agtype, - deletetime = %s - WHERE {where_clause} - """ - update_properties = { - "status": "deleted", - "delete_time": current_time, - "delete_record_id": delete_record_id, - } - logger.info( - f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" - ) - update_params = [json.dumps(update_properties), current_time, user_name_param] - cursor.execute(update_query, update_params) - updated_count = cursor.rowcount + logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") + return deleted_count + else: + delete_time_empty_condition = ( + "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL " + "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)" + ) + delete_record_id_empty_condition = ( + "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL " + "OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)" + ) + where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}" + + current_time = datetime.utcnow().isoformat() + update_query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = ( + properties::jsonb || %s::jsonb + )::text::agtype, + deletetime = %s + WHERE {where_clause} + """ + update_properties = { + "status": "deleted", + "delete_time": current_time, + "delete_record_id": delete_record_id, + } + logger.info( + f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" + ) + update_params = [json.dumps(update_properties), current_time, user_name_param] + cursor.execute(update_query, update_params) + updated_count = cursor.rowcount - logger.info( - f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" - ) - return updated_count + logger.info( + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" + ) + return updated_count except Exception as e: logger.error( f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) - @timed def recover_memory_by_mem_cube_id( self, @@ -5476,52 +5173,49 @@ def 
recover_memory_by_mem_cube_id(
 f"delete_record_id={delete_record_id}"
 )
 
- conn = None
 try:
- conn = self._get_connection()
- with conn.cursor() as cursor:
- user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
- delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
- where_clause = f"{user_name_condition} AND {delete_record_id_condition}"
+ with self._get_connection() as conn:
+ with conn.cursor() as cursor:
+ user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
+ delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
+ where_clause = f"{user_name_condition} AND {delete_record_id_condition}"

- where_params = [
- self.format_param_value(mem_cube_id),
- self.format_param_value(delete_record_id),
- ]
+ where_params = [
+ self.format_param_value(mem_cube_id),
+ self.format_param_value(delete_record_id),
+ ]

- update_properties = {
- "status": "activated",
- "delete_record_id": "",
- "delete_time": "",
- }
+ update_properties = {
+ "status": "activated",
+ "delete_record_id": "",
+ "delete_time": "",
+ }

- update_query = f"""
- UPDATE "{self.db_name}_graph"."Memory"
- SET properties = (
- properties::jsonb || %s::jsonb
- )::text::agtype,
- deletetime = NULL
- WHERE {where_clause}
- """
+ update_query = f"""
+ UPDATE "{self.db_name}_graph"."Memory"
+ SET properties = (
+ properties::jsonb || %s::jsonb
+ )::text::agtype,
+ deletetime = NULL
+ WHERE {where_clause}
+ """

- logger.info(f"[recover_memory_by_mem_cube_id] Update query: {update_query}")
- logger.info(
- f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}"
- )
+ logger.info(f"[recover_memory_by_mem_cube_id] Update query: {update_query}")
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}"
+ )

- update_params = [json.dumps(update_properties), *where_params]
- cursor.execute(update_query, update_params)
- updated_count = cursor.rowcount
+ update_params = [json.dumps(update_properties), *where_params]
+ cursor.execute(update_query, update_params)
+ updated_count = cursor.rowcount

- logger.info(
- f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
- )
- return updated_count
+ logger.info(
+ f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes"
+ )
+ return updated_count
 except Exception as e:
 logger.error(
 f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True
 )
 raise
- finally:
- self._return_connection(conn)
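The thread that runs through this series is the gate-then-get pattern visible in the hunks above and below: a threading.BoundedSemaphore sized to maxconn must be acquired before ThreadedConnectionPool.getconn(), and is released only after putconn(), so at most maxconn callers ever touch the pool and surplus callers block instead of hitting a pool-exhausted error. A minimal self-contained sketch of that pattern, where the DSN and pool size are placeholders rather than this module's real configuration:

    import threading
    from contextlib import contextmanager

    import psycopg2.pool

    MAXCONN = 8  # illustrative; the module sizes this from config.maxconn
    pool = psycopg2.pool.ThreadedConnectionPool(1, MAXCONN, dsn="dbname=demo")  # placeholder DSN
    slots = threading.BoundedSemaphore(MAXCONN)

    @contextmanager
    def get_connection():
        slots.acquire()  # block until a slot frees up instead of letting getconn() raise PoolError
        conn = None
        try:
            conn = pool.getconn()
            conn.autocommit = True
            yield conn
        finally:
            if conn:
                pool.putconn(conn)  # return the connection for reuse
            slots.release()  # open the slot only after the connection is back

    # Usage: callers wait for capacity rather than crashing on "pool exhausted".
    with get_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())

BoundedSemaphore rather than Semaphore is the deliberate choice here: releasing more times than acquired raises ValueError instead of silently growing the slot count past maxconn.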
From 4164e6b82745d108b11b21975eda51bb11e60835 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Mon, 2 Mar 2026 19:44:18 +0800
Subject: [PATCH 02/10] feat: optimize polardb ThreadedConnectionPool

---
 src/memos/graph_dbs/polardb.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 841839baf..7ed40d31e 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -196,23 +196,19 @@ def _get_connection(self):
 """
 self._semaphore.acquire()
 conn = None
+ broken = False
+
 try:
 conn = self.connection_pool.getconn()
 conn.autocommit = True
- with conn.cursor() as cur:
- cur.execute("SELECT 1")
 yield conn
 except Exception:
- if conn:
- try:
- self.connection_pool.putconn(conn, close=True)
- except Exception:
- pass
+ broken = True
 raise
 finally:
 if conn:
 try:
- self.connection_pool.putconn(conn)
+ self.connection_pool.putconn(conn, close=broken)
 except Exception:
 pass
 self._semaphore.release()

From 0ebbd5a540e8403b00c8bee42fed00db6150c1a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Mon, 2 Mar 2026 19:54:56 +0800
Subject: [PATCH 03/10] feat: optimize polardb ThreadedConnectionPool

---
 src/memos/graph_dbs/polardb.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 7ed40d31e..b4a4fee39 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -137,7 +137,8 @@ def __init__(self, config: PolarDBGraphDBConfig):
 port = config.get("port")
 user = config.get("user")
 password = config.get("password")
- maxconn = config.get("maxconn", 100) # De
+ maxconn = config.get("maxconn", 100)
+ self._connection_wait_timeout = config.get("connection_wait_timeout", 60)
 else:
 self.db_name = config.db_name
 self.user_name = config.user_name
@@ -146,8 +147,11 @@ def __init__(self, config: PolarDBGraphDBConfig):
 user = config.user
 password = config.password
 maxconn = config.maxconn if hasattr(config, "maxconn") else 100
+ self._connection_wait_timeout = getattr(config, "connection_wait_timeout", 60)
 
- logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'")
+ logger.info(
+ f" db_name: {self.db_name} maxconn: {maxconn} connection_wait_timeout: {self._connection_wait_timeout}s"
+ )
 
 # Create connection pool
 self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
@@ -191,10 +195,15 @@ def _get_config_value(self, key: str, default=None):
 
 @contextmanager
 def _get_connection(self):
- """
- 安全获取连接(阻塞等待,不会抛 pool exhausted)
- """
- self._semaphore.acquire()
+ timeout = getattr(self, "_connection_wait_timeout", 60)
+ if timeout <= 0:
+ self._semaphore.acquire()
+ else:
+ if not self._semaphore.acquire(timeout=timeout):
+ raise RuntimeError(
+ f"Connection pool busy: could not acquire a slot within {timeout}s "
+ "(all connections in use). Consider increasing maxconn or reducing load."
+ )
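PATCH 03 above bounds the wait: Semaphore.acquire(timeout=...) returns False when the timeout expires rather than raising, so the caller has to check the result and raise its own error, while a non-positive connection_wait_timeout preserves the old block-forever behavior. The control flow, reduced to its core (slot count and timeout values are illustrative only):

    import threading

    slots = threading.BoundedSemaphore(2)  # pretend maxconn is 2

    def acquire_slot(timeout: float = 60.0) -> None:
        if timeout <= 0:
            slots.acquire()  # non-positive timeout means wait indefinitely
            return
        # acquire(timeout=...) signals expiry by returning False, not by raising,
        # so the caller must convert that into its own error.
        if not slots.acquire(timeout=timeout):
            raise RuntimeError(f"Connection pool busy: could not acquire a slot within {timeout}s")

PATCH 04 below tightens the default fallback timeout to 5 seconds, adds debug logging around getconn()/putconn(), and reinstates the SELECT 1 health check. Note that its error path logs {e} from a bare "except Exception:" that never binds e, which would itself raise NameError; PATCH 05 fixes this by catching "except Exception as e:".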
From eca1c046e5db2833a285919375ea5996557dfa56 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Mon, 2 Mar 2026 20:22:27 +0800
Subject: [PATCH 04/10] feat: optimize polardb ThreadedConnectionPool

---
 src/memos/graph_dbs/polardb.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index b4a4fee39..0aef67385 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -195,31 +195,36 @@ def _get_config_value(self, key: str, default=None):
 
 @contextmanager
 def _get_connection(self):
- timeout = getattr(self, "_connection_wait_timeout", 60)
+ timeout = getattr(self, "_connection_wait_timeout", 5)
 if timeout <= 0:
 self._semaphore.acquire()
 else:
 if not self._semaphore.acquire(timeout=timeout):
+ logger.warning(f"Timeout waiting for connection slot ({timeout}s)")
 raise RuntimeError(
- f"Connection pool busy: could not acquire a slot within {timeout}s "
- "(all connections in use). Consider increasing maxconn or reducing load."
+ f"Connection pool busy: could not acquire a slot within {timeout}s (all connections in use)."
 )
 
 conn = None
 broken = False
 
 try:
 conn = self.connection_pool.getconn()
+ logger.debug(f"Acquired connection {id(conn)} from pool")
 conn.autocommit = True
+ with conn.cursor() as cur:
+ cur.execute("SELECT 1")
 yield conn
 except Exception:
 broken = True
+ logger.error(f"Connection failed or broken: {e}")
 raise
 finally:
 if conn:
 try:
 self.connection_pool.putconn(conn, close=broken)
- except Exception:
- pass
+ logger.debug(f"Returned connection {id(conn)} to pool (broken={broken})")
+ except Exception as e:
+ logger.warning(f"Failed to return connection to pool: {e}")
 self._semaphore.release()

From 89a6ff72c5e81f36842b3391b4ba0aab9bc2697a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Mon, 2 Mar 2026 20:51:17 +0800
Subject: [PATCH 05/10] feat: optimize polardb ThreadedConnectionPool

---
 src/memos/graph_dbs/polardb.py | 1897 ++++++++++++++++----------------
 1 file changed, 933 insertions(+), 964 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 0aef67385..a48eaf0f7 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -1,9 +1,10 @@
 import json
 import random
 import textwrap
+import threading
 import time
 
-from contextlib import contextmanager, suppress
+from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Literal
 
@@ -14,7 +15,6 @@
 from memos.graph_dbs.base import BaseGraphDB
 from memos.log import get_logger
 from memos.utils import timed
-import threading
 
 
 logger = get_logger(__name__)
@@ -214,7 +214,7 @@ def _get_connection(self):
 with conn.cursor() as cur:
 cur.execute("SELECT 1")
 yield conn
- except Exception:
+ except Exception as e:
 broken = True
 logger.error(f"Connection failed or broken: {e}")
 raise
@@ -241,51 +241,50 @@ def _ensure_database_exists(self):
 def _create_graph(self):
 """Create PostgreSQL schema and table for graph storage."""
 try:
- with self._get_connection() as conn:
- with conn.cursor() as cursor:
- # Create schema if it doesn't exist
- cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";')
- logger.info(f"Schema '{self.db_name}_graph' ensured.")
+ with self._get_connection() as conn, conn.cursor() as cursor:
+ # Create schema if it doesn't exist
+ cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";')
+ logger.info(f"Schema '{self.db_name}_graph' ensured.")
+
+ # Create Memory table if it doesn't exist
+ cursor.execute(f"""
+ CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" (
+ id TEXT PRIMARY KEY,
+ properties JSONB NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ );
+ """)
+ logger.info(f"Memory table created in schema '{self.db_name}_graph'.")
 
- # Create Memory table if it doesn't exist
+ # Add embedding column if it doesn't exist (using JSONB for compatibility)
+ try:
 cursor.execute(f"""
- CREATE TABLE IF NOT EXISTS "{self.db_name}_graph"."Memory" (
- id TEXT PRIMARY KEY,
- properties JSONB NOT NULL,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- );
+ ALTER TABLE "{self.db_name}_graph"."Memory"
+ ADD COLUMN IF NOT EXISTS embedding JSONB;
 """)
- logger.info(f"Memory table created in schema '{self.db_name}_graph'.")
+ logger.info("Embedding column added to Memory table.")
+ except Exception as e:
+ logger.warning(f"Failed to add embedding column: {e}")
 
- # Add embedding column if it doesn't exist (using JSONB for 
compatibility) - try: - cursor.execute(f""" - ALTER TABLE "{self.db_name}_graph"."Memory" - ADD COLUMN IF NOT EXISTS embedding JSONB; - """) - logger.info("Embedding column added to Memory table.") - except Exception as e: - logger.warning(f"Failed to add embedding column: {e}") + # Create indexes + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) - # Create indexes + # Create vector index for embedding field + try: cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties - ON "{self.db_name}_graph"."Memory" USING GIN (properties); + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 100); """) + logger.info("Vector index created for Memory table.") + except Exception as e: + logger.warning(f"Vector index creation failed (might not be supported): {e}") - # Create vector index for embedding field - try: - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding - ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) - WITH (lists = 100); - """) - logger.info("Vector index created for Memory table.") - except Exception as e: - logger.warning(f"Vector index creation failed (might not be supported): {e}") - - logger.info("Indexes created for Memory table.") + logger.info("Indexes created for Memory table.") except Exception as e: logger.error(f"Failed to create graph schema: {e}") @@ -303,25 +302,24 @@ def create_index( Note: This creates PostgreSQL indexes on the underlying tables. """ try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - # Create indexes on the underlying PostgreSQL tables - # Apache AGE stores data in regular PostgreSQL tables + with self._get_connection() as conn, conn.cursor() as cursor: + # Create indexes on the underlying PostgreSQL tables + # Apache AGE stores data in regular PostgreSQL tables + cursor.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memory_properties + ON "{self.db_name}_graph"."Memory" USING GIN (properties); + """) + + # Try to create vector index, but don't fail if it doesn't work + try: cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties - ON "{self.db_name}_graph"."Memory" USING GIN (properties); + CREATE INDEX IF NOT EXISTS idx_memory_embedding + ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); """) + except Exception as ve: + logger.warning(f"Vector index creation failed (might not be supported): {ve}") - # Try to create vector index, but don't fail if it doesn't work - try: - cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding - ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); - """) - except Exception as ve: - logger.warning(f"Vector index creation failed (might not be supported): {ve}") - - logger.debug("Indexes created successfully.") + logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") @@ -337,11 +335,10 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in params = [self.format_param_value(memory_type), self.format_param_value(user_name)] try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result[0] if result else 0 + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, 
params) + result = cursor.fetchone() + return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 @@ -360,11 +357,10 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: params = [self.format_param_value(scope), self.format_param_value(user_name)] try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return 1 if result else 0 + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise @@ -398,32 +394,31 @@ def remove_oldest_memory( keep_latest, ] try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - # Execute query to get IDs to delete - cursor.execute(select_query, select_params) - ids_to_delete = [row[0] for row in cursor.fetchall()] - - if not ids_to_delete: - logger.info(f"No {memory_type} memories to remove for user {user_name}") - return - - # Build delete query - placeholders = ",".join(["%s"] * len(ids_to_delete)) - delete_query = f""" + with self._get_connection() as conn, conn.cursor() as cursor: + # Execute query to get IDs to delete + cursor.execute(select_query, select_params) + ids_to_delete = [row[0] for row in cursor.fetchall()] + + if not ids_to_delete: + logger.info(f"No {memory_type} memories to remove for user {user_name}") + return + + # Build delete query + placeholders = ",".join(["%s"] * len(ids_to_delete)) + delete_query = f""" DELETE FROM "{self.db_name}_graph"."Memory" WHERE id IN ({placeholders}) """ - delete_params = ids_to_delete + delete_params = ids_to_delete - # Execute deletion - cursor.execute(delete_query, delete_params) - deleted_count = cursor.rowcount - logger.info( - f"Removed {deleted_count} oldest {memory_type} memories, " - f"keeping {keep_latest} latest for user {user_name}, " - f"removed ids: {ids_to_delete}" - ) + # Execute deletion + cursor.execute(delete_query, delete_params) + deleted_count = cursor.rowcount + logger.info( + f"Removed {deleted_count} oldest {memory_type} memories, " + f"keeping {keep_latest} latest for user {user_name}, " + f"removed ids: {ids_to_delete}" + ) except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise @@ -489,9 +484,8 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N params.append(self.format_param_value(user_name)) try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise @@ -516,9 +510,8 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: params.append(self.format_param_value(user_name)) try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise @@ -527,27 +520,26 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] try: - 
with self._get_connection() as conn: - with conn.cursor() as cursor: - # Ensure in the correct database context - cursor.execute("SELECT current_database();") - current_db = cursor.fetchone()[0] - logger.info(f"Current database context: {current_db}") + with self._get_connection() as conn, conn.cursor() as cursor: + # Ensure in the correct database context + cursor.execute("SELECT current_database();") + current_db = cursor.fetchone()[0] + logger.info(f"Current database context: {current_db}") - for ext_name, ext_desc in extensions: - try: - cursor.execute(f"create extension if not exists {ext_name};") - logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") - except Exception as e: - if "already exists" in str(e): - logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") - else: - logger.warning( - f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" - ) - logger.error( - f"Failed to create extension '{ext_name}': {e}", exc_info=True - ) + for ext_name, ext_desc in extensions: + try: + cursor.execute(f"create extension if not exists {ext_name};") + logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.") + except Exception as e: + if "already exists" in str(e): + logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.") + else: + logger.warning( + f"Failed to create extension '{ext_name}' ({ext_desc}): {e}" + ) + logger.error( + f"Failed to create extension '{ext_name}': {e}", exc_info=True + ) except Exception as e: logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) @@ -555,19 +547,18 @@ def create_extension(self): @timed def create_graph(self): try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(f""" + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(f""" SELECT COUNT(*) FROM ag_catalog.ag_graph WHERE name = '{self.db_name}_graph'; """) - graph_exists = cursor.fetchone()[0] > 0 + graph_exists = cursor.fetchone()[0] > 0 - if graph_exists: - logger.info(f"Graph '{self.db_name}_graph' already exists.") - else: - cursor.execute(f"select create_graph('{self.db_name}_graph');") - logger.info(f"Graph database '{self.db_name}_graph' created.") + if graph_exists: + logger.info(f"Graph '{self.db_name}_graph' already exists.") + else: + cursor.execute(f"select create_graph('{self.db_name}_graph');") + logger.info(f"Graph database '{self.db_name}_graph' created.") except Exception as e: logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) @@ -581,10 +572,9 @@ def create_edge(self): for label_name in valid_rel_types: logger.info(f"Creating elabel: {label_name}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") - logger.info(f"Successfully created elabel: {label_name}") + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") + logger.info(f"Successfully created elabel: {label_name}") except Exception as e: if "already exists" in str(e): logger.info(f"Label '{label_name}' already exists, skipping.") @@ -632,13 +622,12 @@ def add_edge( """ logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - 
cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) - logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) + logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") - elapsed_time = time.time() - start_time - logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") + elapsed_time = time.time() - start_time + logger.info(f" polardb [add_edge] insert completed time in {elapsed_time:.2f}s") except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise @@ -656,10 +645,9 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") @timed def edge_exists_old( @@ -714,11 +702,10 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None @timed def edge_exists( @@ -766,11 +753,10 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None @timed def get_node( @@ -807,52 +793,51 @@ def get_node( logger.info(f"polardb [get_node] query: {query},params: {params}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() - if result: - if include_embedding: - _, properties_json, embedding_json = result - else: - _, properties_json = result - embedding_json = None + if result: + if include_embedding: + _, properties_json, embedding_json = result + else: + _, properties_json = result + embedding_json = None - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {id}") - properties = {} - else: - properties = properties_json if properties_json else {} + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists and include_embedding is 
True - if include_embedding and embedding_json is not None: - try: - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {id}") + # Parse embedding from JSONB if it exists and include_embedding is True + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {id}") - elapsed_time = time.time() - start_time - logger.info( - f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" - ) - return self._parse_node( - { - "id": id, - "memory": properties.get("memory", ""), - **properties, - } - ) - return None + elapsed_time = time.time() - start_time + logger.info( + f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s" + ) + return self._parse_node( + { + "id": id, + "memory": properties.get("memory", ""), + **properties, + } + ) + return None except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) @@ -893,46 +878,45 @@ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, logger.info(f"get_nodes query:{query},params:{params}") - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists - if embedding_json is not None and kwargs.get("include_embedding"): - try: - # remove embedding - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding from JSONB if it exists + if embedding_json is not None and kwargs.get("include_embedding"): + try: + # remove embedding + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + 
"id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } ) - return nodes + ) + return nodes @timed def get_edges_old( @@ -1151,58 +1135,57 @@ def get_children_with_embeddings( """ try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - children = [] - for row in results: - # Handle child_id - remove possible quotes - child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) - if isinstance(child_id_raw, str): - # If string starts and ends with quotes, remove quotes - if child_id_raw.startswith('"') and child_id_raw.endswith('"'): - child_id = child_id_raw[1:-1] - else: - child_id = child_id_raw + children = [] + for row in results: + # Handle child_id - remove possible quotes + child_id_raw = row[0].value if hasattr(row[0], "value") else str(row[0]) + if isinstance(child_id_raw, str): + # If string starts and ends with quotes, remove quotes + if child_id_raw.startswith('"') and child_id_raw.endswith('"'): + child_id = child_id_raw[1:-1] else: - child_id = str(child_id_raw) + child_id = child_id_raw + else: + child_id = str(child_id_raw) - # Handle embedding - get from database embedding column - embedding_raw = row[1] - embedding = [] - if embedding_raw is not None: - try: - if isinstance(embedding_raw, str): - # If it is a JSON string, parse it - embedding = json.loads(embedding_raw) - elif isinstance(embedding_raw, list): - # If already a list, use directly - embedding = embedding_raw - else: - # Try converting to list - embedding = list(embedding_raw) - except (json.JSONDecodeError, TypeError, ValueError) as e: - logger.warning( - f"Failed to parse embedding for child node {child_id}: {e}" - ) - embedding = [] - - # Handle memory - remove possible quotes - memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) - if isinstance(memory_raw, str): - # If string starts and ends with quotes, remove quotes - if memory_raw.startswith('"') and memory_raw.endswith('"'): - memory = memory_raw[1:-1] + # Handle embedding - get from database embedding column + embedding_raw = row[1] + embedding = [] + if embedding_raw is not None: + try: + if isinstance(embedding_raw, str): + # If it is a JSON string, parse it + embedding = json.loads(embedding_raw) + elif isinstance(embedding_raw, list): + # If already a list, use directly + embedding = embedding_raw else: - memory = memory_raw + # Try converting to list + embedding = list(embedding_raw) + except (json.JSONDecodeError, TypeError, ValueError) as e: + logger.warning( + f"Failed to parse embedding for child node {child_id}: {e}" + ) + embedding = [] + + # Handle memory - remove possible quotes + memory_raw = row[2].value if hasattr(row[2], "value") else str(row[2]) + if isinstance(memory_raw, str): + # If string starts and ends with quotes, remove quotes + if memory_raw.startswith('"') and memory_raw.endswith('"'): + memory = memory_raw[1:-1] else: - memory = str(memory_raw) + memory = memory_raw + else: + memory = str(memory_raw) - children.append({"id": child_id, "embedding": embedding, "memory": memory}) + children.append({"id": child_id, "embedding": embedding, "memory": memory}) - return children + return children except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) @@ -1290,128 +1273,127 @@ def get_subgraph( """ 
logger.info(f"[get_subgraph] Query: {query}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - if not results: - return {"core_node": None, "neighbors": [], "edges": []} + if not results: + return {"core_node": None, "neighbors": [], "edges": []} - # Merge results from all UNION ALL rows - all_centers_list = [] - all_neighbors_list = [] - all_edges_list = [] + # Merge results from all UNION ALL rows + all_centers_list = [] + all_neighbors_list = [] + all_edges_list = [] - for result in results: - if not result or not result[0]: - continue + for result in results: + if not result or not result[0]: + continue - centers_data = result[0] if result[0] else "[]" - neighbors_data = result[1] if result[1] else "[]" - edges_data = result[2] if result[2] else "[]" + centers_data = result[0] if result[0] else "[]" + neighbors_data = result[1] if result[1] else "[]" + edges_data = result[2] if result[2] else "[]" - # Parse JSON data - try: - # Clean ::vertex and ::edge suffixes in data - if isinstance(centers_data, str): - centers_data = centers_data.replace("::vertex", "") - if isinstance(neighbors_data, str): - neighbors_data = neighbors_data.replace("::vertex", "") - if isinstance(edges_data, str): - edges_data = edges_data.replace("::edge", "") - - centers_list = ( - json.loads(centers_data) - if isinstance(centers_data, str) - else centers_data - ) - neighbors_list = ( - json.loads(neighbors_data) - if isinstance(neighbors_data, str) - else neighbors_data - ) - edges_list = ( - json.loads(edges_data) if isinstance(edges_data, str) else edges_data - ) + # Parse JSON data + try: + # Clean ::vertex and ::edge suffixes in data + if isinstance(centers_data, str): + centers_data = centers_data.replace("::vertex", "") + if isinstance(neighbors_data, str): + neighbors_data = neighbors_data.replace("::vertex", "") + if isinstance(edges_data, str): + edges_data = edges_data.replace("::edge", "") + + centers_list = ( + json.loads(centers_data) + if isinstance(centers_data, str) + else centers_data + ) + neighbors_list = ( + json.loads(neighbors_data) + if isinstance(neighbors_data, str) + else neighbors_data + ) + edges_list = ( + json.loads(edges_data) if isinstance(edges_data, str) else edges_data + ) - # Collect data from this row - if isinstance(centers_list, list): - all_centers_list.extend(centers_list) - if isinstance(neighbors_list, list): - all_neighbors_list.extend(neighbors_list) - if isinstance(edges_list, list): - all_edges_list.extend(edges_list) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON data: {e}") - continue - - # Deduplicate centers by ID - centers_dict = {} - for center_data in all_centers_list: - if isinstance(center_data, dict) and "properties" in center_data: - center_id_key = center_data["properties"].get("id") - if center_id_key and center_id_key not in centers_dict: - centers_dict[center_id_key] = center_data - - # Parse center node (use first center) - core_node = None - if centers_dict: - center_data = next(iter(centers_dict.values())) - if isinstance(center_data, dict) and "properties" in center_data: - core_node = self._parse_node(center_data["properties"]) - - # Deduplicate neighbors by ID - neighbors_dict = {} - for neighbor_data in all_neighbors_list: - if isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_id = 
neighbor_data["properties"].get("id") - if neighbor_id and neighbor_id not in neighbors_dict: - neighbors_dict[neighbor_id] = neighbor_data - - # Parse neighbor nodes - neighbors = [] - for neighbor_data in neighbors_dict.values(): - if isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_parsed = self._parse_node(neighbor_data["properties"]) - neighbors.append(neighbor_parsed) - - # Deduplicate edges by (source, target, type) - edges_dict = {} - for edge_group in all_edges_list: - if isinstance(edge_group, list): - for edge_data in edge_group: - if isinstance(edge_data, dict): - edge_key = ( - edge_data.get("start_id", ""), - edge_data.get("end_id", ""), - edge_data.get("label", ""), - ) - if edge_key not in edges_dict: - edges_dict[edge_key] = { - "type": edge_data.get("label", ""), - "source": edge_data.get("start_id", ""), - "target": edge_data.get("end_id", ""), - } - elif isinstance(edge_group, dict): - # Handle single edge (not in a list) - edge_key = ( - edge_group.get("start_id", ""), - edge_group.get("end_id", ""), - edge_group.get("label", ""), - ) - if edge_key not in edges_dict: - edges_dict[edge_key] = { - "type": edge_group.get("label", ""), - "source": edge_group.get("start_id", ""), - "target": edge_group.get("end_id", ""), - } + # Collect data from this row + if isinstance(centers_list, list): + all_centers_list.extend(centers_list) + if isinstance(neighbors_list, list): + all_neighbors_list.extend(neighbors_list) + if isinstance(edges_list, list): + all_edges_list.extend(edges_list) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON data: {e}") + continue - edges = list(edges_dict.values()) + # Deduplicate centers by ID + centers_dict = {} + for center_data in all_centers_list: + if isinstance(center_data, dict) and "properties" in center_data: + center_id_key = center_data["properties"].get("id") + if center_id_key and center_id_key not in centers_dict: + centers_dict[center_id_key] = center_data + + # Parse center node (use first center) + core_node = None + if centers_dict: + center_data = next(iter(centers_dict.values())) + if isinstance(center_data, dict) and "properties" in center_data: + core_node = self._parse_node(center_data["properties"]) + + # Deduplicate neighbors by ID + neighbors_dict = {} + for neighbor_data in all_neighbors_list: + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_id = neighbor_data["properties"].get("id") + if neighbor_id and neighbor_id not in neighbors_dict: + neighbors_dict[neighbor_id] = neighbor_data + + # Parse neighbor nodes + neighbors = [] + for neighbor_data in neighbors_dict.values(): + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_parsed = self._parse_node(neighbor_data["properties"]) + neighbors.append(neighbor_parsed) + + # Deduplicate edges by (source, target, type) + edges_dict = {} + for edge_group in all_edges_list: + if isinstance(edge_group, list): + for edge_data in edge_group: + if isinstance(edge_data, dict): + edge_key = ( + edge_data.get("start_id", ""), + edge_data.get("end_id", ""), + edge_data.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_data.get("label", ""), + "source": edge_data.get("start_id", ""), + "target": edge_data.get("end_id", ""), + } + elif isinstance(edge_group, dict): + # Handle single edge (not in a list) + edge_key = ( + edge_group.get("start_id", ""), + edge_group.get("end_id", ""), + edge_group.get("label", ""), + ) + if 
edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_group.get("label", ""), + "source": edge_group.get("start_id", ""), + "target": edge_group.get("end_id", ""), + } - return self._convert_graph_edges( - {"core_node": core_node, "neighbors": neighbors, "edges": edges} - ) + edges = list(edges_dict.values()) + + return self._convert_graph_edges( + {"core_node": core_node, "neighbors": neighbors, "edges": edges} + ) except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) @@ -1529,25 +1511,24 @@ def search_by_keywords_like( logger.info( f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" ) - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - item = {"id": id_val} - if return_fields: - properties = row[2] # properties column - item.update(self._extract_fields_from_properties(properties, return_fields)) - output.append(item) - logger.info( - f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + logger.info( + f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output @timed def search_by_keywords_tfidf( @@ -1633,26 +1614,25 @@ def search_by_keywords_tfidf( logger.info( f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - item = {"id": id_val} - if return_fields: - properties = row[2] # properties column - item.update(self._extract_fields_from_properties(properties, return_fields)) - output.append(item) - - logger.info( - f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + + logger.info( + f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output @timed def search_by_fulltext( @@ -1777,34 +1757,29 @@ def search_by_fulltext( """ params = [tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") - with 
self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] # old_id - rank = row[1] # rank score (no memory_text column) - - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - score_val = float(rank) - - # Apply threshold filter if specified - if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[2] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) - elapsed_time = time.time() - start_time - logger.info( - f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[1] # rank score (no memory_text column) + + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + elapsed_time = time.time() - start_time + logger.info(f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s") + return output[:top_k] @timed def search_by_embedding( @@ -1901,38 +1876,35 @@ def search_by_embedding( logger.info(f"[search_by_embedding] query: {query}, params: {params}") - with self._get_connection() as conn: - with conn.cursor() as cursor: - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - results = cursor.fetchall() - output = [] - for row in results: - if len(row) < 5: - logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") - continue - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[1] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) - elapsed_time = time.time() - start_time - logger.info( - f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] + with self._get_connection() as conn, conn.cursor() as cursor: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() + output = [] + for row in results: + if len(row) < 5: + logger.warning(f"Row has {len(row)} columns, expected 5. 
Row: {row}") + continue + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[1] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + elapsed_time = time.time() - start_time + logger.info( + f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + ) + return output[:top_k] @timed def get_by_metadata( @@ -2049,11 +2021,10 @@ def get_by_metadata( ids = [] logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() - ids = [str(item[0]).strip('"') for item in results] + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() + ids = [str(item[0]).strip('"') for item in results] except Exception as e: logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") @@ -2208,28 +2179,27 @@ def get_grouped_counts( GROUP BY {", ".join(group_by_fields)} """ try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - # Handle parameterized query - if params and isinstance(params, list): - cursor.execute(query, params) - else: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + # Handle parameterized query + if params and isinstance(params, list): + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() - output = [] - for row in results: - group_values = {} - for i, field in enumerate(group_fields): - value = row[i] - if hasattr(value, "value"): - group_values[field] = value.value - else: - group_values[field] = str(value) - count_value = row[-1] # Last column is count - output.append({**group_values, "count": int(count_value)}) + output = [] + for row in results: + group_values = {} + for i, field in enumerate(group_fields): + value = row[i] + if hasattr(value, "value"): + group_values[field] = value.value + else: + group_values[field] = str(value) + count_value = row[-1] # Last column is count + output.append({**group_values, "count": int(count_value)}) - return output + return output except Exception as e: logger.error(f"Failed to get grouped counts: {e}", exc_info=True) @@ -2265,10 +2235,9 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -2572,26 +2541,25 @@ def get_all_memory_items( node_ids = set() logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(cypher_query) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(cypher_query) + results = cursor.fetchall() - for 
row in results:
-                        """
+                    for row in results:
+                        """
 if isinstance(row, (list, tuple)) and len(row) >= 2:
 """
-                        if isinstance(row, list | tuple) and len(row) >= 2:
-                            embedding_val, node_val = row[0], row[1]
-                        else:
-                            embedding_val, node_val = None, row[0]
+                    if isinstance(row, list | tuple) and len(row) >= 2:
+                        embedding_val, node_val = row[0], row[1]
+                    else:
+                        embedding_val, node_val = None, row[0]

-                        node = self._build_node_from_agtype(node_val, embedding_val)
-                        if node:
-                            node_id = node["id"]
-                            if node_id not in node_ids:
-                                nodes.append(node)
-                                node_ids.add(node_id)
+                    node = self._build_node_from_agtype(node_val, embedding_val)
+                    if node:
+                        node_id = node["id"]
+                        if node_id not in node_ids:
+                            nodes.append(node)
+                            node_ids.add(node_id)

         except Exception as e:
             logger.warning(f"Failed to get memories: {e}", exc_info=True)

@@ -2623,21 +2591,20 @@ def get_all_memory_items(
         nodes = []
         logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}")
         try:
-            with self._get_connection() as conn:
-                with conn.cursor() as cursor:
-                    cursor.execute(cypher_query)
-                    results = cursor.fetchall()
+            with self._get_connection() as conn, conn.cursor() as cursor:
+                cursor.execute(cypher_query)
+                results = cursor.fetchall()

-                    for row in results:
-                        """
+                for row in results:
+                    """
 if isinstance(row[0], str):
 memory_data = json.loads(row[0])
 else:
 memory_data = row[0]  # If it is already a dict, use it directly
 nodes.append(self._parse_node(memory_data))
 """
-                        memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0]
-                        nodes.append(self._parse_node(memory_data))
+                    memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0]
+                    nodes.append(self._parse_node(memory_data))

         except Exception as e:
             logger.error(f"Failed to get memories: {e}", exc_info=True)

@@ -2844,81 +2811,80 @@ def get_structure_optimization_candidates(
         candidates = []
         node_ids = set()
         try:
-            with self._get_connection() as conn:
-                with conn.cursor() as cursor:
-                    cursor.execute(cypher_query)
-                    results = cursor.fetchall()
-                    logger.info(f"Found {len(results)} structure optimization candidates")
-                    for row in results:
-                        if include_embedding:
-                            # When include_embedding=True, return full node object
-                            """
+            with self._get_connection() as conn, conn.cursor() as cursor:
+                cursor.execute(cypher_query)
+                results = cursor.fetchall()
+                logger.info(f"Found {len(results)} structure optimization candidates")
+                for row in results:
+                    if include_embedding:
+                        # When include_embedding=True, return full node object
+                        """
 if isinstance(row, (list, tuple)) and len(row) >= 2:
 """
-                            if isinstance(row, list | tuple) and len(row) >= 2:
-                                embedding_val, node_val = row[0], row[1]
-                            else:
-                                embedding_val, node_val = None, row[0]
-
-                            node = self._build_node_from_agtype(node_val, embedding_val)
-                            if node:
-                                node_id = node["id"]
-                                if node_id not in node_ids:
-                                    candidates.append(node)
-                                    node_ids.add(node_id)
+                        if isinstance(row, list | tuple) and len(row) >= 2:
+                            embedding_val, node_val = row[0], row[1]
                         else:
-                            # When include_embedding=False, return field dictionary
-                            # Define field names matching the RETURN clause
-                            field_names = [
-                                "id",
-                                "memory",
-                                "user_name",
-                                "user_id",
-                                "session_id",
-                                "status",
-                                "key",
-                                "confidence",
-                                "tags",
-                                "created_at",
-                                "updated_at",
-                                "memory_type",
-                                "sources",
-                                "source",
-                                "node_type",
-                                "visibility",
-                                "usage",
-                                "background",
-                                "graph_id",
-                            ]
-
-                            # Convert row to dictionary
-                            node_data = {}
-                            for i, field_name in enumerate(field_names):
-                                if i < len(row):
-                                    value = row[i]
-                                    # Handle special fields
-                                    if field_name in ["tags", "sources", "usage"] 
and isinstance( - value, str - ): - try: - # Try parsing JSON string - node_data[field_name] = json.loads(value) - except (json.JSONDecodeError, TypeError): - node_data[field_name] = value - else: + embedding_val, node_val = None, row[0] + + node = self._build_node_from_agtype(node_val, embedding_val) + if node: + node_id = node["id"] + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + else: + # When include_embedding=False, return field dictionary + # Define field names matching the RETURN clause + field_names = [ + "id", + "memory", + "user_name", + "user_id", + "session_id", + "status", + "key", + "confidence", + "tags", + "created_at", + "updated_at", + "memory_type", + "sources", + "source", + "node_type", + "visibility", + "usage", + "background", + "graph_id", + ] + + # Convert row to dictionary + node_data = {} + for i, field_name in enumerate(field_names): + if i < len(row): + value = row[i] + # Handle special fields + if field_name in ["tags", "sources", "usage"] and isinstance( + value, str + ): + try: + # Try parsing JSON string + node_data[field_name] = json.loads(value) + except (json.JSONDecodeError, TypeError): node_data[field_name] = value + else: + node_data[field_name] = value - # Parse node using _parse_node_new - try: - node = self._parse_node_new(node_data) - node_id = node["id"] + # Parse node using _parse_node_new + try: + node = self._parse_node_new(node_data) + node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - logger.debug(f"Parsed node successfully: {node_id}") - except Exception as e: - logger.error(f"Failed to parse node: {e}") + if node_id not in node_ids: + candidates.append(node) + node_ids.add(node_id) + logger.debug(f"Parsed node successfully: {node_id}") + except Exception as e: + logger.error(f"Failed to parse node: {e}") except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) @@ -3138,7 +3104,9 @@ def add_node( f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" ) if insert_query: - logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") + logger.info( + f"In add node polardb: id-{id} memory-{memory} query-{insert_query}" + ) except Exception as e: logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) raise @@ -3263,121 +3231,121 @@ def add_nodes_batch( nodes_by_embedding_column[col].append(node) try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - # Process each group separately - for embedding_column, nodes_group in nodes_by_embedding_column.items(): - # Batch delete existing records using IN clause - ids_to_delete = [node["id"] for node in nodes_group] - if ids_to_delete: - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id IN ( - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, unnest(%s::text[])::cstring) - ) - """ - cursor.execute(delete_query, (ids_to_delete,)) - - # Batch get graph_ids for all nodes - get_graph_ids_query = f""" - SELECT - id_val, - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, id_val::text::cstring) as graph_id - FROM unnest(%s::text[]) as id_val + with self._get_connection() as conn, conn.cursor() as cursor: + # Process each group separately + for embedding_column, nodes_group in nodes_by_embedding_column.items(): + # Batch delete existing records using IN clause + ids_to_delete = [node["id"] for node 
in nodes_group] + if ids_to_delete: + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id IN ( + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, unnest(%s::text[])::cstring) + ) """ - cursor.execute(get_graph_ids_query, (ids_to_delete,)) - graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} + cursor.execute(delete_query, (ids_to_delete,)) - # Add graph_id to properties - for node in nodes_group: - graph_id = graph_id_map.get(node["id"]) - if graph_id: - node["properties"]["graph_id"] = str(graph_id) + # Batch get graph_ids for all nodes + get_graph_ids_query = f""" + SELECT + id_val, + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, id_val::text::cstring) as graph_id + FROM unnest(%s::text[]) as id_val + """ + cursor.execute(get_graph_ids_query, (ids_to_delete,)) + graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} - # Use PREPARE/EXECUTE for efficient batch insert - # Generate unique prepare statement name to avoid conflicts - prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" + # Add graph_id to properties + for node in nodes_group: + graph_id = graph_id_map.get(node["id"]) + if graph_id: + node["properties"]["graph_id"] = str(graph_id) - try: - if embedding_column and any( - node["embedding_vector"] for node in nodes_group - ): - # PREPARE statement for insert with embedding - prepare_query = f""" - PREPARE {prepare_name} AS - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), - $2::text::agtype, - $3::vector - ) - """ - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" + # Use PREPARE/EXECUTE for efficient batch insert + # Generate unique prepare statement name to avoid conflicts + prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" + + try: + if embedding_column and any( + node["embedding_vector"] for node in nodes_group + ): + # PREPARE statement for insert with embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), + $2::text::agtype, + $3::vector ) + """ + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" + ) - cursor.execute(prepare_query) + cursor.execute(prepare_query) - # Execute prepared statement for each node - for node in nodes_group: - properties_json = json.dumps(node["properties"]) - embedding_json = ( - json.dumps(node["embedding_vector"]) - if node["embedding_vector"] - else None - ) + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + embedding_json = ( + json.dumps(node["embedding_vector"]) + if node["embedding_vector"] + else None + ) - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s, %s)", - (node["id"], properties_json, embedding_json), - ) - else: - # PREPARE statement for insert without embedding - prepare_query = f""" - PREPARE {prepare_name} AS - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - 
ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), - $2::text::agtype - ) - """ - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s, %s)", + (node["id"], properties_json, embedding_json), ) - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" + else: + # PREPARE statement for insert without embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), + $2::text::agtype ) - cursor.execute(prepare_query) + """ + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" + ) + cursor.execute(prepare_query) - # Execute prepared statement for each node - for node in nodes_group: - properties_json = json.dumps(node["properties"]) + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) - ) - finally: - # DEALLOCATE prepared statement (always execute, even on error) - try: - cursor.execute(f"DEALLOCATE {prepare_name}") - logger.info( - f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" - ) - except Exception as dealloc_error: - logger.warning( - f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s)", + (node["id"], properties_json), ) + finally: + # DEALLOCATE prepared statement (always execute, even on error) + try: + cursor.execute(f"DEALLOCATE {prepare_name}") + logger.info( + f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" + ) + except Exception as dealloc_error: + logger.warning( + f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" + ) - logger.info( - f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" - ) - elapsed_time = time.time() - batch_start_time - logger.info( - f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" - ) + logger.info( + f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" + ) + elapsed_time = time.time() - batch_start_time + logger.info( + f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + ) except Exception as e: logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) @@ -3494,52 +3462,51 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - - nodes_with_overlap = [] - for row in results: - node_id, properties_json, embedding_json = row - properties = properties_json if properties_json else {} + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - # Parse embedding - if include_embedding and embedding_json is not None: - try: - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding 
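# --- Editor's sketch (added in review; not part of the patch) -------------
# The PREPARE/EXECUTE batching used by add_nodes_batch above, reduced to its
# essentials. `demo_table` and the statement name are placeholders; the only
# assumption is a live psycopg2 cursor. The server plans the INSERT once and
# reuses that plan for every EXECUTE, which is where the batch speedup comes
# from:
import json
import time

def batch_insert(cursor, rows):
    # Unique statement name so concurrent sessions cannot collide.
    name = f"ins_{int(time.time() * 1_000_000)}"
    cursor.execute(
        f"PREPARE {name} AS INSERT INTO demo_table(id, doc) VALUES ($1, $2::jsonb)"
    )
    try:
        for row_id, doc in rows:
            cursor.execute(f"EXECUTE {name}(%s, %s)", (row_id, json.dumps(doc)))
    finally:
        # Mirror the patch: always DEALLOCATE, even when an insert fails.
        cursor.execute(f"DEALLOCATE {name}")
# ---------------------------------------------------------------------------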
- except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") + nodes_with_overlap = [] + for row in results: + node_id, properties_json, embedding_json = row + properties = properties_json if properties_json else {} - # Compute tag overlap - node_tags = properties.get("tags", []) - if isinstance(node_tags, str): - try: - node_tags = json.loads(node_tags) - except (json.JSONDecodeError, TypeError): - node_tags = [] - - overlap_tags = [tag for tag in tags if tag in node_tags] - overlap_count = len(overlap_tags) - - if overlap_count >= min_overlap: - node_data = self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding + if include_embedding and embedding_json is not None: + try: + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json ) - nodes_with_overlap.append((node_data, overlap_count)) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + + # Compute tag overlap + node_tags = properties.get("tags", []) + if isinstance(node_tags, str): + try: + node_tags = json.loads(node_tags) + except (json.JSONDecodeError, TypeError): + node_tags = [] + + overlap_tags = [tag for tag in tags if tag in node_tags] + overlap_count = len(overlap_tags) + + if overlap_count >= min_overlap: + node_data = self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) + nodes_with_overlap.append((node_data, overlap_count)) - # Sort by overlap count and return top_k items - nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) - return [node for node, _ in nodes_with_overlap[:top_k]] + # Sort by overlap count and return top_k items + nodes_with_overlap.sort(key=lambda x: x[1], reverse=True) + return [node for node, _ in nodes_with_overlap[:top_k]] except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) @@ -3803,54 +3770,54 @@ def get_edges( """ logger.info(f"get_edges query:{query}") try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - edges = [] - for row in results: - # Extract and clean from_id - from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] - if ( - isinstance(from_id_raw, str) - and from_id_raw.startswith('"') - and from_id_raw.endswith('"') - ): - from_id = from_id_raw[1:-1] - else: - from_id = str(from_id_raw) - - # Extract and clean to_id - to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] - if ( - isinstance(to_id_raw, str) - and to_id_raw.startswith('"') - and to_id_raw.endswith('"') - ): - to_id = to_id_raw[1:-1] - else: - to_id = str(to_id_raw) - - # Extract and clean edge_type - edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] - if ( - isinstance(edge_type_raw, str) - and edge_type_raw.startswith('"') - and edge_type_raw.endswith('"') - ): - edge_type = edge_type_raw[1:-1] - else: - edge_type = str(edge_type_raw) + edges = [] + for row in results: + # Extract and clean from_id + from_id_raw = row[0].value if hasattr(row[0], "value") else row[0] + if ( + isinstance(from_id_raw, str) + and from_id_raw.startswith('"') + and from_id_raw.endswith('"') + ): + 
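# --- Editor's sketch (added in review; not part of the patch) -------------
# The startswith/endswith unquoting below is repeated for from_id, to_id and
# edge_type here, and again in the search methods above. A hypothetical
# helper capturing that logic (the name strip_agtype_quotes is invented for
# illustration; agtype scalars come back as quoted strings such as '"abc"'):
def strip_agtype_quotes(value) -> str:
    raw = value.value if hasattr(value, "value") else str(value)
    if isinstance(raw, str) and len(raw) >= 2 and raw.startswith('"') and raw.endswith('"'):
        return raw[1:-1]
    return str(raw)

# e.g. from_id = strip_agtype_quotes(row[0]); to_id = strip_agtype_quotes(row[1])
# ---------------------------------------------------------------------------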
from_id = from_id_raw[1:-1] + else: + from_id = str(from_id_raw) + + # Extract and clean to_id + to_id_raw = row[1].value if hasattr(row[1], "value") else row[1] + if ( + isinstance(to_id_raw, str) + and to_id_raw.startswith('"') + and to_id_raw.endswith('"') + ): + to_id = to_id_raw[1:-1] + else: + to_id = str(to_id_raw) + + # Extract and clean edge_type + edge_type_raw = row[2].value if hasattr(row[2], "value") else row[2] + if ( + isinstance(edge_type_raw, str) + and edge_type_raw.startswith('"') + and edge_type_raw.endswith('"') + ): + edge_type = edge_type_raw[1:-1] + else: + edge_type = str(edge_type_raw) - edges.append({"from": from_id, "to": to_id, "type": edge_type}) - elapsed_time = time.time() - start_time - logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s") - return edges + edges.append({"from": from_id, "to": to_id, "type": edge_type}) + elapsed_time = time.time() - start_time + logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s") + return edges except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] + def _convert_graph_edges(self, core_node: dict) -> dict: import copy @@ -4857,65 +4824,64 @@ def delete_node_by_prams( total_deleted_count = 0 try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - # Build WHERE conditions list - where_conditions = [] - - # Add memory_ids conditions - if memory_ids: - logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") - id_conditions = [] - for node_id in memory_ids: - id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - where_conditions.append(f"({' OR '.join(id_conditions)})") - - # Add file_ids conditions - if file_ids: - logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") - file_id_conditions = [] - for file_id in file_ids: - file_id_conditions.append( - f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" - ) - where_conditions.append(f"({' OR '.join(file_id_conditions)})") + with self._get_connection() as conn, conn.cursor() as cursor: + # Build WHERE conditions list + where_conditions = [] + + # Add memory_ids conditions + if memory_ids: + logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") + id_conditions = [] + for node_id in memory_ids: + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" + ) + where_conditions.append(f"({' OR '.join(id_conditions)})") + + # Add file_ids conditions + if file_ids: + logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") + file_id_conditions = [] + for file_id in file_ids: + file_id_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" + ) + where_conditions.append(f"({' OR '.join(file_id_conditions)})") - # Add filter conditions - if filter_conditions: - logger.info("[delete_node_by_prams] Processing filter conditions") - where_conditions.extend(filter_conditions) + # Add filter conditions + if filter_conditions: + logger.info("[delete_node_by_prams] Processing filter conditions") + where_conditions.extend(filter_conditions) - # Add user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_conditions.append(f"({user_name_where})") + # Add user_name 
filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_conditions.append(f"({user_name_where})") - # Build final WHERE clause - if not where_conditions: - logger.warning("[delete_node_by_prams] No WHERE conditions to delete") - return 0 + # Build final WHERE clause + if not where_conditions: + logger.warning("[delete_node_by_prams] No WHERE conditions to delete") + return 0 - where_clause = " AND ".join(where_conditions) + where_clause = " AND ".join(where_conditions) - # Delete directly without counting - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + # Delete directly without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") - cursor.execute(delete_query) - deleted_count = cursor.rowcount - total_deleted_count = deleted_count + cursor.execute(delete_query) + deleted_count = cursor.rowcount + total_deleted_count = deleted_count - logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") + logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") - elapsed_time = time.time() - batch_start_time - logger.info( - f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" - ) + elapsed_time = time.time() - batch_start_time + logger.info( + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" + ) except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) raise @@ -4984,47 +4950,47 @@ def escape_memory_id(mid: str) -> str: logger.info(f"[get_user_names_by_memory_ids] query: {query}") result_dict = {} try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + results = cursor.fetchall() - # Build result dictionary from query results - for row in results: - memory_id_raw = row[0] - user_name_raw = row[1] + # Build result dictionary from query results + for row in results: + memory_id_raw = row[0] + user_name_raw = row[1] - # Remove quotes if present - if isinstance(memory_id_raw, str): - memory_id = memory_id_raw.strip('"').strip("'") - else: - memory_id = str(memory_id_raw).strip('"').strip("'") + # Remove quotes if present + if isinstance(memory_id_raw, str): + memory_id = memory_id_raw.strip('"').strip("'") + else: + memory_id = str(memory_id_raw).strip('"').strip("'") - if isinstance(user_name_raw, str): - user_name = user_name_raw.strip('"').strip("'") - else: - user_name = ( - str(user_name_raw).strip('"').strip("'") if user_name_raw else None - ) + if isinstance(user_name_raw, str): + user_name = user_name_raw.strip('"').strip("'") + else: + user_name = ( + str(user_name_raw).strip('"').strip("'") if user_name_raw else None + ) - result_dict[memory_id] = user_name if user_name else None + result_dict[memory_id] = user_name if user_name else None - # Set None for memory_ids that were not found - for mid in normalized_memory_ids: - if mid not in result_dict: - result_dict[mid] = None + # Set None for memory_ids that were not found + for mid in normalized_memory_ids: + if mid not in result_dict: + result_dict[mid] = None - logger.info( - 
f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " - f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" - ) + logger.info( + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" + ) - return result_dict + return result_dict except Exception as e: logger.error( f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise + def exist_user_name(self, user_name: str) -> dict[str, bool]: """Check if user name exists in the graph. @@ -5058,18 +5024,18 @@ def escape_user_name(un: str) -> str: logger.info(f"[exist_user_name] query: {query}") result_dict = {} try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - cursor.execute(query) - count = cursor.fetchone()[0] - result = count > 0 - result_dict[user_name] = result - return result_dict + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + count = cursor.fetchone()[0] + result = count > 0 + result_dict[user_name] = result + return result_dict except Exception as e: logger.error( f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True ) raise + @timed def delete_node_by_mem_cube_id( self, @@ -5093,71 +5059,75 @@ def delete_node_by_mem_cube_id( return 0 try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + with self._get_connection() as conn, conn.cursor() as cursor: + user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - user_name_param = self.format_param_value(mem_cube_id) + user_name_param = self.format_param_value(mem_cube_id) - if hard_delete: - delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" - where_clause = f"{user_name_condition} AND {delete_record_id_condition}" + if hard_delete: + delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" + where_clause = f"{user_name_condition} AND {delete_record_id_condition}" - where_params = [user_name_param, self.format_param_value(delete_record_id)] + where_params = [user_name_param, self.format_param_value(delete_record_id)] - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}") + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}") - cursor.execute(delete_query, where_params) - deleted_count = cursor.rowcount + cursor.execute(delete_query, where_params) + deleted_count = cursor.rowcount - logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") - return deleted_count - else: - delete_time_empty_condition = ( - "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL " - "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)" - ) - delete_record_id_empty_condition = ( - "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL " - "OR 
ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)" - ) - where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}" - - current_time = datetime.utcnow().isoformat() - update_query = f""" - UPDATE "{self.db_name}_graph"."Memory" - SET properties = ( - properties::jsonb || %s::jsonb - )::text::agtype, - deletetime = %s - WHERE {where_clause} - """ - update_properties = { - "status": "deleted", - "delete_time": current_time, - "delete_record_id": delete_record_id, - } - logger.info( - f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" - ) - update_params = [json.dumps(update_properties), current_time, user_name_param] - cursor.execute(update_query, update_params) - updated_count = cursor.rowcount + logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes") + return deleted_count + else: + delete_time_empty_condition = ( + "(ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) IS NULL " + "OR ag_catalog.agtype_access_operator(properties, '\"delete_time\"'::agtype) = '\"\"'::agtype)" + ) + delete_record_id_empty_condition = ( + "(ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) IS NULL " + "OR ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = '\"\"'::agtype)" + ) + where_clause = f"{user_name_condition} AND {delete_time_empty_condition} AND {delete_record_id_empty_condition}" - logger.info( - f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" - ) - return updated_count + current_time = datetime.utcnow().isoformat() + update_query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = ( + properties::jsonb || %s::jsonb + )::text::agtype, + deletetime = %s + WHERE {where_clause} + """ + update_properties = { + "status": "deleted", + "delete_time": current_time, + "delete_record_id": delete_record_id, + } + logger.info( + f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" + ) + update_params = [ + json.dumps(update_properties), + current_time, + user_name_param, + ] + cursor.execute(update_query, update_params) + updated_count = cursor.rowcount + + logger.info( + f"delete_node_by_mem_cube_id Soft deleted (updated) {updated_count} nodes" + ) + return updated_count except Exception as e: logger.error( f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True ) raise + @timed def recover_memory_by_mem_cube_id( self, @@ -5184,45 +5154,44 @@ def recover_memory_by_mem_cube_id( ) try: - with self._get_connection() as conn: - with conn.cursor() as cursor: - user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" - where_clause = f"{user_name_condition} AND {delete_record_id_condition}" - - where_params = [ - self.format_param_value(mem_cube_id), - self.format_param_value(delete_record_id), - ] + with self._get_connection() as conn, conn.cursor() as cursor: + user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" + where_clause = 
f"{user_name_condition} AND {delete_record_id_condition}" + + where_params = [ + self.format_param_value(mem_cube_id), + self.format_param_value(delete_record_id), + ] - update_properties = { - "status": "activated", - "delete_record_id": "", - "delete_time": "", - } + update_properties = { + "status": "activated", + "delete_record_id": "", + "delete_time": "", + } - update_query = f""" - UPDATE "{self.db_name}_graph"."Memory" - SET properties = ( - properties::jsonb || %s::jsonb - )::text::agtype, - deletetime = NULL - WHERE {where_clause} - """ + update_query = f""" + UPDATE "{self.db_name}_graph"."Memory" + SET properties = ( + properties::jsonb || %s::jsonb + )::text::agtype, + deletetime = NULL + WHERE {where_clause} + """ - logger.info(f"[recover_memory_by_mem_cube_id] Update query: {update_query}") - logger.info( - f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}" - ) + logger.info(f"[recover_memory_by_mem_cube_id] Update query: {update_query}") + logger.info( + f"[recover_memory_by_mem_cube_id] update_properties: {update_properties}" + ) - update_params = [json.dumps(update_properties), *where_params] - cursor.execute(update_query, update_params) - updated_count = cursor.rowcount + update_params = [json.dumps(update_properties), *where_params] + cursor.execute(update_query, update_params) + updated_count = cursor.rowcount - logger.info( - f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes" - ) - return updated_count + logger.info( + f"[recover_memory_by_mem_cube_id] Recovered (updated) {updated_count} nodes" + ) + return updated_count except Exception as e: logger.error( From 2394fba803113a33511c2a3275962a4aff2a15aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Mon, 2 Mar 2026 21:48:50 +0800 Subject: [PATCH 06/10] feat:add _warm_up_on_startup --- src/memos/configs/graph_db.py | 21 +++++++++++++++++++++ src/memos/graph_dbs/polardb.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 9b1ce7f9d..070a83ec9 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -202,6 +202,27 @@ class PolarDBGraphDBConfig(BaseConfig): default=100, description="Maximum number of connections in the connection pool", ) + connection_wait_timeout: int = Field( + default=30, + ge=1, + le=3600, + description="Max seconds to wait for a connection slot before raising (0 = wait forever, not recommended)", + ) + skip_connection_health_check: bool = Field( + default=False, + description=( + "If True, skip SELECT 1 health check when getting connections (~1-2ms saved per request). " + "Use only when pool/network is reliable." + ), + ) + warm_up_on_startup: bool = Field( + default=True, + description=( + "If True, run search_by_fulltext warm-up on pool connections at init to reduce " + "first-query latency (~200ms planning). Requires user_name in config." 
+ ), + ) + @model_validator(mode="after") def validate_config(self): diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index a48eaf0f7..a183af3c7 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -139,6 +139,8 @@ def __init__(self, config: PolarDBGraphDBConfig): password = config.get("password") maxconn = config.get("maxconn", 100) self._connection_wait_timeout = config.get("connection_wait_timeout", 60) + self._skip_connection_health_check = config.get("skip_connection_health_check", False) + self._warm_up_on_startup = config.get("warm_up_on_startup", False) else: self.db_name = config.db_name self.user_name = config.user_name @@ -148,6 +150,9 @@ def __init__(self, config: PolarDBGraphDBConfig): password = config.password maxconn = config.maxconn if hasattr(config, "maxconn") else 100 self._connection_wait_timeout = getattr(config, "connection_wait_timeout", 60) + self._skip_connection_health_check = getattr(config, "skip_connection_health_check", False) + self._warm_up_on_startup = getattr(config, "warm_up_on_startup", False) + logger.info(f"connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup:{self._warm_up_on_startup}") logger.info( f" db_name: {self.db_name} maxconn: {maxconn} connection_wait_timeout: {self._connection_wait_timeout}s" @@ -163,12 +168,14 @@ def __init__(self, config: PolarDBGraphDBConfig): password=password, dbname=self.db_name, connect_timeout=60, # Connection timeout in seconds - keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) + keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout) keepalives_interval=15, # Seconds between keepalive retries keepalives_count=5, # Number of keepalive retries before considering connection dead ) self._semaphore = threading.BoundedSemaphore(maxconn) + if self._warm_up_on_startup: + self._warm_up_search_connections() """ # Handle auto_create @@ -193,6 +200,27 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) + def _warm_up_search_connections(self, user_name: str | None = None) -> None: + user_name = user_name or self.user_name + if not user_name: + logger.debug("[warm_up] Skipped: no user_name for warm-up") + return + warm_count = min(5, self.connection_pool.minconn) + for _ in range(warm_count): + try: + self.search_by_fulltext( + query_words=["warmup"], + top_k=1, + user_name=user_name, + ) + except Exception as e: + logger.debug(f"[warm_up] Warm-up query failed (non-fatal): {e}") + break + logger.info(f"[warm_up] Pre-warmed {warm_count} connections for search_by_fulltext") + + def warm_up_search_connections(self, user_name: str | None = None) -> None: + self._warm_up_search_connections(user_name) + @contextmanager def _get_connection(self): timeout = getattr(self, "_connection_wait_timeout", 5) From 009215644ab63186345bcd42e58f5eb7a81865e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Mon, 2 Mar 2026 22:44:54 +0800 Subject: [PATCH 07/10] feat:add _warm_up_on_startup --- src/memos/graph_dbs/polardb.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index a183af3c7..2267f3b3f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -176,6 +176,7 @@ def 
__init__(self, config: PolarDBGraphDBConfig): self._semaphore = threading.BoundedSemaphore(maxconn) if self._warm_up_on_startup: self._warm_up_search_connections() + # self._warm_up_connections() """ # Handle auto_create @@ -221,6 +222,21 @@ def _warm_up_search_connections(self, user_name: str | None = None) -> None: def warm_up_search_connections(self, user_name: str | None = None) -> None: self._warm_up_search_connections(user_name) + def _warm_up_connections(self): + warm_count = self.connection_pool.minconn + preheated = 0 + logger.info(f"[warm_up] Pre-warming {warm_count} connections...") + for _ in range(warm_count): + try: + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + preheated += 1 + except Exception as e: + logger.warning(f"[warm_up] Failed to pre-warm connection: {e}") + continue + logger.info(f"[warm_up] Pre-warmed {preheated}/{warm_count} connections") + @contextmanager def _get_connection(self): timeout = getattr(self, "_connection_wait_timeout", 5) @@ -244,7 +260,7 @@ def _get_connection(self): yield conn except Exception as e: broken = True - logger.error(f"Connection failed or broken: {e}") + logger.exception(f"Connection failed or broken: {e}") raise finally: if conn: From 972c4de8b3cb071326aa3b8ede5b2af21921da61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Tue, 3 Mar 2026 09:41:36 +0800 Subject: [PATCH 08/10] feat:add _warm_up_on_startup --- src/memos/api/config.py | 9 ++++++++- src/memos/configs/graph_db.py | 9 ++++++++- src/memos/graph_dbs/polardb.py | 25 +++++++++++++++---------- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 65049b0c2..45233660f 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -675,7 +675,14 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: "user_name": user_name, "use_multi_db": use_multi_db, "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "1024")), + # .env: CONNECTION_WAIT_TIMEOUT, SKIP_CONNECTION_HEALTH_CHECK, WARM_UP_ON_STARTUP_BY_FULL, WARM_UP_ON_STARTUP_BY_ALL + "connection_wait_timeout": int(os.getenv("CONNECTION_WAIT_TIMEOUT", "60")), + "skip_connection_health_check": os.getenv("SKIP_CONNECTION_HEALTH_CHECK", "false").lower() == "true", + "warm_up_on_startup_by_full": ( + os.getenv("WARM_UP_ON_STARTUP_BY_FULL") or os.getenv("WARM_UP_ON_STARTUP", "true") + ).lower() == "true", + "warm_up_on_startup_by_all": os.getenv("WARM_UP_ON_STARTUP_BY_ALL", "false").lower() == "true", } @staticmethod diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 070a83ec9..37cdb7555 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -215,13 +215,20 @@ class PolarDBGraphDBConfig(BaseConfig): "Use only when pool/network is reliable." ), ) - warm_up_on_startup: bool = Field( + warm_up_on_startup_by_full: bool = Field( default=True, description=( "If True, run search_by_fulltext warm-up on pool connections at init to reduce " "first-query latency (~200ms planning). Requires user_name in config." ), ) + warm_up_on_startup_by_all: bool = Field( + default=False, + description=( + "If True, run all connection warm-up on pool connections at init to reduce " + "first-query latency (~200ms planning). Requires user_name in config." 
+ ), + ) @model_validator(mode="after") diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 2267f3b3f..bb25ac50e 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -140,7 +140,8 @@ def __init__(self, config: PolarDBGraphDBConfig): maxconn = config.get("maxconn", 100) self._connection_wait_timeout = config.get("connection_wait_timeout", 60) self._skip_connection_health_check = config.get("skip_connection_health_check", False) - self._warm_up_on_startup = config.get("warm_up_on_startup", False) + self._warm_up_on_startup_by_full = config.get("warm_up_on_startup_by_full", False) + self._warm_up_on_startup_by_all = config.get("warm_up_on_startup_by_all", False) else: self.db_name = config.db_name self.user_name = config.user_name @@ -151,8 +152,9 @@ def __init__(self, config: PolarDBGraphDBConfig): maxconn = config.maxconn if hasattr(config, "maxconn") else 100 self._connection_wait_timeout = getattr(config, "connection_wait_timeout", 60) self._skip_connection_health_check = getattr(config, "skip_connection_health_check", False) - self._warm_up_on_startup = getattr(config, "warm_up_on_startup", False) - logger.info(f"connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup:{self._warm_up_on_startup}") + self._warm_up_on_startup_by_full = getattr(config, "warm_up_on_startup_by_full", False) + self._warm_up_on_startup_by_all = getattr(config, "warm_up_on_startup_by_all", False) + logger.info(f"polardb init config connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup_by_full:{self._warm_up_on_startup_by_full},warm_up_on_startup_by_all:{self._warm_up_on_startup_by_all}") logger.info( f" db_name: {self.db_name} maxconn: {maxconn} connection_wait_timeout: {self._connection_wait_timeout}s" @@ -174,9 +176,10 @@ def __init__(self, config: PolarDBGraphDBConfig): ) self._semaphore = threading.BoundedSemaphore(maxconn) - if self._warm_up_on_startup: - self._warm_up_search_connections() - # self._warm_up_connections() + if self._warm_up_on_startup_by_full: + self._warm_up_search_connections_by_full() + if self._warm_up_on_startup_by_all: + self._warm_up_connections_by_all() """ # Handle auto_create @@ -201,7 +204,8 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) - def _warm_up_search_connections(self, user_name: str | None = None) -> None: + def _warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: + logger.info(f"--warm_up_search_connections_by_full--start-up----") user_name = user_name or self.user_name if not user_name: logger.debug("[warm_up] Skipped: no user_name for warm-up") @@ -219,10 +223,11 @@ def _warm_up_search_connections(self, user_name: str | None = None) -> None: break logger.info(f"[warm_up] Pre-warmed {warm_count} connections for search_by_fulltext") - def warm_up_search_connections(self, user_name: str | None = None) -> None: - self._warm_up_search_connections(user_name) + def warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: + self._warm_up_search_connections_by_full(user_name) - def _warm_up_connections(self): + def _warm_up_connections_by_all(self): + logger.info(f"--_warm_up_connections_by_all--start-up") warm_count = self.connection_pool.minconn preheated = 0 logger.info(f"[warm_up] Pre-warming {warm_count} connections...") From 
36cd3eafd3018f63c700bd3f40076bab7ec452ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Tue, 3 Mar 2026 09:49:16 +0800 Subject: [PATCH 09/10] feat:add _warm_up_on_startup --- src/memos/api/config.py | 4 +--- src/memos/configs/graph_db.py | 1 - src/memos/graph_dbs/polardb.py | 17 ++++++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 45233660f..be8a396e7 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -679,9 +679,7 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: # .env: CONNECTION_WAIT_TIMEOUT, SKIP_CONNECTION_HEALTH_CHECK, WARM_UP_ON_STARTUP_BY_FULL, WARM_UP_ON_STARTUP_BY_ALL "connection_wait_timeout": int(os.getenv("CONNECTION_WAIT_TIMEOUT", "60")), "skip_connection_health_check": os.getenv("SKIP_CONNECTION_HEALTH_CHECK", "false").lower() == "true", - "warm_up_on_startup_by_full": ( - os.getenv("WARM_UP_ON_STARTUP_BY_FULL") or os.getenv("WARM_UP_ON_STARTUP", "true") - ).lower() == "true", + "warm_up_on_startup_by_full": os.getenv("WARM_UP_ON_STARTUP_BY_FULL","false").lower() == "true", "warm_up_on_startup_by_all": os.getenv("WARM_UP_ON_STARTUP_BY_ALL", "false").lower() == "true", } diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 37cdb7555..5900d2357 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -230,7 +230,6 @@ class PolarDBGraphDBConfig(BaseConfig): ), ) - @model_validator(mode="after") def validate_config(self): """Validate config.""" diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index bb25ac50e..ac03cda2e 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -151,10 +151,14 @@ def __init__(self, config: PolarDBGraphDBConfig): password = config.password maxconn = config.maxconn if hasattr(config, "maxconn") else 100 self._connection_wait_timeout = getattr(config, "connection_wait_timeout", 60) - self._skip_connection_health_check = getattr(config, "skip_connection_health_check", False) + self._skip_connection_health_check = getattr( + config, "skip_connection_health_check", False + ) self._warm_up_on_startup_by_full = getattr(config, "warm_up_on_startup_by_full", False) self._warm_up_on_startup_by_all = getattr(config, "warm_up_on_startup_by_all", False) - logger.info(f"polardb init config connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup_by_full:{self._warm_up_on_startup_by_full},warm_up_on_startup_by_all:{self._warm_up_on_startup_by_all}") + logger.info( + f"polardb init config connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup_by_full:{self._warm_up_on_startup_by_full},warm_up_on_startup_by_all:{self._warm_up_on_startup_by_all}" + ) logger.info( f" db_name: {self.db_name} maxconn: {maxconn} connection_wait_timeout: {self._connection_wait_timeout}s" @@ -205,7 +209,7 @@ def _get_config_value(self, key: str, default=None): return getattr(self.config, key, default) def _warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: - logger.info(f"--warm_up_search_connections_by_full--start-up----") + logger.info("--warm_up_search_connections_by_full--start-up----") user_name = user_name or self.user_name if not user_name: logger.debug("[warm_up] Skipped: no user_name for warm-up") @@ 
-227,15 +231,14 @@ def warm_up_search_connections_by_full(self, user_name: str | None = None) -> No self._warm_up_search_connections_by_full(user_name) def _warm_up_connections_by_all(self): - logger.info(f"--_warm_up_connections_by_all--start-up") + logger.info("--_warm_up_connections_by_all--start-up") warm_count = self.connection_pool.minconn preheated = 0 logger.info(f"[warm_up] Pre-warming {warm_count} connections...") for _ in range(warm_count): try: - with self._get_connection() as conn: - with conn.cursor() as cur: - cur.execute("SELECT 1") + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute("SELECT 1") preheated += 1 except Exception as e: logger.warning(f"[warm_up] Failed to pre-warm connection: {e}") From 6a2b8c292c9f59564ba064e2799707bb0304a32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Tue, 3 Mar 2026 09:58:28 +0800 Subject: [PATCH 10/10] feat:config format --- src/memos/api/config.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index be8a396e7..fa12bcf55 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -678,9 +678,14 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "1024")), # .env: CONNECTION_WAIT_TIMEOUT, SKIP_CONNECTION_HEALTH_CHECK, WARM_UP_ON_STARTUP_BY_FULL, WARM_UP_ON_STARTUP_BY_ALL "connection_wait_timeout": int(os.getenv("CONNECTION_WAIT_TIMEOUT", "60")), - "skip_connection_health_check": os.getenv("SKIP_CONNECTION_HEALTH_CHECK", "false").lower() == "true", - "warm_up_on_startup_by_full": os.getenv("WARM_UP_ON_STARTUP_BY_FULL","false").lower() == "true", - "warm_up_on_startup_by_all": os.getenv("WARM_UP_ON_STARTUP_BY_ALL", "false").lower() == "true", + "skip_connection_health_check": os.getenv( + "SKIP_CONNECTION_HEALTH_CHECK", "false" + ).lower() + == "true", + "warm_up_on_startup_by_full": os.getenv("WARM_UP_ON_STARTUP_BY_FULL", "false").lower() + == "true", + "warm_up_on_startup_by_all": os.getenv("WARM_UP_ON_STARTUP_BY_ALL", "false").lower() + == "true", } @staticmethod
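
Taken together, the `__init__` and `_get_connection` hunks above converge on one pattern: a `threading.BoundedSemaphore(maxconn)` caps concurrent checkouts, so `ThreadedConnectionPool.getconn()` is only ever called while a slot is held and can no longer raise "pool exhausted"; callers block for up to `connection_wait_timeout` seconds instead, and the `SELECT 1` probe can be bypassed via `skip_connection_health_check`. A minimal sketch of that pattern, assuming the attribute names shown in the diffs; the parts of `_get_connection` elided from the hunks (semaphore acquire, putconn) are reconstructed here for illustration only, not quoted from the source:

    import threading
    from contextlib import contextmanager

    from psycopg2.pool import ThreadedConnectionPool

    class PoolSketch:
        def __init__(self, maxconn: int = 100, wait_timeout: int = 60, **dsn):
            # dsn kwargs (host, dbname, ...) are forwarded to psycopg2.connect
            self.connection_pool = ThreadedConnectionPool(minconn=10, maxconn=maxconn, **dsn)
            # Cap concurrent checkouts at maxconn: getconn() only runs while a
            # semaphore slot is held, so the pool is never over-asked.
            self._semaphore = threading.BoundedSemaphore(maxconn)
            self._connection_wait_timeout = wait_timeout
            self._skip_connection_health_check = False

        @contextmanager
        def _get_connection(self):
            timeout = getattr(self, "_connection_wait_timeout", 5)
            if not self._semaphore.acquire(timeout=timeout):
                raise TimeoutError(f"No free connection within {timeout}s")
            conn, broken = None, False
            try:
                conn = self.connection_pool.getconn()
                conn.autocommit = True
                if not self._skip_connection_health_check:
                    with conn.cursor() as cur:  # cheap liveness probe
                        cur.execute("SELECT 1")
                yield conn
            except Exception:
                broken = True  # discard rather than recycle a broken connection
                raise
            finally:
                if conn is not None:
                    # close=True drops the connection; the pool opens a fresh one later
                    self.connection_pool.putconn(conn, close=broken)
                self._semaphore.release()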
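
One behavioral nit in `_warm_up_search_connections_by_full` as committed: the loop `break`s on the first failed probe, but the closing log still reports `warm_count` connections as pre-warmed. A corrected sketch that counts actual successes, keeping the same signature and the same `search_by_fulltext` call used in the hunks above (`_warm_up_connections_by_all` already counts its `preheated` successes this way):

    def _warm_up_search_connections_by_full(self, user_name: str | None = None) -> None:
        user_name = user_name or self.user_name
        if not user_name:
            logger.debug("[warm_up] Skipped: no user_name for warm-up")
            return
        warm_count = min(5, self.connection_pool.minconn)
        warmed = 0
        for _ in range(warm_count):
            try:
                # Any cheap full-text probe exercises the same planning path;
                # "warmup" is just a throwaway term.
                self.search_by_fulltext(query_words=["warmup"], top_k=1, user_name=user_name)
                warmed += 1
            except Exception as e:
                logger.debug(f"[warm_up] Warm-up query failed (non-fatal): {e}")
                break
        logger.info(
            f"[warm_up] Pre-warmed {warmed}/{warm_count} connections for search_by_fulltext"
        )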
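
The two warm-up modes introduced in PATCH 08 are independent: `warm_up_on_startup_by_full` issues real `search_by_fulltext` queries (requires `user_name` and targets the ~200ms first-query planning cost named in the Field descriptions), while `warm_up_on_startup_by_all` only checks out `minconn` connections and runs `SELECT 1`, with no user context needed. A configuration example using the keys read in `__init__`; the values are placeholders and the connection fields (`host`, `password`, ...) are elided:

    polardb_config = {
        "db_name": "memos",
        "user_name": "alice",                   # required by the by_full warm-up path
        "maxconn": 100,
        "connection_wait_timeout": 60,          # seconds to block for a free pool slot
        "skip_connection_health_check": False,  # keep the SELECT 1 probe enabled
        "warm_up_on_startup_by_full": True,     # real full-text probes at init
        "warm_up_on_startup_by_all": False,     # plain SELECT 1 over minconn connections
    }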
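
PATCH 10 reformats the same `os.getenv(...).lower() == "true"` idiom three times to satisfy the line-length limit. If a follow-up wanted to deduplicate instead of reflow, a small helper (hypothetical, not part of this series) would keep `get_polardb_config` flat:

    import os

    def _env_bool(name: str, default: str = "false") -> bool:
        """Parse a boolean .env flag; anything other than 'true' (case-insensitive) is False."""
        return os.getenv(name, default).lower() == "true"

    # Same keys as the final patch:
    flags = {
        "skip_connection_health_check": _env_bool("SKIP_CONNECTION_HEALTH_CHECK"),
        "warm_up_on_startup_by_full": _env_bool("WARM_UP_ON_STARTUP_BY_FULL"),
        "warm_up_on_startup_by_all": _env_bool("WARM_UP_ON_STARTUP_BY_ALL"),
    }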