Skip to content

Commit 5a51056

Browse files
committed
Validate full sampling tool result history
1 parent 47bbab3 commit 5a51056

3 files changed

Lines changed: 222 additions & 22 deletions

File tree

src/mcp/server/validation.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from mcp.shared.exceptions import MCPError
7-
from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, Tool, ToolChoice
7+
from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, SamplingMessageContentBlock, Tool, ToolChoice
88

99

1010
def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool:
@@ -52,6 +52,7 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None:
5252
1. Messages with tool_result content contain ONLY tool_result content
5353
2. tool_result messages are preceded by a message with tool_use
5454
3. tool_result IDs match the tool_use IDs from the previous message
55+
4. Every tool_use message in the history is followed by matching tool_result content
5556
5657
See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
5758
@@ -64,24 +65,26 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None:
6465
if not messages:
6566
return
6667

67-
last_content = messages[-1].content_as_list
68-
has_tool_results = any(c.type == "tool_result" for c in last_content)
69-
70-
previous_content = messages[-2].content_as_list if len(messages) >= 2 else None
71-
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
72-
73-
if has_tool_results:
74-
# Per spec: "SamplingMessage with tool result content blocks
75-
# MUST NOT contain other content types."
76-
if any(c.type != "tool_result" for c in last_content):
77-
raise ValueError("The last message must contain only tool_result content if any is present")
78-
if previous_content is None:
79-
raise ValueError("tool_result requires a previous message containing tool_use")
80-
if not has_previous_tool_use:
81-
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
82-
83-
if has_previous_tool_use and previous_content:
84-
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
85-
tool_result_ids = {c.tool_use_id for c in last_content if c.type == "tool_result"}
86-
if tool_use_ids != tool_result_ids:
87-
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
68+
previous_content: list[SamplingMessageContentBlock] | None = None
69+
for content in (message.content_as_list for message in messages):
70+
has_tool_results = any(c.type == "tool_result" for c in content)
71+
previous_tool_use_ids: set[str] = set()
72+
if previous_content is not None:
73+
previous_tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
74+
75+
if has_tool_results:
76+
# Per spec: "SamplingMessage with tool result content blocks
77+
# MUST NOT contain other content types."
78+
if any(c.type != "tool_result" for c in content):
79+
raise ValueError("A message must contain only tool_result content if any is present")
80+
if previous_content is None:
81+
raise ValueError("tool_result requires a previous message containing tool_use")
82+
if not previous_tool_use_ids:
83+
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
84+
85+
if previous_tool_use_ids:
86+
tool_result_ids = {c.tool_use_id for c in content if c.type == "tool_result"}
87+
if previous_tool_use_ids != tool_result_ids:
88+
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
89+
90+
previous_content = content

tests/server/test_session.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,118 @@ async def test_send_request_skips_the_surface_gate_when_method_absent_at_version
158158
assert isinstance(result, types.EmptyResult)
159159

160160

161+
@pytest.mark.anyio
162+
async def test_create_message_tool_result_validation():
163+
"""Test tool_use/tool_result validation in create_message."""
164+
dispatcher = StubDispatcher(
165+
result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}
166+
)
167+
session = _make_session(
168+
dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability()))
169+
)
170+
tool = types.Tool(name="test_tool", input_schema={"type": "object"})
171+
text = types.TextContent(type="text", text="hello")
172+
tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})
173+
tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[])
174+
175+
# Case 1: tool_result mixed with other content
176+
with pytest.raises(ValueError, match="only tool_result content"):
177+
await session.create_message(
178+
messages=[
179+
types.SamplingMessage(role="user", content=text),
180+
types.SamplingMessage(role="assistant", content=tool_use),
181+
types.SamplingMessage(role="user", content=[tool_result, text]),
182+
],
183+
max_tokens=100,
184+
tools=[tool],
185+
)
186+
187+
# Case 2: tool_result without previous message
188+
with pytest.raises(ValueError, match="requires a previous message"):
189+
await session.create_message(
190+
messages=[types.SamplingMessage(role="user", content=tool_result)],
191+
max_tokens=100,
192+
tools=[tool],
193+
)
194+
195+
# Case 3: tool_result without previous tool_use
196+
with pytest.raises(ValueError, match="do not match any tool_use"):
197+
await session.create_message(
198+
messages=[
199+
types.SamplingMessage(role="user", content=text),
200+
types.SamplingMessage(role="user", content=tool_result),
201+
],
202+
max_tokens=100,
203+
tools=[tool],
204+
)
205+
206+
# Case 4: mismatched tool IDs
207+
with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"):
208+
await session.create_message(
209+
messages=[
210+
types.SamplingMessage(role="user", content=text),
211+
types.SamplingMessage(role="assistant", content=tool_use),
212+
types.SamplingMessage(
213+
role="user",
214+
content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
215+
),
216+
],
217+
max_tokens=100,
218+
tools=[tool],
219+
)
220+
221+
# Case 4b: earlier mismatched tool result with a later plain message
222+
with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"):
223+
await session.create_message(
224+
messages=[
225+
types.SamplingMessage(role="assistant", content=tool_use),
226+
types.SamplingMessage(
227+
role="user",
228+
content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
229+
),
230+
types.SamplingMessage(role="assistant", content=text),
231+
],
232+
max_tokens=100,
233+
tools=[tool],
234+
)
235+
236+
# Case 5: text-only message with tools (no tool_results) - passes validation
237+
await session.create_message(
238+
messages=[types.SamplingMessage(role="user", content=text)],
239+
max_tokens=100,
240+
tools=[tool],
241+
)
242+
243+
# Case 6: valid matching tool_result/tool_use IDs - passes validation
244+
await session.create_message(
245+
messages=[
246+
types.SamplingMessage(role="user", content=text),
247+
types.SamplingMessage(role="assistant", content=tool_use),
248+
types.SamplingMessage(role="user", content=tool_result),
249+
],
250+
max_tokens=100,
251+
tools=[tool],
252+
)
253+
254+
# Case 7: validation runs even without `tools` parameter
255+
# (tool loop continuation may omit tools while containing tool_result)
256+
with pytest.raises(ValueError, match="do not match any tool_use"):
257+
await session.create_message(
258+
messages=[
259+
types.SamplingMessage(role="user", content=text),
260+
types.SamplingMessage(role="user", content=tool_result),
261+
],
262+
max_tokens=100,
263+
)
264+
265+
# Case 8: empty messages list - skips validation entirely
266+
no_tools_session = _make_session(
267+
StubDispatcher(result={"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"}),
268+
capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())),
269+
)
270+
await no_tools_session.create_message(messages=[], max_tokens=100)
271+
272+
161273
@pytest.mark.anyio
162274
async def test_send_request_validates_result_alias_only():
163275
"""Peer results validate alias-only; a snake_case key from the wire is

tests/server/test_validation.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_mixed_with_ot
108108
validate_tool_use_result_messages(messages)
109109

110110

111+
def test_validate_tool_use_result_messages_raises_for_earlier_mixed_tool_result() -> None:
112+
"""Raises when an earlier message mixes tool_result with other content."""
113+
messages = [
114+
SamplingMessage(
115+
role="assistant",
116+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
117+
),
118+
SamplingMessage(
119+
role="user",
120+
content=[
121+
ToolResultContent(type="tool_result", tool_use_id="tool-1"),
122+
TextContent(type="text", text="also this"),
123+
],
124+
),
125+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
126+
]
127+
128+
with pytest.raises(ValueError, match="only tool_result content"):
129+
validate_tool_use_result_messages(messages)
130+
131+
111132
def test_validate_tool_use_result_messages_raises_when_tool_result_without_previous_tool_use() -> None:
112133
"""Raises when tool_result appears without preceding tool_use."""
113134
messages = [
@@ -146,6 +167,39 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_matc
146167
validate_tool_use_result_messages(messages)
147168

148169

170+
def test_validate_tool_use_result_messages_raises_when_earlier_tool_result_ids_dont_match_tool_use() -> None:
171+
"""Raises when an earlier tool_result does not match the previous tool_use."""
172+
messages = [
173+
SamplingMessage(
174+
role="assistant",
175+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
176+
),
177+
SamplingMessage(
178+
role="user",
179+
content=ToolResultContent(type="tool_result", tool_use_id="tool-2"),
180+
),
181+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
182+
]
183+
184+
with pytest.raises(ValueError, match="do not match"):
185+
validate_tool_use_result_messages(messages)
186+
187+
188+
def test_validate_tool_use_result_messages_raises_when_tool_use_is_not_answered() -> None:
189+
"""Raises when a tool_use is followed by a non-tool_result message."""
190+
messages = [
191+
SamplingMessage(
192+
role="assistant",
193+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
194+
),
195+
SamplingMessage(role="user", content=TextContent(type="text", text="not a result")),
196+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
197+
]
198+
199+
with pytest.raises(ValueError, match="do not match"):
200+
validate_tool_use_result_messages(messages)
201+
202+
149203
def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_tool_use() -> None:
150204
"""No error when tool_result IDs match tool_use IDs."""
151205
messages = [
@@ -159,3 +213,34 @@ def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_too
159213
),
160214
]
161215
validate_tool_use_result_messages(messages) # Should not raise
216+
217+
218+
def test_validate_tool_use_result_messages_no_error_for_multiple_tool_pairs() -> None:
219+
"""No error when every tool_use in the history has a matching tool_result."""
220+
messages = [
221+
SamplingMessage(role="user", content=TextContent(type="text", text="first")),
222+
SamplingMessage(
223+
role="assistant",
224+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
225+
),
226+
SamplingMessage(
227+
role="user",
228+
content=ToolResultContent(type="tool_result", tool_use_id="tool-1"),
229+
),
230+
SamplingMessage(
231+
role="assistant",
232+
content=[
233+
ToolUseContent(type="tool_use", id="tool-2", name="test", input={}),
234+
ToolUseContent(type="tool_use", id="tool-3", name="test", input={}),
235+
],
236+
),
237+
SamplingMessage(
238+
role="user",
239+
content=[
240+
ToolResultContent(type="tool_result", tool_use_id="tool-3"),
241+
ToolResultContent(type="tool_result", tool_use_id="tool-2"),
242+
],
243+
),
244+
]
245+
246+
validate_tool_use_result_messages(messages)

0 commit comments

Comments
 (0)