diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 5d50cf68f..f24f1072c 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -15,9 +15,6 @@ logger = get_logger(__name__) -# Graph database configuration -GRAPH_NAME = "test_memos_graph" - def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: node_id = item["id"] @@ -119,6 +116,7 @@ def __init__(self, config: PolarDBGraphDBConfig): but it will be removed automatically before returning to external consumers. """ import psycopg2 + import psycopg2.pool self.config = config @@ -137,12 +135,26 @@ def __init__(self, config: PolarDBGraphDBConfig): port = config.port user = config.user password = config.password - + """ # Create connection self.connection = psycopg2.connect( - host=host, port=port, user=user, password=password, dbname=self.db_name + host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 ) - self.connection.autocommit = True + """ + + # Create connection pool + self.connection_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=5, + maxconn=2000, + host=host, + port=port, + user=user, + password=password, + dbname=self.db_name, + ) + + # Keep a reference to the pool for cleanup + self._pool_closed = False """ # Handle auto_create @@ -167,6 +179,17 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) + def _get_connection(self): + """Get a connection from the pool.""" + if self._pool_closed: + raise RuntimeError("Connection pool has been closed") + return self.connection_pool.getconn() + + def _return_connection(self, connection): + """Return a connection to the pool.""" + if not self._pool_closed and connection: + self.connection_pool.putconn(connection) + def _ensure_database_exists(self): """Create database if it doesn't exist.""" try: @@ -180,8 +203,10 @@ def _ensure_database_exists(self): @timed def _create_graph(self): """Create PostgreSQL schema and table for graph storage.""" + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Create schema if it doesn't exist cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') logger.info(f"Schema '{self.db_name}_graph' ensured.") @@ -229,6 +254,8 @@ def _create_graph(self): except Exception as e: logger.error(f"Failed to create graph schema: {e}") raise e + finally: + self._return_connection(conn) def create_index( self, @@ -241,8 +268,10 @@ def create_index( Create indexes for embedding and other fields. Note: This creates PostgreSQL indexes on the underlying tables. """ + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" @@ -262,6 +291,8 @@ def create_index( logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") + finally: + self._return_connection(conn) def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: """Get count of memory nodes by type.""" @@ -274,14 +305,18 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params = [f'"{memory_type}"', f'"{user_name}"'] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 + finally: + self._return_connection(conn) @timed def node_not_exist(self, scope: str, user_name: str | None = None) -> int: @@ -296,14 +331,18 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: query += "\nLIMIT 1" params = [f'"{scope}"', f'"{user_name}"'] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def remove_oldest_memory( @@ -329,9 +368,9 @@ def remove_oldest_memory( OFFSET %s """ select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Execute query to get IDs to delete cursor.execute(select_query, select_params) ids_to_delete = [row[0] for row in cursor.fetchall()] @@ -357,6 +396,8 @@ def remove_oldest_memory( except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: @@ -414,12 +455,16 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def delete_node(self, id: str, user_name: str | None = None) -> None: @@ -440,18 +485,24 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Ensure in the correct database context cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] @@ -474,11 +525,15 @@ def create_extension(self): except Exception as e: logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def create_graph(self): + # Get a connection from the pool + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(f""" SELECT COUNT(*) FROM ag_catalog.ag_graph WHERE name = '{self.db_name}_graph'; @@ -493,6 +548,8 @@ def create_graph(self): except Exception as e: logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def create_edge(self): @@ -501,9 +558,11 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: + print(f"🪶 Creating elabel: {label_name}") + conn = self._get_connection() logger.info(f"Creating elabel: {label_name}") try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") logger.info(f"Successfully created elabel: {label_name}") except Exception as e: @@ -512,6 +571,8 @@ def create_edge(self): else: logger.warning(f"Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) + finally: + self._return_connection(conn) @timed def add_edge( @@ -543,13 +604,16 @@ def add_edge( ); """ + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise + finally: + self._return_connection(conn) @timed def delete_edge(self, source_id: str, target_id: str, type: str) -> None: @@ -564,10 +628,13 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - - with self.connection.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") + finally: + self._return_connection(conn) @timed def edge_exists_old( @@ -622,11 +689,14 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - - with self.connection.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None + finally: + self._return_connection(conn) @timed def edge_exists( @@ -674,10 +744,14 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - with self.connection.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None + finally: + self._return_connection(conn) @timed def get_node( @@ -720,16 +794,17 @@ def format_param_value(value: str) -> str: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(format_param_value(user_name)) + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() if result: if include_embedding: - node_id, properties_json, embedding_json = result + properties_json, embedding_json = result else: - node_id, properties_json = result + properties_json = result embedding_json = None # Parse properties from JSONB if it's a string @@ -755,13 +830,19 @@ def format_param_value(value: str) -> str: logger.warning(f"Failed to parse embedding for node {id}") return self._parse_node( - {"id": id, "memory": properties.get("memory", ""), **properties} + { + "id": id, + "memory": json.loads(properties[1]).get("memory", ""), + **json.loads(properties[1]), + } ) return None except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) return None + finally: + self._return_connection(conn) @timed def get_nodes( @@ -803,43 +884,47 @@ def get_nodes( query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(f'"{user_name}"') - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists - if embedding_json is not None: - try: - # remove embedding - """ - embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json - # properties["embedding"] = embedding - """ - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding from JSONB if it exists + if embedding_json is not None: + try: + # remove embedding + """ + embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json + # properties["embedding"] = embedding + """ + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } + ) ) - ) - return nodes + return nodes + finally: + self._return_connection(conn) @timed def get_edges_old( @@ -1057,8 +1142,9 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1113,6 +1199,8 @@ def get_children_with_embeddings( except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: """Get the path of nodes from source to target within a limited depth.""" @@ -1174,9 +1262,9 @@ def get_subgraph( r) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() @@ -1250,6 +1338,8 @@ def get_subgraph( except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) return {"core_node": None, "neighbors": [], "edges": []} + finally: + self._return_connection(conn) def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" @@ -1333,24 +1423,28 @@ def search_by_embedding( """ params = [vector] - with self.connection.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output[:top_k] + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + """ + polarId = row[0] # id + properties = row[1] # properties + # embedding = row[3] # embedding + """ + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + return output[:top_k] + finally: + self._return_connection(conn) @timed def get_by_metadata( @@ -1439,13 +1533,16 @@ def get_by_metadata( """ ids = [] + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + finally: + self._return_connection(conn) return ids @@ -1596,8 +1693,9 @@ def get_grouped_counts( GROUP BY {", ".join(group_by_fields)} """ + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: # Handle parameterized query if params and isinstance(params, list): cursor.execute(query, params) @@ -1622,6 +1720,8 @@ def get_grouped_counts( except Exception as e: logger.error(f"Failed to get grouped counts: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" @@ -1653,10 +1753,13 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - - with self.connection.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") + finally: + self._return_connection(conn) except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -1678,7 +1781,7 @@ def export_graph( } """ user_name = user_name if user_name else self._get_config_value("user_name") - + conn = self._get_connection() try: # Export nodes if include_embedding: @@ -1694,16 +1797,16 @@ def export_graph( WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype """ - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(node_query) node_results = cursor.fetchall() nodes = [] for row in node_results: if include_embedding: - node_id, properties_json, embedding_json = row + properties_json, embedding_json = row else: - node_id, properties_json = row + properties_json = row embedding_json = None # Parse properties from JSONB if it's a string @@ -1733,7 +1836,10 @@ def export_graph( except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + finally: + self._return_connection(conn) + conn = self._get_connection() try: # Export edges using cypher query edge_query = f""" @@ -1744,7 +1850,7 @@ def export_graph( $$) AS (source agtype, target agtype, edge agtype) """ - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(edge_query) edge_results = cursor.fetchall() edges = [] @@ -1806,6 +1912,9 @@ def export_graph( except Exception as e: logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e + finally: + self._return_connection(conn) + return {"nodes": nodes, "edges": edges} @timed @@ -1820,9 +1929,12 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - - result = self.execute_query(query) - return int(result.one_or_none()["count"].value) + conn = self._get_connection() + try: + result = self.execute_query(query, conn) + return int(result.one_or_none()["count"].value) + finally: + self._return_connection(conn) @timed def get_all_memory_items( @@ -1863,8 +1975,9 @@ def get_all_memory_items( """ nodes = [] node_ids = set() + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -1886,6 +1999,8 @@ def get_all_memory_items( except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) return nodes else: @@ -1899,8 +2014,9 @@ def get_all_memory_items( """ nodes = [] + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -1917,6 +2033,8 @@ def get_all_memory_items( except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) + finally: + self._return_connection(conn) return nodes @@ -2119,8 +2237,9 @@ def get_structure_optimization_candidates( candidates = [] node_ids = set() + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() logger.info(f"Found {len(results)} structure optimization candidates") @@ -2197,6 +2316,8 @@ def get_structure_optimization_candidates( except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) + finally: + self._return_connection(conn) return candidates @@ -2319,44 +2440,48 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - with self.connection.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s, - %s - ) + conn = self._get_connection() + try: + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) - logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info(f"Added node {id} to graph '{self.db_name}_graph'.") + finally: + self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): """ @@ -2463,8 +2588,9 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -2513,6 +2639,8 @@ def get_neighbors_by_tag( except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def get_neighbors_by_tag_ccl( self, @@ -2758,9 +2886,9 @@ def get_edges( RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ - + conn = self._get_connection() try: - with self.connection.cursor() as cursor: + with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -2805,6 +2933,8 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] + finally: + self._return_connection(conn) def _convert_graph_edges(self, core_node: dict) -> dict: import copy @@ -2812,6 +2942,8 @@ def _convert_graph_edges(self, core_node: dict) -> dict: data = copy.deepcopy(core_node) id_map = {} core_node = data.get("core_node", {}) + if not core_node: + return core_node core_meta = core_node.get("metadata", {}) if "graph_id" in core_meta and "id" in core_node: id_map[core_meta["graph_id"]] = core_node["id"]