Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pkg/dispatcher/internal/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,11 +759,13 @@ func (p *mcpProcessor) forwardResponses(ctx context.Context, conn mcpclient.Forw

response, ok := msg.(*jsonrpc.Response)
if !ok {
err := fmt.Errorf("received non-response message from MCP server: %T", msg)
logger.ErrorContext(
ctx,
"received non-response message from MCP server",
append(attrsToArgs(messageSummaryAttrs(msg)), slog.String("type", fmt.Sprintf("%T", msg)))...,
)
postTerminalErrorResponse(err)
return
}

Expand All @@ -783,11 +785,19 @@ func (p *mcpProcessor) forwardResponses(ctx context.Context, conn mcpclient.Forw
// Responses MUST include the same ID as the request they correspond to.
// Notifications MUST NOT include an ID.
// streamableClientConn.processStream has similar heuristics comparing req/resp IDs and breaking out
finalResponse := response.ID.IsValid() && response.ID == req.ID
if !finalResponse {
logger.ErrorContext(ctx, "Received response without valid ID")
if !response.ID.IsValid() {
err := errors.New("received response without valid ID from MCP server")
logger.ErrorContext(ctx, "received response without valid ID")
postTerminalErrorResponse(err)
return
}
if response.ID != req.ID {
err := errors.New("received response with mismatched ID from MCP server")
logger.ErrorContext(ctx, "received response with mismatched ID", attrsToArgs(jsonRPCResponseCorrelationAttrs(req, response))...)
postTerminalErrorResponse(err)
return
}
finalResponse := true

// Ensure final JSON-RPC responses present as application/json to the control plane,
// even if the upstream server labeled them differently, unless the upstream
Expand Down
54 changes: 42 additions & 12 deletions pkg/dispatcher/internal/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ func TestProcessorForwardResponsesClosesConnectionWhenNotificationForwardingFail
require.True(t, conn.closed)
}

func TestProcessorForwardResponsesStopsOnNonResponseMessage(t *testing.T) {
func TestProcessorForwardResponsesPostsTerminalErrorOnNonResponseMessage(t *testing.T) {
t.Parallel()

logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelDebug}))
Expand Down Expand Up @@ -2257,15 +2257,50 @@ func TestProcessorForwardResponsesStopsOnNonResponseMessage(t *testing.T) {
}

require.NoError(t, processor.Process(context.Background(), cmd))
assertTerminalJSONRPCErrorResponse(t, responder, cmd, http.StatusBadGateway, "Bad Gateway", "received non-response message from MCP server")
}

select {
case resp := <-responder.responses:
t.Fatalf("unexpected response posted: %+v", resp)
default:
func TestProcessorForwardResponsesPostsTerminalErrorOnInvalidID(t *testing.T) {
t.Parallel()

logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelDebug}))
responder := newRecordingResponder()

callID, err := jsonrpc.MakeID("invalid-id-call")
require.NoError(t, err)

transport := &stubForwardingTransport{conn: &scriptedForwardingConnection{
statusCode: http.StatusOK,
readSteps: []readStep{
{msg: &jsonrpc.Response{Result: json.RawMessage(`{"ok":true}`)}, err: nil},
},
}}

meterProvider := newTestMeterProvider(t)
processor, err := NewProcessor(processorParams{
Logger: logger,
ChannelBindings: newTestChannelBindings(transport),
TunnelResponder: responder,
MCPConfig: newTestMCPConfig(t, time.Second),
OAuthHTTPClient: &http.Client{},
ControlPlaneCfg: newTestControlPlaneConfig(t),
MeterProvider: meterProvider,
})
require.NoError(t, err)

cmd := &fakePolledCommand{
id: types.RequestID("invalid-id-request"),
message: &jsonrpc.Request{ID: callID, Method: "ping"},
enqueuedAt: time.Now(),
polledAt: time.Now(),
shardToken: "shard-invalid-id",
}

require.NoError(t, processor.Process(context.Background(), cmd))
assertTerminalJSONRPCErrorResponse(t, responder, cmd, http.StatusBadGateway, "Bad Gateway", "received response without valid ID from MCP server")
}

func TestProcessorForwardResponsesStopsOnIDMismatch(t *testing.T) {
func TestProcessorForwardResponsesPostsTerminalErrorOnIDMismatch(t *testing.T) {
t.Parallel()

logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelDebug}))
Expand Down Expand Up @@ -2304,12 +2339,7 @@ func TestProcessorForwardResponsesStopsOnIDMismatch(t *testing.T) {
}

require.NoError(t, processor.Process(context.Background(), cmd))

select {
case resp := <-responder.responses:
t.Fatalf("unexpected response posted: %+v", resp)
default:
}
assertTerminalJSONRPCErrorResponse(t, responder, cmd, http.StatusBadGateway, "Bad Gateway", "received response with mismatched ID from MCP server")
}

func TestProcessorForwardResponsesStopsWhenConnectionTTLReached(t *testing.T) {
Expand Down