diff --git a/mcp/streamable.go b/mcp/streamable.go index 5bc31771..f3f6527d 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1550,7 +1550,30 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Set pendingJSONMessages to a non-nil value to signal that this is an // application/json stream. stream.pendingJSONMessages = []json.RawMessage{} - } else { + } + + c.mu.Lock() + for reqID := range calls { + if _, ok := c.requestStreams[reqID]; ok { + c.mu.Unlock() + writeJSONRPCError(w, http.StatusBadRequest, reqID, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidRequest, + Message: fmt.Sprintf("duplicate in-flight request ID %v", reqID.Raw()), + }) + return + } + } + c.streams[stream.id] = stream + for reqID := range calls { + c.requestStreams[reqID] = stream.id + } + c.mu.Unlock() + + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + defer stream.release() + + if !c.jsonResponse { // SSE mode: write a priming event if supported. // // SEP-2575 removes Last-Event-ID-based resumable streams for protocol @@ -1573,20 +1596,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // TODO(rfindley): if we have no event store, we should really cancel all - // remaining requests here, since the client will never get the results. - defer stream.release() - - // The stream is now set up to deliver messages. - // - // Register it before publishing incoming messages. - c.mu.Lock() - c.streams[stream.id] = stream - for reqID := range calls { - c.requestStreams[reqID] = stream.id - } - c.mu.Unlock() - // Publish incoming messages. for _, msg := range incoming { select { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..d44c824a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -288,6 +288,66 @@ func TestStreamableConcurrentHandling(t *testing.T) { wg.Wait() } +func TestStreamableServerRejectsDuplicateInFlightRequestID(t *testing.T) { + id := jsonrpc2.Int64ID(1) + conn := &streamableServerConn{ + logger: ensureLogger(nil), + incoming: make(chan jsonrpc.Message, 1), + done: make(chan struct{}), + streams: map[string]*stream{ + "existing": { + id: "existing", + logger: ensureLogger(nil), + requests: map[jsonrpc.ID]struct{}{id: {}}, + }, + }, + requestStreams: map[jsonrpc.ID]string{id: "existing"}, + } + + data, err := jsonrpc2.EncodeMessage(req(1, methodPing, &PingParams{})) + if err != nil { + t.Fatalf("EncodeMessage() error = %v", err) + } + ctx, cancel := context.WithTimeout(t.Context(), 20*time.Millisecond) + defer cancel() + httpReq := httptest.NewRequestWithContext(ctx, http.MethodPost, "/", bytes.NewReader(data)) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + + rec := httptest.NewRecorder() + conn.servePOST(rec, httpReq) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + msg, err := jsonrpc2.DecodeMessage(rec.Body.Bytes()) + if err != nil { + t.Fatalf("DecodeMessage() error = %v; body = %s", err, rec.Body.String()) + } + resp, ok := msg.(*jsonrpc.Response) + if !ok { + t.Fatalf("response type = %T, want *jsonrpc.Response", msg) + } + if got := resp.ID.Raw(); got != int64(1) { + t.Fatalf("response ID = %v, want 1", got) + } + var jerr *jsonrpc.Error + if !errors.As(resp.Error, &jerr) { + t.Fatalf("response error = %v, want JSON-RPC error", resp.Error) + } + if jerr.Code != jsonrpc.CodeInvalidRequest { + t.Fatalf("error code = %d, want %d", jerr.Code, jsonrpc.CodeInvalidRequest) + } + if got := conn.requestStreams[id]; got != "existing" { + t.Fatalf("requestStreams[%v] = %q, want existing", id, got) + } + select { + case msg := <-conn.incoming: + t.Fatalf("duplicate request was published: %v", msg) + default: + } +} + func TestStreamableServerShutdown(t *testing.T) { ctx := context.Background()