diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index b215ab12..283f0daa 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -65,8 +65,10 @@ pub(crate) fn extract_scope_from_header(header: &str) -> Option { #[cfg(test)] mod tests { + #[cfg(feature = "client-side-sse")] use super::*; + #[cfg(feature = "client-side-sse")] #[test] fn extract_scope_quoted() { let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#; @@ -76,6 +78,7 @@ mod tests { ); } + #[cfg(feature = "client-side-sse")] #[test] fn extract_scope_unquoted() { let header = r#"Bearer scope=read:data, error="insufficient_scope""#; @@ -85,12 +88,14 @@ mod tests { ); } + #[cfg(feature = "client-side-sse")] #[test] fn extract_scope_missing() { let header = r#"Bearer error="invalid_token""#; assert_eq!(extract_scope_from_header(header), None); } + #[cfg(feature = "client-side-sse")] #[test] fn extract_scope_empty_header() { assert_eq!(extract_scope_from_header("Bearer"), None); diff --git a/crates/rmcp/src/transport/common/server_side_http.rs b/crates/rmcp/src/transport/common/server_side_http.rs index d24b19af..39a321f9 100644 --- a/crates/rmcp/src/transport/common/server_side_http.rs +++ b/crates/rmcp/src/transport/common/server_side_http.rs @@ -57,7 +57,7 @@ impl sse_stream::Timer for TokioTimer { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[non_exhaustive] pub struct ServerSseMessage { /// The event ID for this message. When set, clients can use this ID @@ -71,6 +71,37 @@ pub struct ServerSseMessage { pub retry: Option, } +impl ServerSseMessage { + /// Create a message carrying a JSON-RPC response/notification with an event ID. + pub fn new(event_id: impl Into, message: ServerJsonRpcMessage) -> Self { + Self { + event_id: Some(event_id.into()), + message: Some(Arc::new(message)), + retry: None, + } + } + + /// Wrap a JSON-RPC message without an event ID or retry hint. + pub fn from_message(message: ServerJsonRpcMessage) -> Self { + Self { + event_id: None, + message: Some(Arc::new(message)), + retry: None, + } + } + + /// Create a priming event that tells the client to reconnect after `retry` + /// if the connection drops. + /// See [SEP-1699](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699). + pub fn priming(event_id: impl Into, retry: Duration) -> Self { + Self { + event_id: Some(event_id.into()), + message: None, + retry: Some(retry), + } + } +} + pub(crate) fn sse_stream_response( stream: impl futures::Stream + Send + Sync + 'static, keep_alive: Option, @@ -169,3 +200,49 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::{EmptyResult, JsonRpcResponse, JsonRpcVersion2_0, RequestId, ServerResult}; + + fn dummy_message() -> ServerJsonRpcMessage { + ServerJsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion2_0, + id: RequestId::Number(1), + result: ServerResult::EmptyResult(EmptyResult {}), + }) + } + + #[test] + fn default_has_all_none() { + let msg = ServerSseMessage::default(); + assert!(msg.event_id.is_none()); + assert!(msg.message.is_none()); + assert!(msg.retry.is_none()); + } + + #[test] + fn new_sets_event_id_and_message() { + let msg = ServerSseMessage::new("42", dummy_message()); + assert_eq!(msg.event_id.as_deref(), Some("42")); + assert!(msg.message.is_some()); + assert!(msg.retry.is_none()); + } + + #[test] + fn from_message_has_no_event_id() { + let msg = ServerSseMessage::from_message(dummy_message()); + assert!(msg.event_id.is_none()); + assert!(msg.message.is_some()); + assert!(msg.retry.is_none()); + } + + #[test] + fn priming_sets_event_id_and_retry() { + let msg = ServerSseMessage::priming("0", Duration::from_secs(5)); + assert_eq!(msg.event_id.as_deref(), Some("0")); + assert!(msg.message.is_none()); + assert_eq!(msg.retry, Some(Duration::from_secs(5))); + } +} diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 2d2059c5..96d77ffc 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -1,7 +1,6 @@ use std::{ collections::{HashMap, HashSet, VecDeque}, num::ParseIntError, - sync::Arc, time::Duration, }; @@ -222,21 +221,13 @@ impl CachedTx { async fn send(&mut self, message: ServerJsonRpcMessage) { let event_id = self.next_event_id(); - let message = ServerSseMessage { - event_id: Some(event_id.to_string()), - message: Some(Arc::new(message)), - retry: None, - }; + let message = ServerSseMessage::new(event_id.to_string(), message); self.cache_and_send(message).await; } async fn send_priming(&mut self, retry: Duration) { let event_id = self.next_event_id(); - let message = ServerSseMessage { - event_id: Some(event_id.to_string()), - message: None, - retry: Some(retry), - }; + let message = ServerSseMessage::priming(event_id.to_string(), retry); self.cache_and_send(message).await; } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 8f9c0a70..5dc7996c 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -499,11 +499,7 @@ where .map_err(internal_error_response("create standalone stream"))?; // Prepend priming event if sse_retry configured let stream = if let Some(retry) = self.config.sse_retry { - let priming = ServerSseMessage { - event_id: Some("0".into()), - message: None, - retry: Some(retry), - }; + let priming = ServerSseMessage::priming("0", retry); futures::stream::once(async move { priming }) .chain(stream) .left_stream() @@ -609,11 +605,7 @@ where .map_err(internal_error_response("get session"))?; // Prepend priming event if sse_retry configured let stream = if let Some(retry) = self.config.sse_retry { - let priming = ServerSseMessage { - event_id: Some("0".into()), - message: None, - retry: Some(retry), - }; + let priming = ServerSseMessage::priming("0", retry); futures::stream::once(async move { priming }) .chain(stream) .left_stream() @@ -687,20 +679,11 @@ where .initialize_session(&session_id, message) .await .map_err(internal_error_response("create stream"))?; - let stream = futures::stream::once(async move { - ServerSseMessage { - event_id: None, - message: Some(Arc::new(response)), - retry: None, - } - }); + let stream = + futures::stream::once(async move { ServerSseMessage::from_message(response) }); // Prepend priming event if sse_retry configured let stream = if let Some(retry) = self.config.sse_retry { - let priming = ServerSseMessage { - event_id: Some("0".into()), - message: None, - retry: Some(retry), - }; + let priming = ServerSseMessage::priming("0", retry); futures::stream::once(async move { priming }) .chain(stream) .left_stream() @@ -774,11 +757,7 @@ where // SSE mode (default): original behaviour preserved unchanged let stream = ReceiverStream::new(receiver).map(|message| { tracing::trace!(?message); - ServerSseMessage { - event_id: None, - message: Some(Arc::new(message)), - retry: None, - } + ServerSseMessage::from_message(message) }); Ok(sse_stream_response( stream, diff --git a/crates/rmcp/tests/test_inflight_response_drain.rs b/crates/rmcp/tests/test_inflight_response_drain.rs index b5fc160e..2381644d 100644 --- a/crates/rmcp/tests/test_inflight_response_drain.rs +++ b/crates/rmcp/tests/test_inflight_response_drain.rs @@ -1,4 +1,4 @@ -#![cfg(not(feature = "local"))] +#![cfg(all(feature = "client", feature = "server", not(feature = "local")))] // cargo test --test test_inflight_response_drain --features "client server" use std::{