From b6ba435765880e86b2bcd992aea3afc01ae930da Mon Sep 17 00:00:00 2001 From: Denis Tu Date: Mon, 2 Mar 2026 08:41:01 -0800 Subject: [PATCH 1/6] fix: recover stale MCP sessions after server restart Signed-off-by: Denis Tu --- .../translator/agent/adk_api_translator.go | 9 + .../kagent-adk/src/kagent/adk/_a2a.py | 66 +++++++ .../kagent-adk/src/kagent/adk/_mcp_toolset.py | 68 ++++++- .../unittests/test_mcp_health_endpoint.py | 134 +++++++++++++ .../unittests/test_mcp_session_recovery.py | 182 ++++++++++++++++++ 5 files changed, 458 insertions(+), 1 deletion(-) create mode 100644 python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py create mode 100644 python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py diff --git a/go/core/internal/controller/translator/agent/adk_api_translator.go b/go/core/internal/controller/translator/agent/adk_api_translator.go index b3f7b36d5..81c59b246 100644 --- a/go/core/internal/controller/translator/agent/adk_api_translator.go +++ b/go/core/internal/controller/translator/agent/adk_api_translator.go @@ -568,6 +568,15 @@ func (a *adkApiTranslator) buildManifest( TimeoutSeconds: probeConf.TimeoutSeconds, PeriodSeconds: probeConf.PeriodSeconds, }, + LivenessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{Path: "/healthz/mcp", Port: intstr.FromString("http")}, + }, + InitialDelaySeconds: 30, + PeriodSeconds: 30, + TimeoutSeconds: 10, + FailureThreshold: 3, + }, SecurityContext: securityContext, VolumeMounts: volumeMounts, }}, diff --git a/python/packages/kagent-adk/src/kagent/adk/_a2a.py b/python/packages/kagent-adk/src/kagent/adk/_a2a.py index 4329d94d1..fc3e26399 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_a2a.py +++ b/python/packages/kagent-adk/src/kagent/adk/_a2a.py @@ -1,4 +1,5 @@ #! /usr/bin/env python3 +import asyncio import faulthandler import logging import os @@ -20,6 +21,7 @@ from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService from google.genai import types +from starlette.responses import JSONResponse from kagent.core.a2a import ( KAgentRequestContextBuilder, @@ -29,6 +31,7 @@ from ._agent_executor import A2aAgentExecutor, A2aAgentExecutorConfig from ._lifespan import LifespanManager +from ._mcp_toolset import KAgentMCPSessionManager, _is_server_alive_error from ._memory_service import KagentMemoryService from ._session_service import KAgentSessionService from ._token import KAgentTokenService @@ -36,6 +39,8 @@ logger = logging.getLogger(__name__) +_MCP_HEALTH_TIMEOUT_SECONDS = 5.0 + def health_check(request: Request) -> PlainTextResponse: return PlainTextResponse("OK") @@ -50,6 +55,63 @@ def thread_dump(request: Request) -> PlainTextResponse: return PlainTextResponse(tmp.read()) +def _build_mcp_health_check(agent_config: Optional[AgentConfig]): + """Return a request handler that pings every configured MCP server. + + Returns 200 with ``{"status": "ok"}`` when all servers respond to ping, + or 503 with per-server error details when any fail. When the agent has + no MCP tools configured the endpoint always returns 200. + + Checks run concurrently so that one slow server does not block others + or cause the liveness probe to time out. + """ + connection_params_list: list[Any] = [] + if agent_config: + for cfg in agent_config.http_tools or []: + connection_params_list.append(cfg.params) + for cfg in agent_config.sse_tools or []: + connection_params_list.append(cfg.params) + + async def _check_one(params: Any) -> tuple[str, Optional[str]]: + """Return (url, error_string | None) for a single MCP server.""" + url = getattr(params, "url", "unknown") + mgr = KAgentMCPSessionManager(connection_params=params) + try: + await asyncio.wait_for( + mgr.create_session(), timeout=_MCP_HEALTH_TIMEOUT_SECONDS + ) + return url, None + except Exception as exc: + if _is_server_alive_error(exc): + return url, None + return url, str(exc) + finally: + try: + await mgr.close() + except Exception: + pass + + async def mcp_health(request: Request): + if not connection_params_list: + return JSONResponse({"status": "ok", "servers": 0}) + + results = await asyncio.gather( + *(_check_one(p) for p in connection_params_list) + ) + errors = {url: err for url, err in results if err is not None} + + if errors: + return JSONResponse( + {"status": "error", "errors": errors}, + status_code=503, + ) + return JSONResponse( + {"status": "ok", "servers": len(connection_params_list)} + ) + + return mcp_health + + kagent_url_override = os.getenv("KAGENT_URL") @@ -170,6 +232,10 @@ def create_runner() -> Runner: # Health check/readiness probe app.add_route("/health", methods=["GET"], route=health_check) app.add_route("/thread_dump", methods=["GET"], route=thread_dump) + + mcp_health = _build_mcp_health_check(self.agent_config) + app.add_route("/healthz/mcp", methods=["GET"], route=mcp_health) + a2a_app.add_routes_to_app(app) return app diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 26c4c6df7..5774db14c 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -2,13 +2,72 @@ import asyncio import logging -from typing import Optional +from typing import Dict, Optional from google.adk.tools import BaseTool +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext +from mcp import ClientSession +from mcp.shared.exceptions import McpError logger = logging.getLogger("kagent_adk." + __name__) +_PING_TIMEOUT_SECONDS = 2.0 +_JSONRPC_METHOD_NOT_FOUND = -32601 + + +def _is_server_alive_error(exc: Exception) -> bool: + """Return True if the error proves the server is reachable. + + Some MCP servers don't implement the optional ``ping`` method and + reply with JSON-RPC "Method not found" (-32601). This still means + the session is valid and the server is responding. + """ + if isinstance(exc, McpError): + return exc.error.code == _JSONRPC_METHOD_NOT_FOUND + return False + + +class KAgentMCPSessionManager(MCPSessionManager): + """Session manager that validates cached sessions via ping before reuse. + + The upstream ``MCPSessionManager`` checks ``_read_stream._closed`` / + ``_write_stream._closed`` to decide whether a cached session is still + usable. Those are in-memory anyio channels that stay open even when + the remote MCP server restarts, so the check always passes and the + stale ``mcp-session-id`` is sent to the new server, which replies + with HTTP 404 → ``"Session terminated"``. + + This subclass adds a lightweight ``send_ping()`` probe after the + parent returns a cached session. If the ping fails the cached + session is torn down and a brand-new one is created transparently. + """ + + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: + session = await super().create_session(headers) + + try: + await asyncio.wait_for( + session.send_ping(), timeout=_PING_TIMEOUT_SECONDS + ) + return session + except Exception as exc: + if _is_server_alive_error(exc): + return session + logger.warning( + "MCP session failed ping validation, " + "invalidating cached session and creating a fresh one" + ) + try: + await self.close() + except Exception as close_exc: + logger.debug( + "Non-fatal error while closing stale session: %s", close_exc + ) + return await super().create_session(headers) + def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError: message = "Failed to create MCP session: operation cancelled" @@ -25,6 +84,13 @@ class KAgentMcpToolset(McpToolset): implementation may not catch and propagate without enough context. """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._mcp_session_manager = KAgentMCPSessionManager( + connection_params=self._connection_params, + errlog=self._errlog, + ) + async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]: try: return await super().get_tools(readonly_context) diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py b/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py new file mode 100644 index 000000000..795c57938 --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py @@ -0,0 +1,134 @@ +"""Tests for the /healthz/mcp endpoint.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from starlette.testclient import TestClient + +from kagent.adk._a2a import _build_mcp_health_check +from kagent.adk._mcp_toolset import KAgentMCPSessionManager +from kagent.adk.types import HttpMcpServerConfig + + +def _make_app(handler): + from fastapi import FastAPI + + app = FastAPI() + app.add_route("/healthz/mcp", methods=["GET"], route=handler) + return app + + +def _make_http_tool(url="http://mcp1.example.com/mcp"): + params = StreamableHTTPConnectionParams(url=url) + return HttpMcpServerConfig(params=params, tools=[]) + + +def test_no_mcp_tools_returns_ok(): + config = MagicMock() + config.http_tools = None + config.sse_tools = None + + handler = _build_mcp_health_check(config) + app = _make_app(handler) + + with TestClient(app) as client: + resp = client.get("/healthz/mcp") + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["servers"] == 0 + + +def test_none_config_returns_ok(): + handler = _build_mcp_health_check(None) + app = _make_app(handler) + + with TestClient(app) as client: + resp = client.get("/healthz/mcp") + + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + +def test_healthy_mcp_returns_ok(): + config = MagicMock() + config.http_tools = [_make_http_tool()] + config.sse_tools = None + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock(return_value=None) + + with patch.object( + KAgentMCPSessionManager, + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ), patch.object( + KAgentMCPSessionManager, "close", new_callable=AsyncMock + ): + handler = _build_mcp_health_check(config) + app = _make_app(handler) + + with TestClient(app) as client: + resp = client.get("/healthz/mcp") + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["servers"] == 1 + + +def test_unhealthy_mcp_returns_503(): + config = MagicMock() + config.http_tools = [_make_http_tool("http://dead-mcp.example.com/mcp")] + config.sse_tools = None + + with patch.object( + KAgentMCPSessionManager, + "create_session", + new_callable=AsyncMock, + side_effect=ConnectionError("connection refused"), + ), patch.object( + KAgentMCPSessionManager, "close", new_callable=AsyncMock + ): + handler = _build_mcp_health_check(config) + app = _make_app(handler) + + with TestClient(app) as client: + resp = client.get("/healthz/mcp") + + assert resp.status_code == 503 + body = resp.json() + assert body["status"] == "error" + assert "http://dead-mcp.example.com/mcp" in body["errors"] + + +def test_method_not_found_treated_as_healthy(): + """MCP servers that don't support ping (-32601) should be reported as ok.""" + from mcp.shared.exceptions import McpError + from mcp.types import ErrorData + + config = MagicMock() + config.http_tools = [_make_http_tool("http://no-ping-mcp.example.com/mcp")] + config.sse_tools = None + + with patch.object( + KAgentMCPSessionManager, + "create_session", + new_callable=AsyncMock, + side_effect=McpError(error=ErrorData(code=-32601, message="Method not found")), + ), patch.object( + KAgentMCPSessionManager, "close", new_callable=AsyncMock + ): + handler = _build_mcp_health_check(config) + app = _make_app(handler) + + with TestClient(app) as client: + resp = client.get("/healthz/mcp") + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py new file mode 100644 index 000000000..e83e3b404 --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py @@ -0,0 +1,182 @@ +"""Tests for KAgentMCPSessionManager ping-validated session caching.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData + +from kagent.adk._mcp_toolset import _PING_TIMEOUT_SECONDS, KAgentMCPSessionManager + + +def _make_manager(**overrides): + """Create a KAgentMCPSessionManager with a mocked connection_params.""" + params = MagicMock() + params.url = "http://mcp.example.com/mcp" + params.timeout = 5.0 + params.sse_read_timeout = 300.0 + params.headers = None + return KAgentMCPSessionManager(connection_params=params, **overrides) + + +@pytest.mark.asyncio +async def test_create_session_returns_session_when_ping_succeeds(): + mgr = _make_manager() + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock(return_value=None) + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ) as parent_create: + result = await mgr.create_session() + + assert result is mock_session + mock_session.send_ping.assert_awaited_once() + parent_create.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_invalidates_and_retries_when_ping_fails(): + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(side_effect=Exception("Session terminated")) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock) as mock_close, + ): + result = await mgr.create_session() + + assert result is fresh_session + mock_close.assert_awaited_once() + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_create_session_propagates_error_when_server_truly_down(): + mgr = _make_manager() + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + side_effect=ConnectionError("Failed to create MCP session"), + ): + with pytest.raises(ConnectionError, match="Failed to create MCP session"): + await mgr.create_session() + + +@pytest.mark.asyncio +async def test_create_session_ping_respects_timeout(): + mgr = _make_manager() + + async def _slow_ping(): + await asyncio.sleep(10) + + slow_session = AsyncMock() + slow_session.send_ping = _slow_ping + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return slow_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock), + ): + result = await mgr.create_session() + + assert result is fresh_session + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_create_session_accepts_method_not_found_as_alive(): + """Servers that don't implement ping reply with -32601; session is still valid.""" + mgr = _make_manager() + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock( + side_effect=McpError(error=ErrorData(code=-32601, message="Method not found: ping")) + ) + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ) as parent_create: + result = await mgr.create_session() + + assert result is mock_session + parent_create.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_recovers_even_when_close_raises(): + """Recovery must succeed even if close() raises during stale session teardown.""" + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(side_effect=Exception("Session terminated")) + + fresh_session = AsyncMock() + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object( + mgr, "close", new_callable=AsyncMock, + side_effect=RuntimeError("cancel scope in different task"), + ), + ): + result = await mgr.create_session() + + assert result is fresh_session + assert call_count == 2 From 9ab8f5ad3996fd4ec6742dae61d9682f16665155 Mon Sep 17 00:00:00 2001 From: Denis Tu Date: Tue, 3 Mar 2026 05:03:16 -0800 Subject: [PATCH 2/6] Pull request overview fixes Signed-off-by: Denis Tu --- .../kagent-adk/src/kagent/adk/_a2a.py | 16 +++---- .../kagent-adk/src/kagent/adk/_mcp_toolset.py | 19 ++------ .../unittests/test_mcp_health_endpoint.py | 47 ++++++++++--------- .../unittests/test_mcp_session_recovery.py | 4 +- 4 files changed, 39 insertions(+), 47 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_a2a.py b/python/packages/kagent-adk/src/kagent/adk/_a2a.py index fc3e26399..b1249743e 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_a2a.py +++ b/python/packages/kagent-adk/src/kagent/adk/_a2a.py @@ -77,9 +77,7 @@ async def _check_one(params: Any) -> tuple[str, Optional[str]]: url = getattr(params, "url", "unknown") mgr = KAgentMCPSessionManager(connection_params=params) try: - await asyncio.wait_for( - mgr.create_session(), timeout=_MCP_HEALTH_TIMEOUT_SECONDS - ) + await asyncio.wait_for(mgr.create_session(), timeout=_MCP_HEALTH_TIMEOUT_SECONDS) return url, None except Exception as exc: if _is_server_alive_error(exc): @@ -95,19 +93,17 @@ async def mcp_health(request: Request): if not connection_params_list: return JSONResponse({"status": "ok", "servers": 0}) - results = await asyncio.gather( - *(_check_one(p) for p in connection_params_list) - ) + results = await asyncio.gather(*(_check_one(p) for p in connection_params_list)) errors = {url: err for url, err in results if err is not None} if errors: + for url, err in errors.items(): + logger.warning("MCP health check failed: %s: %s", url, err) return JSONResponse( - {"status": "error", "errors": errors}, + {"status": "error", "unhealthy_count": len(errors)}, status_code=503, ) - return JSONResponse( - {"status": "ok", "servers": len(connection_params_list)} - ) + return JSONResponse({"status": "ok", "servers": len(connection_params_list)}) return mcp_health diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 5774db14c..88419b834 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -2,7 +2,7 @@ import asyncio import logging -from typing import Dict, Optional +from typing import Optional from google.adk.tools import BaseTool from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager @@ -43,29 +43,20 @@ class KAgentMCPSessionManager(MCPSessionManager): session is torn down and a brand-new one is created transparently. """ - async def create_session( - self, headers: Optional[Dict[str, str]] = None - ) -> ClientSession: + async def create_session(self, headers: dict[str, str] | None = None) -> ClientSession: session = await super().create_session(headers) try: - await asyncio.wait_for( - session.send_ping(), timeout=_PING_TIMEOUT_SECONDS - ) + await asyncio.wait_for(session.send_ping(), timeout=_PING_TIMEOUT_SECONDS) return session except Exception as exc: if _is_server_alive_error(exc): return session - logger.warning( - "MCP session failed ping validation, " - "invalidating cached session and creating a fresh one" - ) + logger.warning("MCP session failed ping validation, invalidating cached session and creating a fresh one") try: await self.close() except Exception as close_exc: - logger.debug( - "Non-fatal error while closing stale session: %s", close_exc - ) + logger.debug("Non-fatal error while closing stale session: %s", close_exc) return await super().create_session(headers) diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py b/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py index 795c57938..1c2277da5 100644 --- a/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py @@ -61,13 +61,14 @@ def test_healthy_mcp_returns_ok(): mock_session = AsyncMock() mock_session.send_ping = AsyncMock(return_value=None) - with patch.object( - KAgentMCPSessionManager, - "create_session", - new_callable=AsyncMock, - return_value=mock_session, - ), patch.object( - KAgentMCPSessionManager, "close", new_callable=AsyncMock + with ( + patch.object( + KAgentMCPSessionManager, + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ), + patch.object(KAgentMCPSessionManager, "close", new_callable=AsyncMock), ): handler = _build_mcp_health_check(config) app = _make_app(handler) @@ -86,13 +87,14 @@ def test_unhealthy_mcp_returns_503(): config.http_tools = [_make_http_tool("http://dead-mcp.example.com/mcp")] config.sse_tools = None - with patch.object( - KAgentMCPSessionManager, - "create_session", - new_callable=AsyncMock, - side_effect=ConnectionError("connection refused"), - ), patch.object( - KAgentMCPSessionManager, "close", new_callable=AsyncMock + with ( + patch.object( + KAgentMCPSessionManager, + "create_session", + new_callable=AsyncMock, + side_effect=ConnectionError("connection refused"), + ), + patch.object(KAgentMCPSessionManager, "close", new_callable=AsyncMock), ): handler = _build_mcp_health_check(config) app = _make_app(handler) @@ -103,7 +105,7 @@ def test_unhealthy_mcp_returns_503(): assert resp.status_code == 503 body = resp.json() assert body["status"] == "error" - assert "http://dead-mcp.example.com/mcp" in body["errors"] + assert body["unhealthy_count"] == 1 def test_method_not_found_treated_as_healthy(): @@ -115,13 +117,14 @@ def test_method_not_found_treated_as_healthy(): config.http_tools = [_make_http_tool("http://no-ping-mcp.example.com/mcp")] config.sse_tools = None - with patch.object( - KAgentMCPSessionManager, - "create_session", - new_callable=AsyncMock, - side_effect=McpError(error=ErrorData(code=-32601, message="Method not found")), - ), patch.object( - KAgentMCPSessionManager, "close", new_callable=AsyncMock + with ( + patch.object( + KAgentMCPSessionManager, + "create_session", + new_callable=AsyncMock, + side_effect=McpError(error=ErrorData(code=-32601, message="Method not found")), + ), + patch.object(KAgentMCPSessionManager, "close", new_callable=AsyncMock), ): handler = _build_mcp_health_check(config) app = _make_app(handler) diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py index e83e3b404..6c61cfc03 100644 --- a/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py @@ -172,7 +172,9 @@ async def _parent_create(headers=None): side_effect=_parent_create, ), patch.object( - mgr, "close", new_callable=AsyncMock, + mgr, + "close", + new_callable=AsyncMock, side_effect=RuntimeError("cancel scope in different task"), ), ): From 36d80acdd07e80c64f420caf4108c2f53acb0546 Mon Sep 17 00:00:00 2001 From: Denis Tu Date: Wed, 4 Mar 2026 00:24:20 -0800 Subject: [PATCH 3/6] remove adk LivenessProbe and update mcpSessionManager Signed-off-by: Denis Tu --- .../translator/agent/adk_api_translator.go | 9 -- .../kagent-adk/src/kagent/adk/_mcp_toolset.py | 50 +++++++-- .../unittests/test_mcp_session_recovery.py | 103 ++++++++++++++++++ 3 files changed, 141 insertions(+), 21 deletions(-) diff --git a/go/core/internal/controller/translator/agent/adk_api_translator.go b/go/core/internal/controller/translator/agent/adk_api_translator.go index 81c59b246..b3f7b36d5 100644 --- a/go/core/internal/controller/translator/agent/adk_api_translator.go +++ b/go/core/internal/controller/translator/agent/adk_api_translator.go @@ -568,15 +568,6 @@ func (a *adkApiTranslator) buildManifest( TimeoutSeconds: probeConf.TimeoutSeconds, PeriodSeconds: probeConf.PeriodSeconds, }, - LivenessProbe: &corev1.Probe{ - ProbeHandler: corev1.ProbeHandler{ - HTTPGet: &corev1.HTTPGetAction{Path: "/healthz/mcp", Port: intstr.FromString("http")}, - }, - InitialDelaySeconds: 30, - PeriodSeconds: 30, - TimeoutSeconds: 10, - FailureThreshold: 3, - }, SecurityContext: securityContext, VolumeMounts: volumeMounts, }}, diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 88419b834..6ad6ec0c7 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -13,6 +13,7 @@ logger = logging.getLogger("kagent_adk." + __name__) _PING_TIMEOUT_SECONDS = 2.0 +_SESSION_REVALIDATE_TIMEOUT_SECONDS = 5.0 _JSONRPC_METHOD_NOT_FOUND = -32601 @@ -28,8 +29,17 @@ def _is_server_alive_error(exc: Exception) -> bool: return False +def _is_session_invalid_error(exc: Exception) -> bool: + """Return True if the error indicates the MCP session is no longer valid (e.g. 404).""" + parts = [str(exc)] + if isinstance(exc, McpError) and exc.error.message: + parts.append(exc.error.message) + msg = " ".join(parts).lower() + return "404" in msg or "session terminated" in msg + + class KAgentMCPSessionManager(MCPSessionManager): - """Session manager that validates cached sessions via ping before reuse. + """Session manager that validates cached sessions via ping and list_tools before reuse. The upstream ``MCPSessionManager`` checks ``_read_stream._closed`` / ``_write_stream._closed`` to decide whether a cached session is still @@ -38,9 +48,10 @@ class KAgentMCPSessionManager(MCPSessionManager): stale ``mcp-session-id`` is sent to the new server, which replies with HTTP 404 → ``"Session terminated"``. - This subclass adds a lightweight ``send_ping()`` probe after the - parent returns a cached session. If the ping fails the cached - session is torn down and a brand-new one is created transparently. + This subclass: (1) runs ``send_ping()`` after the parent returns a cached + session; (2) then revalidates the session with ``list_tools()`` so that + if the server restarted and the session id is invalid (404), we prune + the session from the cache and create a new one. """ async def create_session(self, headers: dict[str, str] | None = None) -> ClientSession: @@ -48,16 +59,31 @@ async def create_session(self, headers: dict[str, str] | None = None) -> ClientS try: await asyncio.wait_for(session.send_ping(), timeout=_PING_TIMEOUT_SECONDS) - return session except Exception as exc: if _is_server_alive_error(exc): - return session - logger.warning("MCP session failed ping validation, invalidating cached session and creating a fresh one") - try: - await self.close() - except Exception as close_exc: - logger.debug("Non-fatal error while closing stale session: %s", close_exc) - return await super().create_session(headers) + pass + else: + logger.warning( + "MCP session failed ping validation, invalidating cached session and creating a fresh one" + ) + try: + await self.close() + except Exception as close_exc: + logger.debug("Non-fatal error while closing stale session: %s", close_exc) + return await super().create_session(headers) + + try: + await asyncio.wait_for(session.list_tools(), timeout=_SESSION_REVALIDATE_TIMEOUT_SECONDS) + return session + except Exception as exc: + if _is_session_invalid_error(exc): + logger.warning("MCP session invalid (e.g. 404), pruning from cache and creating a fresh one") + try: + await self.close() + except Exception as close_exc: + logger.debug("Non-fatal error while closing stale session: %s", close_exc) + return await super().create_session(headers) + raise def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError: diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py index 6c61cfc03..43124c766 100644 --- a/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py +++ b/python/packages/kagent-adk/tests/unittests/test_mcp_session_recovery.py @@ -26,6 +26,7 @@ async def test_create_session_returns_session_when_ping_succeeds(): mock_session = AsyncMock() mock_session.send_ping = AsyncMock(return_value=None) + mock_session.list_tools = AsyncMock(return_value=[]) with patch.object( KAgentMCPSessionManager.__bases__[0], @@ -37,6 +38,7 @@ async def test_create_session_returns_session_when_ping_succeeds(): assert result is mock_session mock_session.send_ping.assert_awaited_once() + mock_session.list_tools.assert_awaited_once() parent_create.assert_awaited_once() @@ -49,6 +51,7 @@ async def test_create_session_invalidates_and_retries_when_ping_fails(): fresh_session = AsyncMock() fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) call_count = 0 @@ -100,6 +103,7 @@ async def _slow_ping(): fresh_session = AsyncMock() fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) call_count = 0 @@ -133,6 +137,7 @@ async def test_create_session_accepts_method_not_found_as_alive(): mock_session.send_ping = AsyncMock( side_effect=McpError(error=ErrorData(code=-32601, message="Method not found: ping")) ) + mock_session.list_tools = AsyncMock(return_value=[]) with patch.object( KAgentMCPSessionManager.__bases__[0], @@ -143,9 +148,105 @@ async def test_create_session_accepts_method_not_found_as_alive(): result = await mgr.create_session() assert result is mock_session + mock_session.list_tools.assert_awaited_once() parent_create.assert_awaited_once() +@pytest.mark.asyncio +async def test_create_session_invalidates_when_list_tools_returns_session_terminated(): + """After ping passes, list_tools is used to revalidate; 404/session terminated → prune and recreate.""" + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(return_value=None) + stale_session.list_tools = AsyncMock(side_effect=Exception("Session terminated")) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock) as mock_close, + ): + result = await mgr.create_session() + + assert result is fresh_session + mock_close.assert_awaited_once() + assert call_count == 2 + stale_session.list_tools.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_invalidates_when_list_tools_returns_404(): + """list_tools returning 404 (session invalid) triggers prune and recreate.""" + mgr = _make_manager() + + stale_session = AsyncMock() + stale_session.send_ping = AsyncMock(return_value=None) + stale_session.list_tools = AsyncMock(side_effect=McpError(error=ErrorData(code=-32000, message="404 Not Found"))) + + fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) + + call_count = 0 + + async def _parent_create(headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return stale_session + return fresh_session + + with ( + patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + side_effect=_parent_create, + ), + patch.object(mgr, "close", new_callable=AsyncMock) as mock_close, + ): + result = await mgr.create_session() + + assert result is fresh_session + mock_close.assert_awaited_once() + assert call_count == 2 + stale_session.list_tools.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_session_propagates_non_session_error_from_list_tools(): + """Transient errors (e.g. timeout) from list_tools are propagated, not treated as session invalid.""" + mgr = _make_manager() + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock(return_value=None) + mock_session.list_tools = AsyncMock(side_effect=asyncio.TimeoutError()) + + with patch.object( + KAgentMCPSessionManager.__bases__[0], + "create_session", + new_callable=AsyncMock, + return_value=mock_session, + ): + with pytest.raises(asyncio.TimeoutError): + await mgr.create_session() + + @pytest.mark.asyncio async def test_create_session_recovers_even_when_close_raises(): """Recovery must succeed even if close() raises during stale session teardown.""" @@ -155,6 +256,8 @@ async def test_create_session_recovers_even_when_close_raises(): stale_session.send_ping = AsyncMock(side_effect=Exception("Session terminated")) fresh_session = AsyncMock() + fresh_session.send_ping = AsyncMock(return_value=None) + fresh_session.list_tools = AsyncMock(return_value=[]) call_count = 0 From 3a3baffdac80018d5fe0e2f42c44f692ee09cf0d Mon Sep 17 00:00:00 2001 From: Denis Tu Date: Wed, 4 Mar 2026 12:09:27 -0800 Subject: [PATCH 4/6] remove mcp health check Signed-off-by: Denis Tu --- .../kagent-adk/src/kagent/adk/_a2a.py | 61 -------- .../unittests/test_mcp_health_endpoint.py | 137 ------------------ 2 files changed, 198 deletions(-) delete mode 100644 python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py diff --git a/python/packages/kagent-adk/src/kagent/adk/_a2a.py b/python/packages/kagent-adk/src/kagent/adk/_a2a.py index b1249743e..9e275a583 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_a2a.py +++ b/python/packages/kagent-adk/src/kagent/adk/_a2a.py @@ -21,8 +21,6 @@ from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService from google.genai import types -from starlette.responses import JSONResponse - from kagent.core.a2a import ( KAgentRequestContextBuilder, KAgentTaskStore, @@ -31,7 +29,6 @@ from ._agent_executor import A2aAgentExecutor, A2aAgentExecutorConfig from ._lifespan import LifespanManager -from ._mcp_toolset import KAgentMCPSessionManager, _is_server_alive_error from ._memory_service import KagentMemoryService from ._session_service import KAgentSessionService from ._token import KAgentTokenService @@ -39,8 +36,6 @@ logger = logging.getLogger(__name__) -_MCP_HEALTH_TIMEOUT_SECONDS = 5.0 - def health_check(request: Request) -> PlainTextResponse: return PlainTextResponse("OK") @@ -55,59 +50,6 @@ def thread_dump(request: Request) -> PlainTextResponse: return PlainTextResponse(tmp.read()) -def _build_mcp_health_check(agent_config: Optional[AgentConfig]): - """Return a request handler that pings every configured MCP server. - - Returns 200 with ``{"status": "ok"}`` when all servers respond to ping, - or 503 with per-server error details when any fail. When the agent has - no MCP tools configured the endpoint always returns 200. - - Checks run concurrently so that one slow server does not block others - or cause the liveness probe to time out. - """ - connection_params_list: list[Any] = [] - if agent_config: - for cfg in agent_config.http_tools or []: - connection_params_list.append(cfg.params) - for cfg in agent_config.sse_tools or []: - connection_params_list.append(cfg.params) - - async def _check_one(params: Any) -> tuple[str, Optional[str]]: - """Return (url, error_string | None) for a single MCP server.""" - url = getattr(params, "url", "unknown") - mgr = KAgentMCPSessionManager(connection_params=params) - try: - await asyncio.wait_for(mgr.create_session(), timeout=_MCP_HEALTH_TIMEOUT_SECONDS) - return url, None - except Exception as exc: - if _is_server_alive_error(exc): - return url, None - return url, str(exc) - finally: - try: - await mgr.close() - except Exception: - pass - - async def mcp_health(request: Request): - if not connection_params_list: - return JSONResponse({"status": "ok", "servers": 0}) - - results = await asyncio.gather(*(_check_one(p) for p in connection_params_list)) - errors = {url: err for url, err in results if err is not None} - - if errors: - for url, err in errors.items(): - logger.warning("MCP health check failed: %s: %s", url, err) - return JSONResponse( - {"status": "error", "unhealthy_count": len(errors)}, - status_code=503, - ) - return JSONResponse({"status": "ok", "servers": len(connection_params_list)}) - - return mcp_health - - kagent_url_override = os.getenv("KAGENT_URL") @@ -229,9 +171,6 @@ def create_runner() -> Runner: app.add_route("/health", methods=["GET"], route=health_check) app.add_route("/thread_dump", methods=["GET"], route=thread_dump) - mcp_health = _build_mcp_health_check(self.agent_config) - app.add_route("/healthz/mcp", methods=["GET"], route=mcp_health) - a2a_app.add_routes_to_app(app) return app diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py b/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py deleted file mode 100644 index 1c2277da5..000000000 --- a/python/packages/kagent-adk/tests/unittests/test_mcp_health_endpoint.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Tests for the /healthz/mcp endpoint.""" - -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -from starlette.testclient import TestClient - -from kagent.adk._a2a import _build_mcp_health_check -from kagent.adk._mcp_toolset import KAgentMCPSessionManager -from kagent.adk.types import HttpMcpServerConfig - - -def _make_app(handler): - from fastapi import FastAPI - - app = FastAPI() - app.add_route("/healthz/mcp", methods=["GET"], route=handler) - return app - - -def _make_http_tool(url="http://mcp1.example.com/mcp"): - params = StreamableHTTPConnectionParams(url=url) - return HttpMcpServerConfig(params=params, tools=[]) - - -def test_no_mcp_tools_returns_ok(): - config = MagicMock() - config.http_tools = None - config.sse_tools = None - - handler = _build_mcp_health_check(config) - app = _make_app(handler) - - with TestClient(app) as client: - resp = client.get("/healthz/mcp") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ok" - assert body["servers"] == 0 - - -def test_none_config_returns_ok(): - handler = _build_mcp_health_check(None) - app = _make_app(handler) - - with TestClient(app) as client: - resp = client.get("/healthz/mcp") - - assert resp.status_code == 200 - assert resp.json()["status"] == "ok" - - -def test_healthy_mcp_returns_ok(): - config = MagicMock() - config.http_tools = [_make_http_tool()] - config.sse_tools = None - - mock_session = AsyncMock() - mock_session.send_ping = AsyncMock(return_value=None) - - with ( - patch.object( - KAgentMCPSessionManager, - "create_session", - new_callable=AsyncMock, - return_value=mock_session, - ), - patch.object(KAgentMCPSessionManager, "close", new_callable=AsyncMock), - ): - handler = _build_mcp_health_check(config) - app = _make_app(handler) - - with TestClient(app) as client: - resp = client.get("/healthz/mcp") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ok" - assert body["servers"] == 1 - - -def test_unhealthy_mcp_returns_503(): - config = MagicMock() - config.http_tools = [_make_http_tool("http://dead-mcp.example.com/mcp")] - config.sse_tools = None - - with ( - patch.object( - KAgentMCPSessionManager, - "create_session", - new_callable=AsyncMock, - side_effect=ConnectionError("connection refused"), - ), - patch.object(KAgentMCPSessionManager, "close", new_callable=AsyncMock), - ): - handler = _build_mcp_health_check(config) - app = _make_app(handler) - - with TestClient(app) as client: - resp = client.get("/healthz/mcp") - - assert resp.status_code == 503 - body = resp.json() - assert body["status"] == "error" - assert body["unhealthy_count"] == 1 - - -def test_method_not_found_treated_as_healthy(): - """MCP servers that don't support ping (-32601) should be reported as ok.""" - from mcp.shared.exceptions import McpError - from mcp.types import ErrorData - - config = MagicMock() - config.http_tools = [_make_http_tool("http://no-ping-mcp.example.com/mcp")] - config.sse_tools = None - - with ( - patch.object( - KAgentMCPSessionManager, - "create_session", - new_callable=AsyncMock, - side_effect=McpError(error=ErrorData(code=-32601, message="Method not found")), - ), - patch.object(KAgentMCPSessionManager, "close", new_callable=AsyncMock), - ): - handler = _build_mcp_health_check(config) - app = _make_app(handler) - - with TestClient(app) as client: - resp = client.get("/healthz/mcp") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ok" From 695d2fa9975a301f6e74edf1dd9b4a4e7964eea1 Mon Sep 17 00:00:00 2001 From: Denis Tu Date: Thu, 5 Mar 2026 02:39:26 -0800 Subject: [PATCH 5/6] _close_and_recreate_session Signed-off-by: Denis Tu --- .../kagent-adk/src/kagent/adk/_mcp_toolset.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 6ad6ec0c7..f43554de9 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -12,8 +12,9 @@ logger = logging.getLogger("kagent_adk." + __name__) -_PING_TIMEOUT_SECONDS = 2.0 -_SESSION_REVALIDATE_TIMEOUT_SECONDS = 5.0 +# Short timeouts to fail fast on the request path; avoid adding latency when session is valid. +_PING_TIMEOUT_SECONDS = 1.0 +_SESSION_REVALIDATE_TIMEOUT_SECONDS = 2.0 _JSONRPC_METHOD_NOT_FOUND = -32601 @@ -54,6 +55,17 @@ class KAgentMCPSessionManager(MCPSessionManager): the session from the cache and create a new one. """ + async def _close_and_recreate_session( + self, headers: dict[str, str] | None, reason: str + ) -> ClientSession: + """Close the cached session (best-effort) and create a new one.""" + logger.warning("%s", reason) + try: + await self.close() + except Exception as close_exc: + logger.debug("Non-fatal error while closing stale session: %s", close_exc) + return await super().create_session(headers) + async def create_session(self, headers: dict[str, str] | None = None) -> ClientSession: session = await super().create_session(headers) @@ -63,26 +75,20 @@ async def create_session(self, headers: dict[str, str] | None = None) -> ClientS if _is_server_alive_error(exc): pass else: - logger.warning( - "MCP session failed ping validation, invalidating cached session and creating a fresh one" + return await self._close_and_recreate_session( + headers, + "MCP session failed ping validation, invalidating cached session and creating a fresh one", ) - try: - await self.close() - except Exception as close_exc: - logger.debug("Non-fatal error while closing stale session: %s", close_exc) - return await super().create_session(headers) try: await asyncio.wait_for(session.list_tools(), timeout=_SESSION_REVALIDATE_TIMEOUT_SECONDS) return session except Exception as exc: if _is_session_invalid_error(exc): - logger.warning("MCP session invalid (e.g. 404), pruning from cache and creating a fresh one") - try: - await self.close() - except Exception as close_exc: - logger.debug("Non-fatal error while closing stale session: %s", close_exc) - return await super().create_session(headers) + return await self._close_and_recreate_session( + headers, + "MCP session invalid (e.g. 404), pruning from cache and creating a fresh one", + ) raise From b283686ab16f9022dbd0adeeadecc8ba1e7f9c31 Mon Sep 17 00:00:00 2001 From: Denis Tu Date: Tue, 10 Mar 2026 15:51:14 -0700 Subject: [PATCH 6/6] apply ruff format Signed-off-by: Denis Tu --- python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index f43554de9..d1eea8637 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -55,9 +55,7 @@ class KAgentMCPSessionManager(MCPSessionManager): the session from the cache and create a new one. """ - async def _close_and_recreate_session( - self, headers: dict[str, str] | None, reason: str - ) -> ClientSession: + async def _close_and_recreate_session(self, headers: dict[str, str] | None, reason: str) -> ClientSession: """Close the cached session (best-effort) and create a new one.""" logger.warning("%s", reason) try: