46 changes: 26 additions & 20 deletions src/memos/graph_dbs/polardb.py
@@ -247,33 +247,39 @@ def _warm_up_connections_by_all(self):

@contextmanager
def _get_connection(self):
import psycopg2

timeout = self._connection_wait_timeout
if timeout <= 0:
self._semaphore.acquire()
else:
if not self._semaphore.acquire(timeout=timeout):
logger.warning(f"Timeout waiting for connection slot ({timeout}s)")
raise RuntimeError(
f"Connection pool busy: acquire a slot within {timeout}s (all connections in use)."
)
if not self._semaphore.acquire(timeout=max(timeout, 0)):
logger.warning(f"Timeout waiting for connection slot ({timeout}s)")
raise RuntimeError("Connection pool busy")
logger.info(
"Connection pool usage: %s/%s",
self.connection_pool.maxconn - self._semaphore._value,
self.connection_pool.maxconn,
)
conn = None
broken = False

try:
conn = self.connection_pool.getconn()
logger.debug(f"Acquired connection {id(conn)} from pool")
conn.autocommit = True
for attempt in range(2):
try:
with conn.cursor() as cur:
cur.execute("SELECT 1")
break
except psycopg2.Error:
logger.warning("Dead connection detected, recreating (attempt %d)", attempt + 1)
self.connection_pool.putconn(conn, close=True)
conn = self.connection_pool.getconn()
conn.autocommit = True
else:
raise RuntimeError("Cannot obtain valid DB connection after 2 attempts")
with conn.cursor() as cur:
cur.execute(f'SET search_path = {self.db_name}_graph, ag_catalog, "$user", public;')
yield conn
except Exception as e:
except Exception:
broken = True
logger.exception(f"Connection failed or broken: {e}")
raise
finally:
if conn:
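
The hunk above amounts to a single pattern: a bounded semaphore caps concurrent checkouts and fails fast when the pool is saturated, a `SELECT 1` probe recycles a dead connection once before giving up, the Apache AGE `search_path` is pinned per checkout, and a connection that raised mid-use is discarded rather than returned. The sketch below is a minimal reconstruction of that pattern, not the PR's code; the class name `PooledDB`, the `dsn`/`maxconn`/`wait_timeout` parameters, and the `memos_graph` schema name are illustrative assumptions.

```python
import logging
from contextlib import contextmanager
from threading import BoundedSemaphore

import psycopg2
from psycopg2 import pool

logger = logging.getLogger(__name__)


class PooledDB:
    def __init__(self, dsn: str, maxconn: int = 8, wait_timeout: float = 5.0):
        self.connection_pool = pool.ThreadedConnectionPool(1, maxconn, dsn)
        self._semaphore = BoundedSemaphore(maxconn)
        self._wait_timeout = wait_timeout

    @contextmanager
    def get_connection(self):
        # Cap the number of callers holding a connection at once; fail fast
        # instead of queueing forever when the pool is saturated.
        if not self._semaphore.acquire(timeout=max(self._wait_timeout, 0)):
            raise RuntimeError("Connection pool busy")
        conn, broken = None, False
        try:
            conn = self.connection_pool.getconn()
            conn.autocommit = True
            # Liveness probe: recycle a dead connection once before giving up.
            for attempt in range(2):
                try:
                    with conn.cursor() as cur:
                        cur.execute("SELECT 1")
                    break
                except psycopg2.Error:
                    logger.warning("Dead connection detected, recreating (attempt %d)", attempt + 1)
                    self.connection_pool.putconn(conn, close=True)
                    conn = self.connection_pool.getconn()
                    conn.autocommit = True
            else:
                raise RuntimeError("Cannot obtain valid DB connection after 2 attempts")
            # The PR also pins search_path for Apache AGE; the graph/schema name
            # used here ("memos_graph") is a placeholder, not the real one.
            with conn.cursor() as cur:
                cur.execute('SET search_path = memos_graph, ag_catalog, "$user", public')
            yield conn
        except Exception:
            broken = True
            raise
        finally:
            if conn is not None:
                # Drop connections that raised mid-use; return healthy ones to the pool.
                self.connection_pool.putconn(conn, close=broken)
            self._semaphore.release()
```

With this shape, `with db.get_connection() as conn:` always releases the semaphore and either returns or discards the connection, which is the intent of the hunk above.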
@@ -1814,7 +1820,7 @@ def search_by_fulltext(
properties = row[2] # properties column
item.update(self._extract_fields_from_properties(properties, return_fields))
output.append(item)
elapsed = (time.perf_counter() - start_time) * 1000
elapsed = (time.perf_counter() - start_time) * 1000.0
logger.info("search_by_fulltext internal took %.1f ms", elapsed)
return output[:top_k]
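
Several hunks in this diff apply the same one-line fix: `time.perf_counter()` (and `time.time()`) return seconds, so the value must be multiplied by 1000.0 before being logged under a `%.1f ms` format. A hypothetical helper capturing that pattern (the name `run_timed` is not from the PR):

```python
import logging
import time

logger = logging.getLogger(__name__)


def run_timed(label, fn):
    # perf_counter() yields seconds as a float; * 1000.0 makes the
    # "%.1f ms" log line report actual milliseconds.
    start = time.perf_counter()
    result = fn()
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    logger.info("%s took %.1f ms", label, elapsed_ms)
    return result
```

Usage would look like `run_timed("search_by_fulltext internal", lambda: run_query())`.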

@@ -1945,7 +1951,7 @@ def search_by_embedding(
properties = row[1] # properties column
item.update(self._extract_fields_from_properties(properties, return_fields))
output.append(item)
elapsed_time = time.perf_counter() - start_time
elapsed_time = (time.perf_counter() - start_time) * 1000.0
logger.info(
"search_by_embedding query embedding completed time took %.1f ms", elapsed_time
)
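
Both search paths assemble results the same way: each row carries a `properties` column, only the caller-requested `return_fields` are copied onto the output item, and the list is truncated to `top_k`. A rough, self-contained approximation of that loop (the row layout and the behaviour of `_extract_fields_from_properties` are inferred, not copied from the PR):

```python
def rows_to_items(rows, return_fields, top_k):
    # Assumed row layout: (id, properties_dict); the real queries differ per method.
    output = []
    for row in rows:
        item = {"id": row[0]}
        properties = row[1] or {}
        # Keep only the fields the caller asked for.
        item.update({k: properties[k] for k in return_fields if k in properties})
        output.append(item)
    return output[:top_k]
```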
@@ -2500,7 +2506,7 @@ def _extract_special_filter_values(filter_obj):
except Exception as e:
logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
elapsed = (time.perf_counter() - start_time) * 1000
elapsed = (time.perf_counter() - start_time) * 1000.0
logger.info("export internal took %.1f ms", elapsed)

edges = []
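
The export hunk wraps failures with `raise RuntimeError(...) from e`, which keeps the original exception attached as `__cause__` so both tracebacks survive in the logs. A minimal illustration of that idiom (the `fetch_nodes` callable is a stand-in, not a function from the PR):

```python
import logging

logger = logging.getLogger(__name__)


def export_nodes(fetch_nodes):
    try:
        return fetch_nodes()
    except Exception as e:
        logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
        # "from e" chains the cause so the underlying driver error is not lost.
        raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
```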
@@ -3360,7 +3366,7 @@ def add_nodes_batch(
logger.warning(
f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}"
)
elapsed_time = time.perf_counter() - batch_start_time
elapsed_time = (time.perf_counter() - batch_start_time) * 1000.0
logger.info(
"add_nodes_batch batch insert completed successfully in took %.1f ms",
elapsed_time,
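
The add_nodes_batch hunk shows a prepared statement being deallocated in a cleanup path and the whole batch timed in milliseconds. The sketch below reconstructs that shape under stated assumptions: the statement name, the `memory_nodes` table, and its columns are placeholders, not the PR's actual schema.

```python
import logging
import time

logger = logging.getLogger(__name__)


def add_rows_batch(conn, rows):
    # rows: iterable of (node_id, payload) pairs; table and columns are illustrative.
    prepare_name = "add_nodes_stmt"
    batch_start_time = time.perf_counter()
    with conn.cursor() as cur:
        cur.execute(
            f"PREPARE {prepare_name} (text, text) AS "
            "INSERT INTO memory_nodes (id, payload) VALUES ($1, $2)"
        )
        try:
            for node_id, payload in rows:
                cur.execute(f"EXECUTE {prepare_name} (%s, %s)", (node_id, payload))
        finally:
            # Always release the server-side statement, but never let cleanup
            # failures mask the original error.
            try:
                cur.execute(f"DEALLOCATE {prepare_name}")
            except Exception as dealloc_error:
                logger.warning("Failed to deallocate %s: %s", prepare_name, dealloc_error)
    elapsed_time = (time.perf_counter() - batch_start_time) * 1000.0
    logger.info("add_rows_batch insert completed in %.1f ms", elapsed_time)
```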
@@ -4815,7 +4821,7 @@ def delete_node_by_prams(
"""
batch_start_time = time.time()
logger.info(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
f" delete_node_by_prams memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
@@ -4889,17 +4895,17 @@ def delete_node_by_prams(
DELETE FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
logger.info(f" delete_node_by_prams delete_query: {delete_query}")

cursor.execute(delete_query)
deleted_count = cursor.rowcount
total_deleted_count = deleted_count

logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes")

elapsed_time = time.time() - batch_start_time
elapsed_time = (time.time() - batch_start_time) * 1000.0
logger.info(
f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
f"delete_node_by_prams Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
)
except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)