From 07ef82f85cf6e464a52ab47f19c4b5e6ec7ab534 Mon Sep 17 00:00:00 2001 From: David Gilady Date: Wed, 25 Feb 2026 00:15:08 +0200 Subject: [PATCH 1/2] feat: support multiple database hosts via DATABASES__N__* env vars Add HostConfig dataclass and multi-host configuration support. Hosts are configured with DATABASES__N__HOST, DATABASES__N__PORT, DATABASES__N__USERNAME, DATABASES__N__PASSWORD environment variables. All MCP tools now accept an optional 'host' parameter to target a specific host. When omitted, the single configured host is used; when multiple hosts exist, the server returns a clear error. Legacy single-host configuration (DATABASE_HOST, CLI args) remains fully backward compatible. --- src/postgres_mcp/models.py | 11 ++ src/postgres_mcp/server.py | 208 +++++++++++++++++++-------- tests/unit/test_multi_host_test.py | 216 +++++++++++++++++++++++++++++ 3 files changed, 375 insertions(+), 60 deletions(-) create mode 100644 tests/unit/test_multi_host_test.py diff --git a/src/postgres_mcp/models.py b/src/postgres_mcp/models.py index ed8e110..2195bb5 100644 --- a/src/postgres_mcp/models.py +++ b/src/postgres_mcp/models.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from enum import Enum @@ -6,3 +7,13 @@ class AccessMode(str, Enum): UNRESTRICTED = "unrestricted" # Unrestricted access RESTRICTED = "restricted" # Read-only with safety features + + +@dataclass(frozen=True) +class HostConfig: + """Connection details for a single database host.""" + + host: str + port: int + username: str + password: str diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 086ae0b..3aa10ca 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -18,6 +18,7 @@ from starlette.responses import Response from postgres_mcp.models import AccessMode +from postgres_mcp.models import HostConfig from .database_health import HealthType from .database_service import DatabaseService @@ -29,7 +30,8 @@ PG_STAT_STATEMENTS = "pg_stat_statements" HYPOPG_EXTENSION = "hypopg" -ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] +ResponseType = List[types.TextContent | + types.ImageContent | types.EmbeddedResource] logger = logging.getLogger(__name__) @@ -37,19 +39,73 @@ # Global variables db_services: dict[str, DatabaseService] = {} current_access_mode = AccessMode.UNRESTRICTED -database_host: Optional[str] = None -database_port: Optional[str] = None -database_username: Optional[str] = None -database_password: Optional[str] = None +host_configs: dict[str, HostConfig] = {} query_timeout: Optional[float] = None shutdown_in_progress = False -async def get_service(database_name: str) -> DatabaseService: - if database_name not in db_services: - database_url = create_database_url(database_name) - db_services[database_name] = DatabaseService(database_url, current_access_mode, query_timeout) - return db_services[database_name] +def parse_host_configs_from_env() -> dict[str, HostConfig]: + """Parse DATABASES__N__HOST/PORT/USERNAME/PASSWORD environment variables into HostConfigs.""" + configs: dict[str, HostConfig] = {} + indices: set[str] = set() + for key in os.environ: + if key.startswith("DATABASES__") and key.endswith("__HOST"): + parts = key.split("__") + if len(parts) == 3: + indices.add(parts[1]) + + for idx in sorted(indices): + host = os.environ.get(f"DATABASES__{idx}__HOST") + if not host: + continue + port = int(os.environ.get(f"DATABASES__{idx}__PORT", "5432")) + username = os.environ.get(f"DATABASES__{idx}__USERNAME") + password = os.environ.get(f"DATABASES__{idx}__PASSWORD") + if not username or not password: + raise ValueError( + f"DATABASES__{idx}__USERNAME and DATABASES__{idx}__PASSWORD must both be set when DATABASES__{idx}__HOST is configured") + configs[host] = HostConfig( + host=host, port=port, username=username, password=password) + + return configs + + +def resolve_host_config(host: Optional[str] = None) -> HostConfig: + """Resolve the HostConfig for the given host name. + + When host is None, returns the single configured host or raises if multiple are configured. + """ + if host is not None: + if host not in host_configs: + available = ", ".join(sorted(host_configs.keys())) + raise ValueError( + f"No configuration found for host '{host}'. Available hosts: {available}") + return host_configs[host] + + if len(host_configs) == 1: + return next(iter(host_configs.values())) + + if len(host_configs) == 0: + raise ValueError( + "No database host configured. Set DATABASE_HOST or DATABASES__N__HOST environment variables.") + + available = ", ".join(sorted(host_configs.keys())) + raise ValueError( + f"Multiple database hosts are configured ({available}). The 'host' parameter is required when multiple hosts are configured.") + + +def create_database_url_from_config(config: HostConfig, database_name: str) -> str: + return f"postgresql://{config.username}:{config.password}@{config.host}:{config.port}/{database_name}" + + +async def get_service(database_name: str, host: Optional[str] = None) -> DatabaseService: + config = resolve_host_config(host) + service_key = f"{config.host}:{config.port}/{database_name}" + if service_key not in db_services: + database_url = create_database_url_from_config(config, database_name) + db_services[service_key] = DatabaseService( + database_url, current_access_mode, query_timeout) + return db_services[service_key] @mcp.custom_route("/health", methods=["GET"]) @@ -58,8 +114,12 @@ async def health_check(request: Request) -> Response: @mcp.tool(description="List all schemas in the database") -async def list_schemas(database_name: str = Field(description="Database name")) -> ResponseType: - service = await get_service(database_name) +async def list_schemas( + database_name: str = Field(description="Database name"), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), +) -> ResponseType: + service = await get_service(database_name, host) return await service.list_schemas() @@ -67,9 +127,12 @@ async def list_schemas(database_name: str = Field(description="Database name")) async def list_objects( database_name: str = Field(description="Database name"), schema_name: str = Field(description="Schema name"), - object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), + object_type: str = Field( + description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.list_objects(schema_name, object_type) @@ -78,9 +141,12 @@ async def get_object_details( database_name: str = Field(description="Database name"), schema_name: str = Field(description="Schema name"), object_name: str = Field(description="Object name"), - object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), + object_type: str = Field( + description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.get_object_details(schema_name, object_name, object_type) @@ -106,8 +172,10 @@ async def explain_query( If there is no hypothetical index, you can pass an empty list.""", default=[], ), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.explain_query(sql, analyze, hypothetical_indexes) @@ -115,8 +183,10 @@ async def explain_query( async def execute_sql( database_name: str = Field(description="Database name"), sql: str = Field(description="SQL to run", default="all"), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.execute_sql(sql) @@ -124,10 +194,14 @@ async def execute_sql( @validate_call async def analyze_workload_indexes( database_name: str = Field(description="Database name"), - max_index_size_mb: int = Field(description="Max index size in MB", default=10000), - method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), + max_index_size_mb: int = Field( + description="Max index size in MB", default=10000), + method: Literal["dta", "llm"] = Field( + description="Method to use for analysis", default="dta"), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.analyze_workload_indexes(max_index_size_mb, method) @@ -136,10 +210,14 @@ async def analyze_workload_indexes( async def analyze_query_indexes( database_name: str = Field(description="Database name"), queries: list[str] = Field(description="List of Query strings to analyze"), - max_index_size_mb: int = Field(description="Max index size in MB", default=10000), - method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), + max_index_size_mb: int = Field( + description="Max index size in MB", default=10000), + method: Literal["dta", "llm"] = Field( + description="Method to use for analysis", default="dta"), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.analyze_query_indexes(queries, max_index_size_mb, method) @@ -161,8 +239,10 @@ async def analyze_db_health( description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.", default="all", ), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.analyze_db_health(health_type) @@ -177,16 +257,20 @@ async def get_top_queries( "for resource-intensive queries", default="resources", ), - limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), + limit: int = Field( + description="Number of queries to return when ranking based on mean_time or total_time", default=10), + host: Optional[str] = Field( + description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: - service = await get_service(database_name) + service = await get_service(database_name, host) return await service.get_top_queries(sort_by, limit) async def main(): # Parse command line arguments parser = argparse.ArgumentParser(description="PostgreSQL MCP Server") - parser.add_argument("database_host", help="Database host: e.g database.example.com", nargs="?") + parser.add_argument( + "database_host", help="Database host: e.g database.example.com", nargs="?") parser.add_argument( "--database-port", type=int, @@ -242,24 +326,39 @@ async def main(): # Store the access mode in the global variable global current_access_mode - global database_host - global database_port - global database_username - global database_password + global host_configs global query_timeout current_access_mode = AccessMode(args.access_mode) - database_host = os.environ.get("DATABASE_HOST", args.database_host) - database_port = os.environ.get("DATABASE_PORT", args.database_port) raw_query_timeout = os.environ.get("QUERY_TIMEOUT") - query_timeout = float(raw_query_timeout) if raw_query_timeout is not None else None - - if args.database_creds_file: - database_username, database_password = read_database_creds(args.database_creds_file) - else: - database_username = os.environ.get("DATABASE_USERNAME", args.database_username) - database_password = os.environ.get("DATABASE_PASSWORD", args.database_password) + query_timeout = float( + raw_query_timeout) if raw_query_timeout is not None else None + + # Build host configs from multi-host env vars (DATABASES__N__*) + host_configs = parse_host_configs_from_env() + + # Also support the legacy single-host configuration (CLI args + single env vars) + legacy_host = os.environ.get("DATABASE_HOST", args.database_host) + if legacy_host: + legacy_port = int(os.environ.get( + "DATABASE_PORT", str(args.database_port))) + if args.database_creds_file: + legacy_username, legacy_password = read_database_creds( + args.database_creds_file) + else: + legacy_username = os.environ.get( + "DATABASE_USERNAME", args.database_username) + legacy_password = os.environ.get( + "DATABASE_PASSWORD", args.database_password) + + if legacy_username and legacy_password and legacy_host not in host_configs: + host_configs[legacy_host] = HostConfig( + host=legacy_host, + port=legacy_port, + username=legacy_username, + password=legacy_password, + ) # Add the query tool with a description appropriate to the access mode if current_access_mode == AccessMode.UNRESTRICTED: @@ -267,9 +366,11 @@ async def main(): else: mcp.add_tool(execute_sql, description="Execute a read-only SQL query") + configured_hosts = ", ".join(sorted(host_configs.keys())) or "(none)" logger.info( - f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode. transport={args.transport}, database_host={database_host}, " - f"database_port={database_port}, database_username={database_username}, sse_host={args.sse_host}, sse_port={args.sse_port}" + f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode. " + f"transport={args.transport}, configured_hosts=[{configured_hosts}], " + f"sse_host={args.sse_host}, sse_port={args.sse_port}" ) # Set up proper shutdown handling @@ -277,7 +378,8 @@ async def main(): loop = asyncio.get_running_loop() signals = (signal.SIGTERM, signal.SIGINT) for s in signals: - loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s))) + loop.add_signal_handler( + s, lambda s=s: asyncio.create_task(shutdown(s))) except NotImplementedError: # Windows doesn't support signals properly logger.warning("Signal handling not supported on Windows") @@ -299,28 +401,14 @@ def read_database_creds(creds_file: str) -> tuple[str, str]: with open(creds_file) as f: lines = f.read().splitlines() if len(lines) < 2: - raise ValueError("Credentials file must contain at least two lines: username and password") + raise ValueError( + "Credentials file must contain at least two lines: username and password") return lines[0], lines[1] except Exception as e: logger.error(f"Error reading database credentials from file: {e}") raise -def create_database_url(database_name: str) -> str: - if not database_host: - raise ValueError("Database host must be specified via command line argument or DATABASE_HOST environment variable") - if not database_username: - raise ValueError( - "Database username must be specified via command line argument, DATABASE_USERNAME environment variable or a database credentials file" - ) - if not database_password: - raise ValueError( - "Database password must be specified via command line argument, DATABASE_PASSWORD environment variable or a database credentials file" - ) - - return f"postgresql://{database_username}:{database_password}@{database_host}:{database_port}/{database_name}" - - async def shutdown(sig=None): """Clean shutdown of the server.""" global shutdown_in_progress diff --git a/tests/unit/test_multi_host_test.py b/tests/unit/test_multi_host_test.py new file mode 100644 index 0000000..3bcfa6d --- /dev/null +++ b/tests/unit/test_multi_host_test.py @@ -0,0 +1,216 @@ +import os +from unittest.mock import patch + +import pytest + +import postgres_mcp.server as server_module +from postgres_mcp.models import AccessMode +from postgres_mcp.models import HostConfig +from postgres_mcp.server import create_database_url_from_config +from postgres_mcp.server import get_service +from postgres_mcp.server import parse_host_configs_from_env +from postgres_mcp.server import resolve_host_config + +# --- parse_host_configs_from_env --- + + +class TestParseHostConfigsFromEnv: + def test_no_env_vars_returns_empty(self): + with patch.dict(os.environ, {}, clear=True): + assert parse_host_configs_from_env() == {} + + def test_single_host_parsed(self): + env = { + "DATABASES__1__HOST": "db1.example.com", + "DATABASES__1__PORT": "5433", + "DATABASES__1__USERNAME": "user1", + "DATABASES__1__PASSWORD": "pass1", + } + with patch.dict(os.environ, env, clear=True): + configs = parse_host_configs_from_env() + assert len(configs) == 1 + config = configs["db1.example.com"] + assert config.host == "db1.example.com" + assert config.port == 5433 + assert config.username == "user1" + assert config.password == "pass1" + + def test_multiple_hosts_parsed(self): + env = { + "DATABASES__1__HOST": "db1.example.com", + "DATABASES__1__PORT": "5432", + "DATABASES__1__USERNAME": "user1", + "DATABASES__1__PASSWORD": "pass1", + "DATABASES__2__HOST": "db2.example.com", + "DATABASES__2__PORT": "5433", + "DATABASES__2__USERNAME": "user2", + "DATABASES__2__PASSWORD": "pass2", + } + with patch.dict(os.environ, env, clear=True): + configs = parse_host_configs_from_env() + assert len(configs) == 2 + assert "db1.example.com" in configs + assert "db2.example.com" in configs + assert configs["db2.example.com"].port == 5433 + + def test_default_port_when_omitted(self): + env = { + "DATABASES__1__HOST": "db1.example.com", + "DATABASES__1__USERNAME": "user1", + "DATABASES__1__PASSWORD": "pass1", + } + with patch.dict(os.environ, env, clear=True): + configs = parse_host_configs_from_env() + assert configs["db1.example.com"].port == 5432 + + def test_missing_username_raises(self): + env = { + "DATABASES__1__HOST": "db1.example.com", + "DATABASES__1__PASSWORD": "pass1", + } + with patch.dict(os.environ, env, clear=True): + with pytest.raises(ValueError, match="DATABASES__1__USERNAME and DATABASES__1__PASSWORD must both be set"): + parse_host_configs_from_env() + + def test_missing_password_raises(self): + env = { + "DATABASES__1__HOST": "db1.example.com", + "DATABASES__1__USERNAME": "user1", + } + with patch.dict(os.environ, env, clear=True): + with pytest.raises(ValueError, match="DATABASES__1__USERNAME and DATABASES__1__PASSWORD must both be set"): + parse_host_configs_from_env() + + def test_non_numeric_indices_supported(self): + env = { + "DATABASES__prod__HOST": "prod.example.com", + "DATABASES__prod__USERNAME": "admin", + "DATABASES__prod__PASSWORD": "secret", + } + with patch.dict(os.environ, env, clear=True): + configs = parse_host_configs_from_env() + assert "prod.example.com" in configs + + def test_unrelated_env_vars_ignored(self): + env = { + "DATABASES__1__HOST": "db1.example.com", + "DATABASES__1__USERNAME": "user1", + "DATABASES__1__PASSWORD": "pass1", + "SOME_OTHER_VAR": "value", + "DATABASES__1__EXTRA": "ignored", + } + with patch.dict(os.environ, env, clear=True): + configs = parse_host_configs_from_env() + assert len(configs) == 1 + + +# --- resolve_host_config --- + + +class TestResolveHostConfig: + CONFIG_A = HostConfig(host="a.example.com", port=5432, + username="ua", password="pa") + CONFIG_B = HostConfig(host="b.example.com", port=5433, + username="ub", password="pb") + + def test_single_host_no_host_param_returns_it(self): + with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}): + assert resolve_host_config(None) == self.CONFIG_A + + def test_single_host_with_matching_host_param(self): + with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}): + assert resolve_host_config("a.example.com") == self.CONFIG_A + + def test_single_host_with_wrong_host_param_raises(self): + with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}): + with pytest.raises(ValueError, match="No configuration found for host 'unknown'"): + resolve_host_config("unknown") + + def test_multiple_hosts_no_host_param_raises(self): + configs = {"a.example.com": self.CONFIG_A, + "b.example.com": self.CONFIG_B} + with patch.object(server_module, "host_configs", configs): + with pytest.raises(ValueError, match="'host' parameter is required when multiple hosts are configured"): + resolve_host_config(None) + + def test_multiple_hosts_with_host_param(self): + configs = {"a.example.com": self.CONFIG_A, + "b.example.com": self.CONFIG_B} + with patch.object(server_module, "host_configs", configs): + assert resolve_host_config("b.example.com") == self.CONFIG_B + + def test_no_hosts_configured_raises(self): + with patch.object(server_module, "host_configs", {}): + with pytest.raises(ValueError, match="No database host configured"): + resolve_host_config(None) + + +# --- create_database_url_from_config --- + + +class TestCreateDatabaseUrlFromConfig: + def test_url_format(self): + config = HostConfig(host="myhost", port=5432, + username="myuser", password="mypass") + url = create_database_url_from_config(config, "mydb") + assert url == "postgresql://myuser:mypass@myhost:5432/mydb" + + def test_url_with_custom_port(self): + config = HostConfig(host="myhost", port=5433, + username="u", password="p") + url = create_database_url_from_config(config, "testdb") + assert url == "postgresql://u:p@myhost:5433/testdb" + + +# --- get_service --- + + +class TestGetService: + CONFIG = HostConfig(host="db.example.com", port=5432, + username="user", password="pass") + + @pytest.mark.asyncio + async def test_creates_service_for_new_database(self): + with ( + patch.object(server_module, "host_configs", { + "db.example.com": self.CONFIG}), + patch.object(server_module, "db_services", {}), + patch.object(server_module, "current_access_mode", + AccessMode.UNRESTRICTED), + patch.object(server_module, "query_timeout", None), + ): + service = await get_service("mydb") + assert service.database_url == "postgresql://user:pass@db.example.com:5432/mydb" + + @pytest.mark.asyncio + async def test_reuses_service_for_same_host_and_database(self): + with ( + patch.object(server_module, "host_configs", { + "db.example.com": self.CONFIG}), + patch.object(server_module, "db_services", {}), + patch.object(server_module, "current_access_mode", + AccessMode.UNRESTRICTED), + patch.object(server_module, "query_timeout", None), + ): + service1 = await get_service("mydb") + service2 = await get_service("mydb") + assert service1 is service2 + + @pytest.mark.asyncio + async def test_different_hosts_get_different_services(self): + config_b = HostConfig(host="other.example.com", + port=5433, username="u2", password="p2") + configs = {"db.example.com": self.CONFIG, + "other.example.com": config_b} + with ( + patch.object(server_module, "host_configs", configs), + patch.object(server_module, "db_services", {}), + patch.object(server_module, "current_access_mode", + AccessMode.UNRESTRICTED), + patch.object(server_module, "query_timeout", None), + ): + service_a = await get_service("mydb", "db.example.com") + service_b = await get_service("mydb", "other.example.com") + assert service_a is not service_b + assert "db.example.com" in service_a.database_url + assert "other.example.com" in service_b.database_url From 4e7d388977c34094f9bb7196e65f1aff249d1506 Mon Sep 17 00:00:00 2001 From: David Gilady Date: Wed, 25 Feb 2026 01:15:46 +0200 Subject: [PATCH 2/2] fix formatting errors --- src/postgres_mcp/server.py | 93 ++++++++++-------------------- tests/unit/test_multi_host_test.py | 42 +++++--------- 2 files changed, 45 insertions(+), 90 deletions(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 3aa10ca..fc55db6 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -30,8 +30,7 @@ PG_STAT_STATEMENTS = "pg_stat_statements" HYPOPG_EXTENSION = "hypopg" -ResponseType = List[types.TextContent | - types.ImageContent | types.EmbeddedResource] +ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] logger = logging.getLogger(__name__) @@ -62,10 +61,8 @@ def parse_host_configs_from_env() -> dict[str, HostConfig]: username = os.environ.get(f"DATABASES__{idx}__USERNAME") password = os.environ.get(f"DATABASES__{idx}__PASSWORD") if not username or not password: - raise ValueError( - f"DATABASES__{idx}__USERNAME and DATABASES__{idx}__PASSWORD must both be set when DATABASES__{idx}__HOST is configured") - configs[host] = HostConfig( - host=host, port=port, username=username, password=password) + raise ValueError(f"DATABASES__{idx}__USERNAME and DATABASES__{idx}__PASSWORD must both be set when DATABASES__{idx}__HOST is configured") + configs[host] = HostConfig(host=host, port=port, username=username, password=password) return configs @@ -78,20 +75,17 @@ def resolve_host_config(host: Optional[str] = None) -> HostConfig: if host is not None: if host not in host_configs: available = ", ".join(sorted(host_configs.keys())) - raise ValueError( - f"No configuration found for host '{host}'. Available hosts: {available}") + raise ValueError(f"No configuration found for host '{host}'. Available hosts: {available}") return host_configs[host] if len(host_configs) == 1: return next(iter(host_configs.values())) if len(host_configs) == 0: - raise ValueError( - "No database host configured. Set DATABASE_HOST or DATABASES__N__HOST environment variables.") + raise ValueError("No database host configured. Set DATABASE_HOST or DATABASES__N__HOST environment variables.") available = ", ".join(sorted(host_configs.keys())) - raise ValueError( - f"Multiple database hosts are configured ({available}). The 'host' parameter is required when multiple hosts are configured.") + raise ValueError(f"Multiple database hosts are configured ({available}). The 'host' parameter is required when multiple hosts are configured.") def create_database_url_from_config(config: HostConfig, database_name: str) -> str: @@ -103,8 +97,7 @@ async def get_service(database_name: str, host: Optional[str] = None) -> Databas service_key = f"{config.host}:{config.port}/{database_name}" if service_key not in db_services: database_url = create_database_url_from_config(config, database_name) - db_services[service_key] = DatabaseService( - database_url, current_access_mode, query_timeout) + db_services[service_key] = DatabaseService(database_url, current_access_mode, query_timeout) return db_services[service_key] @@ -116,8 +109,7 @@ async def health_check(request: Request) -> Response: @mcp.tool(description="List all schemas in the database") async def list_schemas( database_name: str = Field(description="Database name"), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.list_schemas() @@ -127,10 +119,8 @@ async def list_schemas( async def list_objects( database_name: str = Field(description="Database name"), schema_name: str = Field(description="Schema name"), - object_type: str = Field( - description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.list_objects(schema_name, object_type) @@ -141,10 +131,8 @@ async def get_object_details( database_name: str = Field(description="Database name"), schema_name: str = Field(description="Schema name"), object_name: str = Field(description="Object name"), - object_type: str = Field( - description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.get_object_details(schema_name, object_name, object_type) @@ -172,8 +160,7 @@ async def explain_query( If there is no hypothetical index, you can pass an empty list.""", default=[], ), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.explain_query(sql, analyze, hypothetical_indexes) @@ -183,8 +170,7 @@ async def explain_query( async def execute_sql( database_name: str = Field(description="Database name"), sql: str = Field(description="SQL to run", default="all"), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.execute_sql(sql) @@ -194,12 +180,9 @@ async def execute_sql( @validate_call async def analyze_workload_indexes( database_name: str = Field(description="Database name"), - max_index_size_mb: int = Field( - description="Max index size in MB", default=10000), - method: Literal["dta", "llm"] = Field( - description="Method to use for analysis", default="dta"), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + max_index_size_mb: int = Field(description="Max index size in MB", default=10000), + method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.analyze_workload_indexes(max_index_size_mb, method) @@ -210,12 +193,9 @@ async def analyze_workload_indexes( async def analyze_query_indexes( database_name: str = Field(description="Database name"), queries: list[str] = Field(description="List of Query strings to analyze"), - max_index_size_mb: int = Field( - description="Max index size in MB", default=10000), - method: Literal["dta", "llm"] = Field( - description="Method to use for analysis", default="dta"), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + max_index_size_mb: int = Field(description="Max index size in MB", default=10000), + method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.analyze_query_indexes(queries, max_index_size_mb, method) @@ -239,8 +219,7 @@ async def analyze_db_health( description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.", default="all", ), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.analyze_db_health(health_type) @@ -257,10 +236,8 @@ async def get_top_queries( "for resource-intensive queries", default="resources", ), - limit: int = Field( - description="Number of queries to return when ranking based on mean_time or total_time", default=10), - host: Optional[str] = Field( - description="Database host name or address. Only provide when explicitly requested.", default=None), + limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), + host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None), ) -> ResponseType: service = await get_service(database_name, host) return await service.get_top_queries(sort_by, limit) @@ -269,8 +246,7 @@ async def get_top_queries( async def main(): # Parse command line arguments parser = argparse.ArgumentParser(description="PostgreSQL MCP Server") - parser.add_argument( - "database_host", help="Database host: e.g database.example.com", nargs="?") + parser.add_argument("database_host", help="Database host: e.g database.example.com", nargs="?") parser.add_argument( "--database-port", type=int, @@ -332,8 +308,7 @@ async def main(): current_access_mode = AccessMode(args.access_mode) raw_query_timeout = os.environ.get("QUERY_TIMEOUT") - query_timeout = float( - raw_query_timeout) if raw_query_timeout is not None else None + query_timeout = float(raw_query_timeout) if raw_query_timeout is not None else None # Build host configs from multi-host env vars (DATABASES__N__*) host_configs = parse_host_configs_from_env() @@ -341,16 +316,12 @@ async def main(): # Also support the legacy single-host configuration (CLI args + single env vars) legacy_host = os.environ.get("DATABASE_HOST", args.database_host) if legacy_host: - legacy_port = int(os.environ.get( - "DATABASE_PORT", str(args.database_port))) + legacy_port = int(os.environ.get("DATABASE_PORT", str(args.database_port))) if args.database_creds_file: - legacy_username, legacy_password = read_database_creds( - args.database_creds_file) + legacy_username, legacy_password = read_database_creds(args.database_creds_file) else: - legacy_username = os.environ.get( - "DATABASE_USERNAME", args.database_username) - legacy_password = os.environ.get( - "DATABASE_PASSWORD", args.database_password) + legacy_username = os.environ.get("DATABASE_USERNAME", args.database_username) + legacy_password = os.environ.get("DATABASE_PASSWORD", args.database_password) if legacy_username and legacy_password and legacy_host not in host_configs: host_configs[legacy_host] = HostConfig( @@ -378,8 +349,7 @@ async def main(): loop = asyncio.get_running_loop() signals = (signal.SIGTERM, signal.SIGINT) for s in signals: - loop.add_signal_handler( - s, lambda s=s: asyncio.create_task(shutdown(s))) + loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s))) except NotImplementedError: # Windows doesn't support signals properly logger.warning("Signal handling not supported on Windows") @@ -401,8 +371,7 @@ def read_database_creds(creds_file: str) -> tuple[str, str]: with open(creds_file) as f: lines = f.read().splitlines() if len(lines) < 2: - raise ValueError( - "Credentials file must contain at least two lines: username and password") + raise ValueError("Credentials file must contain at least two lines: username and password") return lines[0], lines[1] except Exception as e: logger.error(f"Error reading database credentials from file: {e}") diff --git a/tests/unit/test_multi_host_test.py b/tests/unit/test_multi_host_test.py index 3bcfa6d..9ee9508 100644 --- a/tests/unit/test_multi_host_test.py +++ b/tests/unit/test_multi_host_test.py @@ -108,10 +108,8 @@ def test_unrelated_env_vars_ignored(self): class TestResolveHostConfig: - CONFIG_A = HostConfig(host="a.example.com", port=5432, - username="ua", password="pa") - CONFIG_B = HostConfig(host="b.example.com", port=5433, - username="ub", password="pb") + CONFIG_A = HostConfig(host="a.example.com", port=5432, username="ua", password="pa") + CONFIG_B = HostConfig(host="b.example.com", port=5433, username="ub", password="pb") def test_single_host_no_host_param_returns_it(self): with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}): @@ -127,15 +125,13 @@ def test_single_host_with_wrong_host_param_raises(self): resolve_host_config("unknown") def test_multiple_hosts_no_host_param_raises(self): - configs = {"a.example.com": self.CONFIG_A, - "b.example.com": self.CONFIG_B} + configs = {"a.example.com": self.CONFIG_A, "b.example.com": self.CONFIG_B} with patch.object(server_module, "host_configs", configs): with pytest.raises(ValueError, match="'host' parameter is required when multiple hosts are configured"): resolve_host_config(None) def test_multiple_hosts_with_host_param(self): - configs = {"a.example.com": self.CONFIG_A, - "b.example.com": self.CONFIG_B} + configs = {"a.example.com": self.CONFIG_A, "b.example.com": self.CONFIG_B} with patch.object(server_module, "host_configs", configs): assert resolve_host_config("b.example.com") == self.CONFIG_B @@ -150,14 +146,12 @@ def test_no_hosts_configured_raises(self): class TestCreateDatabaseUrlFromConfig: def test_url_format(self): - config = HostConfig(host="myhost", port=5432, - username="myuser", password="mypass") + config = HostConfig(host="myhost", port=5432, username="myuser", password="mypass") url = create_database_url_from_config(config, "mydb") assert url == "postgresql://myuser:mypass@myhost:5432/mydb" def test_url_with_custom_port(self): - config = HostConfig(host="myhost", port=5433, - username="u", password="p") + config = HostConfig(host="myhost", port=5433, username="u", password="p") url = create_database_url_from_config(config, "testdb") assert url == "postgresql://u:p@myhost:5433/testdb" @@ -166,17 +160,14 @@ def test_url_with_custom_port(self): class TestGetService: - CONFIG = HostConfig(host="db.example.com", port=5432, - username="user", password="pass") + CONFIG = HostConfig(host="db.example.com", port=5432, username="user", password="pass") @pytest.mark.asyncio async def test_creates_service_for_new_database(self): with ( - patch.object(server_module, "host_configs", { - "db.example.com": self.CONFIG}), + patch.object(server_module, "host_configs", {"db.example.com": self.CONFIG}), patch.object(server_module, "db_services", {}), - patch.object(server_module, "current_access_mode", - AccessMode.UNRESTRICTED), + patch.object(server_module, "current_access_mode", AccessMode.UNRESTRICTED), patch.object(server_module, "query_timeout", None), ): service = await get_service("mydb") @@ -185,11 +176,9 @@ async def test_creates_service_for_new_database(self): @pytest.mark.asyncio async def test_reuses_service_for_same_host_and_database(self): with ( - patch.object(server_module, "host_configs", { - "db.example.com": self.CONFIG}), + patch.object(server_module, "host_configs", {"db.example.com": self.CONFIG}), patch.object(server_module, "db_services", {}), - patch.object(server_module, "current_access_mode", - AccessMode.UNRESTRICTED), + patch.object(server_module, "current_access_mode", AccessMode.UNRESTRICTED), patch.object(server_module, "query_timeout", None), ): service1 = await get_service("mydb") @@ -198,15 +187,12 @@ async def test_reuses_service_for_same_host_and_database(self): @pytest.mark.asyncio async def test_different_hosts_get_different_services(self): - config_b = HostConfig(host="other.example.com", - port=5433, username="u2", password="p2") - configs = {"db.example.com": self.CONFIG, - "other.example.com": config_b} + config_b = HostConfig(host="other.example.com", port=5433, username="u2", password="p2") + configs = {"db.example.com": self.CONFIG, "other.example.com": config_b} with ( patch.object(server_module, "host_configs", configs), patch.object(server_module, "db_services", {}), - patch.object(server_module, "current_access_mode", - AccessMode.UNRESTRICTED), + patch.object(server_module, "current_access_mode", AccessMode.UNRESTRICTED), patch.object(server_module, "query_timeout", None), ): service_a = await get_service("mydb", "db.example.com")