Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-mcp-json-rpc-request-ids.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"effect": patch
---

Fix JSON-RPC serialization to preserve arbitrary request ids for MCP HTTP clients.
133 changes: 93 additions & 40 deletions packages/effect/src/unstable/rpc/RpcSerialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,16 @@ export const jsonRpc = (options?: {
includesFraming: false,
makeUnsafe: () => {
const decoder = new TextDecoder()
const batches = new Map<string, {
readonly size: number
readonly responses: Map<string, RpcMessage.FromServerEncoded>
}>()
const state = makeJsonRpcState()
return {
decode: (bytes) => {
const decoded: JsonRpcMessage | Array<JsonRpcMessage> = JSON.parse(
typeof bytes === "string" ? bytes : decoder.decode(bytes)
)
return decodeJsonRpcRaw(decoded, batches)
return decodeJsonRpcRaw(decoded, state)
},
encode: (response) => {
const encoded = encodeJsonRpcResponse(response as any, batches)
const encoded = encodeJsonRpcResponse(response as any, state)
return encoded && JSON.stringify(encoded)
}
}
Expand All @@ -188,35 +185,46 @@ export const ndJsonRpc = (options?: {
includesFraming: true,
makeUnsafe: () => {
const parser = ndjson.makeUnsafe()
const batches = new Map<string, {
readonly size: number
readonly responses: Map<string, RpcMessage.FromServerEncoded>
}>()
const state = makeJsonRpcState()
return ({
decode: (bytes) => {
const frames = parser.decode(bytes)
if (frames.length === 0) return []
const messages: Array<RpcMessage.FromClientEncoded | RpcMessage.FromServerEncoded> = []
for (let i = 0; i < frames.length; i++) {
const frame = frames[i]
messages.push(...decodeJsonRpcRaw(frame as any, batches) as any)
messages.push(...decodeJsonRpcRaw(frame as any, state) as any)
}
return messages
},
encode: (response) => {
const encoded = encodeJsonRpcResponse(response as any, batches)
const encoded = encodeJsonRpcResponse(response as any, state)
return encoded && parser.encode(encoded)
}
})
}
})

function decodeJsonRpcRaw(
decoded: JsonRpcMessage | Array<JsonRpcMessage>,
batches: Map<string, {
type JsonRpcRequestId = string | number | null | undefined

type JsonRpcState = {
readonly batches: Map<string, {
readonly size: number
readonly responses: Map<string, RpcMessage.FromServerEncoded>
}>
readonly requestIds: Map<string, JsonRpcRequestId>
nextRequestId: bigint
}

const makeJsonRpcState = (): JsonRpcState => ({
batches: new Map(),
requestIds: new Map(),
nextRequestId: BigInt(-1)
})

function decodeJsonRpcRaw(
decoded: JsonRpcMessage | Array<JsonRpcMessage>,
state: JsonRpcState
) {
if (Array.isArray(decoded)) {
const batch = {
Expand All @@ -225,19 +233,22 @@ function decodeJsonRpcRaw(
}
const messages: Array<RpcMessage.FromClientEncoded | RpcMessage.FromServerEncoded> = []
for (let i = 0; i < decoded.length; i++) {
const message = decodeJsonRpcMessage(decoded[i])
const message = decodeJsonRpcMessage(decoded[i], state)
messages.push(message)
if (message._tag === "Request") {
batch.size++
batches.set(message.id, batch)
state.batches.set(message.id, batch)
}
}
return messages
}
return [decodeJsonRpcMessage(decoded)]
return [decodeJsonRpcMessage(decoded, state)]
}

function decodeJsonRpcMessage(decoded: JsonRpcMessage): RpcMessage.FromClientEncoded | RpcMessage.FromServerEncoded {
function decodeJsonRpcMessage(
decoded: JsonRpcMessage,
state: JsonRpcState
): RpcMessage.FromClientEncoded | RpcMessage.FromServerEncoded {
if ("method" in decoded) {
if (Predicate.isNullish(decoded.id) && decoded.method.startsWith("@effect/rpc/")) {
const tag = decoded.method.slice("@effect/rpc/".length) as
Expand All @@ -251,9 +262,10 @@ function decodeJsonRpcMessage(decoded: JsonRpcMessage): RpcMessage.FromClientEnc
} as any :
{ _tag: tag } as any
}
const id = decodeJsonRpcRequestId(decoded, state)
return {
_tag: "Request",
id: Predicate.isNotNullish(decoded.id) ? String(decoded.id) : "",
id,
tag: decoded.method,
payload: decoded.params ?? null,
headers: decoded.headers ?? [],
Expand Down Expand Up @@ -297,47 +309,82 @@ function decodeJsonRpcMessage(decoded: JsonRpcMessage): RpcMessage.FromClientEnc
}
}

const canonicalIntegerString = /^(?:0|-?[1-9]\d*)$/

function decodeJsonRpcRequestId(decoded: JsonRpcRequest, state: JsonRpcState): string {
const jsonRpcId = Predicate.hasProperty(decoded, "id") ? decoded.id : undefined
let requestId: string | undefined
if (typeof jsonRpcId === "number" && Number.isSafeInteger(jsonRpcId)) {
requestId = String(jsonRpcId)
} else if (typeof jsonRpcId === "string" && canonicalIntegerString.test(jsonRpcId)) {
requestId = jsonRpcId
}
if (requestId === undefined || state.requestIds.has(requestId)) {
requestId = nextJsonRpcRequestId(state)
}
state.requestIds.set(requestId, jsonRpcId)
return requestId
}

function nextJsonRpcRequestId(state: JsonRpcState): string {
let requestId: string
do {
requestId = state.nextRequestId.toString()
state.nextRequestId--
} while (state.requestIds.has(requestId))
return requestId
}

function encodeJsonRpcResponseId(
requestId: string,
state: JsonRpcState,
isExit: boolean
): JsonRpcRequestId {
if (state.requestIds.has(requestId)) {
const jsonRpcId = state.requestIds.get(requestId)
if (isExit) {
state.requestIds.delete(requestId)
}
return jsonRpcId
}
return requestId !== "" ? Number(requestId) : undefined
}

function encodeJsonRpcRaw(
response: RpcMessage.FromServerEncoded | RpcMessage.FromClientEncoded,
batches: Map<string, {
readonly size: number
readonly responses: Map<string, RpcMessage.FromServerEncoded>
}>
state: JsonRpcState
) {
if (!("requestId" in response)) {
return encodeJsonRpcMessage(response)
return encodeJsonRpcMessage(response, state)
}
const batch = batches.get(response.requestId)
const batch = state.batches.get(response.requestId)
if (batch) {
batches.delete(response.requestId)
state.batches.delete(response.requestId)
batch.responses.set(response.requestId, response as any)
if (batch.size === batch.responses.size) {
return Array.from(batch.responses.values(), encodeJsonRpcMessage)
return Array.from(batch.responses.values(), (message) => encodeJsonRpcMessage(message, state))
}
return undefined
}
return encodeJsonRpcMessage(response)
return encodeJsonRpcMessage(response, state)
}

function encodeJsonRpcResponse(
response:
| RpcMessage.FromServerEncoded
| RpcMessage.FromClientEncoded
| Array<RpcMessage.FromServerEncoded | RpcMessage.FromClientEncoded>,
batches: Map<string, {
readonly size: number
readonly responses: Map<string, RpcMessage.FromServerEncoded>
}>
state: JsonRpcState
) {
if (Array.isArray(response) === false) {
return encodeJsonRpcRaw(response, batches)
return encodeJsonRpcRaw(response, state)
}
if (response.length === 0) {
return undefined
}
const encoded: Array<JsonRpcMessage | Array<JsonRpcMessage>> = []
for (let i = 0; i < response.length; i++) {
const current = encodeJsonRpcRaw(response[i], batches)
const current = encodeJsonRpcRaw(response[i], state)
if (current !== undefined) {
encoded.push(current)
}
Expand All @@ -360,7 +407,10 @@ function encodeJsonRpcResponse(
return messages
}

function encodeJsonRpcMessage(response: RpcMessage.FromServerEncoded | RpcMessage.FromClientEncoded): JsonRpcMessage {
function encodeJsonRpcMessage(
response: RpcMessage.FromServerEncoded | RpcMessage.FromClientEncoded,
state: JsonRpcState
): JsonRpcMessage {
switch (response._tag) {
case "Request":
return {
Expand All @@ -383,25 +433,28 @@ function encodeJsonRpcMessage(response: RpcMessage.FromServerEncoded | RpcMessag
method: `@effect/rpc/${response._tag}`,
params: "requestId" in response ? { requestId: response.requestId } : undefined
}
case "Chunk":
case "Chunk": {
const chunkId = encodeJsonRpcResponseId(response.requestId, state, false)
return {
jsonrpc: "2.0",
chunk: true,
id: Number(response.requestId),
...(chunkId === undefined ? {} : { id: chunkId }),
result: response.values
}
}
case "Exit": {
const id = encodeJsonRpcResponseId(response.requestId, state, true)
if (response.exit._tag === "Success") {
return {
jsonrpc: "2.0",
id: response.requestId !== "" ? Number(response.requestId) : undefined,
id,
result: response.exit.value
} as any
}
const error = response.exit.cause.find((failure) => failure._tag === "Fail")
return {
jsonrpc: "2.0",
id: response.requestId !== "" ? Number(response.requestId) : undefined,
id,
error: response.exit._tag === "Failure" ?
{
_tag: "Cause",
Expand Down
75 changes: 67 additions & 8 deletions packages/effect/test/rpc/RpcSerialization.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,87 @@ describe("RpcSerialization", () => {
)
})

it("jsonRpc maps null id to internal notification sentinel", () => {
it("jsonRpc preserves arbitrary string ids across decode and response encode", () => {
const parser = RpcSerialization.jsonRpc().makeUnsafe()
const decoded = parser.decode("{\"jsonrpc\":\"2.0\",\"id\":\"14f40ee1b859ee70\",\"method\":\"users.get\"}")
assert.strictEqual(decoded.length, 1)
const request = decoded[0] as any
assert.strictEqual(request._tag, "Request")

const encoded = parser.encode(responseExitSuccess(request.id, "ok"))
assert(encoded !== undefined)
assert.deepStrictEqual(JSON.parse(encoded as string), {
jsonrpc: "2.0",
id: "14f40ee1b859ee70",
result: "ok"
})
})

it("jsonRpc preserves null id across decode and response encode", () => {
const parser = RpcSerialization.jsonRpc().makeUnsafe()
const decoded = parser.decode("{\"jsonrpc\":\"2.0\",\"id\":null,\"method\":\"users.get\"}")
assert.strictEqual(decoded.length, 1)
const request = decoded[0] as any
assert.strictEqual(request._tag, "Request")

const encoded = parser.encode(responseExitSuccess(request.id, "ok"))
assert(encoded !== undefined)
assert.deepStrictEqual(JSON.parse(encoded as string), {
jsonrpc: "2.0",
id: null,
result: "ok"
})
})

it("jsonRpc preserves omitted id across decode and response encode", () => {
const parser = RpcSerialization.jsonRpc().makeUnsafe()
const decoded = parser.decode("{\"jsonrpc\":\"2.0\",\"method\":\"users.get\"}")
assert.strictEqual(decoded.length, 1)
const request = decoded[0] as any
assert.strictEqual(request._tag, "Request")

const encoded = parser.encode(responseExitSuccess(request.id, "ok"))
assert(encoded !== undefined)
assert.deepStrictEqual(JSON.parse(encoded as string), {
jsonrpc: "2.0",
result: "ok"
})
})

it("jsonRpc preserves numeric string ids across decode and response encode", () => {
const parser = RpcSerialization.jsonRpc().makeUnsafe()
const decoded = parser.decode("{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"method\":\"users.get\"}")
assert.deepStrictEqual(decoded, [{
_tag: "Request",
id: "",
id: "1",
tag: "users.get",
payload: null,
headers: []
}])

const encoded = parser.encode(responseExitSuccess("1", "ok"))
assert(encoded !== undefined)
assert.deepStrictEqual(JSON.parse(encoded as string), {
jsonrpc: "2.0",
id: "1",
result: "ok"
})
})

it("jsonRpc preserves empty string id across decode and encode", () => {
const parser = RpcSerialization.jsonRpc().makeUnsafe()
const decoded = parser.decode("{\"jsonrpc\":\"2.0\",\"id\":\"\",\"method\":\"users.get\"}")
assert.deepStrictEqual(decoded, [{
_tag: "Request",
assert.strictEqual(decoded.length, 1)
const request = decoded[0] as any
assert.strictEqual(request._tag, "Request")

const response = parser.encode(responseExitSuccess(request.id, "ok"))
assert(response !== undefined)
assert.deepStrictEqual(JSON.parse(response as string), {
jsonrpc: "2.0",
id: "",
tag: "users.get",
payload: null,
headers: []
}])
result: "ok"
})

const encoded = parser.encode({
_tag: "Request",
Expand Down
Loading