|
11 | 11 | import pytest |
12 | 12 |
|
13 | 13 | from mcp import types |
| 14 | +from mcp.server import Server, ServerRequestContext |
14 | 15 | from mcp.server.connection import Connection |
15 | 16 | from mcp.server.session import ServerSession |
16 | 17 | from mcp.shared.dispatcher import CallOptions |
|
26 | 27 | SamplingToolsCapability, |
27 | 28 | ) |
28 | 29 |
|
| 30 | +from .test_runner import connected_runner |
| 31 | + |
29 | 32 |
|
30 | 33 | class StubDispatcher: |
31 | 34 | """Records `send_raw_request` / `notify` calls and returns a canned result.""" |
@@ -158,3 +161,63 @@ def test_check_client_capability_delegates_to_connection(): |
158 | 161 | session = _make_session(dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability())) |
159 | 162 | assert session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())) is True |
160 | 163 | assert session.check_client_capability(ClientCapabilities(experimental={"x": {}})) is False |
| 164 | + |
| 165 | + |
| 166 | +def _runner_server(seen_versions: list[str | None]) -> Server[dict[str, Any]]: |
| 167 | + """A lowlevel Server whose tools/list handler records `ctx.session.protocol_version`.""" |
| 168 | + |
| 169 | + async def list_tools( |
| 170 | + ctx: ServerRequestContext[dict[str, Any], Any], params: types.PaginatedRequestParams | None |
| 171 | + ) -> types.ListToolsResult: |
| 172 | + seen_versions.append(ctx.session.protocol_version) |
| 173 | + return types.ListToolsResult(tools=[]) |
| 174 | + |
| 175 | + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) |
| 176 | + |
| 177 | + |
| 178 | +def _init_params(protocol_version: str) -> dict[str, Any]: |
| 179 | + return InitializeRequestParams( |
| 180 | + protocol_version=protocol_version, |
| 181 | + capabilities=ClientCapabilities(), |
| 182 | + client_info=Implementation(name="test-client", version="1.0"), |
| 183 | + ).model_dump(by_alias=True, exclude_none=True) |
| 184 | + |
| 185 | + |
| 186 | +@pytest.mark.anyio |
| 187 | +async def test_protocol_version_is_none_before_initialize(): |
| 188 | + async with connected_runner(_runner_server([]), initialized=False) as (_client, runner): |
| 189 | + assert runner.session.protocol_version is None |
| 190 | + |
| 191 | + |
| 192 | +@pytest.mark.anyio |
| 193 | +async def test_protocol_version_is_negotiated_version_after_initialize(): |
| 194 | + """A supported requested version is echoed back and readable on the session, |
| 195 | + both directly and from inside a handler via `ctx.session`.""" |
| 196 | + seen: list[str | None] = [] |
| 197 | + async with connected_runner(_runner_server(seen), initialized=False) as (client, runner): |
| 198 | + result = await client.send_raw_request("initialize", _init_params("2025-03-26")) |
| 199 | + assert result["protocolVersion"] == "2025-03-26" |
| 200 | + assert runner.session.protocol_version == "2025-03-26" |
| 201 | + await client.send_raw_request("tools/list", None) |
| 202 | + assert seen == ["2025-03-26"] |
| 203 | + |
| 204 | + |
| 205 | +@pytest.mark.anyio |
| 206 | +async def test_protocol_version_reads_latest_when_requested_version_unsupported(): |
| 207 | + """An unsupported requested version negotiates down to LATEST_PROTOCOL_VERSION.""" |
| 208 | + async with connected_runner(_runner_server([]), initialized=False) as (client, runner): |
| 209 | + result = await client.send_raw_request("initialize", _init_params("1999-01-01")) |
| 210 | + assert result["protocolVersion"] == LATEST_PROTOCOL_VERSION |
| 211 | + assert runner.session.protocol_version == LATEST_PROTOCOL_VERSION |
| 212 | + |
| 213 | + |
| 214 | +@pytest.mark.anyio |
| 215 | +async def test_protocol_version_is_none_on_stateless_connection(): |
| 216 | + """Stateless connections never see a handshake: requests flow, but the |
| 217 | + negotiated version legitimately stays None.""" |
| 218 | + seen: list[str | None] = [] |
| 219 | + async with connected_runner(_runner_server(seen), initialized=False, stateless=True) as (client, runner): |
| 220 | + result = await client.send_raw_request("tools/list", None) |
| 221 | + assert result == {"tools": []} |
| 222 | + assert seen == [None] |
| 223 | + assert runner.session.protocol_version is None |
0 commit comments