Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,8 @@ def add_edge(

start_time = time.time()
if not source_id or not target_id:
logger.info(f"Edge '{source_id}' and '{target_id}' are both None")
raise ValueError("[add_edge] source_id and target_id must be provided")
logger.error(f"Edge '{source_id}' and '{target_id}' are both None")
return

source_exists = self.get_node(source_id) is not None
target_exists = self.get_node(target_id) is not None
Expand All @@ -806,11 +806,6 @@ def add_edge(
logger.warning(
"[add_edge] Source %s or target %s does not exist.", source_exists, target_exists
)
logger.info(
"[add_edge_error] Source %s or target %s does not exist.",
source_exists,
target_exists,
)
return

properties = {}
Expand Down Expand Up @@ -4039,34 +4034,47 @@ def get_edges(
...
]
"""
start_time = time.time()
logger.info(f" get_edges id:{id},type:{type},direction:{direction},user_name:{user_name}")
user_name = user_name if user_name else self._get_config_value("user_name")

if direction == "OUTGOING":
pattern = "(a:Memory)-[r]->(b:Memory)"
where_clause = f"a.id = '{id}'"
elif direction == "INCOMING":
pattern = "(a:Memory)<-[r]-(b:Memory)"
where_clause = f"a.id = '{id}'"
elif direction == "ANY":
pattern = "(a:Memory)-[r]-(b:Memory)"
where_clause = f"a.id = '{id}' OR b.id = '{id}'"
else:
if direction not in ("OUTGOING", "INCOMING", "ANY"):
raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")

# Add type filter
if type != "ANY":
where_clause += f" AND type(r) = '{type}'"

# Add user filter
where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
# Escape single quotes for safe embedding in Cypher string
id_esc = (id or "").replace("'", "''")
user_esc = (user_name or "").replace("'", "''")
type_esc = (type or "").replace("'", "''")
type_filter = f" AND type(r) = '{type_esc}'" if type != "ANY" else ""
logger.info(f"type_filter:{type_filter}")

if direction == "OUTGOING":
cypher_body = f"""
MATCH (a:Memory)-[r:{type}]->(b:Memory)
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
"""
elif direction == "INCOMING":
cypher_body = f"""
MATCH (b:Memory)<-[r:{type}]-(a:Memory)
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
"""
else: # ANY: union of OUTGOING and INCOMING
cypher_body = f"""
MATCH (a:Memory)-[r]->(b:Memory)
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter}
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
UNION ALL
MATCH (b:Memory)<-[r]-(a:Memory)
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter}
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
"""
query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH {pattern}
WHERE {where_clause}
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
{cypher_body.strip()}
$$) AS (from_id agtype, to_id agtype, edge_type agtype)
"""
logger.info(f"get_edges query:{query}")
conn = None
try:
conn = self._get_connection()
Expand Down Expand Up @@ -4110,6 +4118,8 @@ def get_edges(
edge_type = str(edge_type_raw)

edges.append({"from": from_id, "to": to_id, "type": edge_type})
elapsed_time = time.time() - start_time
logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s")
return edges

except Exception as e:
Expand Down