def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None:
    """
    Retrieve the metadata and memory of a node.

    Args:
        id: Node identifier (matched against the node's `id` property, not
            Neo4j's internal element id).
        include_embedding: When False (the default), all embedding vector
            fields are stripped from the returned node to keep the payload
            small; pass True to keep them.
        **kwargs: Optional `user_name` to scope the lookup to a single user;
            when omitted, no user filter is applied.

    Returns:
        Dictionary of node fields (as parsed by `_parse_node`), or None if
        no matching node exists.
    """
    logger.info("[get_node] id: %s", id)

    params: dict[str, Any] = {"id": id}
    where_user = ""
    user_name = kwargs.get("user_name")
    if user_name is not None:
        where_user = " AND n.user_name = $user_name"
        params["user_name"] = user_name

    query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n"
    logger.info("[get_node] query: %s", query)

    with self.driver.session(database=self.db_name) as session:
        record = session.run(query, params).single()

    if not record:
        return None

    node_dict = dict(record["n"])
    if not include_embedding:
        # Drop every embedding variant the schema may carry; callers that
        # need the raw vectors must opt in via include_embedding=True.
        for key in ("embedding", "embedding_1024", "embedding_3072", "embedding_768"):
            node_dict.pop(key, None)

    return self._parse_node(node_dict)
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
    """Get user names by memory ids.

    Args:
        memory_ids: List of memory node IDs to query.

    Returns:
        dict[str, str | None]: Dictionary mapping memory_id to user_name.
            - Key: memory_id
            - Value: user_name if it exists, None if the memory_id does not
              exist (or its user_name is empty).
            Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}

    Raises:
        Exception: Re-raises any driver/query failure after logging it.
    """
    if not memory_ids:
        return {}

    logger.info(
        "[ neo4j_community get_user_names_by_memory_ids] Querying memory_ids %s", memory_ids
    )

    query = """
        MATCH (n:Memory)
        WHERE n.id IN $memory_ids
        RETURN n.id AS memory_id, n.user_name AS user_name
    """
    try:
        with self.driver.session(database=self.db_name) as session:
            logger.info("[get_user_names_by_memory_ids] query: %s", query)
            result = session.run(query, memory_ids=memory_ids)

            # Pre-fill every requested id with None so ids the query did not
            # return are reported as "not found" without a second pass.
            result_dict: dict[str, str | None] = dict.fromkeys(memory_ids)
            for record in result:
                # `or None` also normalizes empty-string user names to None.
                result_dict[record["memory_id"]] = record["user_name"] or None

        found = sum(1 for v in result_dict.values() if v is not None)
        logger.info(
            "[get_user_names_by_memory_ids] Found %d memory_ids with user_names, "
            "%d memory_ids without user_names",
            found,
            len(result_dict) - found,
        )
        return result_dict
    except Exception as e:
        logger.error(
            f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
        )
        raise
