diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ad75f4b65..a87c83510 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -124,7 +124,6 @@ def __init__(self, config: PolarDBGraphDBConfig): All node queries will enforce `user_name` in WHERE conditions and store it in metadata, but it will be removed automatically before returning to external consumers. """ - import psycopg2 import psycopg2.pool self.config = config @@ -177,6 +176,7 @@ def __init__(self, config: PolarDBGraphDBConfig): keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout) keepalives_interval=15, # Seconds between keepalive retries keepalives_count=5, # Number of keepalive retries before considering connection dead + options=f"-c search_path={self.db_name}_graph,ag_catalog,$user,public", ) self._semaphore = threading.BoundedSemaphore(maxconn) @@ -2300,48 +2300,57 @@ def export_graph( status: list[str] | None = None, **kwargs, ) -> dict[str, Any]: - """ - Export all graph nodes and edges in a structured form. - Args: - include_embedding (bool): Whether to include the large embedding field. - user_name (str, optional): User name for filtering in non-multi-db mode - user_id (str, optional): User ID for filtering - page (int, optional): Page number (starts from 1). If None, exports all data without pagination. - page_size (int, optional): Number of items per page. If None, exports all data without pagination. - filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators: - - "=": equality - - "in": value in list - - "contains": array contains value - - "gt", "lt", "gte", "lte": comparison operators - - "like": fuzzy matching - Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]} - memory_type (list[str], optional): List of memory_type values to filter by. If provided, only nodes/edges with - memory_type in this list will be exported. Example: ["LongTermMemory", "WorkingMemory"] - status (list[str], optional): List of status values to filter by. If not provided, only nodes/edges with - status != 'deleted' are exported. If provided, only nodes/edges with status in this list are exported. - Example: ["activated"] or ["activated", "archived"] - - Returns: - { - "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ], - "total_nodes": int, # Total number of nodes matching the filter criteria - "total_edges": int, # Total number of edges matching the filter criteria - } - """ + start_time = time.perf_counter() logger.info( f" export_graph include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" ) user_id = user_id if user_id else self._get_config_value("user_id") - # Initialize total counts + extracted_object_type: str | None = None + extracted_mem_cube_id: str | None = None + + def _extract_special_filter_values(filter_obj): + nonlocal extracted_object_type, extracted_mem_cube_id + + if isinstance(filter_obj, dict): + if "and" in filter_obj and isinstance(filter_obj["and"], list): + cleaned_items = [] + for item in filter_obj["and"]: + cleaned_item = _extract_special_filter_values(item) + if cleaned_item not in (None, {}, []): + cleaned_items.append(cleaned_item) + return {"and": cleaned_items} if cleaned_items else None + + if "or" in filter_obj and isinstance(filter_obj["or"], list): + cleaned_items = [] + for item in filter_obj["or"]: + cleaned_item = _extract_special_filter_values(item) + if cleaned_item not in (None, {}, []): + cleaned_items.append(cleaned_item) + return {"or": cleaned_items} if cleaned_items else None + + cleaned_dict = {} + for key, value in filter_obj.items(): + if key == "object_type" and isinstance(value, str): + if extracted_object_type is None: + extracted_object_type = value + continue + if key == "mem_cube_id" and isinstance(value, str): + if extracted_mem_cube_id is None: + extracted_mem_cube_id = value + continue + cleaned_dict[key] = value + return cleaned_dict if cleaned_dict else None + + return filter_obj + + filter_for_sql = _extract_special_filter_values(filter) + total_nodes = 0 total_edges = 0 - # Determine if pagination is needed use_pagination = page is not None and page_size is not None - # Validate pagination parameters if pagination is enabled if use_pagination: if page < 1: page = 1 @@ -2351,129 +2360,148 @@ def export_graph( else: offset = None - try: - with self._get_connection() as conn: - # Build WHERE conditions - where_conditions = [] - if user_name: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" - ) - if user_id: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" - ) + where_conditions = [] + has_object_type_filter = ( + isinstance(extracted_object_type, str) + and isinstance(extracted_mem_cube_id, str) + and extracted_mem_cube_id.strip() != "" + ) - # Add memory_type filter condition - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape memory_type values and build IN clause - memory_type_values = [] - for mt in memory_type: - # Escape single quotes in memory_type value - escaped_memory_type = str(mt).replace("'", "''") - memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") - memory_type_in_clause = ", ".join(memory_type_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" - ) + if user_name and not has_object_type_filter: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) - # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - where_conditions.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" - ) - elif isinstance(status, list) and len(status) > 0: - # status IN (list) - status_values = [] - for st in status: - escaped_status = str(st).replace("'", "''") - status_values.append(f"'\"{escaped_status}\"'::agtype") - status_in_clause = ", ".join(status_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" - ) + if has_object_type_filter: + object_type_value = extracted_object_type.strip().lower() + escaped_mem_cube_id = extracted_mem_cube_id.replace("'", "''") + if object_type_value == "user": + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) <> '\"{escaped_mem_cube_id}\"'::agtype" + ) + elif object_type_value == "public": + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{escaped_mem_cube_id}\"'::agtype" + ) - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[export_graph] filter_conditions: {filter_conditions}") - if filter_conditions: - where_conditions.extend(filter_conditions) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) + + if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: + memory_type_values = [] + for mt in memory_type: + escaped_memory_type = str(mt).replace("'", "''") + memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") + memory_type_in_clause = ", ".join(memory_type_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" + ) + + if status is None: + where_conditions.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" + ) + elif isinstance(status, list) and len(status) > 0: + status_values = [] + for st in status: + escaped_status = str(st).replace("'", "''") + status_values.append(f"'\"{escaped_status}\"'::agtype") + status_in_clause = ", ".join(status_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" + ) + + filter_conditions = self._build_filter_conditions_sql(filter_for_sql) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) - where_clause = "" - if where_conditions: - where_clause = f"WHERE {' AND '.join(where_conditions)}" + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" - # Get total count of nodes before pagination - count_node_query = f""" - SELECT COUNT(*) + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + order_clause = """ + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,id DESC + """ + if include_embedding: + node_query = f""" + WITH filtered AS ( + SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" {where_clause} - """ - logger.info(f"[export_graph nodes count] Query: {count_node_query}") - with conn.cursor() as cursor: - cursor.execute(count_node_query) - total_nodes = cursor.fetchone()[0] + ) + SELECT p.id, p.properties, p.embedding, c.total_count + FROM (SELECT COUNT(*) AS total_count FROM filtered) c + LEFT JOIN LATERAL ( + SELECT id, properties, embedding + FROM filtered + {order_clause} + {pagination_clause} + ) p ON TRUE + """ + else: + node_query = f""" + WITH filtered AS ( + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ) + SELECT p.id, p.properties, c.total_count + FROM (SELECT COUNT(*) AS total_count FROM filtered) c + LEFT JOIN LATERAL ( + SELECT id, properties + FROM filtered + {order_clause} + {pagination_clause} + ) p ON TRUE + """ + logger.info(f"[export_graph nodes] Query: {node_query}") - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + try: + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + for row in node_results: if include_embedding: - node_query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ + row_id, properties_json, embedding_json, row_total_count = row else: - node_query = f""" - SELECT id, properties - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ - logger.info(f"[export_graph nodes] Query: {node_query}") - with conn.cursor() as cursor: - cursor.execute(node_query) - node_results = cursor.fetchall() - nodes = [] - - for row in node_results: - if include_embedding: - """row is (id, properties, embedding)""" - _, properties_json, embedding_json = row - else: - """row is (id, properties)""" - _, properties_json = row - embedding_json = None + row_id, properties_json, row_total_count = row + embedding_json = None - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except json.JSONDecodeError: - properties = {} - else: - properties = properties_json if properties_json else {} + if row_total_count is not None: + total_nodes = int(row_total_count) + + if row_id is None: + continue + + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} - # Remove embedding field if include_embedding is False - if not include_embedding: - properties.pop("embedding", None) - elif include_embedding and embedding_json is not None: - properties["embedding"] = embedding_json + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json - nodes.append(self._parse_node(properties)) + nodes.append(self._parse_node(properties)) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e + elapsed = (time.perf_counter() - start_time) * 1000 + logger.info("export internal took %.1f ms", elapsed) edges = [] return {