Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,26 +502,36 @@ def edge_exists(
return result.single() is not None

# Graph Query & Reasoning
def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
    """
    Retrieve the metadata and memory of a node.

    Args:
        id: Node identifier.
        include_embedding: If True, embedding fields are kept in the result;
            otherwise every embedding variant is stripped so callers do not
            receive large vectors they did not ask for.
        **kwargs: Optional "user_name" to additionally scope the match to a
            single user's nodes.

    Returns:
        Dictionary of node fields, or None if not found.
    """
    logger.info(f"[get_node] id: {id}")
    user_name = kwargs.get("user_name")
    where_user = ""
    params = {"id": id}
    if user_name is not None:
        where_user = " AND n.user_name = $user_name"
        params["user_name"] = user_name

    query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n"
    logger.info(f"[get_node] query: {query}")

    with self.driver.session(database=self.db_name) as session:
        record = session.run(query, params).single()
        if not record:
            return None

        node_dict = dict(record["n"])
        # Fix: original used `include_embedding is False`, which kept
        # embeddings for any other falsy value (0, None, "").
        if not include_embedding:
            for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
                node_dict.pop(key, None)

        return self._parse_node(node_dict)

def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
"""
Expand Down
55 changes: 55 additions & 0 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,3 +1056,58 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}")
new_node["metadata"]["embedding"] = None
return new_node

def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
"""Get user names by memory ids.

Args:
memory_ids: List of memory node IDs to query.

Returns:
dict[str, str | None]: Dictionary mapping memory_id to user_name.
- Key: memory_id
- Value: user_name if exists, None if memory_id does not exist
Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
"""
if not memory_ids:
return {}

logger.info(
f"[ neo4j_community get_user_names_by_memory_ids] Querying memory_ids {memory_ids}"
)

try:
with self.driver.session(database=self.db_name) as session:
# Query to get memory_id and user_name pairs
query = """
MATCH (n:Memory)
WHERE n.id IN $memory_ids
RETURN n.id AS memory_id, n.user_name AS user_name
"""
logger.info(f"[get_user_names_by_memory_ids] query: {query}")

result = session.run(query, memory_ids=memory_ids)
result_dict = {}

# Build result dictionary from query results
for record in result:
memory_id = record["memory_id"]
user_name = record["user_name"]
result_dict[memory_id] = user_name if user_name else None

# Set None for memory_ids that were not found
for mid in memory_ids:
if mid not in result_dict:
result_dict[mid] = None

logger.info(
f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
)

return result_dict
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
)
raise
231 changes: 229 additions & 2 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2534,6 +2534,7 @@ def export_graph(
page: int | None = None,
page_size: int | None = None,
filter: dict | None = None,
memory_type: list[str] | None = None,
**kwargs,
) -> dict[str, Any]:
"""
Expand All @@ -2551,6 +2552,8 @@ def export_graph(
- "gt", "lt", "gte", "lte": comparison operators
- "like": fuzzy matching
Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]}
memory_type (list[str], optional): List of memory_type values to filter by. If provided, only nodes/edges with
memory_type in this list will be exported. Example: ["LongTermMemory", "WorkingMemory"]

