Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/postgres_mcp/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from enum import Enum


Expand All @@ -6,3 +7,13 @@ class AccessMode(str, Enum):

UNRESTRICTED = "unrestricted" # Unrestricted access
RESTRICTED = "restricted" # Read-only with safety features


@dataclass(frozen=True)
class HostConfig:
"""Connection details for a single database host."""

host: str
port: int
username: str
password: str
151 changes: 104 additions & 47 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from starlette.responses import Response

from postgres_mcp.models import AccessMode
from postgres_mcp.models import HostConfig

from .database_health import HealthType
from .database_service import DatabaseService
Expand All @@ -37,19 +38,67 @@
# Global variables
db_services: dict[str, DatabaseService] = {}
current_access_mode = AccessMode.UNRESTRICTED
database_host: Optional[str] = None
database_port: Optional[str] = None
database_username: Optional[str] = None
database_password: Optional[str] = None
host_configs: dict[str, HostConfig] = {}
query_timeout: Optional[float] = None
shutdown_in_progress = False


async def get_service(database_name: str) -> DatabaseService:
if database_name not in db_services:
database_url = create_database_url(database_name)
db_services[database_name] = DatabaseService(database_url, current_access_mode, query_timeout)
return db_services[database_name]
def parse_host_configs_from_env() -> dict[str, HostConfig]:
"""Parse DATABASES__N__HOST/PORT/USERNAME/PASSWORD environment variables into HostConfigs."""
configs: dict[str, HostConfig] = {}
indices: set[str] = set()
for key in os.environ:
if key.startswith("DATABASES__") and key.endswith("__HOST"):
parts = key.split("__")
if len(parts) == 3:
indices.add(parts[1])

for idx in sorted(indices):
host = os.environ.get(f"DATABASES__{idx}__HOST")
if not host:
continue
port = int(os.environ.get(f"DATABASES__{idx}__PORT", "5432"))
username = os.environ.get(f"DATABASES__{idx}__USERNAME")
password = os.environ.get(f"DATABASES__{idx}__PASSWORD")
if not username or not password:
raise ValueError(f"DATABASES__{idx}__USERNAME and DATABASES__{idx}__PASSWORD must both be set when DATABASES__{idx}__HOST is configured")
configs[host] = HostConfig(host=host, port=port, username=username, password=password)

return configs


def resolve_host_config(host: Optional[str] = None) -> HostConfig:
"""Resolve the HostConfig for the given host name.

When host is None, returns the single configured host or raises if multiple are configured.
"""
if host is not None:
if host not in host_configs:
available = ", ".join(sorted(host_configs.keys()))
raise ValueError(f"No configuration found for host '{host}'. Available hosts: {available}")
return host_configs[host]

if len(host_configs) == 1:
return next(iter(host_configs.values()))

if len(host_configs) == 0:
raise ValueError("No database host configured. Set DATABASE_HOST or DATABASES__N__HOST environment variables.")

available = ", ".join(sorted(host_configs.keys()))
raise ValueError(f"Multiple database hosts are configured ({available}). The 'host' parameter is required when multiple hosts are configured.")


def create_database_url_from_config(config: HostConfig, database_name: str) -> str:
return f"postgresql://{config.username}:{config.password}@{config.host}:{config.port}/{database_name}"


async def get_service(database_name: str, host: Optional[str] = None) -> DatabaseService:
config = resolve_host_config(host)
service_key = f"{config.host}:{config.port}/{database_name}"
if service_key not in db_services:
database_url = create_database_url_from_config(config, database_name)
db_services[service_key] = DatabaseService(database_url, current_access_mode, query_timeout)
return db_services[service_key]


@mcp.custom_route("/health", methods=["GET"])
Expand All @@ -58,8 +107,11 @@ async def health_check(request: Request) -> Response:


@mcp.tool(description="List all schemas in the database")
async def list_schemas(database_name: str = Field(description="Database name")) -> ResponseType:
service = await get_service(database_name)
async def list_schemas(
database_name: str = Field(description="Database name"),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name, host)
return await service.list_schemas()


