Skip to content

Commit 67116fb

Browse files
committed
Wire per-version validation into ClientSession; fix elicit schema generation
ClientSession now uses parse_server_request/notification for inbound dispatch (KeyError -> METHOD_NOT_FOUND), validates callback results against the surface schema before sending (ValidationError -> INTERNAL_ERROR), and gates server results through validate_server_result before parsing. Adds a protocol_version property. Surface codegen: SCHEMA_PATCHES targets ["integer", "number"] instead of bare "number" so generated types use int | float (pydantic smart-union preserves int through the sieve; previously coerced 37 -> 37.0). Regenerated both packages. elicitation.py: Optional[T] fields now emit {"type": ...} with the field omitted from required (was non-spec anyOf). Gate tightened to reject list[...] and multi-primitive unions; only T or T | None accepted. The interaction-suite divergence for nested requestedSchema is flipped (client now rejects with INVALID_PARAMS). Documented the elicit gate change in migration.md.
1 parent 44a1e83 commit 67116fb

10 files changed

Lines changed: 285 additions & 163 deletions

File tree

docs/migration.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,10 @@ await ctx.log(level="info", data="hello")
535535

536536
Positional calls (`await ctx.info("hello")`) are unaffected.
537537

538+
### `Context.elicit()` schema gate tightened
539+
540+
`Context.elicit()` (and `elicit_with_validation()`) now accept only schemas whose fields are a single primitive type (`str`, `int`, `float`, `bool`) or `Optional[primitive]`. `list[...]` fields and unions of multiple primitives (e.g. `int | str`) raise `TypeError` at the call site; previously `list[str]` and arbitrary primitive unions were allowed but produced a `requestedSchema` outside the spec's restricted subset. `Optional[T]` fields now render as `{"type": ...}` with the field omitted from `required` instead of the non-spec `anyOf` shape.
541+
538542
### Replace `RootModel` by union types with `TypeAdapter` validation
539543

540544
The following union types are no longer `RootModel` subclasses:

scripts/gen_surface_types.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,30 @@
2727

2828
# schema.ts -> schema.json renders TypeScript ``number`` as JSON Schema
2929
# ``integer`` at these sites; patch the JSON before codegen so floats validate.
30+
# Patched to ``["integer", "number"]`` (not bare ``"number"``) so codegen emits
31+
# ``int | float`` and pydantic's smart-union preserves ints on round-trip.
3032
# TODO: drop once modelcontextprotocol/modelcontextprotocol fixes the schema.ts -> schema.json number rendering.
3133
SCHEMA_PATCHES: dict[str, list[tuple[str, Any, Any]]] = {
3234
"2025-11-25": [
33-
("$defs/NumberSchema/properties/default/type", "integer", "number"),
34-
("$defs/NumberSchema/properties/maximum/type", "integer", "number"),
35-
("$defs/NumberSchema/properties/minimum/type", "integer", "number"),
35+
("$defs/NumberSchema/properties/default/type", "integer", ["integer", "number"]),
36+
("$defs/NumberSchema/properties/maximum/type", "integer", ["integer", "number"]),
37+
("$defs/NumberSchema/properties/minimum/type", "integer", ["integer", "number"]),
3638
(
3739
"$defs/ElicitResult/properties/content/additionalProperties/anyOf/1/type",
3840
["string", "integer", "boolean"],
39-
["string", "number", "boolean"],
41+
["string", "integer", "number", "boolean"],
4042
),
4143
],
4244
"2026-07-28": [
45+
("$defs/NumberSchema/properties/default/type", "number", ["integer", "number"]),
46+
("$defs/NumberSchema/properties/maximum/type", "number", ["integer", "number"]),
47+
("$defs/NumberSchema/properties/minimum/type", "number", ["integer", "number"]),
4348
(
4449
"$defs/ElicitResult/properties/content/additionalProperties/anyOf/1/type",
4550
["string", "integer", "boolean"],
46-
["string", "number", "boolean"],
51+
["string", "integer", "number", "boolean"],
4752
),
48-
("$defs/JSONValue/anyOf/2/type", ["string", "integer", "boolean"], ["string", "number", "boolean"]),
53+
("$defs/JSONValue/anyOf/2/type", ["string", "integer", "boolean"], ["string", "integer", "number", "boolean"]),
4954
],
5055
}
5156

