From 7f79ba5c34fd98ccd7b60f47a7f84196ee177b61 Mon Sep 17 00:00:00 2001 From: townwish Date: Fri, 27 Mar 2026 15:01:30 +0800 Subject: [PATCH] fix: backfill interrupted tool results on Ctrl+C --- agent/loop/engine.go | 46 +++++- agent/loop/engine_persistence_test.go | 200 ++++++++++++++++++++++++++ runtime/shell/runner.go | 5 + tools/shell/shell.go | 4 + tools/shell/shell_test.go | 23 +++ 5 files changed, 275 insertions(+), 3 deletions(-) diff --git a/agent/loop/engine.go b/agent/loop/engine.go index e91fbee..3a70ebc 100644 --- a/agent/loop/engine.go +++ b/agent/loop/engine.go @@ -349,8 +349,13 @@ func (ex *executor) handleResponse(ctx context.Context, resp *llm.CompletionResp } if len(resp.ToolCalls) > 0 { - for _, tc := range resp.ToolCalls { + for i, tc := range resp.ToolCalls { if err := ex.executeToolCall(ctx, tc); err != nil { + if errors.Is(err, context.Canceled) { + if backfillErr := ex.addInterruptedToolResults(resp.ToolCalls[i+1:]); backfillErr != nil { + return false, backfillErr + } + } return false, err } } @@ -384,6 +389,9 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error path := extractPathArg(tc.Function.Arguments) granted, err := ex.engine.permission.Request(ctx, toolName, action, path) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { + return ex.addInterruptedToolResult(tc.ID) + } return err } if !granted { @@ -402,7 +410,7 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error result, err := tool.Execute(ctx, tc.Function.Arguments) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { - return context.Canceled + return ex.addInterruptedToolResult(tc.ID) } errMsg := fmt.Sprintf("Tool execution error: %v", err) if err := ex.addToolResultWithFallback(tc.ID, errMsg); err != nil { @@ -417,7 +425,7 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error if result.Error != nil { if errors.Is(result.Error, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { - return context.Canceled + return ex.addInterruptedToolResult(tc.ID) } errMsg := result.Error.Error() if err := ex.addToolResultWithFallback(tc.ID, errMsg); err != nil { @@ -447,6 +455,38 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error return nil } +func (ex *executor) addInterruptedToolResult(callID string) error { + const interruptedToolResult = "tool execution interrupted by user" + if err := ex.addToolResultWithFallback(callID, interruptedToolResult); err != nil { + return err + } + if err := ex.persistSnapshot(); err != nil { + return err + } + ex.addEvent(NewEvent(EventToolError, interruptedToolResult)) + return context.Canceled +} + +func (ex *executor) addInterruptedToolResults(calls []llm.ToolCall) error { + if len(calls) == 0 { + return nil + } + + const interruptedToolResult = "tool execution interrupted by user before execution" + for _, tc := range calls { + if strings.TrimSpace(tc.ID) == "" { + continue + } + if err := ex.addToolResultWithFallback(tc.ID, interruptedToolResult); err != nil { + return err + } + } + if err := ex.persistSnapshot(); err != nil { + return err + } + return nil +} + var toolEventMap = map[string]string{ "read": EventToolRead, "grep": EventToolGrep, diff --git a/agent/loop/engine_persistence_test.go b/agent/loop/engine_persistence_test.go index 0f4523a..8c22d05 100644 --- a/agent/loop/engine_persistence_test.go +++ b/agent/loop/engine_persistence_test.go @@ -3,9 +3,12 @@ package loop import ( "context" "encoding/json" + "errors" "io" + "strings" "sync" "testing" + "time" "github.com/vigo999/ms-cli/integrations/llm" "github.com/vigo999/ms-cli/tools" @@ -278,3 +281,200 @@ func TestRunPersistsToolErrorBeforeErrorRender(t *testing.T) { requireOrder(t, log, "tool_call:missing_tool", "snapshot:tool_call:missing_tool", "ui:ToolCallStart") requireOrder(t, log, "tool_result:missing_tool", "snapshot:tool_result:missing_tool", "ui:ToolError") } + +type blockingCancelTool struct { + name string + started chan struct{} +} + +func (t blockingCancelTool) Name() string { + return t.name +} + +func (t blockingCancelTool) Description() string { + return "blocking cancel tool" +} + +func (t blockingCancelTool) Schema() llm.ToolSchema { + return llm.ToolSchema{Type: "object"} +} + +func (t blockingCancelTool) Execute(ctx context.Context, _ json.RawMessage) (*tools.Result, error) { + select { + case <-t.started: + default: + close(t.started) + } + <-ctx.Done() + return nil, ctx.Err() +} + +func TestRunAddsInterruptedToolResultOnCanceledToolCall(t *testing.T) { + args, err := json.Marshal(map[string]string{"path": "README.md"}) + if err != nil { + t.Fatalf("marshal tool args: %v", err) + } + + provider := &scriptedStreamProvider{ + responses: []*llm.CompletionResponse{{ + ToolCalls: []llm.ToolCall{{ + ID: "call-cancel-1", + Type: "function", + Function: llm.ToolCallFunc{ + Name: "read", + Arguments: args, + }, + }}, + FinishReason: llm.FinishToolCalls, + }}, + } + + started := make(chan struct{}) + registry := tools.NewRegistry() + registry.MustRegister(blockingCancelTool{name: "read", started: started}) + + engine := NewEngine(EngineConfig{ + MaxIterations: 2, + ContextWindow: 4096, + }, provider, registry) + + ctx, cancel := context.WithCancel(context.Background()) + var events []Event + done := make(chan error, 1) + go func() { + done <- engine.RunWithContextStream(ctx, Task{ + ID: "cancel-tool", + Description: "read file", + }, func(ev Event) { + events = append(events, ev) + }) + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tool execution to start") + } + cancel() + + select { + case runErr := <-done: + if !errors.Is(runErr, context.Canceled) { + t.Fatalf("RunWithContextStream error = %v, want context canceled", runErr) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for cancellation") + } + + msgs := engine.ctxManager.GetNonSystemMessages() + if len(msgs) < 3 { + t.Fatalf("expected at least 3 non-system messages, got %d", len(msgs)) + } + + last := msgs[len(msgs)-1] + if last.Role != "tool" { + t.Fatalf("last role = %q, want tool", last.Role) + } + if last.ToolCallID != "call-cancel-1" { + t.Fatalf("last tool_call_id = %q, want %q", last.ToolCallID, "call-cancel-1") + } + if !strings.Contains(strings.ToLower(last.Content), "interrupted") { + t.Fatalf("last tool result = %q, want interrupted marker", last.Content) + } + + foundInterruptedEvent := false + for _, ev := range events { + if ev.Type == EventToolError && strings.Contains(strings.ToLower(ev.Message), "interrupted by user") { + foundInterruptedEvent = true + break + } + } + if !foundInterruptedEvent { + t.Fatalf("expected interrupted tool error event, got events: %#v", events) + } +} + +func TestRunBackfillsInterruptedResultsForUnexecutedToolCalls(t *testing.T) { + args, err := json.Marshal(map[string]string{"path": "README.md"}) + if err != nil { + t.Fatalf("marshal tool args: %v", err) + } + + provider := &scriptedStreamProvider{ + responses: []*llm.CompletionResponse{{ + ToolCalls: []llm.ToolCall{ + { + ID: "call-cancel-1", + Type: "function", + Function: llm.ToolCallFunc{ + Name: "read", + Arguments: args, + }, + }, + { + ID: "call-cancel-2", + Type: "function", + Function: llm.ToolCallFunc{ + Name: "grep", + Arguments: args, + }, + }, + }, + FinishReason: llm.FinishToolCalls, + }}, + } + + started := make(chan struct{}) + registry := tools.NewRegistry() + registry.MustRegister(blockingCancelTool{name: "read", started: started}) + registry.MustRegister(stubTool{name: "grep", content: "should-not-run", summary: "1 line"}) + + engine := NewEngine(EngineConfig{ + MaxIterations: 2, + ContextWindow: 4096, + }, provider, registry) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- engine.RunWithContextStream(ctx, Task{ + ID: "cancel-two-tools", + Description: "read and grep", + }, nil) + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first tool execution to start") + } + cancel() + + select { + case runErr := <-done: + if !errors.Is(runErr, context.Canceled) { + t.Fatalf("RunWithContextStream error = %v, want context canceled", runErr) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for cancellation") + } + + msgs := engine.ctxManager.GetNonSystemMessages() + toolResultsByID := map[string]string{} + for _, msg := range msgs { + if msg.Role != "tool" { + continue + } + toolResultsByID[msg.ToolCallID] = msg.Content + } + + for _, id := range []string{"call-cancel-1", "call-cancel-2"} { + content, ok := toolResultsByID[id] + if !ok { + t.Fatalf("missing tool result for %s in messages: %#v", id, msgs) + } + if !strings.Contains(strings.ToLower(content), "interrupted") { + t.Fatalf("tool result for %s = %q, want interrupted marker", id, content) + } + } +} diff --git a/runtime/shell/runner.go b/runtime/shell/runner.go index 651b065..bfe4548 100644 --- a/runtime/shell/runner.go +++ b/runtime/shell/runner.go @@ -132,6 +132,11 @@ func (r *Runner) Run(ctx context.Context, command string) (*Result, error) { result.Error = err } } + if result.Error == nil { + if ctxErr := ctx.Err(); ctxErr != nil { + result.Error = ctxErr + } + } return result, nil } diff --git a/tools/shell/shell.go b/tools/shell/shell.go index fb6e512..b29be34 100644 --- a/tools/shell/shell.go +++ b/tools/shell/shell.go @@ -5,6 +5,7 @@ package shell import ( "context" "encoding/json" + "errors" "fmt" "strings" "time" @@ -99,6 +100,9 @@ func (t *ShellTool) Execute(ctx context.Context, params json.RawMessage) (*tools summary = fmt.Sprintf("exit %d", result.ExitCode) } if result.Error != nil { + if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, context.DeadlineExceeded) { + return nil, result.Error + } summary = fmt.Sprintf("error: %s", result.Error.Error()) } diff --git a/tools/shell/shell_test.go b/tools/shell/shell_test.go index df15100..9cad74c 100644 --- a/tools/shell/shell_test.go +++ b/tools/shell/shell_test.go @@ -2,6 +2,7 @@ package shell import ( "context" + "errors" "strings" "testing" "time" @@ -34,3 +35,25 @@ func TestShellToolExecute_DoesNotDuplicateCommandOrExit0InContent(t *testing.T) t.Fatalf("expected summary not to be 'exit 0'") } } + +func TestShellToolExecute_ReturnsContextCanceledOnInterrupt(t *testing.T) { + runner := rshell.NewRunner(rshell.Config{ + WorkDir: ".", + Timeout: 10 * time.Second, + }) + tool := NewShellTool(runner) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + result, err := tool.Execute(ctx, []byte(`{"command":"sleep 100","timeout":120}`)) + if !errors.Is(err, context.Canceled) { + t.Fatalf("execute shell tool error = %v, want context canceled", err) + } + if result != nil { + t.Fatalf("expected nil result on cancel, got %#v", result) + } +}