Expand All @@ -68,8 +120,9 @@ async def list_objects(
database_name: str = Field(description="Database name"),
schema_name: str = Field(description="Schema name"),
object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.list_objects(schema_name, object_type)


Expand All @@ -79,8 +132,9 @@ async def get_object_details(
schema_name: str = Field(description="Schema name"),
object_name: str = Field(description="Object name"),
object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.get_object_details(schema_name, object_name, object_type)


Expand All @@ -106,17 +160,19 @@ async def explain_query(
If there is no hypothetical index, you can pass an empty list.""",
default=[],
),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.explain_query(sql, analyze, hypothetical_indexes)


# Query function declaration without the decorator - we'll add it dynamically based on access mode
async def execute_sql(
database_name: str = Field(description="Database name"),
sql: str = Field(description="SQL to run", default="all"),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.execute_sql(sql)


Expand All @@ -126,8 +182,9 @@ async def analyze_workload_indexes(
database_name: str = Field(description="Database name"),
max_index_size_mb: int = Field(description="Max index size in MB", default=10000),
method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.analyze_workload_indexes(max_index_size_mb, method)


Expand All @@ -138,8 +195,9 @@ async def analyze_query_indexes(
queries: list[str] = Field(description="List of Query strings to analyze"),
max_index_size_mb: int = Field(description="Max index size in MB", default=10000),
method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.analyze_query_indexes(queries, max_index_size_mb, method)


Expand All @@ -161,8 +219,9 @@ async def analyze_db_health(
description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.",
default="all",
),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.analyze_db_health(health_type)


Expand All @@ -178,8 +237,9 @@ async def get_top_queries(
default="resources",
),
limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10),
host: Optional[str] = Field(description="Database host name or address. Only provide when explicitly requested.", default=None),
) -> ResponseType:
service = await get_service(database_name)
service = await get_service(database_name, host)
return await service.get_top_queries(sort_by, limit)


Expand Down Expand Up @@ -242,34 +302,46 @@ async def main():

# Store the access mode in the global variable
global current_access_mode
global database_host
global database_port
global database_username
global database_password
global host_configs
global query_timeout

current_access_mode = AccessMode(args.access_mode)
database_host = os.environ.get("DATABASE_HOST", args.database_host)
database_port = os.environ.get("DATABASE_PORT", args.database_port)

raw_query_timeout = os.environ.get("QUERY_TIMEOUT")
query_timeout = float(raw_query_timeout) if raw_query_timeout is not None else None

if args.database_creds_file:
database_username, database_password = read_database_creds(args.database_creds_file)
else:
database_username = os.environ.get("DATABASE_USERNAME", args.database_username)
database_password = os.environ.get("DATABASE_PASSWORD", args.database_password)
# Build host configs from multi-host env vars (DATABASES__N__*)
host_configs = parse_host_configs_from_env()

# Also support the legacy single-host configuration (CLI args + single env vars)
legacy_host = os.environ.get("DATABASE_HOST", args.database_host)
if legacy_host:
legacy_port = int(os.environ.get("DATABASE_PORT", str(args.database_port)))
if args.database_creds_file:
legacy_username, legacy_password = read_database_creds(args.database_creds_file)
else:
legacy_username = os.environ.get("DATABASE_USERNAME", args.database_username)
legacy_password = os.environ.get("DATABASE_PASSWORD", args.database_password)

if legacy_username and legacy_password and legacy_host not in host_configs:
host_configs[legacy_host] = HostConfig(
host=legacy_host,
port=legacy_port,
username=legacy_username,
password=legacy_password,
)

# Add the query tool with a description appropriate to the access mode
if current_access_mode == AccessMode.UNRESTRICTED:
mcp.add_tool(execute_sql, description="Execute any SQL query")
else:
mcp.add_tool(execute_sql, description="Execute a read-only SQL query")

configured_hosts = ", ".join(sorted(host_configs.keys())) or "(none)"
logger.info(
f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode. transport={args.transport}, database_host={database_host}, "
f"database_port={database_port}, database_username={database_username}, sse_host={args.sse_host}, sse_port={args.sse_port}"
f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode. "
f"transport={args.transport}, configured_hosts=[{configured_hosts}], "
f"sse_host={args.sse_host}, sse_port={args.sse_port}"
)

# Set up proper shutdown handling
Expand Down Expand Up @@ -306,21 +378,6 @@ def read_database_creds(creds_file: str) -> tuple[str, str]:
raise


def create_database_url(database_name: str) -> str:
if not database_host:
raise ValueError("Database host must be specified via command line argument or DATABASE_HOST environment variable")
if not database_username:
raise ValueError(
"Database username must be specified via command line argument, DATABASE_USERNAME environment variable or a database credentials file"
)
if not database_password:
raise ValueError(
"Database password must be specified via command line argument, DATABASE_PASSWORD environment variable or a database credentials file"
)

return f"postgresql://{database_username}:{database_password}@{database_host}:{database_port}/{database_name}"


async def shutdown(sig=None):
"""Clean shutdown of the server."""
global shutdown_in_progress
Expand Down
Loading