From 5cacd74d592a312d6c63ddbd14e0917aae457e7a Mon Sep 17 00:00:00 2001 From: David Gilady Date: Tue, 24 Feb 2026 11:11:57 +0000 Subject: [PATCH] feat: add QUERY_TIMEOUT env var for server-side statement_timeout - Read QUERY_TIMEOUT environment variable (seconds) in server.py - Pass query_timeout through DatabaseService to DbConnPool - DbConnPool sets PostgreSQL statement_timeout via connection options - Timeout is enforced server-side, independent of access mode - Add tests for timeout configuration in both access modes --- src/postgres_mcp/database_service.py | 15 ++++--- src/postgres_mcp/server.py | 7 ++- src/postgres_mcp/sql/sql_driver.py | 10 ++++- tests/unit/test_access_mode.py | 66 +++++++++++++++++++++++----- 4 files changed, 80 insertions(+), 18 deletions(-) diff --git a/src/postgres_mcp/database_service.py b/src/postgres_mcp/database_service.py index 6919e89..524d19c 100644 --- a/src/postgres_mcp/database_service.py +++ b/src/postgres_mcp/database_service.py @@ -36,9 +36,10 @@ class DatabaseService: - def __init__(self, database_url: str, current_access_mode: models.AccessMode): + def __init__(self, database_url: str, current_access_mode: models.AccessMode, query_timeout: float | None = None): self.database_url = database_url self.current_access_mode = current_access_mode + self.query_timeout = query_timeout self._connect_lock = asyncio.Lock() db_connection: Optional[DbConnPool] = None @@ -54,15 +55,19 @@ async def get_sql_driver(self) -> Union[SqlDriver, SafeSqlDriver]: if self.current_access_mode == models.AccessMode.RESTRICTED: logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") - # 30 second timeout - return SafeSqlDriver(sql_driver=base_driver, timeout=30) + return SafeSqlDriver(sql_driver=base_driver) else: logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") return base_driver async def create_db_connection(self) -> DbConnPool: - logger.info(f"Creating new database connection pool for URL: {obfuscate_password(self.database_url)}") - self.db_connection = DbConnPool(connection_url=self.database_url) + logger.info( + f"Creating new database connection pool for URL: {obfuscate_password(self.database_url)}, statement_timeout={self.query_timeout}s" + ) + self.db_connection = DbConnPool( + connection_url=self.database_url, + statement_timeout_seconds=self.query_timeout, + ) try: await self.db_connection.pool_connect(self.database_url) logger.info("Successfully connected to database and initialized connection pool") diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 14612d4..086ae0b 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -41,13 +41,14 @@ database_port: Optional[str] = None database_username: Optional[str] = None database_password: Optional[str] = None +query_timeout: Optional[float] = None shutdown_in_progress = False async def get_service(database_name: str) -> DatabaseService: if database_name not in db_services: database_url = create_database_url(database_name) - db_services[database_name] = DatabaseService(database_url, current_access_mode) + db_services[database_name] = DatabaseService(database_url, current_access_mode, query_timeout) return db_services[database_name] @@ -245,11 +246,15 @@ async def main(): global database_port global database_username global database_password + global query_timeout current_access_mode = AccessMode(args.access_mode) database_host = os.environ.get("DATABASE_HOST", args.database_host) database_port = os.environ.get("DATABASE_PORT", args.database_port) + raw_query_timeout = os.environ.get("QUERY_TIMEOUT") + query_timeout = float(raw_query_timeout) if raw_query_timeout is not None else None + if args.database_creds_file: database_username, database_password = read_database_creds(args.database_creds_file) else: diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index f4017e8..c014c92 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -63,8 +63,9 @@ def obfuscate_password(text: str | None) -> str | None: class DbConnPool: """Database connection manager using psycopg's connection pool.""" - def __init__(self, connection_url: Optional[str] = None): + def __init__(self, connection_url: Optional[str] = None, statement_timeout_seconds: float | None = None): self.connection_url = connection_url + self.statement_timeout_seconds = statement_timeout_seconds self.pool: AsyncConnectionPool | None = None self._is_valid = False self._last_error = None @@ -92,12 +93,19 @@ async def pool_connect(self, connection_url: Optional[str] = None) -> AsyncConne await self.close() try: + # Build connection kwargs with optional statement_timeout + connect_kwargs: dict[str, Any] = {} + if self.statement_timeout_seconds is not None: + timeout_ms = int(self.statement_timeout_seconds * 1000) + connect_kwargs["options"] = f"-c statement_timeout={timeout_ms}" + # Configure connection pool with appropriate settings self.pool = AsyncConnectionPool( conninfo=url, min_size=1, max_size=5, open=False, # Don't connect immediately, let's do it explicitly + kwargs=connect_kwargs if connect_kwargs else None, ) # Open the pool explicitly diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index 4955247..e835380 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -35,31 +35,75 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive driver = await service.get_sql_driver() assert isinstance(driver, expected_driver_type) - # When in RESTRICTED mode, verify timeout is set - if access_mode == AccessMode.RESTRICTED: - assert isinstance(driver, SafeSqlDriver) - assert driver.timeout == 30 - @pytest.mark.asyncio -async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): - """Test that get_sql_driver sets the timeout in restricted mode.""" +async def test_get_sql_driver_sets_restricted_mode(mock_db_connection): + """Test that get_sql_driver wraps with SafeSqlDriver in restricted mode.""" service = DatabaseService(database_url="postgresql://user:pass@localhost/test", current_access_mode=AccessMode.RESTRICTED) with patch.object(service, "db_connection", mock_db_connection): driver = await service.get_sql_driver() assert isinstance(driver, SafeSqlDriver) - assert driver.timeout == 30 assert hasattr(driver, "sql_driver") @pytest.mark.asyncio -async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): - """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" +async def test_get_sql_driver_in_unrestricted_mode(mock_db_connection): + """Test that get_sql_driver in unrestricted mode returns a plain SqlDriver.""" service = DatabaseService(database_url="postgresql://user:pass@localhost/test", current_access_mode=AccessMode.UNRESTRICTED) with patch.object(service, "db_connection", mock_db_connection): driver = await service.get_sql_driver() assert isinstance(driver, SqlDriver) - assert not hasattr(driver, "timeout") + assert not isinstance(driver, SafeSqlDriver) + + +@pytest.mark.asyncio +async def test_create_db_connection_no_timeout_restricted_mode(): + """Test that restricted mode has no statement_timeout by default.""" + service = DatabaseService( + database_url="postgresql://user:pass@localhost/test", + current_access_mode=AccessMode.RESTRICTED, + ) + with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock): + pool = await service.create_db_connection() + assert pool.statement_timeout_seconds is None + + +@pytest.mark.asyncio +async def test_create_db_connection_no_timeout_unrestricted_mode(): + """Test that unrestricted mode has no statement_timeout by default.""" + service = DatabaseService( + database_url="postgresql://user:pass@localhost/test", + current_access_mode=AccessMode.UNRESTRICTED, + ) + with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock): + pool = await service.create_db_connection() + assert pool.statement_timeout_seconds is None + + +@pytest.mark.asyncio +async def test_create_db_connection_custom_timeout_restricted_mode(): + """Test that query_timeout overrides the default 30s in restricted mode.""" + service = DatabaseService( + database_url="postgresql://user:pass@localhost/test", + current_access_mode=AccessMode.RESTRICTED, + query_timeout=120, + ) + with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock): + pool = await service.create_db_connection() + assert pool.statement_timeout_seconds == 120 + + +@pytest.mark.asyncio +async def test_create_db_connection_custom_timeout_unrestricted_mode(): + """Test that query_timeout is applied in unrestricted mode without restricting SQL.""" + service = DatabaseService( + database_url="postgresql://user:pass@localhost/test", + current_access_mode=AccessMode.UNRESTRICTED, + query_timeout=60, + ) + with patch.object(DbConnPool, "pool_connect", new_callable=AsyncMock): + pool = await service.create_db_connection() + assert pool.statement_timeout_seconds == 60 @pytest.mark.asyncio