diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index a87c83510..d740ad1d2 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -247,15 +247,12 @@ def _warm_up_connections_by_all(self):
 
     @contextmanager
     def _get_connection(self):
+        import psycopg2
+
         timeout = self._connection_wait_timeout
-        if timeout <= 0:
-            self._semaphore.acquire()
-        else:
-            if not self._semaphore.acquire(timeout=timeout):
-                logger.warning(f"Timeout waiting for connection slot ({timeout}s)")
-                raise RuntimeError(
-                    f"Connection pool busy: acquire a slot within {timeout}s (all connections in use)."
-                )
+        if not self._semaphore.acquire(timeout=timeout if timeout > 0 else None):  # <= 0: wait indefinitely
+            logger.warning(f"Timeout waiting for connection slot ({timeout}s)")
+            raise RuntimeError(f"Connection pool busy: no free slot within {timeout}s")
         logger.info(
             "Connection pool usage: %s/%s",
             self.connection_pool.maxconn - self._semaphore._value,
@@ -263,17 +260,26 @@ def _get_connection(self):
         )
         conn = None
         broken = False
-
         try:
             conn = self.connection_pool.getconn()
-            logger.debug(f"Acquired connection {id(conn)} from pool")
             conn.autocommit = True
+            for attempt in range(2):  # probe the pooled connection; replace it once if dead
+                try:
+                    with conn.cursor() as cur:
+                        cur.execute("SELECT 1")
+                    break
+                except psycopg2.Error:
+                    logger.warning("Dead connection detected, recreating (attempt %d)", attempt + 1)
+                    self.connection_pool.putconn(conn, close=True)
+                    conn = self.connection_pool.getconn()
+                    conn.autocommit = True
+            else:  # no break: every probe failed
+                raise RuntimeError("Cannot obtain valid DB connection after 2 attempts")
             with conn.cursor() as cur:
                 cur.execute(f'SET search_path = {self.db_name}_graph, ag_catalog, "$user", public;')
             yield conn
-        except Exception as e:
+        except Exception:
             broken = True
-            logger.exception(f"Connection failed or broken: {e}")
             raise
         finally:
             if conn:
@@ -1814,7 +1820,7 @@ def search_by_fulltext(
                 properties = row[2]  # properties column
                 item.update(self._extract_fields_from_properties(properties, return_fields))
             output.append(item)
-        elapsed = (time.perf_counter() - start_time) * 1000
+        elapsed = (time.perf_counter() - start_time) * 1000.0
         logger.info("search_by_fulltext internal took %.1f ms", elapsed)
 
         return output[:top_k]
@@ -1945,7 +1951,7 @@ def search_by_embedding(
                 properties = row[1]  # properties column
                 item.update(self._extract_fields_from_properties(properties, return_fields))
             output.append(item)
-        elapsed_time = time.perf_counter() - start_time
+        elapsed_time = (time.perf_counter() - start_time) * 1000.0
         logger.info(
             "search_by_embedding query embedding completed time took %.1f ms", elapsed_time
         )
@@ -2500,7 +2506,7 @@ def _extract_special_filter_values(filter_obj):
         except Exception as e:
             logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
             raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
-        elapsed = (time.perf_counter() - start_time) * 1000
+        elapsed = (time.perf_counter() - start_time) * 1000.0
         logger.info("export internal took %.1f ms", elapsed)
 
         edges = []
@@ -3360,7 +3366,7 @@ def add_nodes_batch(
                     logger.warning(
                         f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}"
                     )
-        elapsed_time = time.perf_counter() - batch_start_time
+        elapsed_time = (time.perf_counter() - batch_start_time) * 1000.0
         logger.info(
             "add_nodes_batch batch insert completed successfully in took %.1f ms",
             elapsed_time,
@@ -4815,7 +4821,7 @@ def delete_node_by_prams(
         """
        batch_start_time = time.time()
         logger.info(
-            f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
+            f"delete_node_by_prams memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
         )
 
         # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
@@ -4889,15 +4895,15 @@ def delete_node_by_prams(
                         DELETE FROM "{self.db_name}_graph"."Memory"
                         WHERE {where_clause}
                     """
-                    logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
+                    logger.info(f"delete_node_by_prams delete_query: {delete_query}")
                     cursor.execute(delete_query)
                     deleted_count = cursor.rowcount
 
                     logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes")
 
-            elapsed_time = time.time() - batch_start_time
+            elapsed_time = (time.time() - batch_start_time) * 1000.0
             logger.info(
-                f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
+                f"delete_node_by_prams Deletion completed successfully in {elapsed_time:.2f}ms, total deleted {total_deleted_count} nodes"
             )
         except Exception as e:
             logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
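
The core of this patch is the checkout-time liveness probe in `_get_connection`. Below is a minimal standalone sketch of the same validate-then-replace pattern; `checkout` and `MAX_ATTEMPTS` are illustrative names, not part of the PolarDB class, and the pool is assumed to be a `psycopg2.pool.SimpleConnectionPool`.

```python
# Sketch only: mirrors the validate-then-replace loop added in _get_connection.
import psycopg2
from psycopg2.pool import SimpleConnectionPool

MAX_ATTEMPTS = 2

def checkout(pool: SimpleConnectionPool):
    conn = pool.getconn()
    conn.autocommit = True
    for _ in range(MAX_ATTEMPTS):
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")  # cheap liveness probe
            return conn  # probe succeeded: hand the connection out
        except psycopg2.Error:
            # The socket is dead (server restart, idle timeout, ...):
            # close it and draw a replacement from the pool.
            pool.putconn(conn, close=True)
            conn = pool.getconn()
            conn.autocommit = True
    # Out of attempts; discard the last, unprobed connection before failing.
    # (The patch reaches the same end state through its finally block.)
    pool.putconn(conn, close=True)
    raise RuntimeError(f"no valid DB connection after {MAX_ATTEMPTS} attempts")
```

Probing on every checkout costs one extra round trip per acquisition; pools whose connections sit idle long enough to be dropped by the server usually accept that trade.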
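The rewritten slot acquisition collapses the old two-branch logic into a single `acquire` call. The detail that matters is the sentinel: `threading.Semaphore.acquire(timeout=None)` blocks until a slot frees up, while `timeout=0` gives up immediately, so a non-positive configured timeout must map to `None` (not be clamped to `0`) to preserve the previous wait-forever behavior. A toy demonstration:

```python
import threading

sem = threading.Semaphore(1)
sem.acquire()                  # simulate an exhausted pool

print(sem.acquire(timeout=0))  # False: returns immediately, no slot granted
# sem.acquire(timeout=None)    # would block here until another thread calls release()

sem.release()
```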
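Several hunks only normalize timing math so that values logged under an "ms" label are actually milliseconds (`search_by_embedding` and `add_nodes_batch` previously logged raw `perf_counter` deltas in seconds against "%.1f ms" formats). A small helper like the following, hypothetical and not part of this patch, would keep the conversion in one place so the mismatch cannot recur:

```python
# Hypothetical helper, not in the PR: centralizes the ms conversion.
import time

def elapsed_ms(start: float) -> float:
    """Milliseconds elapsed since `start`, a time.perf_counter() reading."""
    return (time.perf_counter() - start) * 1000.0

start = time.perf_counter()
sum(range(1_000_000))  # stand-in for a database round trip
print(f"took {elapsed_ms(start):.1f} ms")
```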