diff --git a/README.md b/README.md index 9a4837fc8..b6f5877d0 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,20 @@ mcp: - To disable OAuth for a specific server , set `auth.oauth: false` for that server. +## MCP Ping (optional) + +The MCP ping utility can be enabled by either peer (client or server). See the [Ping overview](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/ping#overview). + +Client-side pinging is configured per server (default: 30s interval, 3 missed pings): + +```yaml +mcp: + servers: + myserver: + ping_interval_seconds: 30 # optional; <=0 disables + max_missed_pings: 3 # optional; consecutive timeouts before marking failed +``` + ## Workflows ### Chain diff --git a/src/fast_agent/config.py b/src/fast_agent/config.py index 1733213c6..1e99572f8 100644 --- a/src/fast_agent/config.py +++ b/src/fast_agent/config.py @@ -213,6 +213,12 @@ class MCPServerSettings(BaseModel): read_timeout_seconds: int | None = None """The timeout in seconds for the session.""" + ping_interval_seconds: int = 30 + """Interval for MCP ping requests. Set <=0 to disable pinging.""" + + max_missed_pings: int = 3 + """Number of consecutive missed ping responses before treating the connection as failed.""" + http_timeout_seconds: int | None = None """Overall HTTP timeout (seconds) for StreamableHTTP transport. Defaults to MCP SDK.""" @@ -262,6 +268,16 @@ class MCPServerSettings(BaseModel): implementation: Implementation | None = None + @field_validator("max_missed_pings", mode="before") + @classmethod + def _coerce_max_missed_pings(cls, value: Any) -> int: + if isinstance(value, str): + value = int(value.strip()) + value = int(value) + if value <= 0: + raise ValueError("max_missed_pings must be greater than zero.") + return value + @model_validator(mode="before") @classmethod def validate_transport_inference(cls, values): diff --git a/src/fast_agent/mcp/mcp_agent_client_session.py b/src/fast_agent/mcp/mcp_agent_client_session.py index 60ab85739..4be316b22 100644 --- a/src/fast_agent/mcp/mcp_agent_client_session.py +++ b/src/fast_agent/mcp/mcp_agent_client_session.py @@ -18,11 +18,13 @@ CallToolRequestParams, CallToolResult, ClientRequest, + EmptyResult, GetPromptRequest, GetPromptRequestParams, GetPromptResult, Implementation, ListRootsResult, + PingRequest, ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, @@ -224,6 +226,13 @@ async def send_request( ) -> ReceiveResultT: logger.debug("send_request: request=", data=request.model_dump()) request_id = getattr(self, "_request_id", None) + is_ping_request = self._is_ping_request(request) + if ( + is_ping_request + and request_id is not None + and self._transport_metrics is not None + ): + self._transport_metrics.register_ping_request(request_id) try: result = await super().send_request( request=request, @@ -237,8 +246,20 @@ async def send_request( data=result.model_dump() if result is not None else "no response returned", ) self._attach_transport_channel(request_id, result) + if ( + is_ping_request + and request_id is not None + and self._transport_metrics is not None + ): + self._transport_metrics.discard_ping_request(request_id) return result except Exception as e: + if ( + is_ping_request + and request_id is not None + and self._transport_metrics is not None + ): + self._transport_metrics.discard_ping_request(request_id) from anyio import ClosedResourceError from fast_agent.core.exceptions import ServerSessionTerminatedError @@ -262,6 +283,15 @@ async def send_request( logger.error(f"send_request failed: {str(e)}") raise + @staticmethod + def _is_ping_request(request: ClientRequest) -> bool: + root = getattr(request, "root", None) + method = getattr(root, "method", None) + if not isinstance(method, str): + return False + method_lower = method.lower() + return method_lower == "ping" or method_lower.endswith("/ping") or method_lower.endswith(".ping") + def _is_session_terminated_error(self, exc: Exception) -> bool: """Check if exception is a session terminated error (code 32600 from 404).""" from mcp.shared.exceptions import McpError @@ -365,6 +395,15 @@ async def call_tool( progress_callback=progress_callback, ) + async def ping(self, read_timeout_seconds: timedelta | None = None) -> EmptyResult: + """Send a ping request to check server liveness.""" + request = PingRequest(method="ping") + return await self.send_request( + ClientRequest(request), + EmptyResult, + request_read_timeout_seconds=read_timeout_seconds, + ) + async def read_resource( self, uri: AnyUrl | str, _meta: dict | None = None, **kwargs ) -> ReadResourceResult: diff --git a/src/fast_agent/mcp/mcp_aggregator.py b/src/fast_agent/mcp/mcp_aggregator.py index b7672b925..38934acd0 100644 --- a/src/fast_agent/mcp/mcp_aggregator.py +++ b/src/fast_agent/mcp/mcp_aggregator.py @@ -115,6 +115,17 @@ class ServerStatus(BaseModel): transport_channels: TransportSnapshot | None = None skybridge: SkybridgeServerConfig | None = None reconnect_count: int = 0 + ping_interval_seconds: int | None = None + ping_max_missed: int | None = None + ping_ok_count: int | None = None + ping_fail_count: int | None = None + ping_consecutive_failures: int | None = None + ping_last_ok_at: datetime | None = None + ping_last_fail_at: datetime | None = None + ping_last_error: str | None = None + ping_activity_buckets: list[str] | None = None + ping_activity_bucket_seconds: int | None = None + ping_activity_bucket_count: int | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -997,49 +1008,88 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: server_conn = None transport: str | None = None transport_snapshot: TransportSnapshot | None = None + ping_interval_seconds: int | None = None + ping_max_missed: int | None = None + ping_ok_count: int | None = None + ping_fail_count: int | None = None + ping_consecutive_failures: int | None = None + ping_last_ok_at: datetime | None = None + ping_last_fail_at: datetime | None = None + ping_last_error: str | None = None + ping_activity_buckets: list[str] | None = None + ping_activity_bucket_seconds: int | None = None + ping_activity_bucket_count: int | None = None manager = getattr(self, "_persistent_connection_manager", None) if self.connection_persistence and manager is not None: try: - server_conn = await manager.get_server( - server_name, - client_session_factory=self._create_session_factory(server_name), - ) - implementation = server_conn.server_implementation - if implementation is not None: - implementation_name = implementation.name - implementation_version = implementation.version - capabilities = server_conn.server_capabilities - client_capabilities = server_conn.client_capabilities - session = server_conn.session - client_info = getattr(session, "client_info", None) if session else None - if client_info: - client_info_name = getattr(client_info, "name", None) - client_info_version = getattr(client_info, "version", None) - is_connected = server_conn.is_healthy() - error_message = server_conn._error_message - instructions_available = server_conn.server_instructions_available - instructions_enabled = server_conn.server_instructions_enabled - instructions_included = bool(server_conn.server_instructions) - server_cfg = server_conn.server_config - if session: - elicitation_mode = session.effective_elicitation_mode - session_id = server_conn.session_id - if not session_id and server_conn._get_session_id_cb: + async with manager._lock: + server_conn = manager.running_servers.get(server_name) + if server_conn is None: + is_connected = False + else: + implementation = server_conn.server_implementation + if implementation is not None: + implementation_name = implementation.name + implementation_version = implementation.version + capabilities = server_conn.server_capabilities + client_capabilities = server_conn.client_capabilities + session = server_conn.session + client_info = getattr(session, "client_info", None) if session else None + if client_info: + client_info_name = getattr(client_info, "name", None) + client_info_version = getattr(client_info, "version", None) + if server_conn._initialized_event.is_set(): + is_connected = server_conn.is_healthy() + else: + is_connected = False + error_message = error_message or "initializing..." + error_message = error_message or server_conn._error_message + instructions_available = server_conn.server_instructions_available + instructions_enabled = server_conn.server_instructions_enabled + instructions_included = bool(server_conn.server_instructions) + server_cfg = server_conn.server_config + ping_interval_seconds = server_cfg.ping_interval_seconds + ping_max_missed = server_cfg.max_missed_pings + ping_ok_count = server_conn._ping_ok_count + ping_fail_count = server_conn._ping_fail_count + ping_consecutive_failures = server_conn._ping_consecutive_failures + ping_last_ok_at = server_conn._ping_last_ok_at + ping_last_fail_at = server_conn._ping_last_fail_at + ping_last_error = server_conn._ping_last_error + if session: + elicitation_mode = session.effective_elicitation_mode + session_id = server_conn.session_id + if not session_id and server_conn._get_session_id_cb: + try: + session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined] + except Exception: + session_id = None + metrics = server_conn.transport_metrics + if metrics is not None: try: - session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined] + transport_snapshot = metrics.snapshot() except Exception: - session_id = None - metrics = server_conn.transport_metrics - if metrics is not None: - try: - transport_snapshot = metrics.snapshot() - except Exception: - logger.debug( - "Failed to snapshot transport metrics for server '%s'", - server_name, - exc_info=True, - ) + logger.debug( + "Failed to snapshot transport metrics for server '%s'", + server_name, + exc_info=True, + ) + bucket_seconds = ( + transport_snapshot.activity_bucket_seconds + if transport_snapshot and transport_snapshot.activity_bucket_seconds + else 30 + ) + bucket_count = ( + transport_snapshot.activity_bucket_count + if transport_snapshot and transport_snapshot.activity_bucket_count + else 20 + ) + ping_activity_buckets = server_conn.build_ping_activity_buckets( + bucket_seconds, bucket_count + ) + ping_activity_bucket_seconds = bucket_seconds + ping_activity_bucket_count = bucket_count except Exception as exc: logger.debug( f"Failed to collect status for server '{server_name}'", @@ -1068,6 +1118,8 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: elicitation_mode = ( getattr(elicitation, "mode", None) if elicitation else elicitation_mode ) + ping_interval_seconds = ping_interval_seconds or server_cfg.ping_interval_seconds + ping_max_missed = ping_max_missed or server_cfg.max_missed_pings sampling_cfg = server_cfg.sampling spoofing_enabled = server_cfg.implementation is not None if implementation_name is None and server_cfg.implementation is not None: @@ -1123,6 +1175,17 @@ async def collect_server_status(self) -> dict[str, ServerStatus]: transport_channels=transport_snapshot, skybridge=self._skybridge_configs.get(server_name), reconnect_count=reconnect_count, + ping_interval_seconds=ping_interval_seconds, + ping_max_missed=ping_max_missed, + ping_ok_count=ping_ok_count, + ping_fail_count=ping_fail_count, + ping_consecutive_failures=ping_consecutive_failures, + ping_last_ok_at=ping_last_ok_at, + ping_last_fail_at=ping_last_fail_at, + ping_last_error=ping_last_error, + ping_activity_buckets=ping_activity_buckets, + ping_activity_bucket_seconds=ping_activity_bucket_seconds, + ping_activity_bucket_count=ping_activity_bucket_count, ) return status_map diff --git a/src/fast_agent/mcp/mcp_connection_manager.py b/src/fast_agent/mcp/mcp_connection_manager.py index 211b8daf1..6b8c4fb87 100644 --- a/src/fast_agent/mcp/mcp_connection_manager.py +++ b/src/fast_agent/mcp/mcp_connection_manager.py @@ -4,9 +4,10 @@ import asyncio import traceback -from contextlib import AbstractAsyncContextManager -from datetime import timedelta -from typing import TYPE_CHECKING, Callable, Union +from collections import deque +from contextlib import AbstractAsyncContextManager, suppress +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Callable, Union, cast import httpx from anyio import Event, Lock, create_task_group @@ -150,6 +151,13 @@ def __init__( self.session_id: str | None = None self._get_session_id_cb: GetSessionIdCallback | None = None self.transport_metrics: TransportChannelMetrics | None = None + self._ping_ok_count = 0 + self._ping_fail_count = 0 + self._ping_consecutive_failures = 0 + self._ping_last_ok_at: datetime | None = None + self._ping_last_fail_at: datetime | None = None + self._ping_last_error: str | None = None + self._ping_history: deque[tuple[datetime, str]] = deque(maxlen=200) def is_healthy(self) -> bool: """Check if the server connection is healthy and ready to use.""" @@ -216,6 +224,43 @@ async def wait_for_initialized(self) -> None: """ await self._initialized_event.wait() + def record_ping_event(self, state: str) -> None: + self._ping_history.append((datetime.now(timezone.utc), state)) + + def build_ping_activity_buckets(self, bucket_seconds: int, bucket_count: int) -> list[str]: + try: + seconds = int(bucket_seconds) + except (TypeError, ValueError): + seconds = 30 + if seconds <= 0: + seconds = 30 + + try: + count = int(bucket_count) + except (TypeError, ValueError): + count = 20 + if count <= 0: + count = 20 + + if not self._ping_history: + return ["none"] * count + + priority = {"error": 2, "ping": 1, "none": 0} + history_map: dict[int, str] = {} + for timestamp, state in self._ping_history: + bucket = int(timestamp.timestamp() // seconds) + existing = history_map.get(bucket) + if existing is None or priority.get(state, 0) >= priority.get(existing, 0): + history_map[bucket] = state + + current_bucket = int(datetime.now(timezone.utc).timestamp() // seconds) + buckets: list[str] = [] + for offset in range(count - 1, -1, -1): + bucket_index = current_bucket - offset + buckets.append(history_map.get(bucket_index, "none")) + + return buckets + def create_session( self, read_stream: MemoryObjectReceiveStream, @@ -245,6 +290,57 @@ def create_session( return session +async def _run_ping_loop(server_conn: ServerConnection) -> None: + interval = server_conn.server_config.ping_interval_seconds + if not interval or interval <= 0: + return + + max_missed = server_conn.server_config.max_missed_pings + missed = 0 + read_timeout = ( + timedelta(seconds=server_conn.server_config.read_timeout_seconds) + if server_conn.server_config.read_timeout_seconds + else None + ) + if read_timeout is None: + read_timeout = timedelta(seconds=interval) + + while not server_conn._shutdown_event.is_set(): + await asyncio.sleep(interval) + if server_conn._shutdown_event.is_set(): + break + session = server_conn.session + if session is None: + break + if not hasattr(session, "ping"): + return + try: + await cast("MCPAgentClientSession", session).ping(read_timeout_seconds=read_timeout) + missed = 0 + server_conn._ping_ok_count += 1 + server_conn._ping_consecutive_failures = 0 + server_conn._ping_last_ok_at = datetime.now(timezone.utc) + server_conn._ping_last_error = None + server_conn.record_ping_event("ping") + except Exception as exc: + missed += 1 + server_conn._ping_fail_count += 1 + server_conn._ping_consecutive_failures = missed + server_conn._ping_last_fail_at = datetime.now(timezone.utc) + server_conn._ping_last_error = str(exc) + server_conn.record_ping_event("error") + logger.warning( + f"{server_conn.server_name}: Ping failed ({missed}/{max_missed}): {exc}" + ) + if missed >= max_missed: + server_conn._error_occurred = True + server_conn._error_message = ( + f"Ping failed {missed} time(s); last error: {exc}" + ) + server_conn.request_shutdown() + break + + async def _server_lifecycle_task(server_conn: ServerConnection) -> None: """ Manage the lifecycle of a single server connection. @@ -285,19 +381,46 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: elif server_conn.server_config.transport == "stdio": server_conn.session_id = "local" - await server_conn.wait_for_shutdown_request() + if ( + server_conn.server_config.ping_interval_seconds + and server_conn.server_config.ping_interval_seconds > 0 + ): + ping_task = asyncio.create_task(_run_ping_loop(server_conn)) + try: + await server_conn.wait_for_shutdown_request() + finally: + if not ping_task.done(): + ping_task.cancel() + with suppress(asyncio.CancelledError): + await ping_task + else: + await server_conn.wait_for_shutdown_request() except Exception as session_exit_exc: - # Catch exceptions during session cleanup (e.g., when session was terminated) - # This prevents cleanup errors from propagating to the task group - logger.debug( - f"{server_name}: Exception during session cleanup (expected during reconnect): {session_exit_exc}" - ) + if server_conn._shutdown_event.is_set(): + # Cleanup errors can happen when disconnecting a session that was already + # terminated; treat as expected during shutdown. + logger.debug( + f"{server_name}: Exception during session cleanup (expected during shutdown): {session_exit_exc}" + ) + if not server_conn._initialized_event.is_set(): + server_conn._error_occurred = True + server_conn._error_message = "Shutdown requested before initialization" + server_conn._initialized_event.set() + else: + raise except Exception as transport_exit_exc: - # Catch exceptions during transport cleanup - # This can happen when disconnecting a session that was already terminated - logger.debug( - f"{server_name}: Exception during transport cleanup (expected during reconnect): {transport_exit_exc}" - ) + if server_conn._shutdown_event.is_set(): + # Cleanup errors can happen when disconnecting a transport that was already + # terminated; treat as expected during shutdown. + logger.debug( + f"{server_name}: Exception during transport cleanup (expected during shutdown): {transport_exit_exc}" + ) + if not server_conn._initialized_event.is_set(): + server_conn._error_occurred = True + server_conn._error_message = "Shutdown requested before initialization" + server_conn._initialized_event.set() + else: + raise except HTTPStatusError as http_exc: logger.error( diff --git a/src/fast_agent/mcp/transport_tracking.py b/src/fast_agent/mcp/transport_tracking.py index 933f25fd8..9f3c6e126 100644 --- a/src/fast_agent/mcp/transport_tracking.py +++ b/src/fast_agent/mcp/transport_tracking.py @@ -162,6 +162,7 @@ def __init__( self._stdio_notification_count = 0 self._response_channel_by_id: dict[RequestId, ChannelName] = {} + self._ping_request_ids: set[RequestId] = set() try: seconds = 30 if bucket_seconds is None else int(bucket_seconds) @@ -208,6 +209,14 @@ def record_event(self, event: ChannelEvent) -> None: elif event.channel == "stdio": self._handle_stdio_event(event, now) + def register_ping_request(self, request_id: RequestId) -> None: + with self._lock: + self._ping_request_ids.add(request_id) + + def discard_ping_request(self, request_id: RequestId) -> None: + with self._lock: + self._ping_request_ids.discard(request_id) + def _handle_post_event(self, event: ChannelEvent, now: datetime) -> None: mode = "json" if event.channel == "post-json" else "sse" if event.event_type == "message" and event.message is not None: @@ -226,7 +235,8 @@ def _handle_post_event(self, event: ChannelEvent, now: datetime) -> None: self._post_last_at = now self._record_response_channel(event) - self._record_history(event.channel, classification, now) + if classification != "ping": + self._record_history(event.channel, classification, now) elif event.event_type == "error": self._record_history(event.channel, "error", now) @@ -347,6 +357,20 @@ def _tally_message_counts( sub_mode: str | None = None, ) -> str: classification = self._classify_message(message) + root = message.root + request_id: RequestId | None = None + if isinstance(root, (JSONRPCRequest, JSONRPCResponse, JSONRPCError)): + request_id = getattr(root, "id", None) + + if classification == "ping" and request_id is not None and isinstance(root, JSONRPCRequest): + self._ping_request_ids.add(request_id) + elif ( + classification == "response" + and request_id is not None + and request_id in self._ping_request_ids + ): + self._ping_request_ids.discard(request_id) + classification = "ping" if channel_key == "post": if classification == "request": diff --git a/src/fast_agent/ui/mcp_display.py b/src/fast_agent/ui/mcp_display.py index 17490f54d..2987c21f3 100644 --- a/src/fast_agent/ui/mcp_display.py +++ b/src/fast_agent/ui/mcp_display.py @@ -308,6 +308,133 @@ def _format_relative_time(dt: datetime | None) -> str: return _format_compact_duration(seconds) or "<1s" +def _truncate_detail(value: str, max_len: int = 48) -> str: + if len(value) <= max_len: + return value + return value[: max_len - 3] + "..." + + +def _build_health_text(status: ServerStatus) -> Text | None: + interval = status.ping_interval_seconds + if interval is None: + return None + + health = Text() + state_label, state_style = _get_health_state(status) + if interval <= 0: + health.append(state_label, style=state_style) + return health + + max_missed = status.ping_max_missed or 0 + misses = _compute_display_misses(status) + + health.append(state_label, style=state_style) + health.append(f" | interval: {interval}s", style=Colours.TEXT_DIM) + + misses_text = f"{misses}/{max_missed}" if max_missed else str(misses) + misses_style = Colours.TEXT_WARNING if misses > 0 else Colours.TEXT_DIM + health.append(f" | misses: {misses_text}", style=misses_style) + + last_ok = _format_relative_time(status.ping_last_ok_at) + health.append(f" | last ok: {last_ok}", style=Colours.TEXT_DIM) + + if misses > 0: + last_fail = _format_relative_time(status.ping_last_fail_at) + health.append(f" | last fail: {last_fail}", style=Colours.TEXT_DIM) + if status.ping_last_error: + err = _truncate_detail(status.ping_last_error) + health.append(f" | last err: {err}", style=Colours.TEXT_ERROR) + + return health + + +def _get_health_state(status: ServerStatus) -> tuple[str, str]: + interval = status.ping_interval_seconds + if interval is None: + return ("unknown", Colours.TEXT_DIM) + if interval <= 0: + return ("disabled", Colours.TEXT_DIM) + + if status.is_connected is False: + if status.error_message and "initializing" in status.error_message: + return ("pending", Colours.TEXT_DIM) + return ("offline", Colours.TEXT_ERROR) + + if _has_transport_error(status): + return ("error", Colours.TEXT_ERROR) + + max_missed = status.ping_max_missed or 0 + misses = _compute_display_misses(status) + has_activity = bool(status.ping_last_ok_at or status.ping_last_fail_at) + + last_ping_at = status.ping_last_ok_at or status.ping_last_fail_at + if last_ping_at is not None and max_missed > 0: + if last_ping_at.tzinfo is None: + last_ping_at = last_ping_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + if (now - last_ping_at).total_seconds() > interval * max_missed: + return ("stale", Colours.TEXT_ERROR) + + if not has_activity: + return ("pending", Colours.TEXT_DIM) + if max_missed and misses >= max_missed: + return ("failed", Colours.TEXT_ERROR) + if misses > 0: + return ("missed", Colours.TEXT_WARNING) + return ("ok", Colours.TEXT_SUCCESS) + + +def _has_transport_error(status: ServerStatus) -> bool: + snapshot = status.transport_channels + if snapshot is None: + return False + channels = [ + getattr(snapshot, "get", None), + getattr(snapshot, "post_json", None), + getattr(snapshot, "post_sse", None), + getattr(snapshot, "post", None), + getattr(snapshot, "resumption", None), + getattr(snapshot, "stdio", None), + ] + for channel in channels: + if channel is None: + continue + if channel.last_status_code == 405 or channel.state == "disabled": + continue + if channel.last_error and "405" in channel.last_error: + continue + if channel.state == "error": + return True + return False + + +def _compute_display_misses(status: ServerStatus) -> int: + interval = status.ping_interval_seconds + if interval is None or interval <= 0: + return status.ping_consecutive_failures or 0 + + last_ping_at = status.ping_last_ok_at or status.ping_last_fail_at + if last_ping_at is None: + return status.ping_consecutive_failures or 0 + + if last_ping_at.tzinfo is None: + last_ping_at = last_ping_at.replace(tzinfo=timezone.utc) + + elapsed = (datetime.now(timezone.utc) - last_ping_at).total_seconds() + if elapsed <= 0: + return status.ping_consecutive_failures or 0 + + derived = int(elapsed // interval) + recorded = status.ping_consecutive_failures or 0 + return max(recorded, derived) + + +def _get_ping_attempts(status: ServerStatus) -> int: + ok = status.ping_ok_count or 0 + fail = status.ping_fail_count or 0 + return ok + fail + + def _format_label(label: str, width: int = 10) -> str: return f"{label:<{width}}" if len(label) < width else label @@ -440,6 +567,72 @@ def _render_channel_summary(status: ServerStatus, indent: str, total_width: int) # Get appropriate timeline color map timeline_color_map = TIMELINE_COLORS_STDIO if is_stdio else TIMELINE_COLORS + health_insert_label = None + if status.ping_interval_seconds is not None: + label_names = [entry[0] for entry in entries] + if "POST (JSON)" in label_names: + health_insert_label = "POST (JSON)" + elif "POST (SSE)" in label_names: + health_insert_label = "POST (SSE)" + elif "STDIO" in label_names: + health_insert_label = "STDIO" + elif label_names: + health_insert_label = label_names[-1] + + def render_health_row() -> None: + line = Text(indent) + line.append("│ ", style="dim") + _, state_style = _get_health_state(status) + line.append(SYMBOL_PING, style=state_style) + line.append(f" {'HEALTH':<13}", style=state_style) + + bucket_seconds = status.ping_activity_bucket_seconds or default_bucket_seconds + bucket_count = status.ping_activity_bucket_count or default_bucket_count + timeline_label = _format_timeline_label(bucket_seconds * bucket_count) + line.append(f"{timeline_label} ", style="dim") + + bucket_states = status.ping_activity_buckets or [] + if len(bucket_states) < bucket_count: + bucket_states = list(bucket_states) + ["none"] * (bucket_count - len(bucket_states)) + elif len(bucket_states) > bucket_count: + bucket_states = bucket_states[-bucket_count:] + + for bucket_state in bucket_states: + color = timeline_color_map.get(bucket_state, "dim") + if bucket_state in {"idle", "none"}: + symbol = SYMBOL_IDLE + elif bucket_state == "error": + symbol = SYMBOL_ERROR + elif bucket_state == "ping": + symbol = SYMBOL_PING + else: + symbol = SYMBOL_RESPONSE + line.append(symbol, style=f"bold {color}") + + line.append(" now", style="dim") + ping_attempts = _get_ping_attempts(status) + if is_stdio: + activity = str(ping_attempts).rjust(8) if ping_attempts > 0 else "-".rjust(8) + activity_style = Colours.TEXT_DEFAULT if ping_attempts > 0 else Colours.TEXT_DIM + line.append(f" {activity}", style=activity_style) + else: + req = "-".rjust(5) + resp = "-".rjust(5) + notif = "-".rjust(5) + ping = str(ping_attempts).rjust(5) if ping_attempts > 0 else "-".rjust(5) + ping_style = Colours.TEXT_DEFAULT if ping_attempts > 0 else Colours.TEXT_DIM + line.append(" ", style="dim") + line.append(req, style=Colours.TEXT_DIM) + line.append(" ", style="dim") + line.append(resp, style=Colours.TEXT_DIM) + line.append(" ", style="dim") + line.append(notif, style=Colours.TEXT_DIM) + line.append(" ", style="dim") + line.append(ping, style=ping_style) + console.console.print(line) + + health_inserted = False + for label, arrow, channel in entries: line = Text(indent) line.append("│ ", style="dim") @@ -596,6 +789,10 @@ def _render_channel_summary(status: ServerStatus, indent: str, total_width: int) console.console.print(line) + if health_insert_label == label and not health_inserted: + render_health_row() + health_inserted = True + # Debug: print the raw line length # import sys # print(f"Line length: {len(line.plain)}", file=sys.stderr) @@ -771,6 +968,12 @@ def render_header(label: Text, right: Text | None = None) -> None: session_text = _format_session_id(status.session_id) session_line.append_text(_build_aligned_field("session", session_text)) console.console.print(session_line) + + health_text = _build_health_text(status) + if health_text is not None: + health_line = Text(indent + " ") + health_line.append_text(_build_aligned_field("health", health_text)) + console.console.print(health_line) console.console.print() # Build status segments diff --git a/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py b/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py index bbad00843..013c2da9b 100644 --- a/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py +++ b/tests/unit/fast_agent/mcp/test_mcp_connection_manager.py @@ -1,6 +1,14 @@ +import asyncio + +import pytest + from fast_agent.config import MCPServerSettings -from fast_agent.mcp.mcp_connection_manager import _prepare_headers_and_auth +from fast_agent.mcp.mcp_connection_manager import ( + ServerConnection, + _prepare_headers_and_auth, + _server_lifecycle_task, +) def test_prepare_headers_respects_user_authorization(monkeypatch): @@ -76,3 +84,41 @@ def _builder(received_config: MCPServerSettings): assert auth is sentinel assert user_keys == set() assert calls == [config] + + +@pytest.mark.asyncio +async def test_server_lifecycle_sets_initialized_on_startup_failure(): + class DummyTransportContext: + async def __aenter__(self): + return object(), object(), None + + async def __aexit__(self, exc_type, exc, tb): + return None + + class DummySession: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def initialize(self): + raise RuntimeError("boom") + + def session_factory(*_args, **_kwargs): + return DummySession() + + server_conn = ServerConnection( + server_name="test-server", + server_config=MCPServerSettings(name="test-server", url="http://example.com/mcp"), + transport_context_factory=DummyTransportContext, + client_session_factory=session_factory, + ) + + lifecycle_task = asyncio.create_task(_server_lifecycle_task(server_conn)) + try: + await asyncio.wait_for(server_conn.wait_for_initialized(), timeout=1.0) + finally: + await lifecycle_task + + assert server_conn._error_occurred is True