From a69f0a51ee7248ef17f6f6904578ad6187448c74 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Thu, 25 Jun 2026 11:39:03 +0200 Subject: [PATCH] Fix MCP malformed response terminal errors --- pkg/dispatcher/internal/processor.go | 16 +++++-- pkg/dispatcher/internal/processor_test.go | 54 ++++++++++++++++++----- 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/pkg/dispatcher/internal/processor.go b/pkg/dispatcher/internal/processor.go index af61529..a2ea71f 100644 --- a/pkg/dispatcher/internal/processor.go +++ b/pkg/dispatcher/internal/processor.go @@ -759,11 +759,13 @@ func (p *mcpProcessor) forwardResponses(ctx context.Context, conn mcpclient.Forw response, ok := msg.(*jsonrpc.Response) if !ok { + err := fmt.Errorf("received non-response message from MCP server: %T", msg) logger.ErrorContext( ctx, "received non-response message from MCP server", append(attrsToArgs(messageSummaryAttrs(msg)), slog.String("type", fmt.Sprintf("%T", msg)))..., ) + postTerminalErrorResponse(err) return } @@ -783,11 +785,19 @@ func (p *mcpProcessor) forwardResponses(ctx context.Context, conn mcpclient.Forw // Responses MUST include the same ID as the request they correspond to. // Notifications MUST NOT include an ID. // streamableClientConn.processStream has similar heuristics comparing req/resp IDs and breaking out - finalResponse := response.ID.IsValid() && response.ID == req.ID - if !finalResponse { - logger.ErrorContext(ctx, "Received response without valid ID") + if !response.ID.IsValid() { + err := errors.New("received response without valid ID from MCP server") + logger.ErrorContext(ctx, "received response without valid ID") + postTerminalErrorResponse(err) + return + } + if response.ID != req.ID { + err := errors.New("received response with mismatched ID from MCP server") + logger.ErrorContext(ctx, "received response with mismatched ID", attrsToArgs(jsonRPCResponseCorrelationAttrs(req, response))...) + postTerminalErrorResponse(err) return } + finalResponse := true // Ensure final JSON-RPC responses present as application/json to the control plane, // even if the upstream server labeled them differently, unless the upstream diff --git a/pkg/dispatcher/internal/processor_test.go b/pkg/dispatcher/internal/processor_test.go index ae8cf8b..eb89611 100644 --- a/pkg/dispatcher/internal/processor_test.go +++ b/pkg/dispatcher/internal/processor_test.go @@ -2219,7 +2219,7 @@ func TestProcessorForwardResponsesClosesConnectionWhenNotificationForwardingFail require.True(t, conn.closed) } -func TestProcessorForwardResponsesStopsOnNonResponseMessage(t *testing.T) { +func TestProcessorForwardResponsesPostsTerminalErrorOnNonResponseMessage(t *testing.T) { t.Parallel() logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelDebug})) @@ -2257,15 +2257,50 @@ func TestProcessorForwardResponsesStopsOnNonResponseMessage(t *testing.T) { } require.NoError(t, processor.Process(context.Background(), cmd)) + assertTerminalJSONRPCErrorResponse(t, responder, cmd, http.StatusBadGateway, "Bad Gateway", "received non-response message from MCP server") +} - select { - case resp := <-responder.responses: - t.Fatalf("unexpected response posted: %+v", resp) - default: +func TestProcessorForwardResponsesPostsTerminalErrorOnInvalidID(t *testing.T) { + t.Parallel() + + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelDebug})) + responder := newRecordingResponder() + + callID, err := jsonrpc.MakeID("invalid-id-call") + require.NoError(t, err) + + transport := &stubForwardingTransport{conn: &scriptedForwardingConnection{ + statusCode: http.StatusOK, + readSteps: []readStep{ + {msg: &jsonrpc.Response{Result: json.RawMessage(`{"ok":true}`)}, err: nil}, + }, + }} + + meterProvider := newTestMeterProvider(t) + processor, err := NewProcessor(processorParams{ + Logger: logger, + ChannelBindings: newTestChannelBindings(transport), + TunnelResponder: responder, + MCPConfig: newTestMCPConfig(t, time.Second), + OAuthHTTPClient: &http.Client{}, + ControlPlaneCfg: newTestControlPlaneConfig(t), + MeterProvider: meterProvider, + }) + require.NoError(t, err) + + cmd := &fakePolledCommand{ + id: types.RequestID("invalid-id-request"), + message: &jsonrpc.Request{ID: callID, Method: "ping"}, + enqueuedAt: time.Now(), + polledAt: time.Now(), + shardToken: "shard-invalid-id", } + + require.NoError(t, processor.Process(context.Background(), cmd)) + assertTerminalJSONRPCErrorResponse(t, responder, cmd, http.StatusBadGateway, "Bad Gateway", "received response without valid ID from MCP server") } -func TestProcessorForwardResponsesStopsOnIDMismatch(t *testing.T) { +func TestProcessorForwardResponsesPostsTerminalErrorOnIDMismatch(t *testing.T) { t.Parallel() logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelDebug})) @@ -2304,12 +2339,7 @@ func TestProcessorForwardResponsesStopsOnIDMismatch(t *testing.T) { } require.NoError(t, processor.Process(context.Background(), cmd)) - - select { - case resp := <-responder.responses: - t.Fatalf("unexpected response posted: %+v", resp) - default: - } + assertTerminalJSONRPCErrorResponse(t, responder, cmd, http.StatusBadGateway, "Bad Gateway", "received response with mismatched ID from MCP server") } func TestProcessorForwardResponsesStopsWhenConnectionTTLReached(t *testing.T) {