diff --git a/python/packages/kagent-adk/src/kagent/adk/_a2a.py b/python/packages/kagent-adk/src/kagent/adk/_a2a.py index 4329d94d1..9e275a583 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_a2a.py +++ b/python/packages/kagent-adk/src/kagent/adk/_a2a.py @@ -1,4 +1,5 @@ #! /usr/bin/env python3 +import asyncio import faulthandler import logging import os @@ -20,7 +21,6 @@ from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService from google.genai import types - from kagent.core.a2a import ( KAgentRequestContextBuilder, KAgentTaskStore, @@ -170,6 +170,7 @@ def create_runner() -> Runner: # Health check/readiness probe app.add_route("/health", methods=["GET"], route=health_check) app.add_route("/thread_dump", methods=["GET"], route=thread_dump) + a2a_app.add_routes_to_app(app) return app diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 26c4c6df7..d1eea8637 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -5,10 +5,90 @@ from typing import Optional from google.adk.tools import BaseTool +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext +from mcp import ClientSession +from mcp.shared.exceptions import McpError logger = logging.getLogger("kagent_adk." + __name__) +# Short timeouts to fail fast on the request path; avoid adding latency when session is valid. +_PING_TIMEOUT_SECONDS = 1.0 +_SESSION_REVALIDATE_TIMEOUT_SECONDS = 2.0 +_JSONRPC_METHOD_NOT_FOUND = -32601 + + +def _is_server_alive_error(exc: Exception) -> bool: + """Return True if the error proves the server is reachable. + + Some MCP servers don't implement the optional ``ping`` method and + reply with JSON-RPC "Method not found" (-32601). This still means + the session is valid and the server is responding. + """ + if isinstance(exc, McpError): + return exc.error.code == _JSONRPC_METHOD_NOT_FOUND + return False + + +def _is_session_invalid_error(exc: Exception) -> bool: + """Return True if the error indicates the MCP session is no longer valid (e.g. 404).""" + parts = [str(exc)] + if isinstance(exc, McpError) and exc.error.message: + parts.append(exc.error.message) + msg = " ".join(parts).lower() + return "404" in msg or "session terminated" in msg + + +class KAgentMCPSessionManager(MCPSessionManager): + """Session manager that validates cached sessions via ping and list_tools before reuse. + + The upstream ``MCPSessionManager`` checks ``_read_stream._closed`` / + ``_write_stream._closed`` to decide whether a cached session is still + usable. Those are in-memory anyio channels that stay open even when + the remote MCP server restarts, so the check always passes and the + stale ``mcp-session-id`` is sent to the new server, which replies + with HTTP 404 → ``"Session terminated"``. + + This subclass: (1) runs ``send_ping()`` after the parent returns a cached + session; (2) then revalidates the session with ``list_tools()`` so that + if the server restarted and the session id is invalid (404), we prune + the session from the cache and create a new one. + """ + + async def _close_and_recreate_session(self, headers: dict[str, str] | None, reason: str) -> ClientSession: + """Close the cached session (best-effort) and create a new one.""" + logger.warning("%s", reason) + try: + await self.close() + except Exception as close_exc: + logger.debug("Non-fatal error while closing stale session: %s", close_exc) + return await super().create_session(headers) + + async def create_session(self, headers: dict[str, str] | None = None) -> ClientSession: + session = await super().create_session(headers) + + try: + await asyncio.wait_for(session.send_ping(), timeout=_PING_TIMEOUT_SECONDS) + except Exception as exc: + if _is_server_alive_error(exc): + pass + else: + return await self._close_and_recreate_session( + headers, + "MCP session failed ping validation, invalidating cached session and creating a fresh one", + ) + + try: + await asyncio.wait_for(session.list_tools(), timeout=_SESSION_REVALIDATE_TIMEOUT_SECONDS) + return session + except Exception as exc: + if _is_session_invalid_error(exc): + return await self._close_and_recreate_session( + headers, + "MCP session invalid (e.g. 404), pruning from cache and creating a fresh one", + ) + raise + def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError: message = "Failed to create MCP session: operation cancelled" @@ -25,6 +105,13 @@ class KAgentMcpToolset(McpToolset): implementation may not catch and propagate without enough context. """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._mcp_session_manager = KAgentMCPSessionManager( + connection_params=self._connection_params, + errlog=self._errlog, + ) + async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]: try: return await super().get_tools(readonly_context) diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py new file mode 100644 index 000000000..43124c766 --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py @@ -0,0 +1,287 @@ +"""Tests for KAgentMCPSessionManager ping-validated session caching.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData + +from kagent.adk._mcp_toolset import _PING_TIMEOUT_SECONDS, KAgentMCPSessionManager + + +def _make_manager(**overrides): + """Create a KAgentMCPSessionManager with a mocked connection_params.""" + params = MagicMock() + params.url = "http://mcp.example.com/mcp" + params.timeout = 5.0 + params.sse_read_timeout = 300.0 + params.headers = None + return KAgentMCPSessionManager(connection_params=params, **overrides) + + +@pytest.mark.asyncio +async def test_create_session_returns_session_when_ping_succeeds(): + mgr = _make_manager() + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock(return_value=None) + mock_session.list_tools = AsyncMock(return_value=[]) + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ) as parent_create: + result = await mgr.create_session() + + assert result is mock_session + mock_session.send_ping.assert_awaited_once() + mock_session.list_tools.assert_awaited_once() + parent_create.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_invalidates_and_retries_when_ping_fails(): + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(side_effect=Exception("Session terminated")) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock) as mock_close, + ): + result = await mgr.create_session() + + assert result is fresh_session + mock_close.assert_awaited_once() + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_create_session_propagates_error_when_server_truly_down(): + mgr = _make_manager() + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + side_effect=ConnectionError("Failed to create MCP session"), + ): + with pytest.raises(ConnectionError, match="Failed to create MCP session"): + await mgr.create_session() + + +@pytest.mark.asyncio +async def test_create_session_ping_respects_timeout(): + mgr = _make_manager() + + async def _slow_ping(): + await asyncio.sleep(10) + + slow_session = AsyncMock() + slow_session.send_ping = _slow_ping + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return slow_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock), + ): + result = await mgr.create_session() + + assert result is fresh_session + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_create_session_accepts_method_not_found_as_alive(): + """Servers that don't implement ping reply with -32601; session is still valid.""" + mgr = _make_manager() + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock( + side_effect=McpError(error=ErrorData(code=-32601, message="Method not found: ping")) + ) + mock_session.list_tools = AsyncMock(return_value=[]) + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ) as parent_create: + result = await mgr.create_session() + + assert result is mock_session + mock_session.list_tools.assert_awaited_once() + parent_create.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_invalidates_when_list_tools_returns_session_terminated(): + """After ping passes, list_tools is used to revalidate; 404/session terminated → prune and recreate.""" + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(return_value=None) + stale_session.list_tools = AsyncMock(side_effect=Exception("Session terminated")) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock) as mock_close, + ): + result = await mgr.create_session() + + assert result is fresh_session + mock_close.assert_awaited_once() + assert call_count == 2 + stale_session.list_tools.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_invalidates_when_list_tools_returns_404(): + """list_tools returning 404 (session invalid) triggers prune and recreate.""" + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(return_value=None) + stale_session.list_tools = AsyncMock(side_effect=McpError(error=ErrorData(code=-32000, message="404 Not Found"))) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock) as mock_close, + ): + result = await mgr.create_session() + + assert result is fresh_session + mock_close.assert_awaited_once() + assert call_count == 2 + stale_session.list_tools.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_propagates_non_session_error_from_list_tools(): + """Transient errors (e.g. timeout) from list_tools are propagated, not treated as session invalid.""" + mgr = _make_manager() + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock(return_value=None) + mock_session.list_tools = AsyncMock(side_effect=asyncio.TimeoutError()) + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ): + with pytest.raises(asyncio.TimeoutError): + await mgr.create_session() + + +@pytest.mark.asyncio +async def test_create_session_recovers_even_when_close_raises(): + """Recovery must succeed even if close() raises during stale session teardown.""" + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(side_effect=Exception("Session terminated")) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object( + mgr, + "close", + new_callable=AsyncMock, + side_effect=RuntimeError("cancel scope in different task"), + ), + ): + result = await mgr.create_session() + + assert result is fresh_session + assert call_count == 2