Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/postgres_mcp/database_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ruff: noqa: B008
import asyncio
import logging
from typing import Any
from typing import List
Expand Down Expand Up @@ -38,12 +39,16 @@ class DatabaseService:
def __init__(self, database_url: str, current_access_mode: models.AccessMode):
self.database_url = database_url
self.current_access_mode = current_access_mode
self._connect_lock = asyncio.Lock()

db_connection: Optional[DbConnPool] = None

async def get_sql_driver(self) -> Union[SqlDriver, SafeSqlDriver]:
if not self.db_connection:
self.db_connection = await self.create_db_connection()
if not self.db_connection or not self.db_connection.is_valid:
async with self._connect_lock:
# Re-check after acquiring lock
if not self.db_connection or not self.db_connection.is_valid:
self.db_connection = await self.create_db_connection()

base_driver = SqlDriver(conn=self.db_connection)

Expand All @@ -63,6 +68,7 @@ async def create_db_connection(self) -> DbConnPool:
logger.info("Successfully connected to database and initialized connection pool")
return self.db_connection
except Exception as e:
self.db_connection = None
logger.warning(
f"Could not connect to database: {obfuscate_password(str(e))}",
)
Expand Down
91 changes: 50 additions & 41 deletions src/postgres_mcp/sql/sql_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SQL driver adapter for PostgreSQL connections."""

import asyncio
import logging
import re
from dataclasses import dataclass
Expand Down Expand Up @@ -67,51 +68,57 @@ def __init__(self, connection_url: Optional[str] = None):
self.pool: AsyncConnectionPool | None = None
self._is_valid = False
self._last_error = None
self._connect_lock = asyncio.Lock()

async def pool_connect(self, connection_url: Optional[str] = None) -> AsyncConnectionPool:
"""Initialize connection pool with retry logic."""
# If we already have a valid pool, return it
# Fast path: if pool is already valid, return without acquiring lock
if self.pool and self._is_valid:
return self.pool

url = connection_url or self.connection_url
self.connection_url = url
if not url:
self._is_valid = False
self._last_error = "Database connection URL not provided"
raise ValueError(self._last_error)
async with self._connect_lock:
# Re-check after acquiring lock (another coroutine may have connected)
if self.pool and self._is_valid:
return self.pool

# Close any existing pool before creating a new one
await self.close()

try:
# Configure connection pool with appropriate settings
self.pool = AsyncConnectionPool(
conninfo=url,
min_size=1,
max_size=5,
open=False, # Don't connect immediately, let's do it explicitly
)

# Open the pool explicitly
await self.pool.open()

# Test the connection pool by executing a simple query
async with self.pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")

self._is_valid = True
self._last_error = None
return self.pool
except Exception as e:
self._is_valid = False
self._last_error = str(e)
url = connection_url or self.connection_url
self.connection_url = url
if not url:
self._is_valid = False
self._last_error = "Database connection URL not provided"
raise ValueError(self._last_error)

# Clean up failed pool
# Close any existing pool before creating a new one
await self.close()

raise ValueError(f"Connection attempt failed: {obfuscate_password(str(e))}") from e
try:
# Configure connection pool with appropriate settings
self.pool = AsyncConnectionPool(
conninfo=url,
min_size=1,
max_size=5,
open=False, # Don't connect immediately, let's do it explicitly
)

# Open the pool explicitly
await self.pool.open()

# Test the connection pool by executing a simple query
async with self.pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")

self._is_valid = True
self._last_error = None
return self.pool
except Exception as e:
self._is_valid = False
self._last_error = str(e)

# Clean up failed pool
await self.close()

raise ValueError(f"Connection attempt failed: {obfuscate_password(str(e))}") from e

async def close(self) -> None:
"""Close the connection pool."""
Expand Down Expand Up @@ -212,11 +219,12 @@ async def execute_query(
# Direct connection approach
return await self._execute_with_connection(self.conn, query, params, force_readonly=force_readonly)
except Exception as e:
# Mark pool as invalid if there was a connection issue
if self.conn and self.is_pool:
self.conn._is_valid = False # type: ignore
self.conn._last_error = str(e) # type: ignore
elif self.conn and not self.is_pool:
# For direct (non-pool) connections, clear the reference so it
# will be re-established on the next call. Pool connections are
# managed by psycopg_pool which handles broken connections
# internally — invalidating the whole pool here would race with
# concurrent in-flight queries that still hold a reference to it.
if self.conn and not self.is_pool:
self.conn = None

raise e
Expand All @@ -240,7 +248,8 @@ async def _execute_with_connection(self, connection, query, params, force_readon
while cursor.nextset():
pass

if cursor.description is None: # No results (like DDL statements)
# No results (like DDL statements)
if cursor.description is None:
if not force_readonly:
await cursor.execute("COMMIT")
elif transaction_started:
Expand Down
21 changes: 12 additions & 9 deletions tests/unit/sql/test_sql_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,27 +327,30 @@ async def mock_pool_execute(*args, **kwargs):


@pytest.mark.asyncio
async def test_connection_error_marks_pool_invalid(mock_db_pool):
"""Test that connection errors mark the pool as invalid."""
async def test_connection_error_does_not_invalidate_pool(mock_db_pool):
"""Test that query errors do not mark the pool as invalid.

The pool manages its own connection health internally via psycopg_pool.
Blindly invalidating the pool on every error would race with concurrent
in-flight queries that still reference the same pool.
"""
db_pool, connection, cursor = mock_db_pool

# Configure pool_connect to raise an exception
db_pool.pool_connect.side_effect = Exception("Connection failed")

# Mark pool as valid before the query
db_pool._is_valid = True

# Create SqlDriver with the mocked pool
driver = SqlDriver(conn=db_pool)

# Execute a query that will fail due to connection error
with pytest.raises(Exception):
await driver.execute_query("SELECT * FROM test")

# Make pool invalid manually (since we're bypassing the actual method)
db_pool._is_valid = False
db_pool._last_error = "Connection failed"

# Verify pool was marked as invalid
assert db_pool._is_valid is False
assert isinstance(db_pool._last_error, str)
# Pool should remain valid — the pool itself handles connection health
assert db_pool._is_valid is True


@pytest.mark.asyncio
Expand Down