diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index ac03cda2e..6a0be0d32 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -1703,32 +1703,19 @@ def search_by_fulltext(
         return_fields: list[str] | None = None,
         **kwargs,
     ) -> list[dict]:
-        """
-        Full-text search functionality using PostgreSQL's full-text search capabilities.
-
-        Args:
-            query_text: query text
-            top_k: maximum number of results to return
-            scope: memory type filter (memory_type)
-            status: status filter, defaults to "activated"
-            threshold: similarity threshold filter
-            search_filter: additional property filter conditions
-            user_name: username filter
-            knowledgebase_ids: knowledgebase ids filter
-            filter: filter conditions with 'and' or 'or' logic for search results.
-            tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1
-            tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
-            return_fields: additional node fields to include in results
-            **kwargs: other parameters (e.g. cube_name)
-
-        Returns:
-            list[dict]: result list containing id and score.
-            If return_fields is specified, each dict also includes the requested fields.
-        """
+        start_time = time.perf_counter()
         logger.info(
-            f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}"
+            "search_by_fulltext query_words=%s top_k=%s scope=%s status=%s threshold=%s search_filter=%s user_name=%s knowledgebase_ids=%s filter=%s",
+            query_words,
+            top_k,
+            scope,
+            status,
+            threshold,
+            search_filter,
+            user_name,
+            knowledgebase_ids,
+            filter,
         )
-        start_time = time.time()
         where_clauses = []
         if scope:
@@ -1744,22 +1731,18 @@ def search_by_fulltext(
                 "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
             )
 
-        # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
         user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
             user_name=user_name,
             knowledgebase_ids=knowledgebase_ids,
             default_user_name=self.config.user_name,
         )
-        logger.info(f"[search_by_fulltext] user_name_conditions: {user_name_conditions}")
 
-        # Add OR condition if we have any user_name conditions
         if user_name_conditions:
             if len(user_name_conditions) == 1:
                 where_clauses.append(user_name_conditions[0])
             else:
                 where_clauses.append(f"({' OR '.join(user_name_conditions)})")
 
-        # Add search_filter conditions
         if search_filter:
             for key, value in search_filter.items():
                 if isinstance(value, str):
@@ -1772,17 +1755,12 @@ def search_by_fulltext(
                     )
 
         filter_conditions = self._build_filter_conditions_sql(filter)
-        logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}")
         where_clauses.extend(filter_conditions)
 
         tsquery_string = " | ".join(query_words)
         where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
 
-        where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
-
-        logger.info(f"[search_by_fulltext] where_clause: {where_clause}")
-
         select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id,
             ts_rank(m.{tsvector_field}, q.fq) AS rank"""
         if return_fields:
@@ -1808,7 +1786,8 @@ def search_by_fulltext(
             LIMIT {top_k};
         """
         params = [tsquery_string]