Returns:
{
Expand All @@ -2561,7 +2564,7 @@ def export_graph(
}
"""
logger.info(
f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}"
f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}"
)
user_id = user_id if user_id else self._get_config_value("user_id")

Expand Down Expand Up @@ -2596,6 +2599,19 @@ def export_graph(
f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype"
)

# Add memory_type filter condition
if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
# Escape memory_type values and build IN clause
memory_type_values = []
for mt in memory_type:
# Escape single quotes in memory_type value
escaped_memory_type = str(mt).replace("'", "''")
memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype")
memory_type_in_clause = ", ".join(memory_type_values)
where_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})"
)

# Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[export_graph] filter_conditions: {filter_conditions}")
Expand Down Expand Up @@ -2691,6 +2707,15 @@ def export_graph(
cypher_where_conditions.append(f"a.user_id = '{user_id}'")
cypher_where_conditions.append(f"b.user_id = '{user_id}'")

# Add memory_type filter condition for edges (apply to both source and target nodes)
if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
# Escape single quotes in memory_type values for Cypher
escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type]
memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types])
# Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory']
cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]")
cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]")

# Build filter conditions for edges (apply to both source and target nodes)
filter_where_clause = self._build_filter_conditions_cypher(filter)
logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}")
Expand Down Expand Up @@ -4310,7 +4335,7 @@ def _build_user_name_and_kb_ids_conditions_sql(
user_name_conditions = []
effective_user_name = user_name if user_name else default_user_name

if effective_user_name:
if user_name:
user_name_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype"
)
Expand Down Expand Up @@ -5441,3 +5466,205 @@ def escape_user_name(un: str) -> str:
raise
finally:
self._return_connection(conn)

@timed
def delete_node_by_mem_cube_id(
    self,
    mem_kube_id: dict | None = None,
    delete_record_id: dict | None = None,
    deleted_type: bool = False,
) -> int:
    """Delete (hard or soft) all Memory nodes belonging to a mem cube.

    The cube id is matched against the node's ``user_name`` property.

    Args:
        mem_kube_id: Cube id (maps to ``user_name``). A dict is accepted for
            caller convenience; its first value is used.
        delete_record_id: Id tagging this delete operation. A dict is
            accepted the same way. For hard delete it is part of the match;
            for soft delete it is written into the node properties.
        deleted_type: If True, rows are physically deleted; otherwise the
            nodes are soft-deleted (status/delete_time/delete_record_id are
            merged into ``properties``).

    Returns:
        int: Number of nodes deleted (hard) or updated (soft). 0 when a
        required parameter is missing.

    Raises:
        Exception: Re-raises any database failure after logging it.
    """
    # Handle dict type parameters (extract value if dict)
    if isinstance(mem_kube_id, dict):
        # Try to get a value from dict, use first value if multiple
        mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None

    if isinstance(delete_record_id, dict):
        delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None

    # Validate required parameters
    if not mem_kube_id:
        logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided")
        return 0

    if not delete_record_id:
        logger.warning(
            "[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
        )
        return 0

    # Both values are guaranteed truthy here, so plain str() is enough
    # (original re-checked truthiness redundantly).
    mem_kube_id = str(mem_kube_id)
    delete_record_id = str(delete_record_id)

    logger.info(
        f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, "
        f"delete_record_id={delete_record_id}, deleted_type={deleted_type}"
    )

    conn = None
    try:
        conn = self._get_connection()
        with conn.cursor() as cursor:
            # Build WHERE clause for user_name using parameter binding
            # user_name must match mem_kube_id
            user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"

            # Prepare parameter for user_name
            user_name_param = self.format_param_value(mem_kube_id)

            if deleted_type:
                # Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id
                delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
                where_clause = f"{user_name_condition} AND {delete_record_id_condition}"

                # Prepare parameters for WHERE clause (user_name and delete_record_id)
                where_params = [user_name_param, self.format_param_value(delete_record_id)]

                delete_query = f"""
                DELETE FROM "{self.db_name}_graph"."Memory"
                WHERE {where_clause}
                """
                logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}")

                cursor.execute(delete_query, where_params)
                deleted_count = cursor.rowcount

                logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
                return deleted_count
            else:
                # Soft delete: WHERE user_name = mem_kube_id (only user_name condition)
                where_clause = user_name_condition

                # NOTE(review): datetime.utcnow() is deprecated since Python
                # 3.12; switching to datetime.now(timezone.utc) would append
                # "+00:00" to the stored string — confirm downstream parsers
                # before changing.
                current_time = datetime.utcnow().isoformat()
                # Build update properties JSON with status, delete_time, and delete_record_id
                # Use PostgreSQL JSONB merge operator (||) to update properties
                # Convert agtype to jsonb, merge with new values, then convert back to agtype
                update_query = f"""
                UPDATE "{self.db_name}_graph"."Memory"
                SET properties = (
                    properties::jsonb || %s::jsonb
                )::text::agtype
                WHERE {where_clause}
                """
                # Create update JSON with the three fields to update
                update_properties = {
                    "status": "deleted",
                    "delete_time": current_time,
                    "delete_record_id": delete_record_id,
                }
                logger.info(
                    f"[delete_node_by_mem_cube_id] Soft delete update_query: {update_query}"
                )
                logger.info(
                    f"[delete_node_by_mem_cube_id] update_properties: {update_properties}"
                )

                # Combine update_properties JSON with user_name parameter (only user_name, no delete_record_id)
                update_params = [json.dumps(update_properties), user_name_param]
                cursor.execute(update_query, update_params)
                updated_count = cursor.rowcount

                logger.info(
                    f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
                )
                return updated_count

    except Exception as e:
        logger.error(
            f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
        )
        raise
    finally:
        self._return_connection(conn)

@timed
def recover_memory_by_mem_kube_id(
    self,
    mem_kube_id: str | None = None,
    delete_record_id: str | None = None,
) -> int:
    """
    Recover memory nodes by mem_kube_id (user_name) and delete_record_id.

    This function updates the status to 'activated', and clears delete_record_id and delete_time.

    Args:
        mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
        delete_record_id: The delete_record_id to match.

    Returns:
        int: Number of nodes recovered (updated).
    """
    logger.info(f"recover_memory_by_mem_kube_id mem_kube_id:{mem_kube_id},delete_record_id:{delete_record_id}")

    # Guard clauses: both identifiers are mandatory; bail out with 0 otherwise.
    if not mem_kube_id:
        logger.warning(
            "[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided"
        )
        return 0
    if not delete_record_id:
        logger.warning(
            "[recover_memory_by_mem_kube_id] delete_record_id is required but not provided"
        )
        return 0

    logger.info(
        f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, "
        f"delete_record_id={delete_record_id}"
    )

    # Fields merged into properties: reactivate and wipe the delete markers.
    recovery_fields = {
        "status": "activated",
        "delete_record_id": "",
        "delete_time": "",
    }

    # Parameterized match on user_name and delete_record_id (agtype access).
    match_user = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
    match_record = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype"
    where_clause = f"{match_user} AND {match_record}"

    conn = None
    try:
        conn = self._get_connection()
        with conn.cursor() as cur:
            # JSONB merge (||) overlays the recovery fields onto the stored
            # properties, then the result is cast back to agtype.
            update_query = f"""
                UPDATE "{self.db_name}_graph"."Memory"
                SET properties = (
                    properties::jsonb || %s::jsonb
                )::text::agtype
                WHERE {where_clause}
                """

            logger.info(f"[recover_memory_by_mem_kube_id] Update query: {update_query}")
            logger.info(
                f"[recover_memory_by_mem_kube_id] update_properties: {recovery_fields}"
            )

            cur.execute(
                update_query,
                [
                    json.dumps(recovery_fields),
                    self.format_param_value(mem_kube_id),
                    self.format_param_value(delete_record_id),
                ],
            )
            recovered = cur.rowcount

            logger.info(
                f"[recover_memory_by_mem_kube_id] Recovered (updated) {recovered} nodes"
            )
            return recovered

    except Exception as e:
        logger.error(
            f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True
        )
        raise
    finally:
        self._return_connection(conn)
Loading