src/mcp/client/session.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Mapping
55
from dataclasses import dataclass
66
from types import TracebackType
7-
from typing import Any, Protocol, cast, get_args
7+
from typing import Any, Protocol, cast
88

99
import anyio
1010
import anyio.abc
@@ -22,7 +22,8 @@
2222
from mcp.shared.session import RequestResponder
2323
from mcp.shared.transport_context import TransportContext
2424
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
25-
from mcp.types import RequestId, RequestParamsMeta
25+
from mcp.types import INTERNAL_ERROR, METHOD_NOT_FOUND, RequestId, RequestParamsMeta
26+
from mcp.types import methods as _methods
2627

2728
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
2829

@@ -116,14 +117,6 @@ async def _default_logging_callback(
116117

117118
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
118119

119-
_SERVER_REQUEST_METHODS: frozenset[str] = frozenset(
120-
cast(type[BaseModel], arm).model_fields["method"].default for arm in get_args(types.ServerRequest)
121-
)
122-
"""Method names in the SDK's `ServerRequest` union, derived from the
123-
discriminator literal on each arm. Requests for any other method — including
124-
spec methods this SDK deliberately doesn't model, like `tasks/*` — are
125-
answered with METHOD_NOT_FOUND instead of failing union validation."""
126-
127120

128121
class ClientSession:
129122
"""Client half of an MCP connection, running on a `Dispatcher`.
@@ -244,6 +237,12 @@ async def send_request(
244237
# The spec forbids cancelling initialize.
245238
opts["cancel_on_abandon"] = False
246239
raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts)
240+
# Literal fallback covers pre-handshake and stateless; matches runner.py.
241+
version = self.protocol_version or "2025-11-25"
242+
try:
243+
_methods.validate_server_result(method, version, raw)
244+
except KeyError:
245+
pass
247246
return result_type.model_validate(raw, by_name=False)
248247

249248
async def send_notification(self, notification: types.ClientNotification) -> None:
@@ -308,6 +307,11 @@ def initialize_result(self) -> types.InitializeResult | None:
308307
"""
309308
return self._initialize_result
310309

310+
@property
311+
def protocol_version(self) -> str | None:
312+
"""The negotiated protocol version. None until `initialize()` has completed."""
313+
return self._initialize_result.protocol_version if self._initialize_result else None
314+
311315
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
312316
"""Send a ping request."""
313317
return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult)
@@ -506,12 +510,14 @@ async def _on_request(
506510
self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
507511
) -> dict[str, Any]:
508512
"""Answer a server-initiated request via the registered callbacks."""
509-
if method not in _SERVER_REQUEST_METHODS:
510-
raise MCPError(code=types.METHOD_NOT_FOUND, message="Method not found", data=method)
511-
payload: dict[str, Any] = {"method": method}
512-
if params is not None:
513-
payload["params"] = dict(params)
514-
request = types.server_request_adapter.validate_python(payload, by_name=False)
513+
# Literal, not LATEST_PROTOCOL_VERSION: the fallback covers the initialize
514+
# handshake (which only exists at <=2025) and stateless until the header
515+
# is plumbed; its meaning is fixed regardless of LATEST bumps.
516+
version = self.protocol_version or "2025-11-25"
517+
try:
518+
request = cast(types.ServerRequest, _methods.parse_server_request(method, version, params))
519+
except KeyError:
520+
raise MCPError(code=METHOD_NOT_FOUND, message="Method not found", data=method) from None
515521

