From 3985e55046418a4b72372822fd7e4e47c85220cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E7=BA=AA?= <3049035704@qq.com> Date: Thu, 14 May 2026 02:23:34 +0800 Subject: [PATCH] Fix MCP startup timeout handling --- internal/agent/extensions_runtime.go | 5 +- internal/agent/extensions_runtime_test.go | 28 +++++++++- internal/extensions/mcp/adapter.go | 10 +--- internal/extensions/mcp/adapter_extra_test.go | 15 ++++- internal/extensions/mcp/client.go | 43 ++++++++++++--- internal/extensions/mcp/client_extra_test.go | 55 +++++++++++++++++-- internal/extensions/mcp/client_test.go | 17 ++++++ internal/extensions/mcp/process_tree.go | 18 ++++++ internal/extensions/mcp/process_tree_other.go | 15 +++++ .../extensions/mcp/process_tree_windows.go | 22 ++++++++ 10 files changed, 206 insertions(+), 22 deletions(-) create mode 100644 internal/extensions/mcp/process_tree.go create mode 100644 internal/extensions/mcp/process_tree_other.go create mode 100644 internal/extensions/mcp/process_tree_windows.go diff --git a/internal/agent/extensions_runtime.go b/internal/agent/extensions_runtime.go index ec6c0ac3..0529e24b 100644 --- a/internal/agent/extensions_runtime.go +++ b/internal/agent/extensions_runtime.go @@ -36,7 +36,10 @@ func (r *Runner) syncExtensionTools(ctx context.Context, force bool) error { resolvedTools, resolveErr := resolver.ResolveAllTools(ctx) if resolveErr != nil && (errors.Is(resolveErr, context.Canceled) || errors.Is(resolveErr, context.DeadlineExceeded)) { - return resolveErr + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + resolveErr = nil } nextKeys := map[string]map[string]struct{}{} diff --git a/internal/agent/extensions_runtime_test.go b/internal/agent/extensions_runtime_test.go index 220e79ea..57d313bd 100644 --- a/internal/agent/extensions_runtime_test.go +++ b/internal/agent/extensions_runtime_test.go @@ -3,6 +3,7 @@ package agent import ( "context" "encoding/json" + "errors" "sync" "testing" "time" @@ -14,6 +15,7 @@ import ( type runtimeSyncStubManager struct { items []extensionspkg.ExtensionTool + resolveErr error resolveCount int invalidateCount int resolveStarted chan struct{} @@ -49,7 +51,7 @@ func (m *runtimeSyncStubManager) ResolveAllTools(context.Context) ([]extensionsp } out := make([]extensionspkg.ExtensionTool, len(m.items)) copy(out, m.items) - return out, nil + return out, m.resolveErr } func (m *runtimeSyncStubManager) Invalidate(string) { @@ -197,3 +199,27 @@ func TestSyncExtensionToolsDoesNotHoldLockDuringResolve(t *testing.T) { t.Fatalf("sync failed: %v", err) } } + +func TestSyncExtensionToolsTreatsExtensionTimeoutAsNonFatalWhenCallerActive(t *testing.T) { + registry := toolspkg.DefaultRegistry() + manager := &runtimeSyncStubManager{ + resolveErr: context.DeadlineExceeded, + } + runner := &Runner{ + registry: registry, + extensions: manager, + extensionSyncTTL: time.Minute, + extensionSyncDirty: true, + extensionToolKeys: map[string]map[string]struct{}{}, + } + + if err := runner.syncExtensionTools(context.Background(), true); err != nil { + t.Fatalf("expected extension timeout to be non-fatal while caller context is active, got %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := runner.syncExtensionTools(ctx, true); !errors.Is(err, context.Canceled) { + t.Fatalf("expected canceled caller context to remain fatal, got %v", err) + } +} diff --git a/internal/extensions/mcp/adapter.go b/internal/extensions/mcp/adapter.go index 2b320f14..e46c4d5e 100644 --- a/internal/extensions/mcp/adapter.go +++ b/internal/extensions/mcp/adapter.go @@ -113,7 +113,7 @@ func FromMCPServer(cfg ServerConfig, opts ...Option) (extensionspkg.Extension, e } if options.eagerDiscover { - startupCtx, cancel := withTimeoutIfMissing(context.Background(), cfg.StartupTimeout) + startupCtx, cancel := withTimeoutLimit(context.Background(), cfg.StartupTimeout) defer cancel() _ = adapter.maybeRefresh(startupCtx, true) } @@ -133,7 +133,7 @@ func (a *Adapter) ResolveTools(ctx context.Context) ([]extensionspkg.ExtensionTo if a == nil { return nil, newExtensionError(extensionspkg.ErrCodeInvalidExtension, "mcp adapter is nil", nil) } - if err := a.maybeRefresh(ctx, false); err != nil && contextError(err) != nil { + if err := a.maybeRefresh(ctx, false); err != nil && ctx.Err() != nil && contextError(err) != nil { return nil, err } @@ -408,11 +408,7 @@ func (t mcpTool) Run(ctx context.Context, raw json.RawMessage, _ *toolspkg.Execu } } defer release() - callCtx := ctx - cancel := func() {} - if _, has := ctx.Deadline(); !has && t.server.CallTimeout > 0 { - callCtx, cancel = context.WithTimeout(ctx, t.server.CallTimeout) - } + callCtx, cancel := withTimeoutLimit(ctx, t.server.CallTimeout) defer cancel() output, err := t.client.CallTool(callCtx, t.server, t.descriptor.Name, raw) diff --git a/internal/extensions/mcp/adapter_extra_test.go b/internal/extensions/mcp/adapter_extra_test.go index 67ed541b..9343b3a1 100644 --- a/internal/extensions/mcp/adapter_extra_test.go +++ b/internal/extensions/mcp/adapter_extra_test.go @@ -110,9 +110,20 @@ func TestResolveToolsReturnsContextError(t *testing.T) { t.Fatalf("expected *Adapter type, got %T", ext) } adapter.Invalidate() - _, err = ext.ResolveTools(context.Background()) + tools, err := ext.ResolveTools(context.Background()) + if err != nil { + t.Fatalf("ResolveTools should degrade and continue when only MCP discovery times out, got %v", err) + } + if len(tools) != 0 { + t.Fatalf("expected no tools after discovery timeout, got %#v", tools) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + adapter.Invalidate() + _, err = ext.ResolveTools(ctx) if err == nil { - t.Fatal("expected context error from ResolveTools") + t.Fatal("expected caller context error from ResolveTools") } if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected deadline exceeded, got %v", err) diff --git a/internal/extensions/mcp/client.go b/internal/extensions/mcp/client.go index 25cc1ff4..55e774c2 100644 --- a/internal/extensions/mcp/client.go +++ b/internal/extensions/mcp/client.go @@ -143,7 +143,7 @@ func (c *StdioClient) Discover(ctx context.Context, cfg ServerConfig) (ServerSna return ServerSnapshot{}, err } - callCtx, cancel := withTimeoutIfMissing(ctx, cfg.StartupTimeout) + callCtx, cancel := withTimeoutLimit(ctx, cfg.StartupTimeout) defer cancel() responses, err := c.runWithProtocolFallback(callCtx, cfg, func(protocolVersion string) []rpcRequest { @@ -228,7 +228,7 @@ func (c *StdioClient) CallTool(ctx context.Context, cfg ServerConfig, toolName s } } - callCtx, cancel := withTimeoutIfMissing(ctx, cfg.CallTimeout) + callCtx, cancel := withTimeoutLimit(ctx, cfg.CallTimeout) defer cancel() responses, err := c.runWithProtocolFallback(callCtx, cfg, func(protocolVersion string) []rpcRequest { @@ -284,6 +284,7 @@ func (c *StdioClient) runRPC(ctx context.Context, cfg ServerConfig, requests []r } cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...) + configureCommandCancellation(cmd) if strings.TrimSpace(cfg.CWD) != "" { cmd.Dir = cfg.CWD } @@ -322,8 +323,11 @@ func (c *StdioClient) runRPC(ctx context.Context, cfg ServerConfig, requests []r continue } for { - response, err := readRPCResponse(reader) + response, err := readRPCResponseContext(ctx, reader) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { + return nil, err + } if errors.Is(err, context.DeadlineExceeded) || errors.Is(ctx.Err(), context.DeadlineExceeded) { return nil, newClientError(ClientErrorTimeout, "mcp request timed out", err) } @@ -433,11 +437,11 @@ func validateServerConfig(cfg ServerConfig, requireCommand bool) error { return nil } -func withTimeoutIfMissing(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { +func withTimeoutLimit(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { if timeout <= 0 { return context.WithCancel(ctx) } - if _, has := ctx.Deadline(); has { + if deadline, has := ctx.Deadline(); has && time.Until(deadline) <= timeout { return context.WithCancel(ctx) } return context.WithTimeout(ctx, timeout) @@ -484,9 +488,12 @@ func stopCommand(cmd *exec.Cmd, stdin io.WriteCloser) { case <-time.After(200 * time.Millisecond): } if cmd.Process != nil { - _ = cmd.Process.Kill() + _ = terminateCommand(cmd) + } + select { + case <-done: + case <-time.After(time.Second): } - <-done } type rpcRequest struct { @@ -639,6 +646,28 @@ func readRPCResponse(reader *bufio.Reader) (rpcResponse, error) { return response, nil } +type rpcReadResult struct { + response rpcResponse + err error +} + +func readRPCResponseContext(ctx context.Context, reader *bufio.Reader) (rpcResponse, error) { + if err := ctx.Err(); err != nil { + return rpcResponse{}, err + } + results := make(chan rpcReadResult, 1) + go func() { + response, err := readRPCResponse(reader) + results <- rpcReadResult{response: response, err: err} + }() + select { + case result := <-results: + return result.response, result.err + case <-ctx.Done(): + return rpcResponse{}, ctx.Err() + } +} + func writeRPCResponse(writer *bufio.Writer, response rpcResponse) error { data, err := json.Marshal(response) if err != nil { diff --git a/internal/extensions/mcp/client_extra_test.go b/internal/extensions/mcp/client_extra_test.go index 7fd2f049..4ca62b46 100644 --- a/internal/extensions/mcp/client_extra_test.go +++ b/internal/extensions/mcp/client_extra_test.go @@ -90,9 +90,9 @@ func TestValidateServerConfig(t *testing.T) { } } -func TestWithTimeoutIfMissing(t *testing.T) { +func TestWithTimeoutLimit(t *testing.T) { ctx := context.Background() - withoutTimeout, cancel := withTimeoutIfMissing(ctx, 0) + withoutTimeout, cancel := withTimeoutLimit(ctx, 0) defer cancel() if _, has := withoutTimeout.Deadline(); has { t.Fatal("did not expect deadline when timeout <= 0") @@ -100,7 +100,7 @@ func TestWithTimeoutIfMissing(t *testing.T) { parent, parentCancel := context.WithTimeout(context.Background(), time.Second) defer parentCancel() - inherited, inheritedCancel := withTimeoutIfMissing(parent, 5*time.Second) + inherited, inheritedCancel := withTimeoutLimit(parent, 5*time.Second) defer inheritedCancel() parentDeadline, _ := parent.Deadline() inheritedDeadline, hasDeadline := inherited.Deadline() @@ -111,7 +111,19 @@ func TestWithTimeoutIfMissing(t *testing.T) { t.Fatalf("expected inherited deadline %v, got %v", parentDeadline, inheritedDeadline) } - withTimeout, withTimeoutCancel := withTimeoutIfMissing(context.Background(), 50*time.Millisecond) + longParent, longParentCancel := context.WithTimeout(context.Background(), time.Hour) + defer longParentCancel() + capped, cappedCancel := withTimeoutLimit(longParent, 50*time.Millisecond) + defer cappedCancel() + cappedDeadline, hasCappedDeadline := capped.Deadline() + if !hasCappedDeadline { + t.Fatal("expected capped deadline") + } + if remaining := time.Until(cappedDeadline); remaining <= 0 || remaining > time.Second { + t.Fatalf("expected deadline near timeout, got remaining=%v", remaining) + } + + withTimeout, withTimeoutCancel := withTimeoutLimit(context.Background(), 50*time.Millisecond) defer withTimeoutCancel() if _, has := withTimeout.Deadline(); !has { t.Fatal("expected deadline when timeout is set") @@ -364,6 +376,41 @@ func TestCallToolErrorAndFallbackPaths(t *testing.T) { assertClientErrorCode(t, err, ClientErrorTimeout) } +func TestDiscoverUsesStartupTimeoutWhenParentDeadlineIsLonger(t *testing.T) { + client := NewStdioClient() + cfg := helperServerConfig(t, "sleep") + cfg.StartupTimeout = 20 * time.Millisecond + + parent, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + started := time.Now() + _, err := client.Discover(parent, cfg) + elapsed := time.Since(started) + + assertClientErrorCode(t, err, ClientErrorTimeout) + if elapsed > 2*time.Second { + t.Fatalf("expected startup timeout to cap long parent deadline, elapsed=%v", elapsed) + } +} + +func TestRunRPCTimeoutDoesNotWaitForStdoutEOF(t *testing.T) { + client := NewStdioClient() + cfg := helperServerConfig(t, "hold_stdout_child") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + started := time.Now() + _, err := client.runRPC(ctx, cfg, []rpcRequest{ + newRPCRequest(1, "initialize", initializeParams(defaultProtocolVersions[0])), + }) + elapsed := time.Since(started) + + assertClientErrorCode(t, err, ClientErrorTimeout) + if elapsed > time.Second { + t.Fatalf("expected context timeout without waiting for stdout EOF, elapsed=%v", elapsed) + } +} + func TestRunRPCZeroRequestAndStopCommandBranches(t *testing.T) { client := NewStdioClient() responses, err := client.runRPC(context.Background(), ServerConfig{}, nil) diff --git a/internal/extensions/mcp/client_test.go b/internal/extensions/mcp/client_test.go index 948b4f99..916bda88 100644 --- a/internal/extensions/mcp/client_test.go +++ b/internal/extensions/mcp/client_test.go @@ -7,6 +7,7 @@ import ( "errors" "io" "os" + "os/exec" "strconv" "strings" "testing" @@ -123,10 +124,26 @@ func TestStdioClientCallToolRequiresInitializedNotification(t *testing.T) { } func TestMCPHelperProcess(t *testing.T) { + if os.Getenv("BYTEMIND_MCP_HOLD_STDOUT_CHILD") == "1" { + time.Sleep(2 * time.Second) + os.Exit(0) + } if os.Getenv("BYTEMIND_MCP_HELPER") != "1" { return } scenario := strings.TrimSpace(os.Getenv("BYTEMIND_MCP_SCENARIO")) + if scenario == "hold_stdout_child" { + exe, err := os.Executable() + if err == nil { + child := exec.Command(exe, "-test.run=^TestMCPHelperProcess$") + child.Env = append(os.Environ(), "BYTEMIND_MCP_HOLD_STDOUT_CHILD=1") + child.Stdout = os.Stdout + child.Stderr = os.Stderr + _ = child.Start() + } + time.Sleep(10 * time.Second) + os.Exit(0) + } if scenario == "eof_with_stderr" { _, _ = os.Stderr.WriteString("helper exited early") os.Exit(0) diff --git a/internal/extensions/mcp/process_tree.go b/internal/extensions/mcp/process_tree.go new file mode 100644 index 00000000..9cd3f75a --- /dev/null +++ b/internal/extensions/mcp/process_tree.go @@ -0,0 +1,18 @@ +package mcp + +import ( + "os/exec" + "time" +) + +const commandWaitDelay = 500 * time.Millisecond + +func configureCommandCancellation(cmd *exec.Cmd) { + if cmd == nil { + return + } + cmd.Cancel = func() error { + return terminateCommand(cmd) + } + cmd.WaitDelay = commandWaitDelay +} diff --git a/internal/extensions/mcp/process_tree_other.go b/internal/extensions/mcp/process_tree_other.go new file mode 100644 index 00000000..541b666a --- /dev/null +++ b/internal/extensions/mcp/process_tree_other.go @@ -0,0 +1,15 @@ +//go:build !windows + +package mcp + +import ( + "os" + "os/exec" +) + +func terminateCommand(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return os.ErrProcessDone + } + return cmd.Process.Kill() +} diff --git a/internal/extensions/mcp/process_tree_windows.go b/internal/extensions/mcp/process_tree_windows.go new file mode 100644 index 00000000..a0493233 --- /dev/null +++ b/internal/extensions/mcp/process_tree_windows.go @@ -0,0 +1,22 @@ +//go:build windows + +package mcp + +import ( + "os" + "os/exec" + "strconv" +) + +func terminateCommand(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return os.ErrProcessDone + } + if cmd.Process.Pid <= 0 { + return cmd.Process.Kill() + } + if err := exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid)).Run(); err != nil { + return cmd.Process.Kill() + } + return nil +}