From aeb802d3a375e20e71c17a9e7e79fb54d0906446 Mon Sep 17 00:00:00 2001 From: Prince Roshan Date: Fri, 9 Jan 2026 00:00:37 +0530 Subject: [PATCH 1/4] Add optional client ping --- README.md | 14 +++++ src/fast_agent/config.py | 16 ++++++ .../mcp/mcp_agent_client_session.py | 11 ++++ src/fast_agent/mcp/mcp_connection_manager.py | 51 ++++++++++++++++++- 4 files changed, 90 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9a4837fc8..d1701568a 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,20 @@ mcp: - To disable OAuth for a specific server , set `auth.oauth: false` for that server. +## MCP Ping (optional) + +The MCP ping utility is optional and can be enabled by either peer (client or server). See the [Ping overview](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/ping#overview). + +Client-side pinging is opt-in and configured per server: + +```yaml +mcp: + servers: + myserver: + ping_interval_seconds: 30 # optional; <=0 or unset disables + max_missed_pings: 2 # optional; consecutive timeouts before marking failed +``` + ## Workflows ### Chain diff --git a/src/fast_agent/config.py b/src/fast_agent/config.py index cd2c3f078..36498f930 100644 --- a/src/fast_agent/config.py +++ b/src/fast_agent/config.py @@ -213,6 +213,12 @@ class MCPServerSettings(BaseModel): read_timeout_seconds: int | None = None """The timeout in seconds for the session.""" + ping_interval_seconds: int | None = None + """Optional interval for MCP ping requests. When unset or <=0, pinging is disabled.""" + + max_missed_pings: int = 1 + """Number of consecutive missed ping responses before treating the connection as failed.""" + http_timeout_seconds: int | None = None """Overall HTTP timeout (seconds) for StreamableHTTP transport. Defaults to MCP SDK.""" @@ -262,6 +268,16 @@ class MCPServerSettings(BaseModel): implementation: Implementation | None = None + @field_validator("max_missed_pings", mode="before") + @classmethod + def _coerce_max_missed_pings(cls, value: Any) -> int: + if isinstance(value, str): + value = int(value.strip()) + value = int(value) + if value <= 0: + raise ValueError("max_missed_pings must be greater than zero.") + return value + @model_validator(mode="before") @classmethod def validate_transport_inference(cls, values): diff --git a/src/fast_agent/mcp/mcp_agent_client_session.py b/src/fast_agent/mcp/mcp_agent_client_session.py index 60ab85739..9b76312d3 100644 --- a/src/fast_agent/mcp/mcp_agent_client_session.py +++ b/src/fast_agent/mcp/mcp_agent_client_session.py @@ -18,11 +18,13 @@ CallToolRequestParams, CallToolResult, ClientRequest, + EmptyResult, GetPromptRequest, GetPromptRequestParams, GetPromptResult, Implementation, ListRootsResult, + PingRequest, ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, @@ -365,6 +367,15 @@ async def call_tool( progress_callback=progress_callback, ) + async def ping(self, read_timeout_seconds: timedelta | None = None) -> EmptyResult: + """Send a ping request to check server liveness.""" + request = PingRequest(method="ping") + return await self.send_request( + ClientRequest(request), + EmptyResult, + request_read_timeout_seconds=read_timeout_seconds, + ) + async def read_resource( self, uri: AnyUrl | str, _meta: dict | None = None, **kwargs ) -> ReadResourceResult: diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index 211b8daf1..b1a838f9a 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -6,7 +6,7 @@ import traceback from contextlib import AbstractAsyncContextManager from datetime import timedelta -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Union, cast import httpx from anyio import Event, Lock, create_task_group @@ -245,6 +245,45 @@ def create_session( return session +async def _run_ping_loop(server_conn: ServerConnection) -> None: + interval = server_conn.server_config.ping_interval_seconds + if not interval or interval <= 0: + return + + max_missed = server_conn.server_config.max_missed_pings + missed = 0 + read_timeout = ( + timedelta(seconds=server_conn.server_config.read_timeout_seconds) + if server_conn.server_config.read_timeout_seconds + else None + ) + + while not server_conn._shutdown_event.is_set(): + await asyncio.sleep(interval) + if server_conn._shutdown_event.is_set(): + break + session = server_conn.session + if session is None: + break + if not hasattr(session, "ping"): + return + try: + await cast("MCPAgentClientSession", session).ping(read_timeout_seconds=read_timeout) + missed = 0 + except Exception as exc: + missed += 1 + logger.warning( + f"{server_conn.server_name}: Ping failed ({missed}/{max_missed}): {exc}" + ) + if missed >= max_missed: + server_conn._error_occurred = True + server_conn._error_message = ( + f"Ping failed {missed} time(s); last error: {exc}" + ) + server_conn.request_shutdown() + break + + async def _server_lifecycle_task(server_conn: ServerConnection) -> None: """ Manage the lifecycle of a single server connection. @@ -285,7 +324,15 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: elif server_conn.server_config.transport == "stdio": server_conn.session_id = "local" - await server_conn.wait_for_shutdown_request() + if ( + server_conn.server_config.ping_interval_seconds + and server_conn.server_config.ping_interval_seconds > 0 + ): + async with create_task_group() as ping_group: + ping_group.start_soon(_run_ping_loop, server_conn) + await server_conn.wait_for_shutdown_request() + else: + await server_conn.wait_for_shutdown_request() except Exception as session_exit_exc: # Catch exceptions during session cleanup (e.g., when session was terminated) # This prevents cleanup errors from propagating to the task group From 1fe45792fdd81136678cba594daa97125f59419c Mon Sep 17 00:00:00 2001 From: Prince Roshan Date: Fri, 9 Jan 2026 15:57:09 +0530 Subject: [PATCH 2/4] Add test --- src/fast_agent/mcp/mcp_connection_manager.py | 34 +++++++++---- .../mcp/test_mcp_connection_manager.py | 48 ++++++++++++++++++- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index b1a838f9a..f8ac86e2e 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -334,17 +334,31 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: else: await server_conn.wait_for_shutdown_request() except Exception as session_exit_exc: - # Catch exceptions during session cleanup (e.g., when session was terminated) - # This prevents cleanup errors from propagating to the task group - logger.debug( - f"{server_name}: Exception during session cleanup (expected during reconnect): {session_exit_exc}" - ) + if server_conn._shutdown_event.is_set(): + # Cleanup errors can happen when disconnecting a session that was already + # terminated; treat as expected during shutdown. + logger.debug( + f"{server_name}: Exception during session cleanup (expected during shutdown): {session_exit_exc}" + ) + if not server_conn._initialized_event.is_set(): + server_conn._error_occurred = True + server_conn._error_message = "Shutdown requested before initialization" + server_conn._initialized_event.set() + else: + raise except Exception as transport_exit_exc: - # Catch exceptions during transport cleanup - # This can happen when disconnecting a session that was already terminated - logger.debug( - f"{server_name}: Exception during transport cleanup (expected during reconnect): {transport_exit_exc}" - ) + if server_conn._shutdown_event.is_set(): + # Cleanup errors can happen when disconnecting a transport that was already + # terminated; treat as expected during shutdown. + logger.debug( + f"{server_name}: Exception during transport cleanup (expected during shutdown): {transport_exit_exc}" + ) + if not server_conn._initialized_event.is_set(): + server_conn._error_occurred = True + server_conn._error_message = "Shutdown requested before initialization" + server_conn._initialized_event.set() + else: + raise except HTTPStatusError as http_exc: logger.error( diff --git a/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py b/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py index bbad00843..013c2da9b 100644 --- a/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py +++ b/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py @@ -1,6 +1,14 @@ +import asyncio + +import pytest + from fast_agent.config import MCPServerSettings -from fast_agent.mcp.mcp_connection_manager import _prepare_headers_and_auth +from fast_agent.mcp.mcp_connection_manager import ( + ServerConnection, + _prepare_headers_and_auth, + _server_lifecycle_task, +) def test_prepare_headers_respects_user_authorization(monkeypatch): @@ -76,3 +84,41 @@ def _builder(received_config: MCPServerSettings): assert auth is sentinel assert user_keys == set() assert calls == [config] + + +@pytest.mark.asyncio +async def test_server_lifecycle_sets_initialized_on_startup_failure(): + class DummyTransportContext: + async def __aenter__(self): + return object(), object(), None + + async def __aexit__(self, exc_type, exc, tb): + return None + + class DummySession: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def initialize(self): + raise RuntimeError("boom") + + def session_factory(*_args, **_kwargs): + return DummySession() + + server_conn = ServerConnection( + server_name="test-server", + server_config=MCPServerSettings(name="test-server", url="http://example.com/mcp"), + transport_context_factory=DummyTransportContext, + client_session_factory=session_factory, + ) + + lifecycle_task = asyncio.create_task(_server_lifecycle_task(server_conn)) + try: + await asyncio.wait_for(server_conn.wait_for_initialized(), timeout=1.0) + finally: + await lifecycle_task + + assert server_conn._error_occurred is True From 961b114e35544e768f65e59c3405e2fad89fdd54 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Sat, 10 Jan 2026 12:50:32 +0100 Subject: [PATCH 3/4] - shut down ping task group on close (was delaying shutdown) - add client side health data to `/mcp` display, deduplicate response counting - make /mcp non-blocking, tie health information - make client ping on by default --- README.md | 8 +- src/fast_agent/config.py | 6 +- .../mcp/mcp_agent_client_session.py | 28 +++ src/fast_agent/mcp/mcp_aggregator.py | 137 +++++++++++---- src/fast_agent/mcp/mcp_connection_manager.py | 68 +++++++- src/fast_agent/mcp/transport_tracking.py | 26 ++- src/fast_agent/ui/mcp_display.py | 165 ++++++++++++++++++ 7 files changed, 389 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index d1701568a..b6f5877d0 100644 --- a/README.md +++ b/README.md @@ -268,16 +268,16 @@ mcp: ## MCP Ping (optional) -The MCP ping utility is optional and can be enabled by either peer (client or server). See the [Ping overview](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/ping#overview). +The MCP ping utility can be enabled by either peer (client or server). See the [Ping overview](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/ping#overview). -Client-side pinging is opt-in and configured per server: +Client-side pinging is configured per server (default: 30s interval, 3 missed pings): ```yaml mcp: servers: myserver: - ping_interval_seconds: 30 # optional; <=0 or unset disables - max_missed_pings: 2 # optional; consecutive timeouts before marking failed + ping_interval_seconds: 30 # optional; <=0 disables + max_missed_pings: 3 # optional; consecutive timeouts before marking failed ``` ## Workflows diff --git a/src/fast_agent/config.py b/src/fast_agent/config.py index 42f29158b..1e99572f8 100644 --- a/src/fast_agent/config.py +++ b/src/fast_agent/config.py @@ -213,10 +213,10 @@ class MCPServerSettings(BaseModel): read_timeout_seconds: int | None = None """The timeout in seconds for the session.""" - ping_interval_seconds: int | None = None - """Optional interval for MCP ping requests. When unset or <=0, pinging is disabled.""" + ping_interval_seconds: int = 30 + """Interval for MCP ping requests. Set <=0 to disable pinging.""" - max_missed_pings: int = 1 + max_missed_pings: int = 3 """Number of consecutive missed ping responses before treating the connection as failed.""" http_timeout_seconds: int | None = None diff --git a/src/fast_agent/mcp/mcp_agent_client_session.py b/src/fast_agent/mcp/mcp_agent_client_session.py index 9b76312d3..4be316b22 100644 --- a/src/fast_agent/mcp/mcp_agent_client_session.py +++ b/src/fast_agent/mcp/mcp_agent_client_session.py @@ -226,6 +226,13 @@ async def send_request( ) -> ReceiveResultT: logger.debug("send_request: request=", data=request.model_dump()) request_id = getattr(self, "_request_id", None) + is_ping_request = self._is_ping_request(request) + if ( + is_ping_request + and request_id is not None + and self._transport_metrics is not None + ): + self._transport_metrics.register_ping_request(request_id) try: result = await super().send_request( request=request, @@ -239,8 +246,20 @@ async def send_request( data=result.model_dump() if result is not None else "no response returned", ) self._attach_transport_channel(request_id, result) + if ( + is_ping_request + and request_id is not None + and self._transport_metrics is not None + ): + self._transport_metrics.discard_ping_request(request_id) return result except Exception as e: + if ( + is_ping_request + and request_id is not None + and self._transport_metrics is not None + ): + self._transport_metrics.discard_ping_request(request_id) from anyio import ClosedResourceError from fast_agent.core.exceptions import ServerSessionTerminatedError @@ -264,6 +283,15 @@ async def send_request( logger.error(f"send_request failed: {str(e)}") raise + @staticmethod + def _is_ping_request(request: ClientRequest) -> bool: + root = getattr(request, "root", None) + method = getattr(root, "method", None) + if not isinstance(method, str): + return False + method_lower = method.lower() + return method_lower == "ping" or method_lower.endswith("/ping") or method_lower.endswith(".ping") + def _is_session_terminated_error(self, exc: Exception) -> bool: """Check if exception is a session terminated error (code 32600 from 404).""" from mcp.shared.exceptions import McpError diff --git a/src/fast_agent/mcp/mcp_aggregator.py b/src/fast_agent/mcp/mcp_aggregator.py index b7672b925..38934acd0 100644 --- a/src/fast_agent/mcp/mcp_aggregator.py +++ b/src/fast_agent/mcp/mcp_aggregator.py @@ -115,6 +115,17 @@ class ServerStatus(BaseModel): transport_channels: TransportSnapshot | None = None skybridge: SkybridgeServerConfig | None = None reconnect_count: int = 0 + ping_interval_seconds: int | None = None + ping_max_missed: int | None = None + ping_ok_count: int | None = None + ping_fail_count: int | None = None + ping_consecutive_failures: int | None = None + ping_last_ok_at: datetime | None = None + ping_last_fail_at: datetime | None = None + ping_last_error: str | None = None + ping_activity_buckets: list[str] | None = None + ping_activity_bucket_seconds: int | None = None + ping_activity_bucket_count: int | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -997,49 +1008,88 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: server_conn = None transport: str | None = None transport_snapshot: TransportSnapshot | None = None + ping_interval_seconds: int | None = None + ping_max_missed: int | None = None + ping_ok_count: int | None = None + ping_fail_count: int | None = None + ping_consecutive_failures: int | None = None + ping_last_ok_at: datetime | None = None + ping_last_fail_at: datetime | None = None + ping_last_error: str | None = None + ping_activity_buckets: list[str] | None = None + ping_activity_bucket_seconds: int | None = None + ping_activity_bucket_count: int | None = None manager = getattr(self, "_persistent_connection_manager", None) if self.connection_persistence and manager is not None: try: - server_conn = await manager.get_server( - server_name, - client_session_factory=self._create_session_factory(server_name), - ) - implementation = server_conn.server_implementation - if implementation is not None: - implementation_name = implementation.name - implementation_version = implementation.version - capabilities = server_conn.server_capabilities - client_capabilities = server_conn.client_capabilities - session = server_conn.session - client_info = getattr(session, "client_info", None) if session else None - if client_info: - client_info_name = getattr(client_info, "name", None) - client_info_version = getattr(client_info, "version", None) - is_connected = server_conn.is_healthy() - error_message = server_conn._error_message - instructions_available = server_conn.server_instructions_available - instructions_enabled = server_conn.server_instructions_enabled - instructions_included = bool(server_conn.server_instructions) - server_cfg = server_conn.server_config - if session: - elicitation_mode = session.effective_elicitation_mode - session_id = server_conn.session_id - if not session_id and server_conn._get_session_id_cb: + async with manager._lock: + server_conn = manager.running_servers.get(server_name) + if server_conn is None: + is_connected = False + else: + implementation = server_conn.server_implementation + if implementation is not None: + implementation_name = implementation.name + implementation_version = implementation.version + capabilities = server_conn.server_capabilities + client_capabilities = server_conn.client_capabilities + session = server_conn.session + client_info = getattr(session, "client_info", None) if session else None + if client_info: + client_info_name = getattr(client_info, "name", None) + client_info_version = getattr(client_info, "version", None) + if server_conn._initialized_event.is_set(): + is_connected = server_conn.is_healthy() + else: + is_connected = False + error_message = error_message or "initializing..." + error_message = error_message or server_conn._error_message + instructions_available = server_conn.server_instructions_available + instructions_enabled = server_conn.server_instructions_enabled + instructions_included = bool(server_conn.server_instructions) + server_cfg = server_conn.server_config + ping_interval_seconds = server_cfg.ping_interval_seconds + ping_max_missed = server_cfg.max_missed_pings + ping_ok_count = server_conn._ping_ok_count + ping_fail_count = server_conn._ping_fail_count + ping_consecutive_failures = server_conn._ping_consecutive_failures + ping_last_ok_at = server_conn._ping_last_ok_at + ping_last_fail_at = server_conn._ping_last_fail_at + ping_last_error = server_conn._ping_last_error + if session: + elicitation_mode = session.effective_elicitation_mode + session_id = server_conn.session_id + if not session_id and server_conn._get_session_id_cb: + try: + session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined] + except Exception: + session_id = None + metrics = server_conn.transport_metrics + if metrics is not None: try: - session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined] + transport_snapshot = metrics.snapshot() except Exception: - session_id = None - metrics = server_conn.transport_metrics - if metrics is not None: - try: - transport_snapshot = metrics.snapshot() - except Exception: - logger.debug( - "Failed to snapshot transport metrics for server '%s'", - server_name, - exc_info=True, - ) + logger.debug( + "Failed to snapshot transport metrics for server '%s'", + server_name, + exc_info=True, + ) + bucket_seconds = ( + transport_snapshot.activity_bucket_seconds + if transport_snapshot and transport_snapshot.activity_bucket_seconds + else 30 + ) + bucket_count = ( + transport_snapshot.activity_bucket_count + if transport_snapshot and transport_snapshot.activity_bucket_count + else 20 + ) + ping_activity_buckets = server_conn.build_ping_activity_buckets( + bucket_seconds, bucket_count + ) + ping_activity_bucket_seconds = bucket_seconds + ping_activity_bucket_count = bucket_count except Exception as exc: logger.debug( f"Failed to collect status for server '{server_name}'", @@ -1068,6 +1118,8 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: elicitation_mode = ( getattr(elicitation, "mode", None) if elicitation else elicitation_mode ) + ping_interval_seconds = ping_interval_seconds or server_cfg.ping_interval_seconds + ping_max_missed = ping_max_missed or server_cfg.max_missed_pings sampling_cfg = server_cfg.sampling spoofing_enabled = server_cfg.implementation is not None if implementation_name is None and server_cfg.implementation is not None: @@ -1123,6 +1175,17 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: transport_channels=transport_snapshot, skybridge=self._skybridge_configs.get(server_name), reconnect_count=reconnect_count, + ping_interval_seconds=ping_interval_seconds, + ping_max_missed=ping_max_missed, + ping_ok_count=ping_ok_count, + ping_fail_count=ping_fail_count, + ping_consecutive_failures=ping_consecutive_failures, + ping_last_ok_at=ping_last_ok_at, + ping_last_fail_at=ping_last_fail_at, + ping_last_error=ping_last_error, + ping_activity_buckets=ping_activity_buckets, + ping_activity_bucket_seconds=ping_activity_bucket_seconds, + ping_activity_bucket_count=ping_activity_bucket_count, ) return status_map diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index f8ac86e2e..b7fa21a7e 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -4,8 +4,9 @@ import asyncio import traceback -from contextlib import AbstractAsyncContextManager -from datetime import timedelta +from collections import deque +from contextlib import AbstractAsyncContextManager, suppress +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Callable, Union, cast import httpx @@ -150,6 +151,13 @@ def __init__( self.session_id: str | None = None self._get_session_id_cb: GetSessionIdCallback | None = None self.transport_metrics: TransportChannelMetrics | None = None + self._ping_ok_count = 0 + self._ping_fail_count = 0 + self._ping_consecutive_failures = 0 + self._ping_last_ok_at: datetime | None = None + self._ping_last_fail_at: datetime | None = None + self._ping_last_error: str | None = None + self._ping_history: deque[tuple[datetime, str]] = deque(maxlen=200) def is_healthy(self) -> bool: """Check if the server connection is healthy and ready to use.""" @@ -216,6 +224,43 @@ async def wait_for_initialized(self) -> None: """ await self._initialized_event.wait() + def record_ping_event(self, state: str) -> None: + self._ping_history.append((datetime.now(timezone.utc), state)) + + def build_ping_activity_buckets(self, bucket_seconds: int, bucket_count: int) -> list[str]: + try: + seconds = int(bucket_seconds) + except (TypeError, ValueError): + seconds = 30 + if seconds <= 0: + seconds = 30 + + try: + count = int(bucket_count) + except (TypeError, ValueError): + count = 20 + if count <= 0: + count = 20 + + if not self._ping_history: + return ["none"] * count + + priority = {"error": 2, "ping": 1, "none": 0} + history_map: dict[int, str] = {} + for timestamp, state in self._ping_history: + bucket = int(timestamp.timestamp() // seconds) + existing = history_map.get(bucket) + if existing is None or priority.get(state, 0) >= priority.get(existing, 0): + history_map[bucket] = state + + current_bucket = int(datetime.now(timezone.utc).timestamp() // seconds) + buckets: list[str] = [] + for offset in range(count - 1, -1, -1): + bucket_index = current_bucket - offset + buckets.append(history_map.get(bucket_index, "none")) + + return buckets + def create_session( self, read_stream: MemoryObjectReceiveStream, @@ -270,8 +315,18 @@ async def _run_ping_loop(server_conn: ServerConnection) -> None: try: await cast("MCPAgentClientSession", session).ping(read_timeout_seconds=read_timeout) missed = 0 + server_conn._ping_ok_count += 1 + server_conn._ping_consecutive_failures = 0 + server_conn._ping_last_ok_at = datetime.now(timezone.utc) + server_conn._ping_last_error = None + server_conn.record_ping_event("ping") except Exception as exc: missed += 1 + server_conn._ping_fail_count += 1 + server_conn._ping_consecutive_failures = missed + server_conn._ping_last_fail_at = datetime.now(timezone.utc) + server_conn._ping_last_error = str(exc) + server_conn.record_ping_event("error") logger.warning( f"{server_conn.server_name}: Ping failed ({missed}/{max_missed}): {exc}" ) @@ -328,9 +383,14 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: server_conn.server_config.ping_interval_seconds and server_conn.server_config.ping_interval_seconds > 0 ): - async with create_task_group() as ping_group: - ping_group.start_soon(_run_ping_loop, server_conn) + ping_task = asyncio.create_task(_run_ping_loop(server_conn)) + try: await server_conn.wait_for_shutdown_request() + finally: + if not ping_task.done(): + ping_task.cancel() + with suppress(asyncio.CancelledError): + await ping_task else: await server_conn.wait_for_shutdown_request() except Exception as session_exit_exc: diff --git a/src/fast_agent/mcp/transport_tracking.py b/src/fast_agent/mcp/transport_tracking.py index 933f25fd8..9f3c6e126 100644 --- a/src/fast_agent/mcp/transport_tracking.py +++ b/src/fast_agent/mcp/transport_tracking.py @@ -162,6 +162,7 @@ def __init__( self._stdio_notification_count = 0 self._response_channel_by_id: dict[RequestId, ChannelName] = {} + self._ping_request_ids: set[RequestId] = set() try: seconds = 30 if bucket_seconds is None else int(bucket_seconds) @@ -208,6 +209,14 @@ def record_event(self, event: ChannelEvent) -> None: elif event.channel == "stdio": self._handle_stdio_event(event, now) + def register_ping_request(self, request_id: RequestId) -> None: + with self._lock: + self._ping_request_ids.add(request_id) + + def discard_ping_request(self, request_id: RequestId) -> None: + with self._lock: + self._ping_request_ids.discard(request_id) + def _handle_post_event(self, event: ChannelEvent, now: datetime) -> None: mode = "json" if event.channel == "post-json" else "sse" if event.event_type == "message" and event.message is not None: @@ -226,7 +235,8 @@ def _handle_post_event(self, event: ChannelEvent, now: datetime) -> None: self._post_last_at = now self._record_response_channel(event) - self._record_history(event.channel, classification, now) + if classification != "ping": + self._record_history(event.channel, classification, now) elif event.event_type == "error": self._record_history(event.channel, "error", now) @@ -347,6 +357,20 @@ def _tally_message_counts( sub_mode: str | None = None, ) -> str: classification = self._classify_message(message) + root = message.root + request_id: RequestId | None = None + if isinstance(root, (JSONRPCRequest, JSONRPCResponse, JSONRPCError)): + request_id = getattr(root, "id", None) + + if classification == "ping" and request_id is not None and isinstance(root, JSONRPCRequest): + self._ping_request_ids.add(request_id) + elif ( + classification == "response" + and request_id is not None + and request_id in self._ping_request_ids + ): + self._ping_request_ids.discard(request_id) + classification = "ping" if channel_key == "post": if classification == "request": diff --git a/src/fast_agent/ui/mcp_display.py b/src/fast_agent/ui/mcp_display.py index 17490f54d..1ff0691bd 100644 --- a/src/fast_agent/ui/mcp_display.py +++ b/src/fast_agent/ui/mcp_display.py @@ -308,6 +308,95 @@ def _format_relative_time(dt: datetime | None) -> str: return _format_compact_duration(seconds) or "<1s" +def _truncate_detail(value: str, max_len: int = 48) -> str: + if len(value) <= max_len: + return value + return value[: max_len - 3] + "..." + + +def _build_health_text(status: ServerStatus) -> Text | None: + interval = status.ping_interval_seconds + if interval is None: + return None + + health = Text() + state_label, state_style = _get_health_state(status) + if interval <= 0: + health.append(state_label, style=state_style) + return health + + max_missed = status.ping_max_missed or 0 + misses = status.ping_consecutive_failures or 0 + + health.append(state_label, style=state_style) + health.append(f" | interval: {interval}s", style=Colours.TEXT_DIM) + + misses_text = f"{misses}/{max_missed}" if max_missed else str(misses) + misses_style = Colours.TEXT_WARNING if misses > 0 else Colours.TEXT_DIM + health.append(f" | misses: {misses_text}", style=misses_style) + + last_ok = _format_relative_time(status.ping_last_ok_at) + health.append(f" | last ok: {last_ok}", style=Colours.TEXT_DIM) + + if misses > 0: + last_fail = _format_relative_time(status.ping_last_fail_at) + health.append(f" | last fail: {last_fail}", style=Colours.TEXT_DIM) + if status.ping_last_error: + err = _truncate_detail(status.ping_last_error) + health.append(f" | last err: {err}", style=Colours.TEXT_ERROR) + + return health + + +def _get_health_state(status: ServerStatus) -> tuple[str, str]: + interval = status.ping_interval_seconds + if interval is None: + return ("unknown", Colours.TEXT_DIM) + if interval <= 0: + return ("disabled", Colours.TEXT_DIM) + + if _has_transport_error(status): + return ("error", Colours.TEXT_ERROR) + + max_missed = status.ping_max_missed or 0 + misses = status.ping_consecutive_failures or 0 + has_activity = bool(status.ping_last_ok_at or status.ping_last_fail_at) + + if not has_activity: + return ("ok", Colours.TEXT_SUCCESS) + if max_missed and misses >= max_missed: + return ("failed", Colours.TEXT_ERROR) + if misses > 0: + return ("missed", Colours.TEXT_WARNING) + return ("ok", Colours.TEXT_SUCCESS) + + +def _has_transport_error(status: ServerStatus) -> bool: + snapshot = status.transport_channels + if snapshot is None: + return False + channels = [ + getattr(snapshot, "get", None), + getattr(snapshot, "post_json", None), + getattr(snapshot, "post_sse", None), + getattr(snapshot, "post", None), + getattr(snapshot, "resumption", None), + getattr(snapshot, "stdio", None), + ] + for channel in channels: + if channel is None: + continue + if channel.state == "error" and channel.last_status_code != 405: + return True + return False + + +def _get_ping_attempts(status: ServerStatus) -> int: + ok = status.ping_ok_count or 0 + fail = status.ping_fail_count or 0 + return ok + fail + + def _format_label(label: str, width: int = 10) -> str: return f"{label:<{width}}" if len(label) < width else label @@ -440,6 +529,72 @@ def _render_channel_summary(status: ServerStatus, indent: str, total_width: int) # Get appropriate timeline color map timeline_color_map = TIMELINE_COLORS_STDIO if is_stdio else TIMELINE_COLORS + health_insert_label = None + if status.ping_interval_seconds is not None: + label_names = [entry[0] for entry in entries] + if "POST (JSON)" in label_names: + health_insert_label = "POST (JSON)" + elif "POST (SSE)" in label_names: + health_insert_label = "POST (SSE)" + elif "STDIO" in label_names: + health_insert_label = "STDIO" + elif label_names: + health_insert_label = label_names[-1] + + def render_health_row() -> None: + line = Text(indent) + line.append("│ ", style="dim") + _, state_style = _get_health_state(status) + line.append(SYMBOL_PING, style=state_style) + line.append(f" {'HEALTH':<13}", style=state_style) + + bucket_seconds = status.ping_activity_bucket_seconds or default_bucket_seconds + bucket_count = status.ping_activity_bucket_count or default_bucket_count + timeline_label = _format_timeline_label(bucket_seconds * bucket_count) + line.append(f"{timeline_label} ", style="dim") + + bucket_states = status.ping_activity_buckets or [] + if len(bucket_states) < bucket_count: + bucket_states = list(bucket_states) + ["none"] * (bucket_count - len(bucket_states)) + elif len(bucket_states) > bucket_count: + bucket_states = bucket_states[-bucket_count:] + + for bucket_state in bucket_states: + color = timeline_color_map.get(bucket_state, "dim") + if bucket_state in {"idle", "none"}: + symbol = SYMBOL_IDLE + elif bucket_state == "error": + symbol = SYMBOL_ERROR + elif bucket_state == "ping": + symbol = SYMBOL_PING + else: + symbol = SYMBOL_RESPONSE + line.append(symbol, style=f"bold {color}") + + line.append(" now", style="dim") + ping_attempts = _get_ping_attempts(status) + if is_stdio: + activity = str(ping_attempts).rjust(8) if ping_attempts > 0 else "-".rjust(8) + activity_style = Colours.TEXT_DEFAULT if ping_attempts > 0 else Colours.TEXT_DIM + line.append(f" {activity}", style=activity_style) + else: + req = "-".rjust(5) + resp = "-".rjust(5) + notif = "-".rjust(5) + ping = str(ping_attempts).rjust(5) if ping_attempts > 0 else "-".rjust(5) + ping_style = Colours.TEXT_DEFAULT if ping_attempts > 0 else Colours.TEXT_DIM + line.append(" ", style="dim") + line.append(req, style=Colours.TEXT_DIM) + line.append(" ", style="dim") + line.append(resp, style=Colours.TEXT_DIM) + line.append(" ", style="dim") + line.append(notif, style=Colours.TEXT_DIM) + line.append(" ", style="dim") + line.append(ping, style=ping_style) + console.console.print(line) + + health_inserted = False + for label, arrow, channel in entries: line = Text(indent) line.append("│ ", style="dim") @@ -596,6 +751,10 @@ def _render_channel_summary(status: ServerStatus, indent: str, total_width: int) console.console.print(line) + if health_insert_label == label and not health_inserted: + render_health_row() + health_inserted = True + # Debug: print the raw line length # import sys # print(f"Line length: {len(line.plain)}", file=sys.stderr) @@ -771,6 +930,12 @@ def render_header(label: Text, right: Text | None = None) -> None: session_text = _format_session_id(status.session_id) session_line.append_text(_build_aligned_field("session", session_text)) console.console.print(session_line) + + health_text = _build_health_text(status) + if health_text is not None: + health_line = Text(indent + " ") + health_line.append_text(_build_aligned_field("health", health_text)) + console.console.print(health_line) console.console.print() # Build status segments From a8e7acc6d27e2ad10ed8f4ab0f7dc26ce9465861 Mon Sep 17 00:00:00 2001 From: evalstate <1936278+evalstate@users.noreply.github.com> Date: Sat, 10 Jan 2026 13:35:13 +0100 Subject: [PATCH 4/4] pathological cases/timeout handling --- src/fast_agent/mcp/mcp_connection_manager.py | 2 + src/fast_agent/ui/mcp_display.py | 46 ++++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index b7fa21a7e..6b8c4fb87 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -302,6 +302,8 @@ async def _run_ping_loop(server_conn: ServerConnection) -> None: if server_conn.server_config.read_timeout_seconds else None ) + if read_timeout is None: + read_timeout = timedelta(seconds=interval) while not server_conn._shutdown_event.is_set(): await asyncio.sleep(interval) diff --git a/src/fast_agent/ui/mcp_display.py b/src/fast_agent/ui/mcp_display.py index 1ff0691bd..2987c21f3 100644 --- a/src/fast_agent/ui/mcp_display.py +++ b/src/fast_agent/ui/mcp_display.py @@ -326,7 +326,7 @@ def _build_health_text(status: ServerStatus) -> Text | None: return health max_missed = status.ping_max_missed or 0 - misses = status.ping_consecutive_failures or 0 + misses = _compute_display_misses(status) health.append(state_label, style=state_style) health.append(f" | interval: {interval}s", style=Colours.TEXT_DIM) @@ -355,15 +355,28 @@ def _get_health_state(status: ServerStatus) -> tuple[str, str]: if interval <= 0: return ("disabled", Colours.TEXT_DIM) + if status.is_connected is False: + if status.error_message and "initializing" in status.error_message: + return ("pending", Colours.TEXT_DIM) + return ("offline", Colours.TEXT_ERROR) + if _has_transport_error(status): return ("error", Colours.TEXT_ERROR) max_missed = status.ping_max_missed or 0 - misses = status.ping_consecutive_failures or 0 + misses = _compute_display_misses(status) has_activity = bool(status.ping_last_ok_at or status.ping_last_fail_at) + last_ping_at = status.ping_last_ok_at or status.ping_last_fail_at + if last_ping_at is not None and max_missed > 0: + if last_ping_at.tzinfo is None: + last_ping_at = last_ping_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + if (now - last_ping_at).total_seconds() > interval * max_missed: + return ("stale", Colours.TEXT_ERROR) + if not has_activity: - return ("ok", Colours.TEXT_SUCCESS) + return ("pending", Colours.TEXT_DIM) if max_missed and misses >= max_missed: return ("failed", Colours.TEXT_ERROR) if misses > 0: @@ -386,11 +399,36 @@ def _has_transport_error(status: ServerStatus) -> bool: for channel in channels: if channel is None: continue - if channel.state == "error" and channel.last_status_code != 405: + if channel.last_status_code == 405 or channel.state == "disabled": + continue + if channel.last_error and "405" in channel.last_error: + continue + if channel.state == "error": return True return False +def _compute_display_misses(status: ServerStatus) -> int: + interval = status.ping_interval_seconds + if interval is None or interval <= 0: + return status.ping_consecutive_failures or 0 + + last_ping_at = status.ping_last_ok_at or status.ping_last_fail_at + if last_ping_at is None: + return status.ping_consecutive_failures or 0 + + if last_ping_at.tzinfo is None: + last_ping_at = last_ping_at.replace(tzinfo=timezone.utc) + + elapsed = (datetime.now(timezone.utc) - last_ping_at).total_seconds() + if elapsed <= 0: + return status.ping_consecutive_failures or 0 + + derived = int(elapsed // interval) + recorded = status.ping_consecutive_failures or 0 + return max(recorded, derived) + + def _get_ping_attempts(status: ServerStatus) -> int: ok = status.ping_ok_count or 0 fail = status.ping_fail_count or 0