Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/postgres_mcp/database_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@


class DatabaseService:
def __init__(self, database_url: str, current_access_mode: models.AccessMode):
def __init__(self, database_url: str, current_access_mode: models.AccessMode, query_timeout: float | None = None):
self.database_url = database_url
self.current_access_mode = current_access_mode
self.query_timeout = query_timeout
self._connect_lock = asyncio.Lock()

db_connection: Optional[DbConnPool] = None
Expand All @@ -54,15 +55,19 @@ async def get_sql_driver(self) -> Union[SqlDriver, SafeSqlDriver]:

if self.current_access_mode == models.AccessMode.RESTRICTED:
logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)")
# 30 second timeout
return SafeSqlDriver(sql_driver=base_driver, timeout=30)
return SafeSqlDriver(sql_driver=base_driver)
else:
logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)")
return base_driver

async def create_db_connection(self) -> DbConnPool:
logger.info(f"Creating new database connection pool for URL: {obfuscate_password(self.database_url)}")
self.db_connection = DbConnPool(connection_url=self.database_url)
logger.info(
f"Creating new database connection pool for URL: {obfuscate_password(self.database_url)}, statement_timeout={self.query_timeout}s"
)
self.db_connection = DbConnPool(
connection_url=self.database_url,
statement_timeout_seconds=self.query_timeout,
)
try:
await self.db_connection.pool_connect(self.database_url)
logger.info("Successfully connected to database and initialized connection pool")
Expand Down
7 changes: 6 additions & 1 deletion src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@
database_port: Optional[str] = None
database_username: Optional[str] = None
database_password: Optional[str] = None
query_timeout: Optional[float] = None
shutdown_in_progress = False


async def get_service(database_name: str) -> DatabaseService:
if database_name not in db_services:
database_url = create_database_url(database_name)
db_services[database_name] = DatabaseService(database_url, current_access_mode)
db_services[database_name] = DatabaseService(database_url, current_access_mode, query_timeout)
return db_services[database_name]


Expand Down Expand Up @@ -245,11 +246,15 @@ async def main():
global database_port
global database_username
global database_password
global query_timeout

current_access_mode = AccessMode(args.access_mode)
database_host = os.environ.get("DATABASE_HOST", args.database_host)
database_port = os.environ.get("DATABASE_PORT", args.database_port)

raw_query_timeout = os.environ.get("QUERY_TIMEOUT")
query_timeout = float(raw_query_timeout) if raw_query_timeout is not None else None

if args.database_creds_file:
database_username, database_password = read_database_creds(args.database_creds_file)
else:
Expand Down
10 changes: 9 additions & 1 deletion src/postgres_mcp/sql/sql_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def obfuscate_password(text: str | None) -> str | None:
class DbConnPool:
"""Database connection manager using psycopg's connection pool."""

def __init__(self, connection_url: Optional[str] = None):
def __init__(self, connection_url: Optional[str] = None, statement_timeout_seconds: float | None = None):
self.connection_url = connection_url
self.statement_timeout_seconds = statement_timeout_seconds
self.pool: AsyncConnectionPool | None = None
self._is_valid = False
self._last_error = None
Expand Down Expand Up @@ -92,12 +93,19 @@ async def pool_connect(self, connection_url: Optional[str] = None) -> AsyncConne
await self.close()

try:
# Build connection kwargs with optional statement_timeout
connect_kwargs: dict[str, Any] = {}
if self.statement_timeout_seconds is not None:
timeout_ms = int(self.statement_timeout_seconds * 1000)
connect_kwargs["options"] = f"-c statement_timeout={timeout_ms}"

# Configure connection pool with appropriate settings
self.pool = AsyncConnectionPool(
conninfo=url,
min_size=1,
max_size=5,
open=False, # Don't connect immediately, let's do it explicitly
kwargs=connect_kwargs if connect_kwargs else None,
)

# Open the pool explicitly
Expand Down
66 changes: 55 additions & 11 deletions tests/unit/test_access_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,75 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive
driver = await service.get_sql_driver()
assert isinstance(driver, expected_driver_type)

# When in RESTRICTED mode, verify timeout is set
if access_mode == AccessMode.RESTRICTED:
assert isinstance(driver, SafeSqlDriver)
assert driver.timeout == 30


@pytest.mark.asyncio
async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection):
"""Test that get_sql_driver sets the timeout in restricted mode."""
async def test_get_sql_driver_sets_restricted_mode(mock_db_connection):
"""Test that get_sql_driver wraps with SafeSqlDriver in restricted mode."""
service = DatabaseService(database_url="postgresql://user:pass@localhost/test", current_access_mode=AccessMode.RESTRICTED)
with patch.object(service, "db_connection", mock_db_connection):
driver = await service.get_sql_driver()
assert isinstance(driver, SafeSqlDriver)
assert driver.timeout == 30
assert hasattr(driver, "sql_driver")


@pytest.mark.asyncio
async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection):
"""Test that get_sql_driver in unrestricted mode is a regular SqlDriver."""
async def test_get_sql_driver_in_unrestricted_mode(mock_db_connection):
"""Test that get_sql_driver in unrestricted mode returns a plain SqlDriver."""
service = DatabaseService(database_url="postgresql://user:pass@localhost/test", current_access_mode=AccessMode.UNRESTRICTED)
with patch.object(service, "db_connection", mock_db_connection):
driver = await service.get_sql_driver()
assert isinstance(driver, SqlDriver)
assert not hasattr(driver, "timeout")
assert not isinstance(driver, SafeSqlDriver)


@pytest.mark.asyncio
async def test_create_db_connection_no_timeout_restricted_mode():
"""Test that restricted mode has no statement_timeout by default."""
service = DatabaseService(
database_url="postgresql://user:pass@localhost/test",
current_access_mode=AccessMode.RESTRICTED,
)
with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock):
pool = await service.create_db_connection()
assert pool.statement_timeout_seconds is None


@pytest.mark.asyncio
async def test_create_db_connection_no_timeout_unrestricted_mode():
"""Test that unrestricted mode has no statement_timeout by default."""
service = DatabaseService(
database_url="postgresql://user:pass@localhost/test",
current_access_mode=AccessMode.UNRESTRICTED,
)
with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock):
pool = await service.create_db_connection()
assert pool.statement_timeout_seconds is None


@pytest.mark.asyncio
async def test_create_db_connection_custom_timeout_restricted_mode():
"""Test that query_timeout overrides the default 30s in restricted mode."""
service = DatabaseService(
database_url="postgresql://user:pass@localhost/test",
current_access_mode=AccessMode.RESTRICTED,
query_timeout=120,
)
with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock):
pool = await service.create_db_connection()
assert pool.statement_timeout_seconds == 120


@pytest.mark.asyncio
async def test_create_db_connection_custom_timeout_unrestricted_mode():
"""Test that query_timeout is applied in unrestricted mode without restricting SQL."""
service = DatabaseService(
database_url="postgresql://user:pass@localhost/test",
current_access_mode=AccessMode.UNRESTRICTED,
query_timeout=60,
)
with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock):
pool = await service.create_db_connection()
assert pool.statement_timeout_seconds == 60


@pytest.mark.asyncio
Expand Down