diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f0a23e39b..6f4597982 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2578,6 +2578,7 @@ def export_graph( logger.info( f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" ) + start_time = time.time() user_id = user_id if user_id else self._get_config_value("user_id") # Initialize total counts @@ -2651,26 +2652,14 @@ def export_graph( if where_conditions: where_clause = f"WHERE {' AND '.join(where_conditions)}" - # Get total count of nodes before pagination - count_node_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - logger.info(f"[export_graph nodes count] Query: {count_node_query}") - with conn.cursor() as cursor: - cursor.execute(count_node_query) - total_nodes = cursor.fetchone()[0] - - # Export nodes - # Build pagination clause if needed pagination_clause = "" if use_pagination: pagination_clause = f"LIMIT {page_size} OFFSET {offset}" if include_embedding: node_query = f""" - SELECT id, properties, embedding + SELECT id, properties, embedding, + COUNT(*) OVER () AS total_nodes FROM "{self.db_name}_graph"."Memory" {where_clause} ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, @@ -2679,7 +2668,8 @@ def export_graph( """ else: node_query = f""" - SELECT id, properties + SELECT id, properties, + COUNT(*) OVER () AS total_nodes FROM "{self.db_name}_graph"."Memory" {where_clause} ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, @@ -2690,15 +2680,16 @@ def export_graph( with conn.cursor() as cursor: cursor.execute(node_query) node_results = cursor.fetchall() + total_nodes = int(node_results[0][-1]) if node_results else 0 nodes = [] for row in node_results: if include_embedding: - """row is (id, properties, embedding)""" - _, properties_json, embedding_json = row + """row is (id, properties, embedding, total_nodes)""" + _, properties_json, embedding_json, _ = row else: - """row is (id, properties)""" - _, properties_json = row + """row is (id, properties, total_nodes)""" + _, properties_json, _ = row embedding_json = None # Parse properties from JSONB if it's a string @@ -2723,160 +2714,9 @@ def export_graph( raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e finally: self._return_connection(conn) - - conn = None - try: - conn = self._get_connection() - # Build Cypher WHERE conditions for edges - cypher_where_conditions = [] - if user_name: - cypher_where_conditions.append(f"a.user_name = '{user_name}'") - cypher_where_conditions.append(f"b.user_name = '{user_name}'") - if user_id: - cypher_where_conditions.append(f"a.user_id = '{user_id}'") - cypher_where_conditions.append(f"b.user_id = '{user_id}'") - - # Add memory_type filter condition for edges (apply to both source and target nodes) - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape single quotes in memory_type values for Cypher - escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type] - memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types]) - # Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory'] - cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]") - cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]") - - # Add status filter for edges: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - cypher_where_conditions.append("a.status <> 'deleted' AND b.status <> 'deleted'") - elif isinstance(status, list) and len(status) > 0: - escaped_statuses = [st.replace("'", "\\'") for st in status] - status_list_str = ", ".join([f"'{st}'" for st in escaped_statuses]) - cypher_where_conditions.append(f"a.status IN [{status_list_str}]") - cypher_where_conditions.append(f"b.status IN [{status_list_str}]") - - # Build filter conditions for edges (apply to both source and target nodes) - filter_where_clause = self._build_filter_conditions_cypher(filter) - logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") - if filter_where_clause: - # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists - # Remove the leading " AND " and replace n. with a. for source node and b. for target node - filter_clause = filter_where_clause.strip() - if filter_clause.startswith("AND "): - filter_clause = filter_clause[4:].strip() - # Replace n. with a. for source node and create a copy for target node - source_filter = filter_clause.replace("n.", "a.") - target_filter = filter_clause.replace("n.", "b.") - # Combine source and target filters with AND - combined_filter = f"({source_filter}) AND ({target_filter})" - cypher_where_conditions.append(combined_filter) - - cypher_where_clause = "" - if cypher_where_conditions: - cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" - - # Get total count of edges before pagination - count_edge_query = f""" - SELECT COUNT(*) - FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - """ - logger.info(f"[export_graph edges count] Query: {count_edge_query}") - with conn.cursor() as cursor: - cursor.execute(count_edge_query) - total_edges = cursor.fetchone()[0] - - # Export edges using cypher query - # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery - # Build pagination clause if needed - edge_pagination_clause = "" - if use_pagination: - edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - edge_query = f""" - SELECT source, target, edge FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, - COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, - a.id DESC, b.id DESC - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - {edge_pagination_clause} - """ - logger.info(f"[export_graph edges] Query: {edge_query}") - with conn.cursor() as cursor: - cursor.execute(edge_query) - edge_results = cursor.fetchall() - edges = [] - - for row in edge_results: - source_agtype, target_agtype, edge_agtype = row - - # Extract and clean source - source_raw = ( - source_agtype.value - if hasattr(source_agtype, "value") - else str(source_agtype) - ) - if ( - isinstance(source_raw, str) - and source_raw.startswith('"') - and source_raw.endswith('"') - ): - source = source_raw[1:-1] - else: - source = str(source_raw) - - # Extract and clean target - target_raw = ( - target_agtype.value - if hasattr(target_agtype, "value") - else str(target_agtype) - ) - if ( - isinstance(target_raw, str) - and target_raw.startswith('"') - and target_raw.endswith('"') - ): - target = target_raw[1:-1] - else: - target = str(target_raw) - - # Extract and clean edge type - type_raw = ( - edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) - ) - if ( - isinstance(type_raw, str) - and type_raw.startswith('"') - and type_raw.endswith('"') - ): - edge_type = type_raw[1:-1] - else: - edge_type = str(type_raw) - - edges.append( - { - "source": source, - "target": target, - "type": edge_type, - } - ) - - except Exception as e: - logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - finally: - self._return_connection(conn) - + edges = [] + elapsed_time = time.time() - start_time + logger.info(f"[export_graph] query completed successfully in {elapsed_time:.2f}s") return { "nodes": nodes, "edges": edges,