516522
response: types.ClientResult | types.ErrorData
517523
if isinstance(request, types.PingRequest):
@@ -532,19 +538,24 @@ async def _on_request(
532538
client_response = ClientResponse.validate_python(response)
533539
if isinstance(client_response, types.ErrorData):
534540
raise MCPError.from_error_data(client_response)
535-
return client_response.model_dump(by_alias=True, mode="json", exclude_none=True)
541+
dumped = client_response.model_dump(by_alias=True, mode="json", exclude_none=True)
542+
try:
543+
_methods.validate_client_result(method, version, dumped)
544+
except ValidationError:
545+
logger.exception("client callback for %r returned an invalid result", method)
546+
raise MCPError(code=INTERNAL_ERROR, message="Client callback returned an invalid result") from None
547+
return dumped
536548

537549
async def _on_notify(
538550
self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
539551
) -> None:
540552
"""Route a server notification: validate, run the typed callback, tee to message_handler."""
541-
payload: dict[str, Any] = {"method": method}
542-
if params is not None:
543-
payload["params"] = dict(params)
553+
# Same fallback as `_on_request`: covers pre-handshake and stateless.
554+
version = self.protocol_version or "2025-11-25"
544555
try:
545-
notification = types.server_notification_adapter.validate_python(payload, by_name=False)
546-
except ValidationError:
547-
logger.warning("Failed to validate notification: %s", payload, exc_info=True)
556+
notification = cast(types.ServerNotification, _methods.parse_server_notification(method, version, params))
557+
except (KeyError, ValidationError):
558+
logger.warning("Failed to validate notification: %s", method, exc_info=True)
548559
return
549560
if isinstance(notification, types.CancelledNotification):
550561
# The dispatcher already applied the cancellation; not surfaced to message_handler.

src/mcp/server/elicitation.py

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from __future__ import annotations
44

55
import types
6-
from collections.abc import Sequence
7-
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
6+
from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin
87

98
from pydantic import BaseModel
9+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
10+
from pydantic_core import core_schema
1011

1112
from mcp.server.session import ServerSession
1213
from mcp.types import RequestId
@@ -49,56 +50,43 @@ class AcceptedUrlElicitation(BaseModel):
4950
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
5051

5152

53+
class _ElicitationJsonSchema(GenerateJsonSchema):
54+
"""JSON-Schema generator that flattens `T | None` to `T` and drops `None` defaults.
55+
56+
The spec's `PrimitiveSchemaDefinition` admits no `anyOf` or null type; an
57+
optional field is expressed by leaving it out of `required`, which pydantic
58+
already does for any field with a default.
59+
"""
60+
61+
def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue:
62+
return self.generate_inner(schema["schema"])
63+
64+
def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaValue:
65+
result = super().default_schema(schema)
66+
if result.get("default") is None:
67+
result.pop("default", None)
68+
return result
69+
70+
5271
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
5372
"""Validate that a Pydantic model only contains primitive field types."""
5473
for field_name, field_info in schema.model_fields.items():
55-
annotation = field_info.annotation
56-
57-
if annotation is None or annotation is types.NoneType: # pragma: no cover
58-
continue
59-
elif _is_primitive_field(annotation):
60-
continue
61-
elif _is_string_sequence(annotation):
62-
continue
63-
else:
74+
if not _is_primitive_field(field_info.annotation):
6475
raise TypeError(
6576
f"Elicitation schema field '{field_name}' must be a primitive type "
66-
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
67-
f"or Optional of these types. Nested models and complex types are not allowed."
77+
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of one. Unions of multiple "
78+
f"primitives, lists, and nested models are not allowed."
6879
)
6980

7081

71-
def _is_string_sequence(annotation: type) -> bool:
72-
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
73-
origin = get_origin(annotation)
74-
# Check if it's a sequence-like type with str elements
75-
if origin:
76-
try:
77-
if issubclass(origin, Sequence):
78-
args = get_args(annotation)
79-
# Should have single str type arg
80-
return len(args) == 1 and args[0] is str
81-
except TypeError: # pragma: no cover
82-
# origin is not a class, so it can't be a subclass of Sequence
83-
pass
84-
return False
85-
86-
87-
def _is_primitive_field(annotation: type) -> bool:
88-
"""Check if a field is a primitive type allowed in elicitation schemas."""
89-
# Handle basic primitive types
82+
def _is_primitive_field(annotation: Any) -> bool:
83+
"""True if `annotation` is a single primitive type, optionally wrapped in `Optional`."""
9084
if annotation in _ELICITATION_PRIMITIVE_TYPES:
9185
return True
92-
93-
# Handle Union types
9486
origin = get_origin(annotation)
9587
if origin is Union or origin is types.UnionType:
96-
args = get_args(annotation)
97-
# All args must be primitive types, None, or string sequences
98-
return all(
99-
arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args
100-
)
101-
88+
non_none = [a for a in get_args(annotation) if a is not types.NoneType]
89+
return len(non_none) == 1 and non_none[0] in _ELICITATION_PRIMITIVE_TYPES
10290
return False
10391

10492

@@ -121,7 +109,7 @@ async def elicit_with_validation(
121109
# Validate that schema only contains primitive types and fail loudly if not
122110
_validate_elicitation_schema(schema)
123111

124-
json_schema = schema.model_json_schema()
112+
json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema)
125113

126114
result = await session.elicit_form(
127115
message=message,

src/mcp/types/v2025_11_25/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ class ElicitResult(WireModel):
311311
- "decline": User explicitly decline the action
312312
- "cancel": User dismissed without making an explicit choice
313313
"""
314-
content: dict[str, list[str] | str | float | bool] | None = None
314+
content: dict[str, list[str] | str | int | float | bool] | None = None
315315
"""
316316
The submitted form data, only present when action is "accept" and mode was "form".
317317
Contains values matching the requested schema.
@@ -685,10 +685,10 @@ class NumberSchema(WireModel):
685685
model_config = ConfigDict(
686686
extra="ignore",
687687
)
688-
default: float | None = None
688+
default: int | float | None = None
689689
description: str | None = None
690-
maximum: float | None = None
691-
minimum: float | None = None
690+
maximum: int | float | None = None
691+
minimum: int | float | None = None
692692
title: str | None = None
693693
type: Literal["integer", "number"]
694694

src/mcp/types/v2026_07_28/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class ElicitResult(WireModel):
144144
- `"decline"`: User explicitly declined the action
145145
- `"cancel"`: User dismissed without making an explicit choice
146146
"""
147-
content: dict[str, list[str] | str | float | bool] | None = None
147+
content: dict[str, list[str] | str | int | float | bool] | None = None
148148
"""
149149
The submitted form data, only present when action is `"accept"` and mode was `"form"`.
150150
Contains values matching the requested schema.
@@ -585,10 +585,10 @@ class NumberSchema(WireModel):
585585
model_config = ConfigDict(
586586
extra="ignore",
587587
)
588-
default: float | None = None
588+
default: int | float | None = None
589589
description: str | None = None
590-
maximum: float | None = None
591-
minimum: float | None = None
590+
maximum: int | float | None = None
591+
minimum: int | float | None = None
592592
title: str | None = None
593593
type: Literal["integer", "number"]
594594

@@ -3536,8 +3536,8 @@ class JSONObject(RootModel[dict[str, "JSONValue"]]):
35363536
root: dict[str, "JSONValue"]
35373537

35383538

3539-
class JSONValue(RootModel[Union[JSONObject, list["JSONValue"], str | float | bool]]):
3540-
root: Union[JSONObject, list["JSONValue"], str | float | bool]
3539+
class JSONValue(RootModel[Union[JSONObject, list["JSONValue"], str | int | float | bool]]):
3540+
root: Union[JSONObject, list["JSONValue"], str | int | float | bool]
35413541

35423542

35433543
AnyCallToolResult = CallToolResult | InputRequiredResult

0 commit comments

Comments
 (0)