Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/fast_agent/mcp/ping_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Shared ping failure tracking for MCP transport layers."""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

DEFAULT_PING_FAILURE_RESET_THRESHOLD = 3


class PingFailureTracker:
"""Tracks consecutive ping failures and determines when to reset connection state."""

def __init__(
self,
url: str,
threshold: int = DEFAULT_PING_FAILURE_RESET_THRESHOLD,
) -> None:
"""Initialize ping failure tracker.

Args:
url: URL being tracked (for logging context)
threshold: Number of consecutive failures before reset is recommended
"""
self.url = url
self.threshold = threshold
self._count = 0

def record_failure(self) -> tuple[int, bool]:
"""Record a ping failure and return count and reset recommendation.

Returns:
Tuple of (failure_count, should_reset) where should_reset is True
if threshold has been reached.
"""
self._count += 1
logger.warning(
"Ping timeout waiting for keepalive on %s (%s/%s)",
self.url,
self._count,
self.threshold,
)
should_reset = self._count >= self.threshold
if should_reset:
logger.warning("Multiple ping timeouts on %s; clearing resumption state", self.url)
return self._count, should_reset

def reset(self) -> None:
"""Reset the failure count (called on successful ping or non-timeout error)."""
if self._count > 0:
logger.debug("Resetting ping failure count for %s", self.url)
self._count = 0

@property
def count(self) -> int:
"""Current failure count."""
return self._count

def format_detail(self) -> str:
"""Format error detail string with current failure count."""
detail = f"Ping timeout waiting for keepalive ({self._count}/{self.threshold})"
if self._count >= self.threshold:
detail += "; clearing resumption state"
return detail
22 changes: 19 additions & 3 deletions src/fast_agent/mcp/sse_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage

from fast_agent.mcp.ping_tracker import PingFailureTracker
from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName

if TYPE_CHECKING:
Expand Down Expand Up @@ -105,6 +106,7 @@ async def tracking_sse_client(
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

session_id: str | None = None
ping_tracker = PingFailureTracker(url)

def get_session_id() -> str | None:
return session_id
Expand All @@ -124,6 +126,7 @@ async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
ping_tracker.reset()
async for sse in event_source.aiter_sse():
if sse.event == "endpoint":
endpoint_url = urljoin(url, sse.data)
Expand Down Expand Up @@ -167,7 +170,9 @@ async def sse_reader(

_emit_channel_event(channel_hook, "get", "message", message=message)
await read_stream_writer.send(SessionMessage(message))
ping_tracker.reset()
else:
ping_tracker.reset()
_emit_channel_event(
channel_hook,
"get",
Expand All @@ -176,6 +181,7 @@ async def sse_reader(
)
except SSEError as sse_exc:
logger.exception("Encountered SSE exception")
ping_tracker.reset()
_emit_channel_event(
channel_hook,
"get",
Expand All @@ -184,14 +190,24 @@ async def sse_reader(
)
raise
except Exception as exc:
logger.exception("Error in sse_reader")
if isinstance(exc, (httpx.ReadTimeout, httpx.TimeoutException)):
_, should_reset = ping_tracker.record_failure()
detail = ping_tracker.format_detail()
if should_reset:
logger.warning("SSE ping timeout on %s", url)
error = ConnectionError(detail)
else:
ping_tracker.reset()
detail = str(exc)
logger.exception("Error in sse_reader")
error = exc
_emit_channel_event(
channel_hook,
"get",
"error",
detail=str(exc),
detail=detail,
)
await read_stream_writer.send(exc)
await read_stream_writer.send(error)
finally:
await read_stream_writer.aclose()

Expand Down
31 changes: 27 additions & 4 deletions src/fast_agent/mcp/streamable_http_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

from fast_agent.mcp.ping_tracker import PingFailureTracker
from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName

if TYPE_CHECKING:

from anyio.abc import ObjectReceiveStream, ObjectSendStream

logger = logging.getLogger(__name__)
Expand All @@ -43,6 +43,7 @@ def __init__(
) -> None:
super().__init__(url)
self._channel_hook = channel_hook
self._ping_tracker = PingFailureTracker(url)

def _emit_channel_event(
self,
Expand Down Expand Up @@ -70,6 +71,9 @@ def _emit_channel_event(
except Exception: # pragma: no cover - hook errors must not break transport
logger.exception("Channel hook raised an exception")

def _reset_ping_failures(self) -> None:
self._ping_tracker.reset()

async def _handle_json_response( # type: ignore[override]
self,
response: httpx.Response,
Expand Down Expand Up @@ -101,6 +105,7 @@ async def _handle_sse_event_with_channel(
) -> bool:
if sse.event != "message":
# Treat non-message events (e.g. ping) as keepalive notifications
self._reset_ping_failures()
self._emit_channel_event(channel, "keepalive", raw_event=sse.event or "keepalive")
return False

Expand All @@ -121,6 +126,7 @@ async def _handle_sse_event_with_channel(
):
message.root.id = original_request_id

self._reset_ping_failures()
self._emit_channel_event(channel, "message", message=message)
await read_stream_writer.send(SessionMessage(message))

Expand Down Expand Up @@ -163,6 +169,7 @@ async def handle_get_stream( # type: ignore[override]
event_source.response.raise_for_status()
self._emit_channel_event("get", "connect")
connected = True
self._reset_ping_failures()

async for sse in event_source.aiter_sse():
if sse.id:
Expand All @@ -179,10 +186,22 @@ async def handle_get_stream( # type: ignore[override]
attempt = 0

except Exception as exc: # pragma: no cover - non fatal stream errors
is_ping_timeout = isinstance(exc, (httpx.ReadTimeout, httpx.TimeoutException))
reset_connection = False
if is_ping_timeout:
_, reset_connection = self._ping_tracker.record_failure()
if reset_connection:
last_event_id = None
retry_interval_ms = None
else:
self._ping_tracker.reset()
logger.debug("GET stream error: %s", exc)
attempt += 1
status_code = None
detail = str(exc)
if is_ping_timeout:
detail = self._ping_tracker.format_detail()
else:
detail = str(exc)
if isinstance(exc, httpx.HTTPStatusError) and exc.response is not None:
status_code = exc.response.status_code
reason = exc.response.reason_phrase or ""
Expand All @@ -201,7 +220,9 @@ async def handle_get_stream( # type: ignore[override]
return

delay_ms = (
retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
retry_interval_ms
if retry_interval_ms is not None
else DEFAULT_RECONNECTION_DELAY_MS
)
logger.info("GET stream disconnected, reconnecting in %sms...", delay_ms)
await anyio.sleep(delay_ms / 1000.0)
Expand Down Expand Up @@ -290,7 +311,9 @@ async def _handle_reconnection( # type: ignore[override]
) # pragma: no cover
return

delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
delay_ms = (
retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
)
await anyio.sleep(delay_ms / 1000.0)

headers = self._prepare_headers()
Expand Down
Loading