Skip to content

Commit 9bdd153

Browse files
committed
fix: ServerRunner validates spec methods against ClientRequest before lookup
Parity with BaseSession._receive_loop: a spec method with malformed params surfaces as INVALID_PARAMS via the dispatcher's ValidationError boundary even when no handler is registered (the existing server validates against the discriminated union before any handler lookup). Gated on the set of spec method names (derived from the ClientRequest union discriminator) so custom methods registered via add_request_handler still route. The existing server rejects those too, but nothing pins that and routing them is strictly better. DirectDispatcher gains the same ValidationError -> INVALID_PARAMS mapping JSONRPCDispatcher has, so runner-over-direct unit tests see the same shape.
1 parent 130e160 commit 9bdd153

3 files changed

Lines changed: 54 additions & 4 deletions

File tree

src/mcp/server/runner.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from collections.abc import Mapping
2020
from dataclasses import dataclass, field
2121
from functools import partial, reduce
22-
from typing import Any, Generic, cast
22+
from typing import Any, Generic, cast, get_args
2323

2424
import anyio.abc
2525
from opentelemetry.trace import SpanKind, StatusCode
@@ -37,9 +37,11 @@
3737
INVALID_PARAMS,
3838
LATEST_PROTOCOL_VERSION,
3939
METHOD_NOT_FOUND,
40+
ClientRequest,
4041
Implementation,
4142
InitializeRequestParams,
4243
InitializeResult,
44+
client_request_adapter,
4345
)
4446

4547
__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"]
@@ -51,6 +53,13 @@
5153

5254
_INIT_EXEMPT: frozenset[str] = frozenset({"ping"})
5355

56+
_SPEC_CLIENT_METHODS: frozenset[str] = frozenset(
57+
cast(type[BaseModel], arm).model_fields["method"].default for arm in get_args(ClientRequest)
58+
)
59+
"""Method names in the spec `ClientRequest` union, derived from the
60+
discriminator literal on each arm. Used to gate upfront validation so custom
61+
methods registered via `add_request_handler` are not rejected."""
62+
5463

5564
def otel_middleware(next_on_request: OnRequest) -> OnRequest:
5665
"""Dispatch-tier middleware that wraps each request in an OpenTelemetry span.
@@ -161,6 +170,20 @@ async def _on_request(
161170
method: str,
162171
params: Mapping[str, Any] | None,
163172
) -> dict[str, Any]:
173+
# TODO(maxisbey): pinned compat. `BaseSession._receive_loop` validates
174+
# every inbound request against the spec `ClientRequest` discriminated
175+
# union *before* handler lookup, so a spec method with malformed params
176+
# surfaces as INVALID_PARAMS via the dispatcher's ValidationError
177+
# boundary even when no handler is registered. v2 wanted to decouple
178+
# the runner from the spec union; revisit once the suite's divergence
179+
# entry is resolved. Gated on spec methods so custom methods registered
180+
# via `add_request_handler` still route (the existing server rejects
181+
# those too, but nothing pins that and routing them is strictly better).
182+
if method in _SPEC_CLIENT_METHODS:
183+
payload: dict[str, Any] = {"method": method}
184+
if params is not None:
185+
payload["params"] = dict(params)
186+
client_request_adapter.validate_python(payload)
164187
if method == "initialize":
165188
return self._handle_initialize(params)
166189
if not self._initialized and method not in _INIT_EXEMPT:

src/mcp/shared/direct_dispatcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121

2222
import anyio
2323
import anyio.abc
24+
from pydantic import ValidationError
2425

2526
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
2627
from mcp.shared.exceptions import MCPError, NoBackChannelError
2728
from mcp.shared.message import MessageMetadata
2829
from mcp.shared.transport_context import TransportContext
29-
from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT, RequestId
30+
from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, REQUEST_TIMEOUT, RequestId
3031

3132
__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"]
3233

@@ -149,6 +150,10 @@ async def _dispatch_request(
149150
return await self._on_request(dctx, method, params)
150151
except MCPError:
151152
raise
153+
except ValidationError as e:
154+
# Same shape JSONRPCDispatcher writes, so runner-over-direct
155+
# tests see what runner-over-JSONRPC would.
156+
raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") from e
152157
except Exception as e:
153158
raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e
154159
except TimeoutError:

tests/server/test_runner.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,32 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT):
161161

162162

163163
@pytest.mark.anyio
164-
async def test_runner_unknown_method_raises_method_not_found(server: SrvT):
164+
async def test_runner_spec_method_with_no_handler_raises_method_not_found(server: SrvT):
165+
async with connected_runner(server) as (client, _):
166+
with pytest.raises(MCPError) as exc:
167+
await client.send_raw_request("resources/list", None)
168+
assert exc.value.error.code == METHOD_NOT_FOUND
169+
170+
171+
@pytest.mark.anyio
172+
async def test_runner_non_spec_method_with_no_handler_raises_method_not_found(server: SrvT):
173+
"""Upfront validation is gated to spec methods, so a non-spec method
174+
skips it and reaches handler lookup."""
165175
async with connected_runner(server) as (client, _):
166176
with pytest.raises(MCPError) as exc:
167177
await client.send_raw_request("nonexistent/method", None)
168178
assert exc.value.error.code == METHOD_NOT_FOUND
169179

170180

181+
@pytest.mark.anyio
182+
async def test_runner_malformed_params_for_unregistered_spec_method_raises_invalid_params(server: SrvT):
183+
"""A spec method with malformed params is INVALID_PARAMS even with no handler."""
184+
async with connected_runner(server) as (client, _):
185+
with pytest.raises(MCPError) as exc:
186+
await client.send_raw_request("tools/call", {"name": 123})
187+
assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")
188+
189+
171190
@pytest.mark.anyio
172191
async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT):
173192
async with connected_runner(server, initialized=False) as (client, runner):
@@ -287,6 +306,9 @@ async def test_runner_stateless_skips_init_gate(server: SrvT):
287306

288307
@pytest.mark.anyio
289308
async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT):
309+
"""Custom methods outside the spec `ClientRequest` union skip upfront
310+
validation and route to the registered handler."""
311+
290312
class GreetParams(RequestParams):
291313
name: str
292314

@@ -358,7 +380,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s
358380
async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _):
359381
spans.clear()
360382
with pytest.raises(MCPError) as exc:
361-
await client.send_raw_request("nonexistent/method", None)
383+
await client.send_raw_request("resources/list", None)
362384
assert exc.value.error.code == METHOD_NOT_FOUND
363385
[span] = spans.finished()
364386
assert span.status.status_code == StatusCode.ERROR

0 commit comments

Comments
 (0)