304 changes: 166 additions & 138 deletions src/memos/graph_dbs/polardb.py
@@ -124,7 +124,6 @@ def __init__(self, config: PolarDBGraphDBConfig):
All node queries will enforce `user_name` in WHERE conditions and store it in metadata,
but it will be removed automatically before returning to external consumers.
"""
import psycopg2
import psycopg2.pool

self.config = config
@@ -177,6 +176,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout)
keepalives_interval=15, # Seconds between keepalive retries
keepalives_count=5, # Number of keepalive retries before considering connection dead
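# Resolve unqualified table names against the AGE graph schema and ag_catalog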
options=f"-c search_path={self.db_name}_graph,ag_catalog,$user,public",
)

self._semaphore = threading.BoundedSemaphore(maxconn)
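# The semaphore is sized to the pool, so callers wait for a free connection
# instead of hitting psycopg2.pool's PoolError when all are checked out.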
@@ -2300,48 +2300,57 @@ def export_graph(
status: list[str] | None = None,
**kwargs,
) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.
Args:
include_embedding (bool): Whether to include the large embedding field.
user_name (str, optional): User name for filtering in non-multi-db mode
user_id (str, optional): User ID for filtering
page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
page_size (int, optional): Number of items per page. If None, exports all data without pagination.
filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators:
- "=": equality
- "in": value in list
- "contains": array contains value
- "gt", "lt", "gte", "lte": comparison operators
- "like": fuzzy matching
Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]}
memory_type (list[str], optional): List of memory_type values to filter by. If provided, only nodes/edges with
memory_type in this list will be exported. Example: ["LongTermMemory", "WorkingMemory"]
status (list[str], optional): List of status values to filter by. If not provided, only nodes/edges with
status != 'deleted' are exported. If provided, only nodes/edges with status in this list are exported.
Example: ["activated"] or ["activated", "archived"]

Returns:
{
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ],
"total_nodes": int, # Total number of nodes matching the filter criteria
"total_edges": int, # Total number of edges matching the filter criteria
}
"""
start_time = time.perf_counter()
logger.info(
f" export_graph include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}"
)
user_id = user_id if user_id else self._get_config_value("user_id")

# Special filter keys lifted out of `filter` before SQL translation
extracted_object_type: str | None = None
extracted_mem_cube_id: str | None = None

def _extract_special_filter_values(filter_obj):
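"""Pull "object_type" and "mem_cube_id" out of the metadata filter.
The first occurrence of each key is captured into the enclosing scope
(extracted_object_type / extracted_mem_cube_id); the remaining entries
are returned, with emptied "and"/"or" branches pruned, for normal SQL
translation via _build_filter_conditions_sql.
"""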
nonlocal extracted_object_type, extracted_mem_cube_id

if isinstance(filter_obj, dict):
if "and" in filter_obj and isinstance(filter_obj["and"], list):
cleaned_items = []
for item in filter_obj["and"]:
cleaned_item = _extract_special_filter_values(item)
if cleaned_item not in (None, {}, []):
cleaned_items.append(cleaned_item)
return {"and": cleaned_items} if cleaned_items else None

if "or" in filter_obj and isinstance(filter_obj["or"], list):
cleaned_items = []
for item in filter_obj["or"]:
cleaned_item = _extract_special_filter_values(item)
if cleaned_item not in (None, {}, []):
cleaned_items.append(cleaned_item)
return {"or": cleaned_items} if cleaned_items else None

cleaned_dict = {}
for key, value in filter_obj.items():
if key == "object_type" and isinstance(value, str):
if extracted_object_type is None:
extracted_object_type = value
continue
if key == "mem_cube_id" and isinstance(value, str):
if extracted_mem_cube_id is None:
extracted_mem_cube_id = value
continue
cleaned_dict[key] = value
return cleaned_dict if cleaned_dict else None

return filter_obj

filter_for_sql = _extract_special_filter_values(filter)
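# e.g. filter = {"and": [{"object_type": "user"}, {"created_at": {"gte": "2025-01-01"}}]}
# -> extracted_object_type == "user",
#    filter_for_sql == {"and": [{"created_at": {"gte": "2025-01-01"}}]}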

total_nodes = 0
total_edges = 0

# Determine if pagination is needed
use_pagination = page is not None and page_size is not None

# Validate pagination parameters if pagination is enabled
if use_pagination:
if page < 1:
page = 1
@@ -2351,129 +2360,148 @@ def export_graph(
else:
offset = None

try:
with self._get_connection() as conn:
# Build WHERE conditions
where_conditions = []
if user_name:
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
)
if user_id:
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype"
)
where_conditions = []
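# "object_type" (paired with a non-empty "mem_cube_id") switches the export
# scope: "user" excludes the cube's own nodes, "public" restricts to them,
# and in either case the plain user_name equality condition below is skipped.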
has_object_type_filter = (
isinstance(extracted_object_type, str)
and isinstance(extracted_mem_cube_id, str)
and extracted_mem_cube_id.strip() != ""
)

# Add memory_type filter condition
if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
# Escape memory_type values and build IN clause
memory_type_values = []
for mt in memory_type:
# Escape single quotes in memory_type value
escaped_memory_type = str(mt).replace("'", "''")
memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype")
memory_type_in_clause = ", ".join(memory_type_values)
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})"
)
if user_name and not has_object_type_filter:
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
)

# Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list
if status is None:
# Default behavior: exclude deleted entries
where_conditions.append(
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype"
)
elif isinstance(status, list) and len(status) > 0:
# status IN (list)
status_values = []
for st in status:
escaped_status = str(st).replace("'", "''")
status_values.append(f"'\"{escaped_status}\"'::agtype")
status_in_clause = ", ".join(status_values)
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})"
)
if has_object_type_filter:
object_type_value = extracted_object_type.strip().lower()
escaped_mem_cube_id = extracted_mem_cube_id.replace("'", "''")
if object_type_value == "user":
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) <> '\"{escaped_mem_cube_id}\"'::agtype"
)
elif object_type_value == "public":
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{escaped_mem_cube_id}\"'::agtype"
)

# Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[export_graph] filter_conditions: {filter_conditions}")
if filter_conditions:
where_conditions.extend(filter_conditions)
if user_id:
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype"
)

if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
memory_type_values = []
for mt in memory_type:
escaped_memory_type = str(mt).replace("'", "''")
memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype")
memory_type_in_clause = ", ".join(memory_type_values)
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})"
)

if status is None:
where_conditions.append(
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype"
)
elif isinstance(status, list) and len(status) > 0:
status_values = []
for st in status:
escaped_status = str(st).replace("'", "''")
status_values.append(f"'\"{escaped_status}\"'::agtype")
status_in_clause = ", ".join(status_values)
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})"
)

filter_conditions = self._build_filter_conditions_sql(filter_for_sql)
logger.info(f"[export_graph] filter_conditions: {filter_conditions}")
if filter_conditions:
where_conditions.extend(filter_conditions)

where_clause = ""
if where_conditions:
where_clause = f"WHERE {' AND '.join(where_conditions)}"
where_clause = ""
if where_conditions:
where_clause = f"WHERE {' AND '.join(where_conditions)}"

# Get total count of nodes before pagination
count_node_query = f"""
SELECT COUNT(*)
pagination_clause = ""
if use_pagination:
pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

order_clause = """
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, id DESC
"""
if include_embedding:
node_query = f"""
WITH filtered AS (
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
{where_clause}
"""
logger.info(f"[export_graph nodes count] Query: {count_node_query}")
with conn.cursor() as cursor:
cursor.execute(count_node_query)
total_nodes = cursor.fetchone()[0]
)
SELECT p.id, p.properties, p.embedding, c.total_count
FROM (SELECT COUNT(*) AS total_count FROM filtered) c
LEFT JOIN LATERAL (
SELECT id, properties, embedding
FROM filtered
{order_clause}
{pagination_clause}
) p ON TRUE
"""
else:
node_query = f"""
WITH filtered AS (
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
{where_clause}
)
SELECT p.id, p.properties, c.total_count
FROM (SELECT COUNT(*) AS total_count FROM filtered) c
LEFT JOIN LATERAL (
SELECT id, properties
FROM filtered
{order_clause}
{pagination_clause}
) p ON TRUE
"""
logger.info(f"[export_graph nodes] Query: {node_query}")

# Export nodes
# Build pagination clause if needed
pagination_clause = ""
if use_pagination:
pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
try:
with self._get_connection() as conn, conn.cursor() as cursor:
cursor.execute(node_query)
node_results = cursor.fetchall()
nodes = []

for row in node_results:
if include_embedding:
node_query = f"""
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
{where_clause}
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,
id DESC
{pagination_clause}
"""
row_id, properties_json, embedding_json, row_total_count = row
else:
node_query = f"""
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
{where_clause}
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,
id DESC
{pagination_clause}
"""
logger.info(f"[export_graph nodes] Query: {node_query}")
with conn.cursor() as cursor:
cursor.execute(node_query)
node_results = cursor.fetchall()
nodes = []

for row in node_results:
if include_embedding:
"""row is (id, properties, embedding)"""
_, properties_json, embedding_json = row
else:
"""row is (id, properties)"""
_, properties_json = row
embedding_json = None
row_id, properties_json, row_total_count = row
embedding_json = None

# Parse properties from JSONB if it's a string
if isinstance(properties_json, str):
try:
properties = json.loads(properties_json)
except json.JSONDecodeError:
properties = {}
else:
properties = properties_json if properties_json else {}
if row_total_count is not None:
total_nodes = int(row_total_count)

if row_id is None:
continue

if isinstance(properties_json, str):
try:
properties = json.loads(properties_json)
except json.JSONDecodeError:
properties = {}
else:
properties = properties_json if properties_json else {}

# Remove embedding field if include_embedding is False
if not include_embedding:
properties.pop("embedding", None)
elif include_embedding and embedding_json is not None:
properties["embedding"] = embedding_json
if not include_embedding:
properties.pop("embedding", None)
elif include_embedding and embedding_json is not None:
properties["embedding"] = embedding_json

nodes.append(self._parse_node(properties))
nodes.append(self._parse_node(properties))

except Exception as e:
logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
elapsed = (time.perf_counter() - start_time) * 1000
logger.info("export_graph node export took %.1f ms", elapsed)

edges = []
return {
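For reference, a minimal usage sketch of the reworked export path (hypothetical `db` instance and filter values; parameter shapes follow the docstring above):

result = db.export_graph(
    user_name="alice",
    page=1,
    page_size=100,
    filter={"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]},
    memory_type=["LongTermMemory"],
    status=["activated"],
)
print(result["total_nodes"], len(result["nodes"]))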