155 changes: 1 addition & 154 deletions src/memos/graph_dbs/polardb.py
@@ -2723,160 +2723,7 @@ def export_graph(
raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
finally:
self._return_connection(conn)

conn = None
try:
conn = self._get_connection()
# Build Cypher WHERE conditions for edges
cypher_where_conditions = []
if user_name:
cypher_where_conditions.append(f"a.user_name = '{user_name}'")
cypher_where_conditions.append(f"b.user_name = '{user_name}'")
if user_id:
cypher_where_conditions.append(f"a.user_id = '{user_id}'")
cypher_where_conditions.append(f"b.user_id = '{user_id}'")

# Add memory_type filter condition for edges (apply to both source and target nodes)
if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
# Escape single quotes in memory_type values for Cypher
escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type]
memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types])
# Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory']
cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]")
cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]")

# Add status filter for edges: if not passed, exclude deleted; otherwise filter by IN list
if status is None:
# Default behavior: exclude deleted entries
cypher_where_conditions.append("a.status <> 'deleted' AND b.status <> 'deleted'")
elif isinstance(status, list) and len(status) > 0:
escaped_statuses = [st.replace("'", "\\'") for st in status]
status_list_str = ", ".join([f"'{st}'" for st in escaped_statuses])
cypher_where_conditions.append(f"a.status IN [{status_list_str}]")
cypher_where_conditions.append(f"b.status IN [{status_list_str}]")

# Build filter conditions for edges (apply to both source and target nodes)
filter_where_clause = self._build_filter_conditions_cypher(filter)
logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}")
if filter_where_clause:
# _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists
# Remove the leading " AND " and replace n. with a. for source node and b. for target node
filter_clause = filter_where_clause.strip()
if filter_clause.startswith("AND "):
filter_clause = filter_clause[4:].strip()
# Replace n. with a. for source node and create a copy for target node
source_filter = filter_clause.replace("n.", "a.")
target_filter = filter_clause.replace("n.", "b.")
# Combine source and target filters with AND
combined_filter = f"({source_filter}) AND ({target_filter})"
cypher_where_conditions.append(combined_filter)

cypher_where_clause = ""
if cypher_where_conditions:
cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}"

# Get total count of edges before pagination
count_edge_query = f"""
SELECT COUNT(*)
FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (a:Memory)-[r]->(b:Memory)
{cypher_where_clause}
RETURN a.id AS source, b.id AS target, type(r) as edge
$$) AS (source agtype, target agtype, edge agtype)
) AS edges
"""
logger.info(f"[export_graph edges count] Query: {count_edge_query}")
with conn.cursor() as cursor:
cursor.execute(count_edge_query)
total_edges = cursor.fetchone()[0]

# Export edges using cypher query
# Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
# Build pagination clause if needed
edge_pagination_clause = ""
if use_pagination:
edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

edge_query = f"""
SELECT source, target, edge FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (a:Memory)-[r]->(b:Memory)
{cypher_where_clause}
RETURN a.id AS source, b.id AS target, type(r) as edge
ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC,
COALESCE(b.created_at, '1970-01-01T00:00:00') DESC,
a.id DESC, b.id DESC
$$) AS (source agtype, target agtype, edge agtype)
) AS edges
{edge_pagination_clause}
"""
logger.info(f"[export_graph edges] Query: {edge_query}")
with conn.cursor() as cursor:
cursor.execute(edge_query)
edge_results = cursor.fetchall()
edges = []

for row in edge_results:
source_agtype, target_agtype, edge_agtype = row

# Extract and clean source
source_raw = (
source_agtype.value
if hasattr(source_agtype, "value")
else str(source_agtype)
)
if (
isinstance(source_raw, str)
and source_raw.startswith('"')
and source_raw.endswith('"')
):
source = source_raw[1:-1]
else:
source = str(source_raw)

# Extract and clean target
target_raw = (
target_agtype.value
if hasattr(target_agtype, "value")
else str(target_agtype)
)
if (
isinstance(target_raw, str)
and target_raw.startswith('"')
and target_raw.endswith('"')
):
target = target_raw[1:-1]
else:
target = str(target_raw)

# Extract and clean edge type
type_raw = (
edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype)
)
if (
isinstance(type_raw, str)
and type_raw.startswith('"')
and type_raw.endswith('"')
):
edge_type = type_raw[1:-1]
else:
edge_type = str(type_raw)

edges.append(
{
"source": source,
"target": target,
"type": edge_type,
}
)

except Exception as e:
logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True)
raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
finally:
self._return_connection(conn)

edges = []
return {
"nodes": nodes,
"edges": edges,
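
For reference, the deleted branch follows one pattern throughout: run an Apache AGE cypher() call through the SQL layer, then strip the surrounding double quotes that agtype adds to string values before building plain dicts. Below is a minimal standalone sketch of that pattern, assuming a psycopg2-style connection and a caller-supplied graph name; the function and helper names are illustrative only and are not part of this module's API.

# Minimal sketch of the removed edge-export pattern (assumptions: a psycopg2-style
# connection object and an AGE graph name passed in by the caller; names are
# hypothetical, not the module's actual API).

def _strip_agtype_quotes(value) -> str:
    """agtype string results come back wrapped in double quotes; strip them if present."""
    raw = value.value if hasattr(value, "value") else str(value)
    if isinstance(raw, str) and raw.startswith('"') and raw.endswith('"'):
        return raw[1:-1]
    return str(raw)

def export_edges(conn, graph_name: str, where_clause: str = "") -> list[dict]:
    """Fetch Memory->Memory edges via cypher() and normalize each row into a plain dict."""
    query = f"""
        SELECT source, target, edge FROM (
            SELECT * FROM cypher('{graph_name}', $$
                MATCH (a:Memory)-[r]->(b:Memory)
                {where_clause}
                RETURN a.id AS source, b.id AS target, type(r) AS edge
            $$) AS (source agtype, target agtype, edge agtype)
        ) AS edges
    """
    with conn.cursor() as cursor:
        cursor.execute(query)
        rows = cursor.fetchall()
    return [
        {
            "source": _strip_agtype_quotes(src),
            "target": _strip_agtype_quotes(dst),
            "type": _strip_agtype_quotes(rel),
        }
        for src, dst, rel in rows
    ]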