diff --git a/src/fast_agent/mcp/ping_tracker.py b/src/fast_agent/mcp/ping_tracker.py new file mode 100644 index 000000000..b5582156b --- /dev/null +++ b/src/fast_agent/mcp/ping_tracker.py @@ -0,0 +1,65 @@ +"""Shared ping failure tracking for MCP transport layers.""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +DEFAULT_PING_FAILURE_RESET_THRESHOLD = 3 + + +class PingFailureTracker: + """Tracks consecutive ping failures and determines when to reset connection state.""" + + def __init__( + self, + url: str, + threshold: int = DEFAULT_PING_FAILURE_RESET_THRESHOLD, + ) -> None: + """Initialize ping failure tracker. + + Args: + url: URL being tracked (for logging context) + threshold: Number of consecutive failures before reset is recommended + """ + self.url = url + self.threshold = threshold + self._count = 0 + + def record_failure(self) -> tuple[int, bool]: + """Record a ping failure and return count and reset recommendation. + + Returns: + Tuple of (failure_count, should_reset) where should_reset is True + if threshold has been reached. + """ + self._count += 1 + logger.warning( + "Ping timeout waiting for keepalive on %s (%s/%s)", + self.url, + self._count, + self.threshold, + ) + should_reset = self._count >= self.threshold + if should_reset: + logger.warning("Multiple ping timeouts on %s; clearing resumption state", self.url) + return self._count, should_reset + + def reset(self) -> None: + """Reset the failure count (called on successful ping or non-timeout error).""" + if self._count > 0: + logger.debug("Resetting ping failure count for %s", self.url) + self._count = 0 + + @property + def count(self) -> int: + """Current failure count.""" + return self._count + + def format_detail(self) -> str: + """Format error detail string with current failure count.""" + detail = f"Ping timeout waiting for keepalive ({self._count}/{self.threshold})" + if self._count >= self.threshold: + detail += "; clearing resumption state" + return detail diff --git a/src/fast_agent/mcp/sse_tracking.py b/src/fast_agent/mcp/sse_tracking.py index 119cfc28d..02fec8c74 100644 --- a/src/fast_agent/mcp/sse_tracking.py +++ b/src/fast_agent/mcp/sse_tracking.py @@ -15,6 +15,7 @@ from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage +from fast_agent.mcp.ping_tracker import PingFailureTracker from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName if TYPE_CHECKING: @@ -105,6 +106,7 @@ async def tracking_sse_client( write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) session_id: str | None = None + ping_tracker = PingFailureTracker(url) def get_session_id() -> str | None: return session_id @@ -124,6 +126,7 @@ async def sse_reader( task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, ): try: + ping_tracker.reset() async for sse in event_source.aiter_sse(): if sse.event == "endpoint": endpoint_url = urljoin(url, sse.data) @@ -167,7 +170,9 @@ async def sse_reader( _emit_channel_event(channel_hook, "get", "message", message=message) await read_stream_writer.send(SessionMessage(message)) + ping_tracker.reset() else: + ping_tracker.reset() _emit_channel_event( channel_hook, "get", @@ -176,6 +181,7 @@ async def sse_reader( ) except SSEError as sse_exc: logger.exception("Encountered SSE exception") + ping_tracker.reset() _emit_channel_event( channel_hook, "get", @@ -184,14 +190,24 @@ async def sse_reader( ) raise except Exception as exc: - logger.exception("Error in sse_reader") + if isinstance(exc, (httpx.ReadTimeout, httpx.TimeoutException)): + _, should_reset = ping_tracker.record_failure() + detail = ping_tracker.format_detail() + if should_reset: + logger.warning("SSE ping timeout on %s", url) + error = ConnectionError(detail) + else: + ping_tracker.reset() + detail = str(exc) + logger.exception("Error in sse_reader") + error = exc _emit_channel_event( channel_hook, "get", "error", - detail=str(exc), + detail=detail, ) - await read_stream_writer.send(exc) + await read_stream_writer.send(error) finally: await read_stream_writer.aclose() diff --git a/src/fast_agent/mcp/streamable_http_tracking.py b/src/fast_agent/mcp/streamable_http_tracking.py index b09cf2522..fa56b558f 100644 --- a/src/fast_agent/mcp/streamable_http_tracking.py +++ b/src/fast_agent/mcp/streamable_http_tracking.py @@ -21,10 +21,10 @@ from mcp.shared.message import SessionMessage from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from fast_agent.mcp.ping_tracker import PingFailureTracker from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName if TYPE_CHECKING: - from anyio.abc import ObjectReceiveStream, ObjectSendStream logger = logging.getLogger(__name__) @@ -43,6 +43,7 @@ def __init__( ) -> None: super().__init__(url) self._channel_hook = channel_hook + self._ping_tracker = PingFailureTracker(url) def _emit_channel_event( self, @@ -70,6 +71,9 @@ def _emit_channel_event( except Exception: # pragma: no cover - hook errors must not break transport logger.exception("Channel hook raised an exception") + def _reset_ping_failures(self) -> None: + self._ping_tracker.reset() + async def _handle_json_response( # type: ignore[override] self, response: httpx.Response, @@ -101,6 +105,7 @@ async def _handle_sse_event_with_channel( ) -> bool: if sse.event != "message": # Treat non-message events (e.g. ping) as keepalive notifications + self._reset_ping_failures() self._emit_channel_event(channel, "keepalive", raw_event=sse.event or "keepalive") return False @@ -121,6 +126,7 @@ async def _handle_sse_event_with_channel( ): message.root.id = original_request_id + self._reset_ping_failures() self._emit_channel_event(channel, "message", message=message) await read_stream_writer.send(SessionMessage(message)) @@ -163,6 +169,7 @@ async def handle_get_stream( # type: ignore[override] event_source.response.raise_for_status() self._emit_channel_event("get", "connect") connected = True + self._reset_ping_failures() async for sse in event_source.aiter_sse(): if sse.id: @@ -179,10 +186,22 @@ async def handle_get_stream( # type: ignore[override] attempt = 0 except Exception as exc: # pragma: no cover - non fatal stream errors + is_ping_timeout = isinstance(exc, (httpx.ReadTimeout, httpx.TimeoutException)) + reset_connection = False + if is_ping_timeout: + _, reset_connection = self._ping_tracker.record_failure() + if reset_connection: + last_event_id = None + retry_interval_ms = None + else: + self._ping_tracker.reset() logger.debug("GET stream error: %s", exc) attempt += 1 status_code = None - detail = str(exc) + if is_ping_timeout: + detail = self._ping_tracker.format_detail() + else: + detail = str(exc) if isinstance(exc, httpx.HTTPStatusError) and exc.response is not None: status_code = exc.response.status_code reason = exc.response.reason_phrase or "" @@ -201,7 +220,9 @@ async def handle_get_stream( # type: ignore[override] return delay_ms = ( - retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS + retry_interval_ms + if retry_interval_ms is not None + else DEFAULT_RECONNECTION_DELAY_MS ) logger.info("GET stream disconnected, reconnecting in %sms...", delay_ms) await anyio.sleep(delay_ms / 1000.0) @@ -290,7 +311,9 @@ async def _handle_reconnection( # type: ignore[override] ) # pragma: no cover return - delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS + delay_ms = ( + retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS + ) await anyio.sleep(delay_ms / 1000.0) headers = self._prepare_headers()