diff --git a/mcp/streamable.go b/mcp/streamable.go index 8ff9cd1f..470bc7f7 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1553,6 +1553,30 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques done := make(chan struct{}) stream.done = done stream.protocolVersion = effectiveVersion + + // Reject any call whose ID is already in flight on this session, + // atomically and without partial registration. + c.mu.Lock() + for reqID := range calls { + if _, ok := c.requestStreams[reqID]; ok { + c.mu.Unlock() + writeJSONRPCError(w, http.StatusBadRequest, reqID, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidRequest, + Message: fmt.Sprintf("duplicate in-flight request ID %v", reqID.Raw()), + }) + return + } + } + c.streams[stream.id] = stream + for reqID := range calls { + c.requestStreams[reqID] = stream.id + } + c.mu.Unlock() + + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + defer stream.release() + if c.jsonResponse { // JSON mode: collect messages in pendingJSONMessages until done. // Set pendingJSONMessages to a non-nil value to signal that this is an @@ -1581,20 +1605,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // TODO(rfindley): if we have no event store, we should really cancel all - // remaining requests here, since the client will never get the results. - defer stream.release() - - // The stream is now set up to deliver messages. - // - // Register it before publishing incoming messages. - c.mu.Lock() - c.streams[stream.id] = stream - for reqID := range calls { - c.requestStreams[reqID] = stream.id - } - c.mu.Unlock() - // Publish incoming messages. for _, msg := range incoming { select { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..f3bc98db 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3716,3 +3716,67 @@ func TestStreamableHTTP_E2E_DiscoverSuccess(t *testing.T) { t.Errorf("CallTool result[0] = %+v, want TextContent{Text:\"hello\"}", res.Content[0]) } } + +// A POST whose call ID is already in flight on the same session must be rejected +// with a JSON-RPC error, must not overwrite the existing request-to-stream +// mapping, and must not publish the duplicate message to the session's incoming +// channel. +func TestStreamableServerRejectsDuplicateInFlightRequestID(t *testing.T) { + id := jsonrpc2.Int64ID(1) + conn := &streamableServerConn{ + logger: ensureLogger(nil), + incoming: make(chan jsonrpc.Message, 1), + done: make(chan struct{}), + streams: map[string]*stream{ + "existing": { + id: "existing", + logger: ensureLogger(nil), + requests: map[jsonrpc.ID]struct{}{id: {}}, + }, + }, + requestStreams: map[jsonrpc.ID]string{id: "existing"}, + } + + data, err := jsonrpc2.EncodeMessage(req(1, methodPing, &PingParams{})) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + ctx, cancel := context.WithTimeout(t.Context(), 20*time.Millisecond) + defer cancel() + httpReq := httptest.NewRequestWithContext(ctx, http.MethodPost, "/", bytes.NewReader(data)) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + + rec := httptest.NewRecorder() + conn.servePOST(rec, httpReq) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + msg, err := jsonrpc2.DecodeMessage(rec.Body.Bytes()) + if err != nil { + t.Fatalf("DecodeMessage() error = %v; body = %s", err, rec.Body.String()) + } + resp, ok := msg.(*jsonrpc.Response) + if !ok { + t.Fatalf("response type = %T, want *jsonrpc.Response", msg) + } + if got := resp.ID.Raw(); got != int64(1) { + t.Fatalf("response ID = %v, want 1", got) + } + var jerr *jsonrpc.Error + if !errors.As(resp.Error, &jerr) { + t.Fatalf("response error = %v, want *jsonrpc.Error", resp.Error) + } + if jerr.Code != jsonrpc.CodeInvalidRequest { + t.Fatalf("error code = %d, want %d", jerr.Code, jsonrpc.CodeInvalidRequest) + } + if got := conn.requestStreams[id]; got != "existing" { + t.Fatalf("requestStreams[%v] = %q, want %q", id, got, "existing") + } + select { + case msg := <-conn.incoming: + t.Fatalf("duplicate request was published to incoming: %v", msg) + default: + } +}