diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 61003ac9f..41b8316a5 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,10 +1,13 @@ from collections.abc import Awaitable, Callable, Mapping -from dataclasses import dataclass -from typing import Any, Generic, Protocol +from dataclasses import dataclass, field +from typing import Any, Generic, Protocol, cast from pydantic import BaseModel from typing_extensions import TypeVar +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.connection import Connection from mcp.server.session import ServerSession from mcp.shared.context import BaseContext @@ -34,9 +37,49 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]): request_id: RequestId | None = None meta: RequestParamsMeta | None = None request: RequestT | None = None + transport: TransportContext = field( + default_factory=lambda: TransportContext(kind="unknown", can_send_request=False) + ) close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None + @property + def session_id(self) -> str | None: + """The transport's session id for this request, when one exists.""" + headers = self.headers + if headers is not None: + header_session_id = headers.get("mcp-session-id") + if header_session_id is not None: + return header_session_id + query_params = getattr(self.request, "query_params", None) + if query_params is None: + return None + return query_params.get("session_id") or query_params.get("sessionId") + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + HTTP-based transports expose headers through their request object while + direct/in-memory transports may provide them directly on the transport. + """ + if self.transport.headers is not None: + return self.transport.headers + request_headers = getattr(self.request, "headers", None) + if request_headers is None: + return None + return request_headers + + @property + def access_token(self) -> AccessToken | None: + """The OAuth access token for the current request, if authentication ran.""" + scope = getattr(self.request, "scope", None) + typed_scope = cast("Mapping[str, object]", scope) if isinstance(scope, Mapping) else None + user = typed_scope.get("user") if typed_scope is not None else None + if isinstance(user, AuthenticatedUser): + return user.access_token + return get_access_token() + # Covariant: `lifespan` is exposed read-only, so a `Context[AppState]` passes as `Context[object]`. LifespanT_co = TypeVar("LifespanT_co", default=Any, covariant=True) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index d2536189d..f4311d271 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -64,7 +64,7 @@ async def main(): from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext logger = logging.getLogger(__name__) @@ -80,6 +80,30 @@ async def main(): """A registered notification handler: `(ctx, params) -> None`.""" +def _make_transport_builder( + transport_kind: str | None, + transport_can_send_request: bool | None, +) -> Callable[[MessageMetadata], TransportContext]: + """Build per-message transport metadata from the transport's message wrapper.""" + + def build_transport_context(metadata: MessageMetadata) -> TransportContext: + request = metadata.request_context if isinstance(metadata, ServerMessageMetadata) else None + headers = getattr(request, "headers", None) + query_params = getattr(request, "query_params", None) + + kind = transport_kind + if kind is None and request is not None: + kind = "sse" if query_params is not None and query_params.get("session_id") else "streamable-http" + + return TransportContext( + kind=kind or "jsonrpc", + can_send_request=transport_can_send_request if transport_can_send_request is not None else True, + headers=headers, + ) + + return build_transport_context + + @dataclass(frozen=True, slots=True) class HandlerEntry(Generic[LifespanResultT]): """A registered handler and the params model to validate incoming params against. @@ -406,11 +430,14 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, + transport_kind: str | None = None, + transport_can_send_request: bool | None = None, ) -> None: async with self.lifespan(self) as lifespan_context: dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( read_stream, write_stream, + transport_builder=_make_transport_builder(transport_kind, transport_can_send_request), raise_handler_exceptions=raise_exceptions, # Handle `initialize` inline so a client that pipelines it with # the next request (spec says SHOULD NOT, not MUST NOT) sees diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 92de074d3..73b0374b1 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -1,10 +1,11 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any, Generic from pydantic import AnyUrl, BaseModel +from mcp.server.auth.provider import AccessToken from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext from mcp.server.elicitation import ( ElicitationResult, @@ -14,6 +15,7 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.shared.transport_context import TransportContext from mcp.types import LoggingLevel if TYPE_CHECKING: @@ -228,6 +230,31 @@ def session(self): """Access to the underlying session for advanced usage.""" return self.request_context.session + @property + def transport(self) -> TransportContext: + """Transport-specific metadata for this request.""" + return self.request_context.transport + + @property + def session_id(self) -> str | None: + """The transport's session id for this connection, when one exists.""" + return self.request_context.session_id + + @property + def request(self) -> RequestT | None: + """The HTTP request object for this message, when the transport has one.""" + return self.request_context.request + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them.""" + return self.request_context.headers + + @property + def access_token(self) -> AccessToken | None: + """The OAuth access token for the current request, if authentication ran.""" + return self.request_context.access_token + async def close_sse_stream(self) -> None: """Close the SSE stream to trigger client reconnection. diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index fdb69571d..c3ce2b89a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -936,7 +936,10 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no async with sse.connect_sse(scope, receive, send) as streams: await self._lowlevel_server.run( - streams[0], streams[1], self._lowlevel_server.create_initialization_options() + streams[0], + streams[1], + self._lowlevel_server.create_initialization_options(), + transport_kind="sse", ) return Response() diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 9b1037322..a07dd9d9a 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -350,6 +350,7 @@ def _make_context( request_id=dctx.request_id, meta=meta, request=request, + transport=dctx.transport, close_sse_stream=close_sse_stream, close_standalone_sse_stream=close_standalone_sse_stream, ) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index cec170082..7c66eff63 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -178,6 +178,8 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA write_stream, self.app.create_initialization_options(), stateless=True, + transport_kind="streamable-http", + transport_can_send_request=False, ) except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") @@ -268,6 +270,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE write_stream, self.app.create_initialization_options(), stateless=False, + transport_kind="streamable-http", + transport_can_send_request=not self.json_response, ) if idle_scope.cancelled_caught: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index caed8905d..ab7113d99 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -761,6 +761,22 @@ def __post_init__(self) -> None: source="sdk", behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", ), + "mcpserver:context:transport-metadata": Requirement( + source="issue:#2098", + behavior=( + "Context exposes the current transport metadata, session id, HTTP request, headers, and auth token " + "to tool handlers." + ), + issue="https://github.com/modelcontextprotocol/python-sdk/issues/2098", + ), + "lowlevel:context:transport-metadata": Requirement( + source="issue:#2098", + behavior=( + "ServerRequestContext exposes the current transport metadata, session id, HTTP request, headers, and " + "auth token to low-level handlers." + ), + issue="https://github.com/modelcontextprotocol/python-sdk/issues/2098", + ), # ═══════════════════════════════════════════════════════════════════════════ # Resources # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/interaction/lowlevel/test_context_metadata.py b/tests/interaction/lowlevel/test_context_metadata.py new file mode 100644 index 000000000..d19862203 --- /dev/null +++ b/tests/interaction/lowlevel/test_context_metadata.py @@ -0,0 +1,74 @@ +"""Context transport metadata exposed to low-level server handlers.""" + +import pytest +from starlette.requests import Request + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken +from mcp.types import CallToolResult, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lowlevel:context:transport-metadata") +async def test_lowlevel_context_exposes_transport_metadata(connect: Connect) -> None: + """A low-level handler can read transport/session/auth metadata from context.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect_context", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect_context" + access_token = AccessToken(token="secret", client_id="client-1", scopes=["tools"]) + token = auth_context_var.set(AuthenticatedUser(access_token)) + try: + exposed_token = ctx.access_token + token_matches_helper = exposed_token == get_access_token() + finally: + auth_context_var.reset(token) + request = ctx.request + request_kind = type(request).__name__ if request is not None else "none" + request_path = str(request.url.path) if isinstance(request, Request) else "none" + has_headers = ctx.headers is not None + text = "|".join( + [ + ctx.transport.kind, + ctx.session_id or "none", + request_kind, + request_path, + str(has_headers), + str(token_matches_helper), + exposed_token.client_id if exposed_token is not None else "none", + ] + ) + return CallToolResult(content=[TextContent(text=text)]) + + server = Server("metadata", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("inspect_context", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + transport_kind, session_id, request_kind, request_path, has_headers, token_matches_helper, token_client_id = ( + text.split("|") + ) + assert request_kind in {"Request", "none"} + if request_kind == "Request": + assert transport_kind == "sse" if request_path.startswith("/messages/") else "streamable-http" + assert session_id != "none" + assert has_headers == "True" + else: + assert transport_kind == "jsonrpc" + assert session_id == "none" + assert request_path == "none" + assert has_headers == "False" + assert token_matches_helper == "True" + assert token_client_id == "client-1" diff --git a/tests/interaction/mcpserver/test_context_metadata.py b/tests/interaction/mcpserver/test_context_metadata.py new file mode 100644 index 000000000..526f69ced --- /dev/null +++ b/tests/interaction/mcpserver/test_context_metadata.py @@ -0,0 +1,81 @@ +"""Context transport metadata exposed to MCPServer tools.""" + +import pytest +from starlette.requests import Request + +from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:transport-metadata") +async def test_context_exposes_transport_metadata_to_a_tool(connect: Connect) -> None: + """A tool can read transport/session/auth metadata from its injected Context. + + The in-memory leg has no transport session id; HTTP/SSE legs expose the real HTTP request object + and headers. The handler installs an auth token to prove the Context property matches the shared + auth helper inside the same request scope. + """ + mcp = MCPServer("metadata") + + @mcp.tool() + async def inspect_context(ctx: Context) -> str: + access_token = AccessToken(token="secret", client_id="client-1", scopes=["tools"]) + token = auth_context_var.set(AuthenticatedUser(access_token)) + try: + exposed_token = ctx.access_token + token_matches_helper = exposed_token == get_access_token() + finally: + auth_context_var.reset(token) + request = ctx.request + request_kind = type(request).__name__ if request is not None else "none" + request_path = str(request.url.path) if isinstance(request, Request) else "none" + header_value = ctx.headers.get("mcp-protocol-version", "none") if ctx.headers is not None else "none" + has_headers = ctx.headers is not None + return "|".join( + [ + ctx.transport.kind, + ctx.session_id or "none", + request_kind, + request_path, + header_value, + str(has_headers), + str(token_matches_helper), + exposed_token.client_id if exposed_token is not None else "none", + ] + ) + + async with connect(mcp) as client: + result = await client.call_tool("inspect_context", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + ( + transport_kind, + session_id, + request_kind, + request_path, + header_value, + has_headers, + token_matches_helper, + token_client_id, + ) = text.split("|") + assert request_kind in {"Request", "none"} + if request_kind == "Request": + assert transport_kind == "sse" if request_path.startswith("/messages/") else "streamable-http" + assert session_id != "none" + assert has_headers == "True" + else: + assert transport_kind == "jsonrpc" + assert session_id == "none" + assert request_path == "none" + assert has_headers == "False" + assert header_value == "none" + assert token_matches_helper == "True" + assert token_client_id == "client-1" diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index 01e96ff37..394cba548 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -7,13 +7,15 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Any +from typing import Any, cast import anyio import pytest +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.connection import Connection -from mcp.server.context import Context +from mcp.server.context import Context, ServerRequestContext from mcp.shared.dispatcher import DispatchContext from mcp.shared.transport_context import TransportContext @@ -28,6 +30,40 @@ class _Lifespan: name: str +@dataclass +class _RequestWithHeaders: + headers: Mapping[str, str] + + +@dataclass +class _RequestWithUser: + scope: Mapping[str, Any] + + +def test_server_request_context_reads_headers_from_request_object(): + ctx = ServerRequestContext( + session=cast(Any, object()), + lifespan_context={}, + request=_RequestWithHeaders({"x-test": "present"}), + transport=TransportContext(kind="jsonrpc", can_send_request=True), + ) + + assert ctx.headers == {"x-test": "present"} + assert ctx.session_id is None + + +def test_server_request_context_reads_access_token_from_request_user(): + access_token = AccessToken(token="secret", client_id="client-1", scopes=["tools"]) + ctx = ServerRequestContext( + session=cast(Any, object()), + lifespan_context={}, + request=_RequestWithUser({"user": AuthenticatedUser(access_token)}), + transport=TransportContext(kind="streamable-http", can_send_request=True), + ) + + assert ctx.access_token == access_token + + @pytest.mark.anyio async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): captured: list[Context[_Lifespan]] = []