From a2fa15c02c32fbf68ad5c8153238badfa6658a37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Wed, 25 Feb 2026 11:38:21 +0800 Subject: [PATCH 1/2] feat:optimize user_name && key_words --- src/memos/graph_dbs/polardb.py | 190 +++++++++------------------------ 1 file changed, 50 insertions(+), 140 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 5044564c3..824af8f54 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -204,21 +204,6 @@ def _get_connection_old(self): return conn def _get_connection(self): - """ - Get a connection from the pool. - - This function: - 1. Gets a connection from ThreadedConnectionPool - 2. Checks if connection is closed or unhealthy - 3. Returns healthy connection or retries (max 3 times) - 4. Handles connection pool exhaustion gracefully - - Returns: - psycopg2 connection object - - Raises: - RuntimeError: If connection pool is closed or exhausted after retries - """ logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") if self._pool_closed: raise RuntimeError("Connection pool has been closed") @@ -229,13 +214,9 @@ def _get_connection(self): for attempt in range(max_retries): conn = None try: - # Try to get connection from pool - # This may raise PoolError if pool is exhausted conn = self.connection_pool.getconn() - # Check if connection is closed if conn.closed != 0: - # Connection is closed, return it to pool with close flag and try again logger.warning( f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" ) @@ -295,19 +276,17 @@ def _get_connection(self): return conn except psycopg2.pool.PoolError as pool_error: - # Pool exhausted or other pool-related error - # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly error_msg = str(pool_error).lower() if "exhausted" in error_msg or "pool" in error_msg: # Log pool status for debugging try: # Try to get pool stats if available pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" - logger.error( + logger.warning( f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" ) except Exception: - logger.error( + logger.warning( f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" ) @@ -323,7 +302,6 @@ def _get_connection(self): raise RuntimeError( f"Connection pool exhausted after {max_retries} attempts. " f"This usually means connections are not being returned to the pool. " - f"Check for connection leaks in your code." ) from pool_error else: # Other pool errors - retry with normal backoff @@ -337,12 +315,8 @@ def _get_connection(self): ) from pool_error except Exception as e: - # Other exceptions (not pool-related) - # Only try to return connection if we actually got one - # If getconn() failed (e.g., pool exhausted), conn will be None if conn is not None: try: - # Return connection to pool if it's valid self.connection_pool.putconn(conn, close=True) except Exception as putconn_error: logger.warning( @@ -363,20 +337,7 @@ def _get_connection(self): raise RuntimeError("Failed to get connection after all retries") def _return_connection(self, connection): - """ - Return a connection to the pool. - - This function safely returns a connection to the pool, handling: - - Closed connections (close them instead of returning) - - Pool closed state (close connection directly) - - None connections (no-op) - - putconn() failures (close connection as fallback) - - Args: - connection: psycopg2 connection object or None - """ if self._pool_closed: - # Pool is closed, just close the connection if it exists if connection: try: connection.close() @@ -388,13 +349,10 @@ def _return_connection(self, connection): return if not connection: - # No connection to return - this is normal if _get_connection() failed return try: - # Check if connection is closed if hasattr(connection, "closed") and connection.closed != 0: - # Connection is closed, just close it explicitly and don't return to pool logger.debug( "[_return_connection] Connection is closed, closing it instead of returning to pool" ) @@ -404,12 +362,9 @@ def _return_connection(self, connection): logger.warning(f"[_return_connection] Failed to close closed connection: {e}") return - # Connection is valid, return to pool self.connection_pool.putconn(connection) logger.debug("[_return_connection] Successfully returned connection to pool") except Exception as e: - # If putconn fails, try to close the connection - # This prevents connection leaks if putconn() fails logger.error( f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True ) @@ -1116,9 +1071,7 @@ def get_node( self._return_connection(conn) @timed - def get_nodes( - self, ids: list[str], user_name: str | None = None, **kwargs - ) -> list[dict[str, Any]]: + def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: @@ -1973,7 +1926,6 @@ def search_by_fulltext( logger.info( f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" ) - # Build WHERE clause dynamically, same as search_by_embedding start_time = time.time() where_clauses = [] @@ -2017,13 +1969,10 @@ def search_by_fulltext( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) - # Add fulltext search condition - # Convert query_text to OR query format: "word1 | word2 | word3" tsquery_string = " | ".join(query_words) where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") @@ -2032,23 +1981,31 @@ def search_by_fulltext( logger.info(f"[search_by_fulltext] where_clause: {where_clause}") - # Build fulltext search query - select_clause = f"""SELECT - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text, - ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank""" + select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id, + ts_rank(m.{tsvector_field}, q.fq) AS rank""" if return_fields: - select_clause += ", properties" - + select_cols += ", m.properties" + where_with_q = [] + for w in where_clauses: + if f"{tsvector_field} @@ to_tsquery(" in w: + where_with_q.append(f"m.{tsvector_field} @@ q.fq") + else: + where_with_q.append( + w.replace("(properties,", "(m.properties,") + .replace("(properties)", "(m.properties)") + .replace("ARRAY[properties,", "ARRAY[m.properties,") + ) + where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" - {select_clause} - FROM "{self.db_name}_graph"."Memory" - {where_clause} + WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) + SELECT {select_cols} + FROM "{self.db_name}_graph"."Memory" m + CROSS JOIN q + {where_clause_cte} ORDER BY rank DESC LIMIT {top_k}; """ - - params = [tsquery_string, tsquery_string] + params = [tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") conn = None try: @@ -2059,7 +2016,7 @@ def search_by_fulltext( output = [] for row in results: oldid = row[0] # old_id - rank = row[2] # rank score + rank = row[1] # rank score (no memory_text column) id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): @@ -2070,16 +2027,14 @@ def search_by_fulltext( if threshold is None or score_val >= threshold: item = {"id": id_val, "score": score_val} if return_fields: - properties = row[ - 3 - ] # properties column (after old_id, memory_text, rank) + properties = row[2] # properties column item.update( self._extract_fields_from_properties(properties, return_fields) ) output.append(item) elapsed_time = time.time() - start_time logger.info( - f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" + f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" ) return output[:top_k] finally: @@ -2089,30 +2044,23 @@ def search_by_fulltext( def search_by_embedding( self, vector: list[float], + user_name: str, top_k: int = 5, scope: str | None = None, status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, - user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, - return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ Retrieve node IDs based on vector similarity using PostgreSQL vector operations. - - Args: - return_fields (list[str], optional): Additional node fields to include in results - (e.g., ["memory", "status", "tags"]). When provided, each result dict will - contain these fields in addition to 'id' and 'score'. - Defaults to None (only 'id' and 'score' are returned). """ - # Build WHERE clause dynamically like nebular.py logger.info( - f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids}" ) + start_time = time.time() where_clauses = [] if scope: where_clauses.append( @@ -2127,31 +2075,18 @@ def search_by_embedding( "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" ) where_clauses.append("embedding is not null") - # Add user_name filter like nebular.py - - """ - # user_name = self._get_config_value("user_name") - # if not self.config.use_multi_db and user_name: - # if kwargs.get("cube_name"): - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kwargs['cube_name']}\"'::agtype") - # else: - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype") - """ - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( user_name=user_name, knowledgebase_ids=knowledgebase_ids, default_user_name=self.config.user_name, ) - # Add OR condition if we have any user_name conditions if user_name_conditions: if len(user_name_conditions) == 1: where_clauses.append(user_name_conditions[0]) else: where_clauses.append(f"({' OR '.join(user_name_conditions)})") - # Add search_filter conditions like nebular.py if search_filter: for key, value in search_filter.items(): if isinstance(value, str): @@ -2163,14 +2098,12 @@ def search_by_embedding( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - # Keep original simple query structure but add dynamic WHERE clause query = f""" WITH t AS ( SELECT id, @@ -2187,19 +2120,12 @@ def search_by_embedding( FROM t WHERE scope > 0.1; """ - # Convert vector to string format for PostgreSQL vector type - # PostgreSQL vector type expects a string format like '[1,2,3]' vector_str = convert_to_vector(vector) - # Use string format directly in query instead of parameterized query - # Replace %s with the vector string, but need to quote it properly - # PostgreSQL vector type needs the string to be quoted query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") params = [] - # Split query by lines and wrap long lines to prevent terminal truncation query_lines = query.strip().split("\n") for line in query_lines: - # Wrap lines longer than 200 characters to prevent terminal truncation if len(line) > 200: wrapped_lines = textwrap.wrap( line, width=200, break_long_words=False, break_on_hyphens=False @@ -2209,34 +2135,19 @@ def search_by_embedding( else: pass - logger.info(f"[search_by_embedding] query: {query}, params: {params}") + logger.info(f" search_by_embedding query: {query}, params: {params}") conn = None try: conn = self._get_connection() with conn.cursor() as cursor: - try: - # If params is empty, execute query directly without parameters - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - except Exception as e: - logger.error(f"[search_by_embedding] Error executing query: {e}") - logger.error(f"[search_by_embedding] Query length: {len(query)}") - logger.error( - f"[search_by_embedding] Params type: {type(params)}, length: {len(params)}" - ) - logger.error(f"[search_by_embedding] Query contains %s: {'%s' in query}") - raise + if params: + cursor.execute(query, params) + else: + cursor.execute(query) results = cursor.fetchall() output = [] for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ if len(row) < 5: logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") continue @@ -2248,14 +2159,15 @@ def search_by_embedding( score_val = float(score) score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[1] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) + output.append({"id": id_val, "score": score_val}) + elapsed_time = time.time() - start_time + logger.info( + f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + ) return output[:top_k] + except Exception as e: + logger.error(f"[search_by_embedding] Error executing query: {e}") + raise finally: self._return_connection(conn) @@ -2263,7 +2175,7 @@ def search_by_embedding( def get_by_metadata( self, filters: list[dict[str, Any]], - user_name: str | None = None, + user_name: str, filter: dict | None = None, knowledgebase_ids: list | None = None, user_name_flag: bool = True, @@ -2285,7 +2197,9 @@ def get_by_metadata( Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). """ - logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + logger.info( + f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}" + ) user_name = user_name if user_name else self._get_config_value("user_name") @@ -2340,9 +2254,6 @@ def get_by_metadata( else: raise ValueError(f"Unsupported operator: {op}") - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, knowledgebase_ids=knowledgebase_ids, @@ -2612,8 +2523,8 @@ def clear(self, user_name: str | None = None) -> None: @timed def export_graph( self, + user_name: str, include_embedding: bool = False, - user_name: str | None = None, user_id: str | None = None, page: int | None = None, page_size: int | None = None, @@ -2652,7 +2563,7 @@ def export_graph( } """ logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" + f" export_graph include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" ) user_id = user_id if user_id else self._get_config_value("user_id") @@ -2984,8 +2895,8 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: def get_all_memory_items( self, scope: str, + user_name: str, include_embedding: bool = False, - user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list | None = None, status: str | None = None, @@ -3006,14 +2917,13 @@ def get_all_memory_items( list[dict]: Full list of memory items under this scope. """ logger.info( - f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" + f"[get_all_memory_items] user_name: {user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status},scope:{scope}" ) user_name = user_name if user_name else self._get_config_value("user_name") if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, knowledgebase_ids=knowledgebase_ids, From fd2002304a6257fcdc30100c325ac42008c8c17f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Wed, 25 Feb 2026 11:45:55 +0800 Subject: [PATCH 2/2] feat:optimize user_name && key_words --- src/memos/graph_dbs/polardb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 824af8f54..1524cc5ba 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2293,7 +2293,7 @@ def get_by_metadata( results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: - logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") finally: self._return_connection(conn)