diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 45764f803..ea11723f1 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -185,7 +185,7 @@ impl<'de> Deserialize<'de> for ProtocolVersion { #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum NumberOrString { /// A numeric identifier - Number(u32), + Number(i64), /// A string identifier String(Arc), } @@ -227,10 +227,20 @@ impl<'de> Deserialize<'de> for NumberOrString { { let value: Value = Deserialize::deserialize(deserializer)?; match value { - Value::Number(n) => Ok(NumberOrString::Number( - n.as_u64() - .ok_or(serde::de::Error::custom("Expect an integer"))? as u32, - )), + Value::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(NumberOrString::Number(i)) + } else if let Some(u) = n.as_u64() { + // Handle large unsigned numbers that fit in i64 + if u <= i64::MAX as u64 { + Ok(NumberOrString::Number(u as i64)) + } else { + Err(serde::de::Error::custom("Number too large for i64")) + } + } else { + Err(serde::de::Error::custom("Expected an integer")) + } + } Value::String(s) => Ok(NumberOrString::String(s.into())), _ => Err(serde::de::Error::custom("Expect number or string")), } @@ -1736,6 +1746,85 @@ mod tests { assert_eq!(server_response_json, raw_response_json); } + #[test] + fn test_negative_and_large_request_ids() { + // Test negative ID + let negative_id_json = json!({ + "jsonrpc": "2.0", + "id": -1, + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = + serde_json::from_value(negative_id_json.clone()).expect("Should parse negative ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(-1)); + } + _ => panic!("Expected Request"), + } + + // Test roundtrip serialization + let serialized = serde_json::to_value(&message).expect("Should serialize"); + assert_eq!(serialized, negative_id_json); + + // Test large negative ID + let large_negative_json = json!({ + "jsonrpc": "2.0", + "id": -9007199254740991i64, // JavaScript's MIN_SAFE_INTEGER + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = serde_json::from_value(large_negative_json.clone()) + .expect("Should parse large negative ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(-9007199254740991i64)); + } + _ => panic!("Expected Request"), + } + + // Test large positive ID (JavaScript's MAX_SAFE_INTEGER) + let large_positive_json = json!({ + "jsonrpc": "2.0", + "id": 9007199254740991i64, + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = serde_json::from_value(large_positive_json.clone()) + .expect("Should parse large positive ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(9007199254740991i64)); + } + _ => panic!("Expected Request"), + } + + // Test zero ID + let zero_id_json = json!({ + "jsonrpc": "2.0", + "id": 0, + "method": "test", + "params": {} + }); + + let message: JsonRpcMessage = + serde_json::from_value(zero_id_json.clone()).expect("Should parse zero ID"); + + match &message { + JsonRpcMessage::Request(r) => { + assert_eq!(r.id, RequestId::Number(0)); + } + _ => panic!("Expected Request"), + } + } + #[test] fn test_protocol_version_order() { let v1 = ProtocolVersion::V_2024_11_05; diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index 97baccc96..fd93362b7 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -116,9 +116,19 @@ impl Meta { pub fn get_progress_token(&self) -> Option { self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v { Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))), - Value::Number(n) => n - .as_u64() - .map(|n| ProgressToken(NumberOrString::Number(n as u32))), + Value::Number(n) => { + if let Some(i) = n.as_i64() { + Some(ProgressToken(NumberOrString::Number(i))) + } else if let Some(u) = n.as_u64() { + if u <= i64::MAX as u64 { + Some(ProgressToken(NumberOrString::Number(u as i64))) + } else { + None + } + } else { + None + } + } _ => None, }) } diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index dbf23481e..5fc8934fa 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -193,7 +193,7 @@ impl> DynService for S { use std::{ collections::{HashMap, VecDeque}, ops::Deref, - sync::{Arc, atomic::AtomicU32}, + sync::{Arc, atomic::AtomicU64}, time::Duration, }; @@ -212,20 +212,21 @@ pub type AtomicU32ProgressTokenProvider = AtomicU32Provider; #[derive(Debug, Default)] pub struct AtomicU32Provider { - id: AtomicU32, + id: AtomicU64, } impl RequestIdProvider for AtomicU32Provider { fn next_request_id(&self) -> RequestId { - RequestId::Number(self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)) + let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + // Safe conversion: we start at 0 and increment by 1, so we won't overflow i64::MAX in practice + RequestId::Number(id as i64) } } impl ProgressTokenProvider for AtomicU32Provider { fn next_progress_token(&self) -> ProgressToken { - ProgressToken(NumberOrString::Number( - self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst), - )) + let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + ProgressToken(NumberOrString::Number(id as i64)) } }