diff --git a/src/postgres_mcp/models.py b/src/postgres_mcp/models.py
index ed8e110..2195bb5 100644
--- a/src/postgres_mcp/models.py
+++ b/src/postgres_mcp/models.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+from dataclasses import field
 from enum import Enum
 
 
@@ -6,3 +8,13 @@ class AccessMode(str, Enum):
 
     UNRESTRICTED = "unrestricted"  # Unrestricted access
     RESTRICTED = "restricted"  # Read-only with safety features
+
+
+@dataclass(frozen=True)
+class HostConfig:
+    """Connection details for a single database host."""
+
+    host: str
+    port: int
+    username: str
+    password: str = field(repr=False)  # kept out of repr() so credentials do not leak into logs or tracebacks
diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py
index 086ae0b..fc55db6 100644
--- a/src/postgres_mcp/server.py
+++ b/src/postgres_mcp/server.py
@@ -18,6 +18,7 @@ from starlette.responses import Response
 
 
 from postgres_mcp.models import AccessMode
+from postgres_mcp.models import HostConfig
 
 from .database_health import HealthType
 from .database_service import DatabaseService
@@ -37,19 +38,94 @@
 # Global variables
 db_services: dict[str, DatabaseService] = {}
 current_access_mode = AccessMode.UNRESTRICTED
-database_host: Optional[str] = None
-database_port: Optional[str] = None
-database_username: Optional[str] = None
-database_password: Optional[str] = None
+host_configs: dict[str, HostConfig] = {}
 query_timeout: Optional[float] = None
 shutdown_in_progress = False
 
 
-async def get_service(database_name: str) -> DatabaseService:
-    if database_name not in db_services:
-        database_url = create_database_url(database_name)
-        db_services[database_name] = DatabaseService(database_url, current_access_mode, query_timeout)
-    return db_services[database_name]
+def parse_host_configs_from_env() -> dict[str, HostConfig]:
+    """Parse DATABASES__N__HOST/PORT/USERNAME/PASSWORD environment variables into HostConfigs.
+
+    N is any token without "__" (numeric or not). PORT defaults to 5432 when unset.
+    The returned mapping is keyed by host name.
+
+    Raises:
+        ValueError: if PORT is not an integer, if USERNAME or PASSWORD is missing or empty
+            for a configured host, or if two entries name the same host.
+    """
+    configs: dict[str, HostConfig] = {}
+    indices: set[str] = set()
+    for key in os.environ:
+        if key.startswith("DATABASES__") and key.endswith("__HOST"):
+            parts = key.split("__")
+            if len(parts) == 3:
+                indices.add(parts[1])
+
+    for idx in sorted(indices):
+        host = os.environ.get(f"DATABASES__{idx}__HOST")
+        if not host:
+            continue
+        raw_port = os.environ.get(f"DATABASES__{idx}__PORT", "5432")
+        try:
+            port = int(raw_port)
+        except ValueError:
+            raise ValueError(f"DATABASES__{idx}__PORT must be an integer, got '{raw_port}'") from None
+        username = os.environ.get(f"DATABASES__{idx}__USERNAME")
+        password = os.environ.get(f"DATABASES__{idx}__PASSWORD")
+        if not username or not password:
+            raise ValueError(f"DATABASES__{idx}__USERNAME and DATABASES__{idx}__PASSWORD must both be set when DATABASES__{idx}__HOST is configured")
+        if host in configs:
+            # Configs are keyed by host, so a second entry would silently replace the first.
+            raise ValueError(f"Host '{host}' is configured more than once (see DATABASES__{idx}__HOST)")
+        configs[host] = HostConfig(host=host, port=port, username=username, password=password)
+
+    return configs
+
+
+def resolve_host_config(host: Optional[str] = None) -> HostConfig:
+    """Resolve the HostConfig for the given host name.
+
+    When host is None, returns the single configured host or raises if multiple are configured.
+    """
+    if host is not None:
+        if host not in host_configs:
+            available = ", ".join(sorted(host_configs.keys()))
+            raise ValueError(f"No configuration found for host '{host}'. Available hosts: {available}")
+        return host_configs[host]
+
+    if len(host_configs) == 1:
+        return next(iter(host_configs.values()))
+
+    if len(host_configs) == 0:
+        raise ValueError("No database host configured. Set DATABASE_HOST or DATABASES__N__HOST environment variables.")
+
+    available = ", ".join(sorted(host_configs.keys()))
+    raise ValueError(f"Multiple database hosts are configured ({available}). The 'host' parameter is required when multiple hosts are configured.")
+
+
+def create_database_url_from_config(config: HostConfig, database_name: str) -> str:
+    """Build the postgresql:// connection URL for database_name on the given host.
+
+    Username, password and database name are percent-encoded so that reserved
+    characters such as '@', ':' or '/' cannot change how the URL is parsed.
+    """
+    from urllib.parse import quote
+
+    username = quote(config.username, safe="")
+    password = quote(config.password, safe="")
+    database = quote(database_name, safe="")
+    return f"postgresql://{username}:{password}@{config.host}:{config.port}/{database}"
+
+
+async def get_service(database_name: str, host: Optional[str] = None) -> DatabaseService:
+    """Return the DatabaseService for database_name on the resolved host, creating and caching it on first use."""
+    config = resolve_host_config(host)
+    # Key on host, port and database so the same database name on different hosts gets its own service.
+    service_key = f"{config.host}:{config.port}/{database_name}"
+    if service_key not in db_services:
+        database_url = create_database_url_from_config(config, database_name)
+        db_services[service_key] = DatabaseService(database_url, current_access_mode, query_timeout)
+    return db_services[service_key]
 
 
 @mcp.custom_route("/health", methods=["GET"])