-        logger.info(f"[search_by_fulltext] query: {query}, params: {params}")
+        logger.info("search_by_fulltext query=%s params=%s", query, params)
params=%s", query, params) + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1829,8 +1808,8 @@ def search_by_fulltext( properties = row[2] # properties column item.update(self._extract_fields_from_properties(properties, return_fields)) output.append(item) - elapsed_time = time.time() - start_time - logger.info(f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s") + elapsed = (time.perf_counter() - start_time) * 1000 + logger.info("search_by_fulltext internal took %.1f ms", elapsed) return output[:top_k] @timed @@ -1849,9 +1828,18 @@ def search_by_embedding( **kwargs, ) -> list[dict]: logger.info( - f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}" + "search_by_embedding user_name:%s,filter: %s, knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s", + user_name, + filter, + knowledgebase_ids, + scope, + status, + search_filter, + filter, + knowledgebase_ids, + return_fields, ) - start_time = time.time() + start_time = time.perf_counter() where_clauses = [] if scope: where_clauses.append( @@ -1890,7 +1878,6 @@ def search_by_embedding( ) filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" @@ -1926,7 +1913,7 @@ def search_by_embedding( else: pass - logger.info(f"[search_by_embedding] query: {query}, params: {params}") + logger.info(" search_by_embedding query: %s", query) with self._get_connection() as conn, conn.cursor() as cursor: if params: @@ -1952,9 +1939,9 @@ def search_by_embedding( properties = row[1] # properties column item.update(self._extract_fields_from_properties(properties, return_fields)) output.append(item) - elapsed_time = time.time() - start_time + elapsed_time = time.perf_counter() - start_time logger.info( - f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + "search_by_embedding query embedding completed time took %.1f ms", elapsed_time ) return output[:top_k] @@ -3169,27 +3156,15 @@ def add_nodes_batch( nodes: list[dict[str, Any]], user_name: str | None = None, ) -> None: - """ - Batch add multiple memory nodes to the graph. 
+ logger.info(f" add_nodes_batch Processing only first node (total nodes: {len(nodes)})") - Args: - nodes: List of node dictionaries, each containing: - - id: str - Node ID - - memory: str - Memory content - - metadata: dict[str, Any] - Node metadata - user_name: Optional user name (will use config default if not provided) - """ - batch_start_time = time.time() + batch_start_time = time.perf_counter() if not nodes: logger.warning("[add_nodes_batch] Empty nodes list, skipping") return - logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") - - # user_name comes from parameter; fallback to config if missing effective_user_name = user_name if user_name else self.config.user_name - # Prepare all nodes prepared_nodes = [] for node_data in nodes: try: @@ -3199,16 +3174,13 @@ def add_nodes_batch( logger.debug(f"[add_nodes_batch] Processing node id: {id}") - # Set user_name in metadata metadata["user_name"] = effective_user_name metadata = _prepare_node_metadata(metadata) - # Merge node and set metadata created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) - # Prepare properties properties = { "id": id, "memory": memory, @@ -3219,32 +3191,26 @@ def add_nodes_batch( **metadata, } - # Generate embedding if not provided if "embedding" not in properties or not properties["embedding"]: properties["embedding"] = generate_vector( self._get_config_value("embedding_dimension", 1024) ) - # Serialization - JSON-serialize sources and usage fields for field_name in ["sources", "usage"]: if properties.get(field_name): if isinstance(properties[field_name], list): for idx in range(len(properties[field_name])): - # Serialize only when element is not a string if not isinstance(properties[field_name][idx], str): properties[field_name][idx] = json.dumps( properties[field_name][idx] ) elif isinstance(properties[field_name], str): - # If already a string, leave as-is pass - # Extract embedding for separate column embedding_vector = properties.pop("embedding", []) if not isinstance(embedding_vector, list): embedding_vector = [] - # Select column name based on embedding dimension embedding_column = "embedding" # default column if len(embedding_vector) == 3072: embedding_column = "embedding_3072" @@ -3267,14 +3233,12 @@ def add_nodes_batch( f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", exc_info=True, ) - # Continue with other nodes continue if not prepared_nodes: logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") return - # Group nodes by embedding column to optimize batch inserts nodes_by_embedding_column = {} for node in prepared_nodes: col = node["embedding_column"] @@ -3284,9 +3248,7 @@ def add_nodes_batch( try: with self._get_connection() as conn, conn.cursor() as cursor: - # Process each group separately for embedding_column, nodes_group in nodes_by_embedding_column.items(): - # Batch delete existing records using IN clause ids_to_delete = [node["id"] for node in nodes_group] if ids_to_delete: delete_query = f""" @@ -3297,7 +3259,6 @@ def add_nodes_batch( """ cursor.execute(delete_query, (ids_to_delete,)) - # Batch get graph_ids for all nodes get_graph_ids_query = f""" SELECT id_val, @@ -3307,21 +3268,16 @@ def add_nodes_batch( cursor.execute(get_graph_ids_query, (ids_to_delete,)) graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} - # Add graph_id to properties for node in nodes_group: graph_id = 
graph_id_map.get(node["id"]) if graph_id: node["properties"]["graph_id"] = str(graph_id) - # Use PREPARE/EXECUTE for efficient batch insert - # Generate unique prepare statement name to avoid conflicts prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" - try: if embedding_column and any( node["embedding_vector"] for node in nodes_group ): - # PREPARE statement for insert with embedding prepare_query = f""" PREPARE {prepare_name} AS INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) @@ -3331,16 +3287,9 @@ def add_nodes_batch( $3::vector ) """ - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" - ) cursor.execute(prepare_query) - # Execute prepared statement for each node for node in nodes_group: properties_json = json.dumps(node["properties"]) embedding_json = ( @@ -3354,7 +3303,6 @@ def add_nodes_batch( (node["id"], properties_json, embedding_json), ) else: - # PREPARE statement for insert without embedding prepare_query = f""" PREPARE {prepare_name} AS INSERT INTO {self.db_name}_graph."Memory"(id, properties) @@ -3363,40 +3311,25 @@ def add_nodes_batch( $2::text::agtype ) """ - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" - ) cursor.execute(prepare_query) - # Execute prepared statement for each node for node in nodes_group: properties_json = json.dumps(node["properties"]) - cursor.execute( f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json), ) finally: - # DEALLOCATE prepared statement (always execute, even on error) try: cursor.execute(f"DEALLOCATE {prepare_name}") - logger.info( - f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" - ) except Exception as dealloc_error: logger.warning( f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" ) - - logger.info( - f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" - ) - elapsed_time = time.time() - batch_start_time + elapsed_time = time.perf_counter() - batch_start_time logger.info( - f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + "add_nodes_batch batch insert completed successfully in took %.1f ms", + elapsed_time, ) except Exception as e: