Skip to content

Commit e729d1b

Browse files
committed
fix: respect session group capabilities
1 parent cf110e3 commit e729d1b

2 files changed

Lines changed: 83 additions & 25 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -344,36 +344,42 @@ async def _aggregate_components(self, server_info: types.Implementation, session
344344
tools_temp: dict[str, types.Tool] = {}
345345
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
346346

347+
initialize_result = session.initialize_result
348+
capabilities = initialize_result.capabilities if initialize_result is not None else None
349+
347350
# Query the server for its prompts and aggregate to list.
348-
try:
349-
prompts = (await session.list_prompts()).prompts
350-
for prompt in prompts:
351-
name = self._component_name(prompt.name, server_info)
352-
prompts_temp[name] = prompt
353-
component_names.prompts.add(name)
354-
except MCPError as err: # pragma: no cover
355-
logging.warning(f"Could not fetch prompts: {err}")
351+
if capabilities is None or capabilities.prompts is not None:
352+
try:
353+
prompts = (await session.list_prompts()).prompts
354+
for prompt in prompts:
355+
name = self._component_name(prompt.name, server_info)
356+
prompts_temp[name] = prompt
357+
component_names.prompts.add(name)
358+
except MCPError as err: # pragma: no cover
359+
logging.warning(f"Could not fetch prompts: {err}")
356360

357361
# Query the server for its resources and aggregate to list.
358-
try:
359-
resources = (await session.list_resources()).resources
360-
for resource in resources:
361-
name = self._component_name(resource.name, server_info)
362-
resources_temp[name] = resource
363-
component_names.resources.add(name)
364-
except MCPError as err: # pragma: no cover
365-
logging.warning(f"Could not fetch resources: {err}")
362+
if capabilities is None or capabilities.resources is not None:
363+
try:
364+
resources = (await session.list_resources()).resources
365+
for resource in resources:
366+
name = self._component_name(resource.name, server_info)
367+
resources_temp[name] = resource
368+
component_names.resources.add(name)
369+
except MCPError as err: # pragma: no cover
370+
logging.warning(f"Could not fetch resources: {err}")
366371

367372
# Query the server for its tools and aggregate to list.
368-
try:
369-
tools = (await session.list_tools()).tools
370-
for tool in tools:
371-
name = self._component_name(tool.name, server_info)
372-
tools_temp[name] = tool
373-
tool_to_session_temp[name] = session
374-
component_names.tools.add(name)
375-
except MCPError as err: # pragma: no cover
376-
logging.warning(f"Could not fetch tools: {err}")
373+
if capabilities is None or capabilities.tools is not None:
374+
try:
375+
tools = (await session.list_tools()).tools
376+
for tool in tools:
377+
name = self._component_name(tool.name, server_info)
378+
tools_temp[name] = tool
379+
tool_to_session_temp[name] = session
380+
component_names.tools.add(name)
381+
except MCPError as err: # pragma: no cover
382+
logging.warning(f"Could not fetch tools: {err}")
377383

378384
# Clean up exit stack for session if we couldn't retrieve anything
379385
# from the server.

tests/client/test_session_group.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,58 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
125125
mock_session.list_prompts.assert_awaited_once()
126126

127127

128+
@pytest.mark.anyio
129+
async def test_client_session_group_skips_unadvertised_capabilities(mock_exit_stack: contextlib.AsyncExitStack):
130+
server_info = types.Implementation(name="ToolsOnlyServer", version="1")
131+
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
132+
mock_tool = types.Tool(name="ping", input_schema={})
133+
mock_session.initialize_result = types.InitializeResult(
134+
protocol_version="2025-03-26",
135+
capabilities=types.ServerCapabilities(tools=types.ToolsCapability(list_changed=False)),
136+
server_info=server_info,
137+
)
138+
mock_session.list_tools.return_value = types.ListToolsResult(tools=[mock_tool])
139+
mock_session.list_resources.return_value = types.ListResourcesResult(resources=[])
140+
mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[])
141+
142+
group = ClientSessionGroup(exit_stack=mock_exit_stack)
143+
with mock.patch.object(group, "_establish_session", return_value=(server_info, mock_session)):
144+
await group.connect_to_server(StdioServerParameters(command="test"))
145+
146+
assert group.tools == {"ping": mock_tool}
147+
assert group.resources == {}
148+
assert group.prompts == {}
149+
mock_session.list_tools.assert_awaited_once()
150+
mock_session.list_resources.assert_not_awaited()
151+
mock_session.list_prompts.assert_not_awaited()
152+
153+
154+
@pytest.mark.anyio
155+
async def test_client_session_group_skips_unadvertised_tools(mock_exit_stack: contextlib.AsyncExitStack):
156+
server_info = types.Implementation(name="PromptServer", version="1")
157+
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
158+
mock_prompt = types.Prompt(name="explain")
159+
mock_session.initialize_result = types.InitializeResult(
160+
protocol_version="2025-03-26",
161+
capabilities=types.ServerCapabilities(prompts=types.PromptsCapability(list_changed=False)),
162+
server_info=server_info,
163+
)
164+
mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[mock_prompt])
165+
mock_session.list_resources.return_value = types.ListResourcesResult(resources=[])
166+
mock_session.list_tools.return_value = types.ListToolsResult(tools=[])
167+
168+
group = ClientSessionGroup(exit_stack=mock_exit_stack)
169+
with mock.patch.object(group, "_establish_session", return_value=(server_info, mock_session)):
170+
await group.connect_to_server(StdioServerParameters(command="test"))
171+
172+
assert group.prompts == {"explain": mock_prompt}
173+
assert group.resources == {}
174+
assert group.tools == {}
175+
mock_session.list_prompts.assert_awaited_once()
176+
mock_session.list_resources.assert_not_awaited()
177+
mock_session.list_tools.assert_not_awaited()
178+
179+
128180
@pytest.mark.anyio
129181
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
130182
"""Test connecting with a component name hook."""

0 commit comments

Comments
 (0)