Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internal/agent/extensions_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func (r *Runner) syncExtensionTools(ctx context.Context, force bool) error {

resolvedTools, resolveErr := resolver.ResolveAllTools(ctx)
if resolveErr != nil && (errors.Is(resolveErr, context.Canceled) || errors.Is(resolveErr, context.DeadlineExceeded)) {
return resolveErr
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
resolveErr = nil
}

nextKeys := map[string]map[string]struct{}{}
Expand Down
28 changes: 27 additions & 1 deletion internal/agent/extensions_runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"sync"
"testing"
"time"
Expand All @@ -14,6 +15,7 @@ import (

type runtimeSyncStubManager struct {
items []extensionspkg.ExtensionTool
resolveErr error
resolveCount int
invalidateCount int
resolveStarted chan struct{}
Expand Down Expand Up @@ -49,7 +51,7 @@ func (m *runtimeSyncStubManager) ResolveAllTools(context.Context) ([]extensionsp
}
out := make([]extensionspkg.ExtensionTool, len(m.items))
copy(out, m.items)
return out, nil
return out, m.resolveErr
}

func (m *runtimeSyncStubManager) Invalidate(string) {
Expand Down Expand Up @@ -197,3 +199,27 @@ func TestSyncExtensionToolsDoesNotHoldLockDuringResolve(t *testing.T) {
t.Fatalf("sync failed: %v", err)
}
}

func TestSyncExtensionToolsTreatsExtensionTimeoutAsNonFatalWhenCallerActive(t *testing.T) {
registry := toolspkg.DefaultRegistry()
manager := &runtimeSyncStubManager{
resolveErr: context.DeadlineExceeded,
}
runner := &Runner{
registry: registry,
extensions: manager,
extensionSyncTTL: time.Minute,
extensionSyncDirty: true,
extensionToolKeys: map[string]map[string]struct{}{},
}

if err := runner.syncExtensionTools(context.Background(), true); err != nil {
t.Fatalf("expected extension timeout to be non-fatal while caller context is active, got %v", err)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := runner.syncExtensionTools(ctx, true); !errors.Is(err, context.Canceled) {
t.Fatalf("expected canceled caller context to remain fatal, got %v", err)
}
}
10 changes: 3 additions & 7 deletions internal/extensions/mcp/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func FromMCPServer(cfg ServerConfig, opts ...Option) (extensionspkg.Extension, e
}

if options.eagerDiscover {
startupCtx, cancel := withTimeoutIfMissing(context.Background(), cfg.StartupTimeout)
startupCtx, cancel := withTimeoutLimit(context.Background(), cfg.StartupTimeout)
defer cancel()
_ = adapter.maybeRefresh(startupCtx, true)
}
Expand All @@ -133,7 +133,7 @@ func (a *Adapter) ResolveTools(ctx context.Context) ([]extensionspkg.ExtensionTo
if a == nil {
return nil, newExtensionError(extensionspkg.ErrCodeInvalidExtension, "mcp adapter is nil", nil)
}
if err := a.maybeRefresh(ctx, false); err != nil && contextError(err) != nil {
if err := a.maybeRefresh(ctx, false); err != nil && ctx.Err() != nil && contextError(err) != nil {
return nil, err
}

Expand Down Expand Up @@ -408,11 +408,7 @@ func (t mcpTool) Run(ctx context.Context, raw json.RawMessage, _ *toolspkg.Execu
}
}
defer release()
callCtx := ctx
cancel := func() {}
if _, has := ctx.Deadline(); !has && t.server.CallTimeout > 0 {
callCtx, cancel = context.WithTimeout(ctx, t.server.CallTimeout)
}
callCtx, cancel := withTimeoutLimit(ctx, t.server.CallTimeout)
defer cancel()

output, err := t.client.CallTool(callCtx, t.server, t.descriptor.Name, raw)
Expand Down
15 changes: 13 additions & 2 deletions internal/extensions/mcp/adapter_extra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,20 @@ func TestResolveToolsReturnsContextError(t *testing.T) {
t.Fatalf("expected *Adapter type, got %T", ext)
}
adapter.Invalidate()
_, err = ext.ResolveTools(context.Background())
tools, err := ext.ResolveTools(context.Background())
if err != nil {
t.Fatalf("ResolveTools should degrade and continue when only MCP discovery times out, got %v", err)
}
if len(tools) != 0 {
t.Fatalf("expected no tools after discovery timeout, got %#v", tools)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()
adapter.Invalidate()
_, err = ext.ResolveTools(ctx)
if err == nil {
t.Fatal("expected context error from ResolveTools")
t.Fatal("expected caller context error from ResolveTools")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
Expand Down
43 changes: 36 additions & 7 deletions internal/extensions/mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (c *StdioClient) Discover(ctx context.Context, cfg ServerConfig) (ServerSna
return ServerSnapshot{}, err
}

callCtx, cancel := withTimeoutIfMissing(ctx, cfg.StartupTimeout)
callCtx, cancel := withTimeoutLimit(ctx, cfg.StartupTimeout)
defer cancel()

responses, err := c.runWithProtocolFallback(callCtx, cfg, func(protocolVersion string) []rpcRequest {
Expand Down Expand Up @@ -228,7 +228,7 @@ func (c *StdioClient) CallTool(ctx context.Context, cfg ServerConfig, toolName s
}
}

callCtx, cancel := withTimeoutIfMissing(ctx, cfg.CallTimeout)
callCtx, cancel := withTimeoutLimit(ctx, cfg.CallTimeout)
defer cancel()

responses, err := c.runWithProtocolFallback(callCtx, cfg, func(protocolVersion string) []rpcRequest {
Expand Down Expand Up @@ -284,6 +284,7 @@ func (c *StdioClient) runRPC(ctx context.Context, cfg ServerConfig, requests []r
}

cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
configureCommandCancellation(cmd)
if strings.TrimSpace(cfg.CWD) != "" {
cmd.Dir = cfg.CWD
}
Expand Down Expand Up @@ -322,8 +323,11 @@ func (c *StdioClient) runRPC(ctx context.Context, cfg ServerConfig, requests []r
continue
}
for {
response, err := readRPCResponse(reader)
response, err := readRPCResponseContext(ctx, reader)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
return nil, err
}
if errors.Is(err, context.DeadlineExceeded) || errors.Is(ctx.Err(), context.DeadlineExceeded) {
return nil, newClientError(ClientErrorTimeout, "mcp request timed out", err)
}
Expand Down Expand Up @@ -433,11 +437,11 @@ func validateServerConfig(cfg ServerConfig, requireCommand bool) error {
return nil
}

func withTimeoutIfMissing(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
func withTimeoutLimit(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if timeout <= 0 {
return context.WithCancel(ctx)
}
if _, has := ctx.Deadline(); has {
if deadline, has := ctx.Deadline(); has && time.Until(deadline) <= timeout {
return context.WithCancel(ctx)
}
return context.WithTimeout(ctx, timeout)
Expand Down Expand Up @@ -484,9 +488,12 @@ func stopCommand(cmd *exec.Cmd, stdin io.WriteCloser) {
case <-time.After(200 * time.Millisecond):
}
if cmd.Process != nil {
_ = cmd.Process.Kill()
_ = terminateCommand(cmd)
}
select {
case <-done:
case <-time.After(time.Second):
}
<-done
}

type rpcRequest struct {
Expand Down Expand Up @@ -639,6 +646,28 @@ func readRPCResponse(reader *bufio.Reader) (rpcResponse, error) {
return response, nil
}

// rpcReadResult pairs the two return values of readRPCResponse so that
// readRPCResponseContext can hand them from its reader goroutine to the
// waiting caller over a single channel.
type rpcReadResult struct {
// response is the value returned by readRPCResponse.
response rpcResponse
// err is the error returned by readRPCResponse alongside response.
err error
}

// readRPCResponseContext reads one RPC response from reader, but gives up as
// soon as ctx is done.
//
// If ctx is already done on entry, its error is returned without touching
// reader. Otherwise the blocking readRPCResponse call runs on a separate
// goroutine and the function returns whichever happens first: the read
// finishing (its response and error are returned unchanged) or ctx being
// done (a zero rpcResponse and ctx.Err() are returned).
//
// NOTE(review): when ctx wins, the reader goroutine is not interrupted here;
// it keeps running until readRPCResponse returns on its own, presumably once
// the underlying stream is closed by the caller's cleanup — confirm.
func readRPCResponseContext(ctx context.Context, reader *bufio.Reader) (rpcResponse, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return rpcResponse{}, ctxErr
	}

	// Capacity 1 lets the goroutine deliver its result and exit even when
	// nobody is receiving anymore because ctx finished first.
	done := make(chan rpcReadResult, 1)
	go func() {
		var out rpcReadResult
		out.response, out.err = readRPCResponse(reader)
		done <- out
	}()

	select {
	case <-ctx.Done():
		return rpcResponse{}, ctx.Err()
	case out := <-done:
		return out.response, out.err
	}
}

func writeRPCResponse(writer *bufio.Writer, response rpcResponse) error {
data, err := json.Marshal(response)
if err != nil {
Expand Down
55 changes: 51 additions & 4 deletions internal/extensions/mcp/client_extra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ func TestValidateServerConfig(t *testing.T) {
}
}

func TestWithTimeoutIfMissing(t *testing.T) {
func TestWithTimeoutLimit(t *testing.T) {
ctx := context.Background()
withoutTimeout, cancel := withTimeoutIfMissing(ctx, 0)
withoutTimeout, cancel := withTimeoutLimit(ctx, 0)
defer cancel()
if _, has := withoutTimeout.Deadline(); has {
t.Fatal("did not expect deadline when timeout <= 0")
}

parent, parentCancel := context.WithTimeout(context.Background(), time.Second)
defer parentCancel()
inherited, inheritedCancel := withTimeoutIfMissing(parent, 5*time.Second)
inherited, inheritedCancel := withTimeoutLimit(parent, 5*time.Second)
defer inheritedCancel()
parentDeadline, _ := parent.Deadline()
inheritedDeadline, hasDeadline := inherited.Deadline()
Expand All @@ -111,7 +111,19 @@ func TestWithTimeoutIfMissing(t *testing.T) {
t.Fatalf("expected inherited deadline %v, got %v", parentDeadline, inheritedDeadline)
}

withTimeout, withTimeoutCancel := withTimeoutIfMissing(context.Background(), 50*time.Millisecond)
longParent, longParentCancel := context.WithTimeout(context.Background(), time.Hour)
defer longParentCancel()
capped, cappedCancel := withTimeoutLimit(longParent, 50*time.Millisecond)
defer cappedCancel()
cappedDeadline, hasCappedDeadline := capped.Deadline()
if !hasCappedDeadline {
t.Fatal("expected capped deadline")
}
if remaining := time.Until(cappedDeadline); remaining <= 0 || remaining > time.Second {
t.Fatalf("expected deadline near timeout, got remaining=%v", remaining)
}

withTimeout, withTimeoutCancel := withTimeoutLimit(context.Background(), 50*time.Millisecond)
defer withTimeoutCancel()
if _, has := withTimeout.Deadline(); !has {
t.Fatal("expected deadline when timeout is set")
Expand Down Expand Up @@ -364,6 +376,41 @@ func TestCallToolErrorAndFallbackPaths(t *testing.T) {
assertClientErrorCode(t, err, ClientErrorTimeout)
}

// TestDiscoverUsesStartupTimeoutWhenParentDeadlineIsLonger verifies that
// Discover is bounded by cfg.StartupTimeout even when the caller's context
// carries a later deadline.
//
// The parent deadline is deliberately much longer than the elapsed-time bound.
// If the startup timeout were ignored and only the parent deadline applied,
// the call would run past maxElapsed and the check below would fail. (A parent
// deadline shorter than the bound would let such a regression pass unnoticed,
// because a parent-deadline expiry is reported as a timeout error too.)
func TestDiscoverUsesStartupTimeoutWhenParentDeadlineIsLonger(t *testing.T) {
	const (
		startupTimeout = 20 * time.Millisecond
		parentTimeout  = 10 * time.Second
		// maxElapsed leaves generous slack for process startup and teardown
		// on slow machines while staying well below parentTimeout.
		maxElapsed = 2 * time.Second
	)

	client := NewStdioClient()
	cfg := helperServerConfig(t, "sleep")
	cfg.StartupTimeout = startupTimeout

	parent, cancel := context.WithTimeout(context.Background(), parentTimeout)
	defer cancel()
	started := time.Now()
	_, err := client.Discover(parent, cfg)
	elapsed := time.Since(started)

	assertClientErrorCode(t, err, ClientErrorTimeout)
	if elapsed > maxElapsed {
		t.Fatalf("expected startup timeout to cap long parent deadline, elapsed=%v", elapsed)
	}
}

// TestRunRPCTimeoutDoesNotWaitForStdoutEOF runs a single initialize request
// against the "hold_stdout_child" helper scenario under a 30ms context
// deadline. It expects runRPC to report a timeout error and to return within
// one second, rather than waiting for the helper's stdout to reach EOF.
func TestRunRPCTimeoutDoesNotWaitForStdoutEOF(t *testing.T) {
	stdio := NewStdioClient()
	serverCfg := helperServerConfig(t, "hold_stdout_child")
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
	defer cancel()

	begin := time.Now()
	initialize := newRPCRequest(1, "initialize", initializeParams(defaultProtocolVersions[0]))
	_, err := stdio.runRPC(ctx, serverCfg, []rpcRequest{initialize})
	took := time.Since(begin)

	assertClientErrorCode(t, err, ClientErrorTimeout)
	if took > time.Second {
		t.Fatalf("expected context timeout without waiting for stdout EOF, elapsed=%v", took)
	}
}

func TestRunRPCZeroRequestAndStopCommandBranches(t *testing.T) {
client := NewStdioClient()
responses, err := client.runRPC(context.Background(), ServerConfig{}, nil)
Expand Down
17 changes: 17 additions & 0 deletions internal/extensions/mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"io"
"os"
"os/exec"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -123,10 +124,26 @@ func TestStdioClientCallToolRequiresInitializedNotification(t *testing.T) {
}

func TestMCPHelperProcess(t *testing.T) {
if os.Getenv("BYTEMIND_MCP_HOLD_STDOUT_CHILD") == "1" {
time.Sleep(2 * time.Second)
os.Exit(0)
}
if os.Getenv("BYTEMIND_MCP_HELPER") != "1" {
return
}
scenario := strings.TrimSpace(os.Getenv("BYTEMIND_MCP_SCENARIO"))
if scenario == "hold_stdout_child" {
exe, err := os.Executable()
if err == nil {
child := exec.Command(exe, "-test.run=^TestMCPHelperProcess$")
child.Env = append(os.Environ(), "BYTEMIND_MCP_HOLD_STDOUT_CHILD=1")
child.Stdout = os.Stdout
child.Stderr = os.Stderr
_ = child.Start()
}
time.Sleep(10 * time.Second)
os.Exit(0)
}
if scenario == "eof_with_stderr" {
_, _ = os.Stderr.WriteString("helper exited early")
os.Exit(0)
Expand Down
18 changes: 18 additions & 0 deletions internal/extensions/mcp/process_tree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package mcp

import (
"os/exec"
"time"
)

const commandWaitDelay = 500 * time.Millisecond

func configureCommandCancellation(cmd *exec.Cmd) {
if cmd == nil {
return
}
cmd.Cancel = func() error {
return terminateCommand(cmd)
}
cmd.WaitDelay = commandWaitDelay
}
15 changes: 15 additions & 0 deletions internal/extensions/mcp/process_tree_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//go:build !windows

package mcp

import (
"os"
"os/exec"
)

func terminateCommand(cmd *exec.Cmd) error {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Medium: on non-Windows we still only kill the direct child process here. If an MCP launcher spawns a helper that inherits stdout (the same shape of failure this PR is addressing), readRPCResponseContext will return on ctx.Done(), but the blocked reader goroutine and the descendant process will both stay alive until that helper exits on its own. Repeated timeouts can therefore accumulate leaked goroutines/processes on Unix-like hosts. Consider launching the command in its own process group and terminating the whole group here, not just cmd.Process.

if cmd == nil || cmd.Process == nil {
return os.ErrProcessDone
}
return cmd.Process.Kill()
}
22 changes: 22 additions & 0 deletions internal/extensions/mcp/process_tree_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build windows

package mcp

import (
"os"
"os/exec"
"strconv"
)

// terminateCommand stops the process started by cmd.
//
// When the process has a positive PID it first runs
// "taskkill /T /F /PID <pid>"; if that succeeds, nil is returned. If taskkill
// fails, or the PID is not positive, it falls back to cmd.Process.Kill and
// returns that call's result.
//
// It returns os.ErrProcessDone when cmd is nil or has no started process.
func terminateCommand(cmd *exec.Cmd) error {
	if cmd == nil {
		return os.ErrProcessDone
	}
	proc := cmd.Process
	if proc == nil {
		return os.ErrProcessDone
	}

	if proc.Pid > 0 {
		killTree := exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(proc.Pid))
		if killTree.Run() == nil {
			return nil
		}
	}
	// Either there is no usable PID or taskkill failed: kill the direct
	// process only.
	return proc.Kill()
}
Loading