From 661086939f775ad7c45fc0dd4c8ab88df74f7d71 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 16 Jun 2026 11:44:22 +0000 Subject: [PATCH 1/2] fix: prevent duplicate in-flight request IDs by validating registrations before stream initialization --- mcp/streamable.go | 41 ++++++++++++++++++--------- mcp/streamable_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 14 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 8ff9cd1f..2c95bcac 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1553,6 +1553,33 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques done := make(chan struct{}) stream.done = done stream.protocolVersion = effectiveVersion + + // Register the stream before publishing incoming messages so the server + // can route responses back to this HTTP request. Reject any call whose ID + // is already in flight on this session, atomically and without partial + // registration: pass 1 checks all IDs, pass 2 mutates only if all are + // fresh. + c.mu.Lock() + for reqID := range calls { + if _, ok := c.requestStreams[reqID]; ok { + c.mu.Unlock() + writeJSONRPCError(w, http.StatusBadRequest, reqID, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidRequest, + Message: fmt.Sprintf("duplicate in-flight request ID %v", reqID.Raw()), + }) + return + } + } + c.streams[stream.id] = stream + for reqID := range calls { + c.requestStreams[reqID] = stream.id + } + c.mu.Unlock() + + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + defer stream.release() + if c.jsonResponse { // JSON mode: collect messages in pendingJSONMessages until done. // Set pendingJSONMessages to a non-nil value to signal that this is an @@ -1581,20 +1608,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // TODO(rfindley): if we have no event store, we should really cancel all - // remaining requests here, since the client will never get the results. - defer stream.release() - - // The stream is now set up to deliver messages. - // - // Register it before publishing incoming messages. - c.mu.Lock() - c.streams[stream.id] = stream - for reqID := range calls { - c.requestStreams[reqID] = stream.id - } - c.mu.Unlock() - // Publish incoming messages. for _, msg := range incoming { select { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..f3bc98db 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3716,3 +3716,67 @@ func TestStreamableHTTP_E2E_DiscoverSuccess(t *testing.T) { t.Errorf("CallTool result[0] = %+v, want TextContent{Text:\"hello\"}", res.Content[0]) } } + +// A POST whose call ID is already in flight on the same session must be rejected +// with a JSON-RPC error, must not overwrite the existing request-to-stream +// mapping, and must not publish the duplicate message to the session's incoming +// channel. +func TestStreamableServerRejectsDuplicateInFlightRequestID(t *testing.T) { + id := jsonrpc2.Int64ID(1) + conn := &streamableServerConn{ + logger: ensureLogger(nil), + incoming: make(chan jsonrpc.Message, 1), + done: make(chan struct{}), + streams: map[string]*stream{ + "existing": { + id: "existing", + logger: ensureLogger(nil), + requests: map[jsonrpc.ID]struct{}{id: {}}, + }, + }, + requestStreams: map[jsonrpc.ID]string{id: "existing"}, + } + + data, err := jsonrpc2.EncodeMessage(req(1, methodPing, &PingParams{})) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + ctx, cancel := context.WithTimeout(t.Context(), 20*time.Millisecond) + defer cancel() + httpReq := httptest.NewRequestWithContext(ctx, http.MethodPost, "/", bytes.NewReader(data)) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + + rec := httptest.NewRecorder() + conn.servePOST(rec, httpReq) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + msg, err := jsonrpc2.DecodeMessage(rec.Body.Bytes()) + if err != nil { + t.Fatalf("DecodeMessage() error = %v; body = %s", err, rec.Body.String()) + } + resp, ok := msg.(*jsonrpc.Response) + if !ok { + t.Fatalf("response type = %T, want *jsonrpc.Response", msg) + } + if got := resp.ID.Raw(); got != int64(1) { + t.Fatalf("response ID = %v, want 1", got) + } + var jerr *jsonrpc.Error + if !errors.As(resp.Error, &jerr) { + t.Fatalf("response error = %v, want *jsonrpc.Error", resp.Error) + } + if jerr.Code != jsonrpc.CodeInvalidRequest { + t.Fatalf("error code = %d, want %d", jerr.Code, jsonrpc.CodeInvalidRequest) + } + if got := conn.requestStreams[id]; got != "existing" { + t.Fatalf("requestStreams[%v] = %q, want %q", id, got, "existing") + } + select { + case msg := <-conn.incoming: + t.Fatalf("duplicate request was published to incoming: %v", msg) + default: + } +} From 8d43a3b6627ad77581cdfeec7d2e91753ec37bc1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 16 Jun 2026 12:44:01 +0000 Subject: [PATCH 2/2] refactor: simplify stream registration comment in streamable.go --- mcp/streamable.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 2c95bcac..470bc7f7 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1554,11 +1554,8 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques stream.done = done stream.protocolVersion = effectiveVersion - // Register the stream before publishing incoming messages so the server - // can route responses back to this HTTP request. Reject any call whose ID - // is already in flight on this session, atomically and without partial - // registration: pass 1 checks all IDs, pass 2 mutates only if all are - // fresh. + // Reject any call whose ID is already in flight on this session, + // atomically and without partial registration. c.mu.Lock() for reqID := range calls { if _, ok := c.requestStreams[reqID]; ok {