Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 94 additions & 5 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl<'de> Deserialize<'de> for ProtocolVersion {
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum NumberOrString {
/// A numeric identifier
Number(u32),
Number(i64),
/// A string identifier
String(Arc<str>),
}
Expand Down Expand Up @@ -227,10 +227,20 @@ impl<'de> Deserialize<'de> for NumberOrString {
{
let value: Value = Deserialize::deserialize(deserializer)?;
match value {
Value::Number(n) => Ok(NumberOrString::Number(
n.as_u64()
.ok_or(serde::de::Error::custom("Expect an integer"))? as u32,
)),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
Ok(NumberOrString::Number(i))
} else if let Some(u) = n.as_u64() {
// Handle large unsigned numbers that fit in i64
if u <= i64::MAX as u64 {
Ok(NumberOrString::Number(u as i64))
} else {
Err(serde::de::Error::custom("Number too large for i64"))
}
} else {
Err(serde::de::Error::custom("Expected an integer"))
}
}
Value::String(s) => Ok(NumberOrString::String(s.into())),
_ => Err(serde::de::Error::custom("Expect number or string")),
}
Expand Down Expand Up @@ -1736,6 +1746,85 @@ mod tests {
assert_eq!(server_response_json, raw_response_json);
}

#[test]
fn test_negative_and_large_request_ids() {
// Test negative ID
let negative_id_json = json!({
"jsonrpc": "2.0",
"id": -1,
"method": "test",
"params": {}
});

let message: JsonRpcMessage =
serde_json::from_value(negative_id_json.clone()).expect("Should parse negative ID");

match &message {
JsonRpcMessage::Request(r) => {
assert_eq!(r.id, RequestId::Number(-1));
}
_ => panic!("Expected Request"),
}

// Test roundtrip serialization
let serialized = serde_json::to_value(&message).expect("Should serialize");
assert_eq!(serialized, negative_id_json);

// Test large negative ID
let large_negative_json = json!({
"jsonrpc": "2.0",
"id": -9007199254740991i64, // JavaScript's MIN_SAFE_INTEGER
"method": "test",
"params": {}
});

let message: JsonRpcMessage = serde_json::from_value(large_negative_json.clone())
.expect("Should parse large negative ID");

match &message {
JsonRpcMessage::Request(r) => {
assert_eq!(r.id, RequestId::Number(-9007199254740991i64));
}
_ => panic!("Expected Request"),
}

// Test large positive ID (JavaScript's MAX_SAFE_INTEGER)
let large_positive_json = json!({
"jsonrpc": "2.0",
"id": 9007199254740991i64,
"method": "test",
"params": {}
});

let message: JsonRpcMessage = serde_json::from_value(large_positive_json.clone())
.expect("Should parse large positive ID");

match &message {
JsonRpcMessage::Request(r) => {
assert_eq!(r.id, RequestId::Number(9007199254740991i64));
}
_ => panic!("Expected Request"),
}

// Test zero ID
let zero_id_json = json!({
"jsonrpc": "2.0",
"id": 0,
"method": "test",
"params": {}
});

let message: JsonRpcMessage =
serde_json::from_value(zero_id_json.clone()).expect("Should parse zero ID");

match &message {
JsonRpcMessage::Request(r) => {
assert_eq!(r.id, RequestId::Number(0));
}
_ => panic!("Expected Request"),
}
}

#[test]
fn test_protocol_version_order() {
let v1 = ProtocolVersion::V_2024_11_05;
Expand Down
16 changes: 13 additions & 3 deletions crates/rmcp/src/model/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,19 @@ impl Meta {
pub fn get_progress_token(&self) -> Option<ProgressToken> {
self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v {
Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))),
Value::Number(n) => n
.as_u64()
.map(|n| ProgressToken(NumberOrString::Number(n as u32))),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
Some(ProgressToken(NumberOrString::Number(i)))
} else if let Some(u) = n.as_u64() {
if u <= i64::MAX as u64 {
Some(ProgressToken(NumberOrString::Number(u as i64)))
} else {
None
}
} else {
None
}
}
_ => None,
})
}
Expand Down
13 changes: 7 additions & 6 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl<R: ServiceRole, S: Service<R>> DynService<R> for S {
use std::{
collections::{HashMap, VecDeque},
ops::Deref,
sync::{Arc, atomic::AtomicU32},
sync::{Arc, atomic::AtomicU64},
time::Duration,
};

Expand All @@ -212,20 +212,21 @@ pub type AtomicU32ProgressTokenProvider = AtomicU32Provider;

#[derive(Debug, Default)]
pub struct AtomicU32Provider {
id: AtomicU32,
id: AtomicU64,
}

impl RequestIdProvider for AtomicU32Provider {
fn next_request_id(&self) -> RequestId {
RequestId::Number(self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
// Safe conversion: we start at 0 and increment by 1, so we won't overflow i64::MAX in practice
RequestId::Number(id as i64)
}
}

impl ProgressTokenProvider for AtomicU32Provider {
fn next_progress_token(&self) -> ProgressToken {
ProgressToken(NumberOrString::Number(
self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
))
let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
ProgressToken(NumberOrString::Number(id as i64))
}
}

Expand Down