Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 185 additions & 4 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,26 +502,36 @@ def edge_exists(
return result.single() is not None

# Graph Query & Reasoning
def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
"""
Retrieve the metadata and memory of a node.
Args:
id: Node identifier.
Returns:
Dictionary of node fields, or None if not found.
"""
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
logger.info(f"[get_node] id: {id}")
user_name = kwargs.get("user_name")
where_user = ""
params = {"id": id}
if not self.config.use_multi_db and (self.config.user_name or user_name):
if user_name is not None:
where_user = " AND n.user_name = $user_name"
params["user_name"] = user_name

query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n"
logger.info(f"[get_node] query: {query}")

with self.driver.session(database=self.db_name) as session:
record = session.run(query, params).single()
return self._parse_node(dict(record["n"])) if record else None
if not record:
return None

node_dict = dict(record["n"])
if include_embedding is False:
for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
node_dict.pop(key, None)

return self._parse_node(node_dict)

def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
"""
Expand Down Expand Up @@ -1940,3 +1950,174 @@ def exist_user_name(self, user_name: str) -> dict[str, bool]:
f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
)
raise

def delete_node_by_mem_cube_id(
self,
mem_kube_id: dict | None = None,
delete_record_id: dict | None = None,
deleted_type: bool = False,
) -> int:
"""
Delete nodes by mem_kube_id (user_name) and delete_record_id.

Args:
mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
Can be dict or str. If dict, will extract the value.
delete_record_id: The delete_record_id to match.
Can be dict or str. If dict, will extract the value.
deleted_type: If True, performs hard delete (directly deletes records).
If False, performs soft delete (updates status to 'deleted' and sets delete_record_id and delete_time).

Returns:
int: Number of nodes deleted or updated.
"""
# Handle dict type parameters (extract value if dict)
if isinstance(mem_kube_id, dict):
# Try to get a value from dict, use first value if multiple
mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None

if isinstance(delete_record_id, dict):
delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None

# Validate required parameters
if not mem_kube_id:
logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided")
return 0

if not delete_record_id:
logger.warning(
"[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
)
return 0

# Convert to string if needed
mem_kube_id = str(mem_kube_id) if mem_kube_id else None
delete_record_id = str(delete_record_id) if delete_record_id else None

logger.info(
f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, "
f"delete_record_id={delete_record_id}, deleted_type={deleted_type}"
)

try:
with self.driver.session(database=self.db_name) as session:
if deleted_type:
# Hard delete: WHERE user_name = mem_kube_id AND delete_record_id = $delete_record_id
query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id
DETACH DELETE n
"""
logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {query}")

result = session.run(
query, mem_kube_id=mem_kube_id, delete_record_id=delete_record_id
)
summary = result.consume()
deleted_count = summary.counters.nodes_deleted if summary.counters else 0

logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
return deleted_count
else:
# Soft delete: WHERE user_name = mem_kube_id (only user_name condition)
current_time = datetime.utcnow().isoformat()

query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_kube_id
SET n.status = $status,
n.delete_record_id = $delete_record_id,
n.delete_time = $delete_time
RETURN count(n) AS updated_count
"""
logger.info(f"[delete_node_by_mem_cube_id] Soft delete query: {query}")

result = session.run(
query,
mem_kube_id=mem_kube_id,
status="deleted",
delete_record_id=delete_record_id,
delete_time=current_time,
)
record = result.single()
updated_count = record["updated_count"] if record else 0

logger.info(
f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
)
return updated_count

except Exception as e:
logger.error(
f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
)
raise

def recover_memory_by_mem_kube_id(
self,
mem_kube_id: str | None = None,
delete_record_id: str | None = None,
) -> int:
"""
Recover memory nodes by mem_kube_id (user_name) and delete_record_id.

This function updates the status to 'activated', and clears delete_record_id and delete_time.

Args:
mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
delete_record_id: The delete_record_id to match.

Returns:
int: Number of nodes recovered (updated).
"""
# Validate required parameters
if not mem_kube_id:
logger.warning(
"[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided"
)
return 0

if not delete_record_id:
logger.warning(
"[recover_memory_by_mem_kube_id] delete_record_id is required but not provided"
)
return 0

logger.info(
f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, "
f"delete_record_id={delete_record_id}"
)

try:
with self.driver.session(database=self.db_name) as session:
query = """
MATCH (n:Memory)
WHERE n.user_name = $mem_kube_id AND n.delete_record_id = $delete_record_id
SET n.status = $status,
n.delete_record_id = $delete_record_id_empty,
n.delete_time = $delete_time_empty
RETURN count(n) AS updated_count
"""
logger.info(f"[recover_memory_by_mem_kube_id] Update query: {query}")

result = session.run(
query,
mem_kube_id=mem_kube_id,
delete_record_id=delete_record_id,
status="activated",
delete_record_id_empty="",
delete_time_empty="",
)
record = result.single()
updated_count = record["updated_count"] if record else 0

logger.info(
f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes"
)
return updated_count

except Exception as e:
logger.error(
f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True
)
raise
55 changes: 55 additions & 0 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,3 +1056,58 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}")
new_node["metadata"]["embedding"] = None
return new_node

def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
"""Get user names by memory ids.

Args:
memory_ids: List of memory node IDs to query.

Returns:
dict[str, str | None]: Dictionary mapping memory_id to user_name.
- Key: memory_id
- Value: user_name if exists, None if memory_id does not exist
Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
"""
if not memory_ids:
return {}

logger.info(
f"[ neo4j_community get_user_names_by_memory_ids] Querying memory_ids {memory_ids}"
)

try:
with self.driver.session(database=self.db_name) as session:
# Query to get memory_id and user_name pairs
query = """
MATCH (n:Memory)
WHERE n.id IN $memory_ids
RETURN n.id AS memory_id, n.user_name AS user_name
"""
logger.info(f"[get_user_names_by_memory_ids] query: {query}")

result = session.run(query, memory_ids=memory_ids)
result_dict = {}

# Build result dictionary from query results
for record in result:
memory_id = record["memory_id"]
user_name = record["user_name"]
result_dict[memory_id] = user_name if user_name else None

# Set None for memory_ids that were not found
for mid in memory_ids:
if mid not in result_dict:
result_dict[mid] = None

logger.info(
f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
)

return result_dict
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
)
raise
Loading