diff --git a/AI_CHANGE_CHECKLIST.md b/AI_CHANGE_CHECKLIST.md index 4b2fd94..ebb109e 100644 --- a/AI_CHANGE_CHECKLIST.md +++ b/AI_CHANGE_CHECKLIST.md @@ -70,3 +70,11 @@ Run all of: - `make test` Do not conclude work until all pass, or explicitly report blockers. + +## 9) MCP Extension Changes + +When adding MCP-facing capabilities: +- Place MCP implementation under `app/domains//` (for this repo: `app/domains/mcp_server/`). +- Keep tool registration grouped by capability scopes so new routers can be mapped without editing core platform files. +- Ensure API client wrappers normalize downstream failures to RFC 7807-like error documents. +- Document standalone runtime scripts and required environment variables. diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index c7a568e..ce75383 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -301,3 +301,16 @@ Worker scaffold: - `property`: property-based invariants Coverage is enforced at 100% line + branch for all modules under `app/`. + +## MCP Extension Domain + +- Package: `app/domains/mcp_server/` +- Purpose: expose grouped MCP tools/resources (`auth`, `users`, `posts`, `vote`) mapped to existing `/api/v1/*` routes. +- Standalone runtime: `fastapi-template-mcp` (script entrypoint in `pyproject.toml`) starts a dedicated FastAPI process that serves: + - `GET /mcp/tools` + - `GET /mcp/resources` + - `POST /mcp/tools/call` + - `GET /mcp/config` + - `GET /mcp/health` +- API calls are delegated through a typed async HTTP client with normalized problem-document errors compatible with RFC 7807-like payloads. +- MCP tool execution logging includes action-level IDs, session IDs, source metadata, and outbound API call logs for audit/debug timelines. diff --git a/README.md b/README.md index 90218c6..a3ae1cd 100644 --- a/README.md +++ b/README.md @@ -144,3 +144,25 @@ Kubernetes examples: - Architecture: `ARCHITECTURE.md` - AI checklist: `AI_CHANGE_CHECKLIST.md` - AI agent instructions: `AGENTS.md` + +## MCP Server Extension (Standalone) + +A new extension package is available at `app/domains/mcp_server/` to expose template API capabilities as MCP-style tools/resources without modifying platform core files. + +Run standalone MCP server: + +```bash +fastapi-template-mcp +``` + +Environment variables (prefix `MCP_SERVER_`): +- `MCP_SERVER_HOST` +- `MCP_SERVER_PORT` +- `MCP_SERVER_BASE_URL` +- `MCP_SERVER_AUTH_MODE` (`none` or `bearer`) +- `MCP_SERVER_TIMEOUT_SECONDS` +- `MCP_SERVER_ALLOWED_TOOL_SCOPES` (CSV list, e.g. `auth,users,posts,vote`) + +Default tool groups are mapped from current `/api/v1` endpoints using existing router capability tags (`Authentication`, `Users`, `Post`, `Vote`). + +MCP interactions are logged with per-action IDs and session IDs to support traceability for debugging and security reviews. diff --git a/app/domains/mcp_server/__init__.py b/app/domains/mcp_server/__init__.py new file mode 100644 index 0000000..7a9f7c8 --- /dev/null +++ b/app/domains/mcp_server/__init__.py @@ -0,0 +1,5 @@ +"""MCP server domain package.""" + +from .server import MCPServer, create_mcp_app + +__all__ = ["MCPServer", "create_mcp_app"] diff --git a/app/domains/mcp_server/cli.py b/app/domains/mcp_server/cli.py new file mode 100644 index 0000000..76c62f1 --- /dev/null +++ b/app/domains/mcp_server/cli.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import uvicorn + +from .config import MCPServerSettings +from .server import create_mcp_app + + +def main() -> None: + settings = MCPServerSettings() + app = create_mcp_app(settings) + uvicorn.run(app, host=settings.host, port=settings.port) + + +if __name__ == "__main__": + main() diff --git a/app/domains/mcp_server/client.py b/app/domains/mcp_server/client.py new file mode 100644 index 0000000..7c93e90 --- /dev/null +++ b/app/domains/mcp_server/client.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import logging +from typing import Any, TypeVar + +import httpx +from pydantic import BaseModel + +from .problem import APIClientError, ProblemDocument, normalized_problem + +ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel) +logger = logging.getLogger(__name__) + + +class APIClient: + def __init__( + self, + *, + base_url: str, + timeout_seconds: float, + bearer_token: str | None = None, + ) -> None: + headers: dict[str, str] = {} + if bearer_token: + headers["Authorization"] = f"Bearer {bearer_token}" + self._client = httpx.AsyncClient( + base_url=base_url, + timeout=timeout_seconds, + headers=headers, + ) + + async def close(self) -> None: + await self._client.aclose() + + async def request( + self, + *, + method: str, + path: str, + response_model: type[ResponseModelT], + json_body: BaseModel | dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + action_id: str | None = None, + session_id: str | None = None, + ) -> ResponseModelT: + payload = json_body.model_dump() if isinstance(json_body, BaseModel) else json_body + logger.info( + "mcp_outbound_request", + extra={ + "action_id": action_id, + "session_id": session_id, + "method": method, + "path": path, + }, + ) + response = await self._client.request( + method=method, + url=path, + json=payload, + params=params, + headers={ + "X-MCP-Action-ID": action_id or "", + "X-MCP-Session-ID": session_id or "", + }, + ) + if response.is_error: + self._raise_api_error( + path=path, + response=response, + action_id=action_id, + session_id=session_id, + ) + logger.info( + "mcp_outbound_response", + extra={ + "action_id": action_id, + "session_id": session_id, + "status_code": response.status_code, + "path": path, + }, + ) + return response_model.model_validate(response.json()) + + @staticmethod + def _raise_api_error( + *, + path: str, + response: httpx.Response, + action_id: str | None = None, + session_id: str | None = None, + ) -> None: + data = response.json() if response.content else {} + if isinstance(data, dict) and all( + key in data for key in ("title", "status", "detail", "error_code") + ): + problem = ProblemDocument.model_validate(data) + else: + detail = str(data.get("detail", response.text or "Request failed")) if isinstance(data, dict) else ( + response.text or "Request failed" + ) + problem = normalized_problem( + status=response.status_code, + detail=detail, + instance=path, + error_code="request_failed", + ) + logger.warning( + "mcp_outbound_error", + extra={ + "action_id": action_id, + "session_id": session_id, + "status": problem.status, + "error_code": problem.error_code, + "instance": problem.instance, + }, + ) + raise APIClientError(problem) diff --git a/app/domains/mcp_server/config.py b/app/domains/mcp_server/config.py new file mode 100644 index 0000000..0ff2624 --- /dev/null +++ b/app/domains/mcp_server/config.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from enum import Enum + +from pydantic import AnyHttpUrl, Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class MCPAuthMode(str, Enum): + NONE = "none" + BEARER = "bearer" + + +class MCPServerSettings(BaseSettings): + host: str = "127.0.0.1" + port: int = 8765 + base_url: AnyHttpUrl = "http://127.0.0.1:8000" + auth_mode: MCPAuthMode = MCPAuthMode.BEARER + timeout_seconds: float = 10.0 + allowed_tool_scopes: list[str] = Field( + default_factory=lambda: ["auth", "users", "posts", "vote"] + ) + + model_config = SettingsConfigDict( + env_prefix="MCP_SERVER_", + env_file=".env", + case_sensitive=False, + extra="ignore", + ) + + @field_validator("timeout_seconds") + @classmethod + def validate_timeout(cls, value: float) -> float: + if value <= 0: + raise ValueError("MCP_SERVER_TIMEOUT_SECONDS must be > 0") + return value + + @field_validator("allowed_tool_scopes", mode="before") + @classmethod + def parse_scopes(cls, value: object) -> list[str]: + if isinstance(value, str): + return [scope.strip() for scope in value.split(",") if scope.strip()] + if isinstance(value, list): + return [str(scope).strip() for scope in value if str(scope).strip()] + raise ValueError("MCP_SERVER_ALLOWED_TOOL_SCOPES must be a list or csv string") + + +settings = MCPServerSettings() diff --git a/app/domains/mcp_server/problem.py b/app/domains/mcp_server/problem.py new file mode 100644 index 0000000..1951edd --- /dev/null +++ b/app/domains/mcp_server/problem.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from http import HTTPStatus + +from pydantic import BaseModel + + +class ProblemDocument(BaseModel): + type: str = "about:blank" + title: str + status: int + detail: str + instance: str = "" + error_code: str = "request_failed" + + +class APIClientError(Exception): + def __init__(self, problem: ProblemDocument): + self.problem = problem + super().__init__(problem.detail) + + +def normalized_problem(*, status: int, detail: str, instance: str, error_code: str) -> ProblemDocument: + try: + title = HTTPStatus(status).phrase + except ValueError: + title = "Error" + return ProblemDocument( + title=title, + status=status, + detail=detail, + instance=instance, + error_code=error_code, + ) diff --git a/app/domains/mcp_server/server.py b/app/domains/mcp_server/server.py new file mode 100644 index 0000000..e9b4ecf --- /dev/null +++ b/app/domains/mcp_server/server.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +from collections.abc import Iterable +from contextlib import asynccontextmanager +import logging +from typing import Any +from uuid import uuid4 + +from fastapi import FastAPI, Request +from pydantic import BaseModel, ConfigDict, Field + +from .client import APIClient +from .config import MCPAuthMode, MCPServerSettings +from .problem import APIClientError, ProblemDocument +from .tools import ResourceDefinition, build_resources, build_tool_registry + +logger = logging.getLogger(__name__) + + +class ToolCallRequest(BaseModel): + tool_name: str + args: dict[str, Any] = Field(default_factory=dict) + session_id: str | None = None + source: str | None = None + + +class ToolCallResponse(BaseModel): + ok: bool + data: dict[str, Any] | list[dict[str, Any]] | None = None + problem: ProblemDocument | None = None + action_id: str + session_id: str + + +class MCPTool(BaseModel): + model_config = ConfigDict(from_attributes=True) + + name: str + description: str + scope: str + method: str + path: str + + +class MCPResource(BaseModel): + model_config = ConfigDict(from_attributes=True) + + uri: str + name: str + description: str + + +class MCPServer: + def __init__(self, settings: MCPServerSettings) -> None: + registry = build_tool_registry() + self._allowed_scopes = set(settings.allowed_tool_scopes) + self._tools = { + tool.name: tool + for scope, tools in registry.items() + if scope in self._allowed_scopes + for tool in tools + } + self._resources = build_resources() + self._client = APIClient( + base_url=str(settings.base_url), + timeout_seconds=settings.timeout_seconds, + ) + + def list_tools(self) -> list[MCPTool]: + return [ + MCPTool( + name=tool.name, + description=tool.description, + scope=tool.scope.value, + method=tool.method, + path=tool.path, + ) + for tool in self._tools.values() + ] + + def list_resources(self) -> list[MCPResource]: + return [MCPResource.model_validate(resource) for resource in self._resources] + + async def close(self) -> None: + await self._client.close() + + async def call_tool(self, *, tool_name: str, args: dict[str, Any]) -> ToolCallResponse: + action_id = str(uuid4()) + session_id = str(args.pop("session_id", "") or uuid4()) + source = str(args.pop("source", "unknown")) + logger.info( + "mcp_tool_invocation_started", + extra={ + "action_id": action_id, + "session_id": session_id, + "tool_name": tool_name, + "source": source, + }, + ) + tool = self._tools.get(tool_name) + if not tool: + logger.warning( + "mcp_tool_not_found", + extra={ + "action_id": action_id, + "session_id": session_id, + "tool_name": tool_name, + "source": source, + }, + ) + return ToolCallResponse( + ok=False, + problem=ProblemDocument( + title="Not Found", + status=404, + detail=f"Tool '{tool_name}' is not registered", + instance=f"tool:{tool_name}", + error_code="mcp_tool_not_found", + ), + action_id=action_id, + session_id=session_id, + ) + + try: + path = tool.path.format(**args) + except KeyError as exc: + logger.warning( + "mcp_tool_missing_argument", + extra={ + "action_id": action_id, + "session_id": session_id, + "tool_name": tool_name, + "source": source, + "argument": exc.args[0], + }, + ) + return ToolCallResponse( + ok=False, + problem=ProblemDocument( + title="Bad Request", + status=400, + detail=f"Missing tool argument: {exc.args[0]}", + instance=f"tool:{tool_name}", + error_code="mcp_missing_argument", + ), + action_id=action_id, + session_id=session_id, + ) + + query = args.get("query") if isinstance(args.get("query"), dict) else None + body = args.get("body") if isinstance(args.get("body"), dict) else None + + try: + result = await self._client.request( + method=tool.method, + path=path, + params=query, + json_body=body, + response_model=FreeFormResponse, + action_id=action_id, + session_id=session_id, + ) + logger.info( + "mcp_tool_invocation_succeeded", + extra={ + "action_id": action_id, + "session_id": session_id, + "tool_name": tool_name, + "source": source, + }, + ) + return ToolCallResponse( + ok=True, + data=result.data, + action_id=action_id, + session_id=session_id, + ) + except APIClientError as exc: + logger.warning( + "mcp_tool_invocation_failed", + extra={ + "action_id": action_id, + "session_id": session_id, + "tool_name": tool_name, + "source": source, + "error_code": exc.problem.error_code, + "status": exc.problem.status, + }, + ) + return ToolCallResponse( + ok=False, + problem=exc.problem, + action_id=action_id, + session_id=session_id, + ) + + +class FreeFormResponse(BaseModel): + data: dict[str, Any] | list[dict[str, Any]] | None = None + + @classmethod + def model_validate(cls, obj: Any, *args: Any, **kwargs: Any) -> "FreeFormResponse": + return super().model_validate({"data": obj}, *args, **kwargs) + + +def _serialize_tools(server: MCPServer) -> list[dict[str, Any]]: + return [tool.model_dump() for tool in server.list_tools()] + + +def _serialize_resources(resources: Iterable[ResourceDefinition]) -> list[dict[str, Any]]: + return [MCPResource.model_validate(resource).model_dump() for resource in resources] + + +def create_mcp_app(settings: MCPServerSettings | None = None) -> FastAPI: + cfg = settings or MCPServerSettings() + mcp_server = MCPServer(cfg) + + @asynccontextmanager + async def lifespan(_: FastAPI): + yield + await mcp_server.close() + + app = FastAPI(title="fastapi-template-mcp", lifespan=lifespan) + + @app.get("/mcp/tools") + async def list_tools(request: Request) -> list[dict[str, Any]]: + logger.info( + "mcp_tools_listed", + extra={ + "source_ip": request.client.host if request.client else "unknown", + "path": request.url.path, + }, + ) + return _serialize_tools(mcp_server) + + @app.get("/mcp/resources") + async def list_resources(request: Request) -> list[dict[str, Any]]: + logger.info( + "mcp_resources_listed", + extra={ + "source_ip": request.client.host if request.client else "unknown", + "path": request.url.path, + }, + ) + return [resource.model_dump() for resource in mcp_server.list_resources()] + + @app.post("/mcp/tools/call", response_model=ToolCallResponse) + async def call_tool(payload: ToolCallRequest, request: Request) -> ToolCallResponse: + args = { + **payload.args, + "session_id": payload.session_id, + "source": payload.source or (request.client.host if request.client else "unknown"), + } + return await mcp_server.call_tool(tool_name=payload.tool_name, args=args) + + @app.get("/mcp/config") + async def show_config() -> dict[str, Any]: + auth_required = cfg.auth_mode == MCPAuthMode.BEARER + return { + "base_url": str(cfg.base_url), + "auth_mode": cfg.auth_mode.value, + "auth_required": auth_required, + "timeout_seconds": cfg.timeout_seconds, + "allowed_tool_scopes": sorted(cfg.allowed_tool_scopes), + } + + @app.get("/mcp/health") + async def mcp_health() -> dict[str, str]: + return {"status": "ok"} + + return app diff --git a/app/domains/mcp_server/tools.py b/app/domains/mcp_server/tools.py new file mode 100644 index 0000000..1bfb6c3 --- /dev/null +++ b/app/domains/mcp_server/tools.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + + +class ToolScope(str, Enum): + AUTH = "auth" + USERS = "users" + POSTS = "posts" + VOTE = "vote" + + +@dataclass(frozen=True) +class ToolDefinition: + name: str + description: str + scope: ToolScope + method: str + path: str + + +@dataclass(frozen=True) +class ResourceDefinition: + uri: str + name: str + description: str + + +def build_tool_registry() -> dict[str, list[ToolDefinition]]: + """Grouped registry, anchored to /api/v1 and existing router tags.""" + return { + ToolScope.AUTH.value: [ + ToolDefinition( + name="auth.login", + description="Authentication: exchange username/password for token pair.", + scope=ToolScope.AUTH, + method="POST", + path="/api/v1/login", + ), + ToolDefinition( + name="auth.providers", + description="Authentication: list configured OAuth providers.", + scope=ToolScope.AUTH, + method="GET", + path="/api/v1/auth/oauth/providers", + ), + ], + ToolScope.USERS.value: [ + ToolDefinition( + name="users.list", + description="Users: list users from /api/v1/users endpoint.", + scope=ToolScope.USERS, + method="GET", + path="/api/v1/users", + ), + ToolDefinition( + name="users.get", + description="Users: fetch a user by id from /api/v1/users/{user_id}.", + scope=ToolScope.USERS, + method="GET", + path="/api/v1/users/{user_id}", + ), + ], + ToolScope.POSTS.value: [ + ToolDefinition( + name="posts.list", + description="Posts: list posts using /api/v1/posts.", + scope=ToolScope.POSTS, + method="GET", + path="/api/v1/posts", + ), + ToolDefinition( + name="posts.get", + description="Posts: fetch a post by id using /api/v1/posts/{post_id}.", + scope=ToolScope.POSTS, + method="GET", + path="/api/v1/posts/{post_id}", + ), + ], + ToolScope.VOTE.value: [ + ToolDefinition( + name="vote.create", + description="Vote: submit up/down vote through /api/v1/vote.", + scope=ToolScope.VOTE, + method="POST", + path="/api/v1/vote/", + ) + ], + } + + +def build_resources() -> list[ResourceDefinition]: + return [ + ResourceDefinition( + uri="resource://fastapi/openapi", + name="openapi", + description="OpenAPI schema at /openapi.json.", + ), + ResourceDefinition( + uri="resource://fastapi/health", + name="health", + description="Health check summary from /health.", + ), + ] diff --git a/pyproject.toml b/pyproject.toml index bea7ef0..59bf9bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,3 +57,6 @@ tests_dir = [ ] do_not_mutate = ["*/__init__.py"] also_copy = ["app"] + +[project.scripts] +fastapi-template-mcp = "app.domains.mcp_server.cli:main" diff --git a/tests/unit/test_mcp_client_unit.py b/tests/unit/test_mcp_client_unit.py new file mode 100644 index 0000000..42adcf7 --- /dev/null +++ b/tests/unit/test_mcp_client_unit.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import pytest +from httpx import Request, Response + +from app.domains.mcp_server.client import APIClient +from app.domains.mcp_server.problem import APIClientError + +pytestmark = pytest.mark.unit + + +def test_raise_api_error_uses_problem_document_payload(): + response = Response( + 400, + json={ + "title": "Bad Request", + "status": 400, + "detail": "broken", + "instance": "/api/v1/users", + "error_code": "bad_request", + }, + request=Request("GET", "http://testserver/api/v1/users"), + ) + + with pytest.raises(APIClientError) as exc_info: + APIClient._raise_api_error(path="/api/v1/users", response=response) + + assert exc_info.value.problem.error_code == "bad_request" + + +def test_raise_api_error_normalizes_non_problem_payload(): + response = Response( + 500, + json={"detail": "boom"}, + request=Request("GET", "http://testserver/api/v1/users"), + ) + + with pytest.raises(APIClientError) as exc_info: + APIClient._raise_api_error(path="/api/v1/users", response=response) + + assert exc_info.value.problem.error_code == "request_failed" + assert exc_info.value.problem.status == 500 diff --git a/tests/unit/test_mcp_server_unit.py b/tests/unit/test_mcp_server_unit.py new file mode 100644 index 0000000..0661f54 --- /dev/null +++ b/tests/unit/test_mcp_server_unit.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import asyncio + +from fastapi.testclient import TestClient + +from app.domains.mcp_server.client import APIClient +from app.domains.mcp_server.config import MCPServerSettings +from app.domains.mcp_server.problem import APIClientError, ProblemDocument +from app.domains.mcp_server.server import MCPServer, create_mcp_app +from app.domains.mcp_server.tools import build_resources, build_tool_registry + + +def test_mcp_server_settings_parse_csv_scopes(): + settings = MCPServerSettings(allowed_tool_scopes="auth, users") + assert settings.allowed_tool_scopes == ["auth", "users"] + + +def test_mcp_server_list_tools_filters_by_scope(): + settings = MCPServerSettings(allowed_tool_scopes=["auth"]) + server = MCPServer(settings) + tools = server.list_tools() + assert {tool.scope for tool in tools} == {"auth"} + asyncio.run(server.close()) + + +def test_tool_registry_contains_api_v1_anchors(): + registry = build_tool_registry() + assert registry["auth"][0].path.startswith("/api/v1") + assert registry["users"][0].path.startswith("/api/v1") + + +def test_resources_include_openapi_and_health(): + resources = build_resources() + assert {resource.name for resource in resources} == {"openapi", "health"} + + +def test_create_mcp_app_has_core_endpoints(): + app = create_mcp_app(MCPServerSettings(allowed_tool_scopes=["auth"])) + client = TestClient(app) + + tools_response = client.get("/mcp/tools") + config_response = client.get("/mcp/config") + + assert tools_response.status_code == 200 + assert config_response.status_code == 200 + assert config_response.json()["allowed_tool_scopes"] == ["auth"] + + +def test_call_tool_returns_problem_when_tool_missing(): + settings = MCPServerSettings(allowed_tool_scopes=["auth"]) + server = MCPServer(settings) + + result = asyncio.run(server.call_tool(tool_name="users.list", args={})) + + assert result.ok is False + assert result.problem is not None + assert result.problem.error_code == "mcp_tool_not_found" + assert result.action_id + assert result.session_id + asyncio.run(server.close()) + + +def test_call_tool_surfaces_normalized_problem(monkeypatch): + settings = MCPServerSettings(allowed_tool_scopes=["auth"]) + server = MCPServer(settings) + + async def _raise_error(*args, **kwargs): + raise APIClientError( + ProblemDocument( + title="Unauthorized", + status=401, + detail="Invalid token", + instance="/api/v1/login", + error_code="auth_invalid", + ) + ) + + monkeypatch.setattr(APIClient, "request", _raise_error) + result = asyncio.run( + server.call_tool( + tool_name="auth.login", + args={"body": {}, "session_id": "session-123", "source": "test-suite"}, + ) + ) + + assert result.ok is False + assert result.problem is not None + assert result.problem.status == 401 + assert result.session_id == "session-123" + assert result.action_id + asyncio.run(server.close())