@@ -58,8 +134,11 @@ async def health_check(request: Request) -> Response:
 
 
 @mcp.tool(description="List all schemas in the database")
-async def list_schemas(database_name: str = Field(description="Database name")) -> ResponseType:
-    service = await get_service(database_name)
+async def list_schemas(
+    database_name: str = Field(description="Database name"),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
+) -> ResponseType:
+    service = await get_service(database_name, host)
     return await service.list_schemas()
 
 
@@ -68,8 +147,9 @@ async def list_objects(
     database_name: str = Field(description="Database name"),
     schema_name: str = Field(description="Schema name"),
     object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.list_objects(schema_name, object_type)
 
 
@@ -79,8 +159,9 @@ async def get_object_details(
     schema_name: str = Field(description="Schema name"),
     object_name: str = Field(description="Object name"),
     object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.get_object_details(schema_name, object_name, object_type)
 
 
@@ -106,8 +187,9 @@ async def explain_query(
 If there is no hypothetical index, you can pass an empty list.""",
         default=[],
     ),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.explain_query(sql, analyze, hypothetical_indexes)
 
 
@@ -115,8 +197,9 @@ async def explain_query(
 async def execute_sql(
     database_name: str = Field(description="Database name"),
     sql: str = Field(description="SQL to run", default="all"),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.execute_sql(sql)
 
 
@@ -126,8 +209,9 @@ async def analyze_workload_indexes(
     database_name: str = Field(description="Database name"),
     max_index_size_mb: int = Field(description="Max index size in MB", default=10000),
     method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.analyze_workload_indexes(max_index_size_mb, method)
 
 
@@ -138,8 +222,9 @@ async def analyze_query_indexes(
     queries: list[str] = Field(description="List of Query strings to analyze"),
     max_index_size_mb: int = Field(description="Max index size in MB", default=10000),
     method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.analyze_query_indexes(queries, max_index_size_mb, method)
 
 
@@ -161,8 +246,9 @@ async def analyze_db_health(
         description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.",
         default="all",
     ),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.analyze_db_health(health_type)
 
 
@@ -178,8 +264,9 @@ async def get_top_queries(
         default="resources",
     ),
     limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10),
+    host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
 ) -> ResponseType:
-    service = await get_service(database_name)
+    service = await get_service(database_name, host)
     return await service.get_top_queries(sort_by, limit)
 
 
@@ -242,24 +329,45 @@ async def main():
 
     # Store the access mode in the global variable
     global current_access_mode
-    global database_host
-    global database_port
-    global database_username
-    global database_password
+    global host_configs
     global query_timeout
 
     current_access_mode = AccessMode(args.access_mode)
 
-    database_host = os.environ.get("DATABASE_HOST", args.database_host)
-    database_port = os.environ.get("DATABASE_PORT", args.database_port)
     raw_query_timeout = os.environ.get("QUERY_TIMEOUT")
     query_timeout = float(raw_query_timeout) if raw_query_timeout is not None else None
 
-    if args.database_creds_file:
-        database_username, database_password = read_database_creds(args.database_creds_file)
-    else:
-        database_username = os.environ.get("DATABASE_USERNAME", args.database_username)
-        database_password = os.environ.get("DATABASE_PASSWORD", args.database_password)
+    # Build host configs from multi-host env vars (DATABASES__N__*)
+    host_configs = parse_host_configs_from_env()
+
+    # Also support the legacy single-host configuration (CLI args + single env vars)
+    legacy_host = os.environ.get("DATABASE_HOST", args.database_host)
+    if legacy_host:
+        # args.database_port may be None when no port is supplied (TODO confirm the argparse default);
+        # str(None) would make int() raise, so fall back to the PostgreSQL default port.
+        raw_legacy_port = os.environ.get("DATABASE_PORT", args.database_port)
+        legacy_port = int(raw_legacy_port) if raw_legacy_port is not None else 5432
+        if args.database_creds_file:
+            legacy_username, legacy_password = read_database_creds(args.database_creds_file)
+        else:
+            legacy_username = os.environ.get("DATABASE_USERNAME", args.database_username)
+            legacy_password = os.environ.get("DATABASE_PASSWORD", args.database_password)
+
+        # A DATABASES__N__* entry for the same host takes precedence over the legacy settings.
+        if legacy_host not in host_configs:
+            if legacy_username and legacy_password:
+                host_configs[legacy_host] = HostConfig(
+                    host=legacy_host,
+                    port=legacy_port,
+                    username=legacy_username,
+                    password=legacy_password,
+                )
+            else:
+                # Warn now; otherwise the dropped host only surfaces later as "No database host configured".
+                logger.warning(
+                    f"Ignoring database host '{legacy_host}': a username and password must both be provided via "
+                    "command line arguments, environment variables or a database credentials file"
+                )
 
     # Add the query tool with a description appropriate to the access mode
     if current_access_mode == AccessMode.UNRESTRICTED:
@@ -267,9 +375,11 @@ async def main():
     else:
         mcp.add_tool(execute_sql, description="Execute a read-only SQL query")
 
+    configured_hosts = ", ".join(sorted(host_configs.keys())) or "(none)"
     logger.info(
-        f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode. transport={args.transport}, database_host={database_host}, "
-        f"database_port={database_port}, database_username={database_username}, sse_host={args.sse_host}, sse_port={args.sse_port}"
+        f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode. "
+        f"transport={args.transport}, configured_hosts=[{configured_hosts}], "
+        f"sse_host={args.sse_host}, sse_port={args.sse_port}"
     )
 
     # Set up proper shutdown handling
@@ -306,21 +416,6 @@ def read_database_creds(creds_file: str) -> tuple[str, str]:
         raise
 
 
-def create_database_url(database_name: str) -> str:
-    if not database_host:
-        raise ValueError("Database host must be specified via command line argument or DATABASE_HOST environment variable")
-    if not database_username:
-        raise ValueError(
-            "Database username must be specified via command line argument, DATABASE_USERNAME environment variable or a database credentials file"
-        )
-    if not database_password:
-        raise ValueError(
-            "Database password must be specified via command line argument, DATABASE_PASSWORD environment variable or a database credentials file"
-        )
-
-    return f"postgresql://{database_username}:{database_password}@{database_host}:{database_port}/{database_name}"
-
-
 async def shutdown(sig=None):
     """Clean shutdown of the server."""
     global shutdown_in_progress
diff --git a/tests/unit/test_multi_host_test.py b/tests/unit/test_multi_host_test.py
new file mode 100644
index 0000000..9ee9508
--- /dev/null
+++ b/tests/unit/test_multi_host_test.py
@@ -0,0 +1,231 @@
+import os
+from unittest.mock import patch
+
+import pytest
+
+import postgres_mcp.server as server_module
+from postgres_mcp.models import AccessMode
+from postgres_mcp.models import HostConfig
+from postgres_mcp.server import create_database_url_from_config
+from postgres_mcp.server import get_service
+from postgres_mcp.server import parse_host_configs_from_env
+from postgres_mcp.server import resolve_host_config
+
+# --- parse_host_configs_from_env ---
+
+
+class TestParseHostConfigsFromEnv:
+    def test_no_env_vars_returns_empty(self):
+        with patch.dict(os.environ, {}, clear=True):
+            assert parse_host_configs_from_env() == {}
+
+    def test_single_host_parsed(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__PORT": "5433",
+            "DATABASES__1__USERNAME": "user1",
+            "DATABASES__1__PASSWORD": "pass1",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            configs = parse_host_configs_from_env()
+            assert len(configs) == 1
+            config = configs["db1.example.com"]
+            assert config.host == "db1.example.com"
+            assert config.port == 5433
+            assert config.username == "user1"
+            assert config.password == "pass1"
+
+    def test_multiple_hosts_parsed(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__PORT": "5432",
+            "DATABASES__1__USERNAME": "user1",
+            "DATABASES__1__PASSWORD": "pass1",
+            "DATABASES__2__HOST": "db2.example.com",
+            "DATABASES__2__PORT": "5433",
+            "DATABASES__2__USERNAME": "user2",
+            "DATABASES__2__PASSWORD": "pass2",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            configs = parse_host_configs_from_env()
+            assert len(configs) == 2
+            assert "db1.example.com" in configs
+            assert "db2.example.com" in configs
+            assert configs["db2.example.com"].port == 5433
+
+    def test_default_port_when_omitted(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__USERNAME": "user1",
+            "DATABASES__1__PASSWORD": "pass1",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            configs = parse_host_configs_from_env()
+            assert configs["db1.example.com"].port == 5432
+
+    def test_missing_username_raises(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__PASSWORD": "pass1",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            with pytest.raises(ValueError, match="DATABASES__1__USERNAME and DATABASES__1__PASSWORD must both be set"):
+                parse_host_configs_from_env()
+
+    def test_missing_password_raises(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__USERNAME": "user1",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            with pytest.raises(ValueError, match="DATABASES__1__USERNAME and DATABASES__1__PASSWORD must both be set"):
+                parse_host_configs_from_env()
+
+    def test_non_numeric_indices_supported(self):
+        env = {
+            "DATABASES__prod__HOST": "prod.example.com",
+            "DATABASES__prod__USERNAME": "admin",
+            "DATABASES__prod__PASSWORD": "secret",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            configs = parse_host_configs_from_env()
+            assert "prod.example.com" in configs
+
+    def test_unrelated_env_vars_ignored(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__USERNAME": "user1",
+            "DATABASES__1__PASSWORD": "pass1",
+            "SOME_OTHER_VAR": "value",
+            "DATABASES__1__EXTRA": "ignored",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            configs = parse_host_configs_from_env()
+            assert len(configs) == 1
+
+    def test_invalid_port_raises(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__PORT": "not-a-port",
+            "DATABASES__1__USERNAME": "user1",
+            "DATABASES__1__PASSWORD": "pass1",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            with pytest.raises(ValueError, match="DATABASES__1__PORT must be an integer"):
+                parse_host_configs_from_env()
+
+    def test_duplicate_host_raises(self):
+        env = {
+            "DATABASES__1__HOST": "db1.example.com",
+            "DATABASES__1__USERNAME": "user1",
+            "DATABASES__1__PASSWORD": "pass1",
+            "DATABASES__2__HOST": "db1.example.com",
+            "DATABASES__2__USERNAME": "user2",
+            "DATABASES__2__PASSWORD": "pass2",
+        }
+        with patch.dict(os.environ, env, clear=True):
+            with pytest.raises(ValueError, match="configured more than once"):
+                parse_host_configs_from_env()
+
+
+# --- resolve_host_config ---
+
+
+class TestResolveHostConfig:
+    CONFIG_A = HostConfig(host="a.example.com", port=5432, username="ua", password="pa")
+    CONFIG_B = HostConfig(host="b.example.com", port=5433, username="ub", password="pb")
+
+    def test_single_host_no_host_param_returns_it(self):
+        with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}):
+            assert resolve_host_config(None) == self.CONFIG_A
+
+    def test_single_host_with_matching_host_param(self):
+        with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}):
+            assert resolve_host_config("a.example.com") == self.CONFIG_A
+
+    def test_single_host_with_wrong_host_param_raises(self):
+        with patch.object(server_module, "host_configs", {"a.example.com": self.CONFIG_A}):
+            with pytest.raises(ValueError, match="No configuration found for host 'unknown'"):
+                resolve_host_config("unknown")
+
+    def test_multiple_hosts_no_host_param_raises(self):
+        configs = {"a.example.com": self.CONFIG_A, "b.example.com": self.CONFIG_B}
+        with patch.object(server_module, "host_configs", configs):
+            with pytest.raises(ValueError, match="'host' parameter is required when multiple hosts are configured"):
+                resolve_host_config(None)
+
+    def test_multiple_hosts_with_host_param(self):
+        configs = {"a.example.com": self.CONFIG_A, "b.example.com": self.CONFIG_B}
+        with patch.object(server_module, "host_configs", configs):
+            assert resolve_host_config("b.example.com") == self.CONFIG_B
+
+    def test_no_hosts_configured_raises(self):
+        with patch.object(server_module, "host_configs", {}):
+            with pytest.raises(ValueError, match="No database host configured"):
+                resolve_host_config(None)
+
+
+# --- create_database_url_from_config ---
+
+
+class TestCreateDatabaseUrlFromConfig:
+    def test_url_format(self):
+        config = HostConfig(host="myhost", port=5432, username="myuser", password="mypass")
+        url = create_database_url_from_config(config, "mydb")
+        assert url == "postgresql://myuser:mypass@myhost:5432/mydb"
+
+    def test_url_with_custom_port(self):
+        config = HostConfig(host="myhost", port=5433, username="u", password="p")
+        url = create_database_url_from_config(config, "testdb")
+        assert url == "postgresql://u:p@myhost:5433/testdb"
+
+    def test_special_characters_in_credentials_are_percent_encoded(self):
+        config = HostConfig(host="myhost", port=5432, username="my@user", password="p@ss:w/rd#1")
+        url = create_database_url_from_config(config, "mydb")
+        assert url == "postgresql://my%40user:p%40ss%3Aw%2Frd%231@myhost:5432/mydb"
+
+
+# --- get_service ---
+
+
+class TestGetService:
+    CONFIG = HostConfig(host="db.example.com", port=5432, username="user", password="pass")
+
+    @pytest.mark.asyncio
+    async def test_creates_service_for_new_database(self):
+        with (
+            patch.object(server_module, "host_configs", {"db.example.com": self.CONFIG}),
+            patch.object(server_module, "db_services", {}),
+            patch.object(server_module, "current_access_mode", AccessMode.UNRESTRICTED),
+            patch.object(server_module, "query_timeout", None),
+        ):
+            service = await get_service("mydb")
+            assert service.database_url == "postgresql://user:pass@db.example.com:5432/mydb"
+
+    @pytest.mark.asyncio
+    async def test_reuses_service_for_same_host_and_database(self):
+        with (
+            patch.object(server_module, "host_configs", {"db.example.com": self.CONFIG}),
+            patch.object(server_module, "db_services", {}),
+            patch.object(server_module, "current_access_mode", AccessMode.UNRESTRICTED),
+            patch.object(server_module, "query_timeout", None),
+        ):
+            service1 = await get_service("mydb")
+            service2 = await get_service("mydb")
+            assert service1 is service2
+
+    @pytest.mark.asyncio
+    async def test_different_hosts_get_different_services(self):
+        config_b = HostConfig(host="other.example.com", port=5433, username="u2", password="p2")
+        configs = {"db.example.com": self.CONFIG, "other.example.com": config_b}
+        with (
+            patch.object(server_module, "host_configs", configs),
+            patch.object(server_module, "db_services", {}),
+            patch.object(server_module, "current_access_mode", AccessMode.UNRESTRICTED),
+            patch.object(server_module, "query_timeout", None),
+        ):
+            service_a = await get_service("mydb", "db.example.com")
+            service_b = await get_service("mydb", "other.example.com")
+            assert service_a is not service_b
+            assert "db.example.com" in service_a.database_url
+            assert "other.example.com" in service_b.database_url