Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,20 @@ mcp:

- To disable OAuth for a specific server, set `auth.oauth: false` for that server.

## MCP Ping (optional)

The MCP ping utility can be enabled by either peer (client or server). See the [Ping overview](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/ping#overview).

Client-side pinging is configured per server (defaults: one ping every 30 seconds; the connection is treated as failed after 3 consecutive missed pings):

```yaml
mcp:
  servers:
    myserver:
      ping_interval_seconds: 30  # optional; <=0 disables
      max_missed_pings: 3        # optional; consecutive timeouts before marking failed
```

## Workflows

### Chain
Expand Down
16 changes: 16 additions & 0 deletions src/fast_agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ class MCPServerSettings(BaseModel):
read_timeout_seconds: int | None = None
"""The timeout in seconds for the session."""

ping_interval_seconds: int = 30
"""Interval for MCP ping requests. Set <=0 to disable pinging."""

max_missed_pings: int = 3
"""Number of consecutive missed ping responses before treating the connection as failed."""

http_timeout_seconds: int | None = None
"""Overall HTTP timeout (seconds) for StreamableHTTP transport. Defaults to MCP SDK."""

Expand Down Expand Up @@ -262,6 +268,16 @@ class MCPServerSettings(BaseModel):

implementation: Implementation | None = None

@field_validator("max_missed_pings", mode="before")
@classmethod
def _coerce_max_missed_pings(cls, value: Any) -> int:
if isinstance(value, str):
value = int(value.strip())
value = int(value)
if value <= 0:
raise ValueError("max_missed_pings must be greater than zero.")
return value

@model_validator(mode="before")
@classmethod
def validate_transport_inference(cls, values):
Expand Down
39 changes: 39 additions & 0 deletions src/fast_agent/mcp/mcp_agent_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
CallToolRequestParams,
CallToolResult,
ClientRequest,
EmptyResult,
GetPromptRequest,
GetPromptRequestParams,
GetPromptResult,
Implementation,
ListRootsResult,
PingRequest,
ReadResourceRequest,
ReadResourceRequestParams,
ReadResourceResult,
Expand Down Expand Up @@ -224,6 +226,13 @@ async def send_request(
) -> ReceiveResultT:
logger.debug("send_request: request=", data=request.model_dump())
request_id = getattr(self, "_request_id", None)
is_ping_request = self._is_ping_request(request)
if (
is_ping_request
and request_id is not None
and self._transport_metrics is not None
):
self._transport_metrics.register_ping_request(request_id)
try:
result = await super().send_request(
request=request,
Expand All @@ -237,8 +246,20 @@ async def send_request(
data=result.model_dump() if result is not None else "no response returned",
)
self._attach_transport_channel(request_id, result)
if (
is_ping_request
and request_id is not None
and self._transport_metrics is not None
):
self._transport_metrics.discard_ping_request(request_id)
return result
except Exception as e:
if (
is_ping_request
and request_id is not None
and self._transport_metrics is not None
):
self._transport_metrics.discard_ping_request(request_id)
from anyio import ClosedResourceError

from fast_agent.core.exceptions import ServerSessionTerminatedError
Expand All @@ -262,6 +283,15 @@ async def send_request(
logger.error(f"send_request failed: {str(e)}")
raise

@staticmethod
def _is_ping_request(request: ClientRequest) -> bool:
root = getattr(request, "root", None)
method = getattr(root, "method", None)
if not isinstance(method, str):
return False
method_lower = method.lower()
return method_lower == "ping" or method_lower.endswith("/ping") or method_lower.endswith(".ping")

def _is_session_terminated_error(self, exc: Exception) -> bool:
"""Check if exception is a session terminated error (code 32600 from 404)."""
from mcp.shared.exceptions import McpError
Expand Down Expand Up @@ -365,6 +395,15 @@ async def call_tool(
progress_callback=progress_callback,
)

async def ping(self, read_timeout_seconds: timedelta | None = None) -> EmptyResult:
    """Issue a ``ping`` request to the server and return its empty result.

    Args:
        read_timeout_seconds: Optional per-request read timeout, forwarded to
            ``send_request`` as ``request_read_timeout_seconds``.

    Returns:
        The ``EmptyResult`` produced by ``send_request`` for the ping.
    """
    wrapped_request = ClientRequest(PingRequest(method="ping"))
    response = await self.send_request(
        wrapped_request,
        EmptyResult,
        request_read_timeout_seconds=read_timeout_seconds,
    )
    return response

async def read_resource(
self, uri: AnyUrl | str, _meta: dict | None = None, **kwargs
) -> ReadResourceResult:
Expand Down
137 changes: 100 additions & 37 deletions src/fast_agent/mcp/mcp_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ class ServerStatus(BaseModel):
transport_channels: TransportSnapshot | None = None
skybridge: SkybridgeServerConfig | None = None
reconnect_count: int = 0
ping_interval_seconds: int | None = None
ping_max_missed: int | None = None
ping_ok_count: int | None = None
ping_fail_count: int | None = None
ping_consecutive_failures: int | None = None
ping_last_ok_at: datetime | None = None
ping_last_fail_at: datetime | None = None
ping_last_error: str | None = None
ping_activity_buckets: list[str] | None = None
ping_activity_bucket_seconds: int | None = None
ping_activity_bucket_count: int | None = None

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -997,49 +1008,88 @@ async def collect_server_status(self) -> dict[str, ServerStatus]:
server_conn = None
transport: str | None = None
transport_snapshot: TransportSnapshot | None = None
ping_interval_seconds: int | None = None
ping_max_missed: int | None = None
ping_ok_count: int | None = None
ping_fail_count: int | None = None
ping_consecutive_failures: int | None = None
ping_last_ok_at: datetime | None = None
ping_last_fail_at: datetime | None = None
ping_last_error: str | None = None
ping_activity_buckets: list[str] | None = None
ping_activity_bucket_seconds: int | None = None
ping_activity_bucket_count: int | None = None

manager = getattr(self, "_persistent_connection_manager", None)
if self.connection_persistence and manager is not None:
try:
server_conn = await manager.get_server(
server_name,
client_session_factory=self._create_session_factory(server_name),
)
implementation = server_conn.server_implementation
if implementation is not None:
implementation_name = implementation.name
implementation_version = implementation.version
capabilities = server_conn.server_capabilities
client_capabilities = server_conn.client_capabilities
session = server_conn.session
client_info = getattr(session, "client_info", None) if session else None
if client_info:
client_info_name = getattr(client_info, "name", None)
client_info_version = getattr(client_info, "version", None)
is_connected = server_conn.is_healthy()
error_message = server_conn._error_message
instructions_available = server_conn.server_instructions_available
instructions_enabled = server_conn.server_instructions_enabled
instructions_included = bool(server_conn.server_instructions)
server_cfg = server_conn.server_config
if session:
elicitation_mode = session.effective_elicitation_mode
session_id = server_conn.session_id
if not session_id and server_conn._get_session_id_cb:
async with manager._lock:
server_conn = manager.running_servers.get(server_name)
if server_conn is None:
is_connected = False
else:
implementation = server_conn.server_implementation
if implementation is not None:
implementation_name = implementation.name
implementation_version = implementation.version
capabilities = server_conn.server_capabilities
client_capabilities = server_conn.client_capabilities
session = server_conn.session
client_info = getattr(session, "client_info", None) if session else None
if client_info:
client_info_name = getattr(client_info, "name", None)
client_info_version = getattr(client_info, "version", None)
if server_conn._initialized_event.is_set():
is_connected = server_conn.is_healthy()
else:
is_connected = False
error_message = error_message or "initializing..."
error_message = error_message or server_conn._error_message
instructions_available = server_conn.server_instructions_available
instructions_enabled = server_conn.server_instructions_enabled
instructions_included = bool(server_conn.server_instructions)
server_cfg = server_conn.server_config
ping_interval_seconds = server_cfg.ping_interval_seconds
ping_max_missed = server_cfg.max_missed_pings
ping_ok_count = server_conn._ping_ok_count
ping_fail_count = server_conn._ping_fail_count
ping_consecutive_failures = server_conn._ping_consecutive_failures
ping_last_ok_at = server_conn._ping_last_ok_at
ping_last_fail_at = server_conn._ping_last_fail_at
ping_last_error = server_conn._ping_last_error
if session:
elicitation_mode = session.effective_elicitation_mode
session_id = server_conn.session_id
if not session_id and server_conn._get_session_id_cb:
try:
session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined]
except Exception:
session_id = None
metrics = server_conn.transport_metrics
if metrics is not None:
try:
session_id = server_conn._get_session_id_cb() # type: ignore[attr-defined]
transport_snapshot = metrics.snapshot()
except Exception:
session_id = None
metrics = server_conn.transport_metrics
if metrics is not None:
try:
transport_snapshot = metrics.snapshot()
except Exception:
logger.debug(
"Failed to snapshot transport metrics for server '%s'",
server_name,
exc_info=True,
)
logger.debug(
"Failed to snapshot transport metrics for server '%s'",
server_name,
exc_info=True,
)
bucket_seconds = (
transport_snapshot.activity_bucket_seconds
if transport_snapshot and transport_snapshot.activity_bucket_seconds
else 30
)
bucket_count = (
transport_snapshot.activity_bucket_count
if transport_snapshot and transport_snapshot.activity_bucket_count
else 20
)
ping_activity_buckets = server_conn.build_ping_activity_buckets(
bucket_seconds, bucket_count
)
ping_activity_bucket_seconds = bucket_seconds
ping_activity_bucket_count = bucket_count
except Exception as exc:
logger.debug(
f"Failed to collect status for server '{server_name}'",
Expand Down Expand Up @@ -1068,6 +1118,8 @@ async def collect_server_status(self) -> dict[str, ServerStatus]:
elicitation_mode = (
getattr(elicitation, "mode", None) if elicitation else elicitation_mode
)
ping_interval_seconds = ping_interval_seconds or server_cfg.ping_interval_seconds
ping_max_missed = ping_max_missed or server_cfg.max_missed_pings
sampling_cfg = server_cfg.sampling
spoofing_enabled = server_cfg.implementation is not None
if implementation_name is None and server_cfg.implementation is not None:
Expand Down Expand Up @@ -1123,6 +1175,17 @@ async def collect_server_status(self) -> dict[str, ServerStatus]:
transport_channels=transport_snapshot,
skybridge=self._skybridge_configs.get(server_name),
reconnect_count=reconnect_count,
ping_interval_seconds=ping_interval_seconds,
ping_max_missed=ping_max_missed,
ping_ok_count=ping_ok_count,
ping_fail_count=ping_fail_count,
ping_consecutive_failures=ping_consecutive_failures,
ping_last_ok_at=ping_last_ok_at,
ping_last_fail_at=ping_last_fail_at,
ping_last_error=ping_last_error,
ping_activity_buckets=ping_activity_buckets,
ping_activity_bucket_seconds=ping_activity_bucket_seconds,
ping_activity_bucket_count=ping_activity_bucket_count,
)

return status_map
Expand Down
Loading
Loading