None = None, page_size: int | None = None, filter: dict | None = None, + memory_type: list[str] | None = None, **kwargs, ) -> dict[str, Any]: """ @@ -2551,6 +2552,8 @@ def export_graph( - "gt", "lt", "gte", "lte": comparison operators - "like": fuzzy matching Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]} + memory_type (list[str], optional): List of memory_type values to filter by. If provided, only nodes/edges with + memory_type in this list will be exported. Example: ["LongTermMemory", "WorkingMemory"] Returns: { @@ -2561,7 +2564,7 @@ def export_graph( } """ logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}" + f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}" ) user_id = user_id if user_id else self._get_config_value("user_id") @@ -2596,6 +2599,19 @@ def export_graph( f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" ) + # Add memory_type filter condition + if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: + # Escape memory_type values and build IN clause + memory_type_values = [] + for mt in memory_type: + # Escape single quotes in memory_type value + escaped_memory_type = str(mt).replace("'", "''") + memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") + memory_type_in_clause = ", ".join(memory_type_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" + ) + # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[export_graph] filter_conditions: {filter_conditions}") @@ -2691,6 +2707,15 @@ def export_graph( 
@timed
def delete_node_by_mem_cube_id(
    self,
    mem_kube_id: dict | str | None = None,
    delete_record_id: dict | str | None = None,
    deleted_type: bool = False,
) -> int:
    """Delete (hard) or mark deleted (soft) all Memory nodes of a mem cube.

    Args:
        mem_kube_id: The mem cube id, matched against the node's `user_name`
            property. A dict is tolerated; its first value is used.
        delete_record_id: Identifier recorded on (soft) or matched by (hard)
            the deletion. A dict is tolerated; its first value is used.
        deleted_type: If True, hard-delete rows whose `user_name` matches
            `mem_kube_id` AND whose `delete_record_id` matches. If False
            (default), soft-delete by merging `status='deleted'`,
            `delete_time` and `delete_record_id` into `properties` for every
            row whose `user_name` matches.

    Returns:
        int: Number of rows deleted (hard) or updated (soft); 0 when a
        required argument is missing.

    Raises:
        Exception: Re-raises any database failure after logging it.
    """
    # Tolerate dict-shaped inputs: use the first value found, if any.
    if isinstance(mem_kube_id, dict):
        mem_kube_id = next(iter(mem_kube_id.values())) if mem_kube_id else None
    if isinstance(delete_record_id, dict):
        delete_record_id = next(iter(delete_record_id.values())) if delete_record_id else None

    if not mem_kube_id:
        logger.warning("[delete_node_by_mem_cube_id] mem_kube_id is required but not provided")
        return 0
    if not delete_record_id:
        logger.warning(
            "[delete_node_by_mem_cube_id] delete_record_id is required but not provided"
        )
        return 0

    mem_kube_id = str(mem_kube_id)
    delete_record_id = str(delete_record_id)
    logger.info(
        f"[delete_node_by_mem_cube_id] mem_kube_id={mem_kube_id}, "
        f"delete_record_id={delete_record_id}, deleted_type={deleted_type}"
    )

    # user_name must equal mem_kube_id in both modes; bound as a parameter.
    user_name_condition = (
        "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
    )
    user_name_param = self.format_param_value(mem_kube_id)

    try:
        if deleted_type:
            # Hard delete: user_name AND delete_record_id must both match.
            delete_record_id_condition = (
                "ag_catalog.agtype_access_operator(properties, "
                "'\"delete_record_id\"'::agtype) = %s::agtype"
            )
            delete_query = f"""
                DELETE FROM "{self.db_name}_graph"."Memory"
                WHERE {user_name_condition} AND {delete_record_id_condition}
            """
            logger.info(f"[delete_node_by_mem_cube_id] Hard delete query: {delete_query}")
            conn = None
            try:
                conn = self._get_connection()
                with conn.cursor() as cursor:
                    cursor.execute(
                        delete_query,
                        [user_name_param, self.format_param_value(delete_record_id)],
                    )
                    deleted_count = cursor.rowcount
            finally:
                self._return_connection(conn)
            logger.info(f"[delete_node_by_mem_cube_id] Hard deleted {deleted_count} nodes")
            return deleted_count

        # Soft delete: mark every node of this mem cube as deleted.
        # NOTE(review): datetime.utcnow() is naive and deprecated since
        # Python 3.12; kept for format-compatibility with stored timestamps.
        update_properties = {
            "status": "deleted",
            "delete_time": datetime.utcnow().isoformat(),
            "delete_record_id": delete_record_id,
        }
        logger.info(f"[delete_node_by_mem_cube_id] update_properties: {update_properties}")
        updated_count = self._merge_memory_properties(
            user_name_condition, [user_name_param], update_properties
        )
        logger.info(
            f"[delete_node_by_mem_cube_id] Soft deleted (updated) {updated_count} nodes"
        )
        return updated_count
    except Exception as e:
        logger.error(
            f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True
        )
        raise


def _merge_memory_properties(
    self, where_clause: str, where_params: list, new_properties: dict
) -> int:
    """Merge `new_properties` into the agtype `properties` column of every
    Memory row matching `where_clause`, returning the affected row count.

    The agtype value is cast to jsonb, merged with PostgreSQL's JSONB `||`
    operator, then cast back to agtype. `where_clause` must use `%s`
    placeholders matching `where_params` (already formatted for agtype).
    """
    conn = None
    try:
        conn = self._get_connection()
        with conn.cursor() as cursor:
            update_query = f"""
                UPDATE "{self.db_name}_graph"."Memory"
                SET properties = (
                    properties::jsonb || %s::jsonb
                )::text::agtype
                WHERE {where_clause}
            """
            cursor.execute(update_query, [json.dumps(new_properties), *where_params])
            return cursor.rowcount
    finally:
        self._return_connection(conn)


@timed
def recover_memory_by_mem_kube_id(
    self,
    mem_kube_id: str | None = None,
    delete_record_id: str | None = None,
) -> int:
    """
    Recover memory nodes by mem_kube_id (user_name) and delete_record_id.

    Sets `status` back to 'activated' and clears `delete_record_id` and
    `delete_time` on every node whose `user_name` equals `mem_kube_id` and
    whose `delete_record_id` matches.

    Args:
        mem_kube_id: The mem_kube_id which corresponds to user_name in the table.
        delete_record_id: The delete_record_id to match.

    Returns:
        int: Number of nodes recovered (updated); 0 when a required argument
        is missing.

    Raises:
        Exception: Re-raises any database failure after logging it.
    """
    logger.info(
        f"[recover_memory_by_mem_kube_id] mem_kube_id={mem_kube_id}, "
        f"delete_record_id={delete_record_id}"
    )
    if not mem_kube_id:
        logger.warning(
            "[recover_memory_by_mem_kube_id] mem_kube_id is required but not provided"
        )
        return 0
    if not delete_record_id:
        logger.warning(
            "[recover_memory_by_mem_kube_id] delete_record_id is required but not provided"
        )
        return 0

    # Both conditions bound as parameters via format_param_value.
    where_clause = (
        "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
        " AND "
        "ag_catalog.agtype_access_operator(properties, "
        "'\"delete_record_id\"'::agtype) = %s::agtype"
    )
    where_params = [
        self.format_param_value(mem_kube_id),
        self.format_param_value(delete_record_id),
    ]
    update_properties = {
        "status": "activated",
        "delete_record_id": "",
        "delete_time": "",
    }
    try:
        logger.info(
            f"[recover_memory_by_mem_kube_id] update_properties: {update_properties}"
        )
        updated_count = self._merge_memory_properties(
            where_clause, where_params, update_properties
        )
        logger.info(
            f"[recover_memory_by_mem_kube_id] Recovered (updated) {updated_count} nodes"
        )
        return updated_count
    except Exception as e:
        logger.error(
            f"[recover_memory_by_mem_kube_id] Failed to recover nodes: {e}", exc_info=True
        )
        raise