diff --git a/src/postgres_mcp/database_service.py b/src/postgres_mcp/database_service.py index 9e0bb31..6919e89 100644 --- a/src/postgres_mcp/database_service.py +++ b/src/postgres_mcp/database_service.py @@ -1,4 +1,5 @@ # ruff: noqa: B008 +import asyncio import logging from typing import Any from typing import List @@ -38,12 +39,16 @@ class DatabaseService: def __init__(self, database_url: str, current_access_mode: models.AccessMode): self.database_url = database_url self.current_access_mode = current_access_mode + self._connect_lock = asyncio.Lock() db_connection: Optional[DbConnPool] = None async def get_sql_driver(self) -> Union[SqlDriver, SafeSqlDriver]: - if not self.db_connection: - self.db_connection = await self.create_db_connection() + if not self.db_connection or not self.db_connection.is_valid: + async with self._connect_lock: + # Re-check after acquiring lock + if not self.db_connection or not self.db_connection.is_valid: + self.db_connection = await self.create_db_connection() base_driver = SqlDriver(conn=self.db_connection) @@ -63,6 +68,7 @@ async def create_db_connection(self) -> DbConnPool: logger.info("Successfully connected to database and initialized connection pool") return self.db_connection except Exception as e: + self.db_connection = None logger.warning( f"Could not connect to database: {obfuscate_password(str(e))}", ) diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index 5beacb0..f4017e8 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -1,5 +1,6 @@ """SQL driver adapter for PostgreSQL connections.""" +import asyncio import logging import re from dataclasses import dataclass @@ -67,51 +68,57 @@ def __init__(self, connection_url: Optional[str] = None): self.pool: AsyncConnectionPool | None = None self._is_valid = False self._last_error = None + self._connect_lock = asyncio.Lock() async def pool_connect(self, connection_url: Optional[str] = None) -> AsyncConnectionPool: """Initialize connection pool with retry logic.""" - # If we already have a valid pool, return it + # Fast path: if pool is already valid, return without acquiring lock if self.pool and self._is_valid: return self.pool - url = connection_url or self.connection_url - self.connection_url = url - if not url: - self._is_valid = False - self._last_error = "Database connection URL not provided" - raise ValueError(self._last_error) + async with self._connect_lock: + # Re-check after acquiring lock (another coroutine may have connected) + if self.pool and self._is_valid: + return self.pool - # Close any existing pool before creating a new one - await self.close() - - try: - # Configure connection pool with appropriate settings - self.pool = AsyncConnectionPool( - conninfo=url, - min_size=1, - max_size=5, - open=False, # Don't connect immediately, let's do it explicitly - ) - - # Open the pool explicitly - await self.pool.open() - - # Test the connection pool by executing a simple query - async with self.pool.connection() as conn: - async with conn.cursor() as cursor: - await cursor.execute("SELECT 1") - - self._is_valid = True - self._last_error = None - return self.pool - except Exception as e: - self._is_valid = False - self._last_error = str(e) + url = connection_url or self.connection_url + self.connection_url = url + if not url: + self._is_valid = False + self._last_error = "Database connection URL not provided" + raise ValueError(self._last_error) - # Clean up failed pool + # Close any existing pool before creating a new one await self.close() - raise ValueError(f"Connection attempt failed: {obfuscate_password(str(e))}") from e + try: + # Configure connection pool with appropriate settings + self.pool = AsyncConnectionPool( + conninfo=url, + min_size=1, + max_size=5, + open=False, # Don't connect immediately, let's do it explicitly + ) + + # Open the pool explicitly + await self.pool.open() + + # Test the connection pool by executing a simple query + async with self.pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute("SELECT 1") + + self._is_valid = True + self._last_error = None + return self.pool + except Exception as e: + self._is_valid = False + self._last_error = str(e) + + # Clean up failed pool + await self.close() + + raise ValueError(f"Connection attempt failed: {obfuscate_password(str(e))}") from e async def close(self) -> None: """Close the connection pool.""" @@ -212,11 +219,12 @@ async def execute_query( # Direct connection approach return await self._execute_with_connection(self.conn, query, params, force_readonly=force_readonly) except Exception as e: - # Mark pool as invalid if there was a connection issue - if self.conn and self.is_pool: - self.conn._is_valid = False # type: ignore - self.conn._last_error = str(e) # type: ignore - elif self.conn and not self.is_pool: + # For direct (non-pool) connections, clear the reference so it + # will be re-established on the next call. Pool connections are + # managed by psycopg_pool which handles broken connections + # internally — invalidating the whole pool here would race with + # concurrent in-flight queries that still hold a reference to it. + if self.conn and not self.is_pool: self.conn = None raise e @@ -240,7 +248,8 @@ async def _execute_with_connection(self, connection, query, params, force_readon while cursor.nextset(): pass - if cursor.description is None: # No results (like DDL statements) + # No results (like DDL statements) + if cursor.description is None: if not force_readonly: await cursor.execute("COMMIT") elif transaction_started: diff --git a/tests/unit/sql/test_sql_driver.py b/tests/unit/sql/test_sql_driver.py index 4033537..c89d485 100644 --- a/tests/unit/sql/test_sql_driver.py +++ b/tests/unit/sql/test_sql_driver.py @@ -327,13 +327,21 @@ async def mock_pool_execute(*args, **kwargs): @pytest.mark.asyncio -async def test_connection_error_marks_pool_invalid(mock_db_pool): - """Test that connection errors mark the pool as invalid.""" +async def test_connection_error_does_not_invalidate_pool(mock_db_pool): + """Test that query errors do not mark the pool as invalid. + + The pool manages its own connection health internally via psycopg_pool. + Blindly invalidating the pool on every error would race with concurrent + in-flight queries that still reference the same pool. + """ db_pool, connection, cursor = mock_db_pool # Configure pool_connect to raise an exception db_pool.pool_connect.side_effect = Exception("Connection failed") + # Mark pool as valid before the query + db_pool._is_valid = True + # Create SqlDriver with the mocked pool driver = SqlDriver(conn=db_pool) @@ -341,13 +349,8 @@ async def test_connection_error_marks_pool_invalid(mock_db_pool): with pytest.raises(Exception): await driver.execute_query("SELECT * FROM test") - # Make pool invalid manually (since we're bypassing the actual method) - db_pool._is_valid = False - db_pool._last_error = "Connection failed" - - # Verify pool was marked as invalid - assert db_pool._is_valid is False - assert isinstance(db_pool._last_error, str) + # Pool should remain valid — the pool itself handles connection health + assert db_pool._is_valid is True @pytest.mark.asyncio