Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions agent/loop/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,13 @@ func (ex *executor) handleResponse(ctx context.Context, resp *llm.CompletionResp
}

if len(resp.ToolCalls) > 0 {
for _, tc := range resp.ToolCalls {
for i, tc := range resp.ToolCalls {
if err := ex.executeToolCall(ctx, tc); err != nil {
if errors.Is(err, context.Canceled) {
if backfillErr := ex.addInterruptedToolResults(resp.ToolCalls[i+1:]); backfillErr != nil {
return false, backfillErr
}
}
return false, err
}
}
Expand Down Expand Up @@ -384,6 +389,9 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error
path := extractPathArg(tc.Function.Arguments)
granted, err := ex.engine.permission.Request(ctx, toolName, action, path)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
return ex.addInterruptedToolResult(tc.ID)
}
return err
}
if !granted {
Expand All @@ -402,7 +410,7 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error
result, err := tool.Execute(ctx, tc.Function.Arguments)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
return context.Canceled
return ex.addInterruptedToolResult(tc.ID)
}
errMsg := fmt.Sprintf("Tool execution error: %v", err)
if err := ex.addToolResultWithFallback(tc.ID, errMsg); err != nil {
Expand All @@ -417,7 +425,7 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error

if result.Error != nil {
if errors.Is(result.Error, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
return context.Canceled
return ex.addInterruptedToolResult(tc.ID)
}
errMsg := result.Error.Error()
if err := ex.addToolResultWithFallback(tc.ID, errMsg); err != nil {
Expand Down Expand Up @@ -447,6 +455,38 @@ func (ex *executor) executeToolCall(ctx context.Context, tc llm.ToolCall) error
return nil
}

// addInterruptedToolResult stores an "interrupted by user" tool result for
// callID, persists a snapshot, and then emits an EventToolError event that
// carries the same message.
//
// It never returns nil: when the result was recorded and persisted it returns
// context.Canceled so the caller can propagate the interruption; otherwise it
// returns the error from recording or persisting, and no event is emitted.
func (ex *executor) addInterruptedToolResult(callID string) error {
	const msg = "tool execution interrupted by user"

	if recordErr := ex.addToolResultWithFallback(callID, msg); recordErr != nil {
		return recordErr
	}

	// Persist before emitting the event so the stored state already contains
	// the tool result when the event is observed.
	if persistErr := ex.persistSnapshot(); persistErr != nil {
		return persistErr
	}

	ex.addEvent(NewEvent(EventToolError, msg))
	return context.Canceled
}

// addInterruptedToolResults records an "interrupted before execution" tool
// result for each call in calls whose ID is not blank, then persists one
// snapshot covering all of them.
//
// An empty calls slice returns nil without persisting. Calls with an empty or
// whitespace-only ID are skipped. The first error from recording a result is
// returned immediately; otherwise the result of persisting is returned.
func (ex *executor) addInterruptedToolResults(calls []llm.ToolCall) error {
	const msg = "tool execution interrupted by user before execution"

	if len(calls) == 0 {
		return nil
	}

	for i := range calls {
		id := calls[i].ID
		if strings.TrimSpace(id) == "" {
			continue
		}
		if recordErr := ex.addToolResultWithFallback(id, msg); recordErr != nil {
			return recordErr
		}
	}

	return ex.persistSnapshot()
}

var toolEventMap = map[string]string{
"read": EventToolRead,
"grep": EventToolGrep,
Expand Down
200 changes: 200 additions & 0 deletions agent/loop/engine_persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package loop
import (
"context"
"encoding/json"
"errors"
"io"
"strings"
"sync"
"testing"
"time"

"github.com/vigo999/ms-cli/integrations/llm"
"github.com/vigo999/ms-cli/tools"
Expand Down Expand Up @@ -278,3 +281,200 @@ func TestRunPersistsToolErrorBeforeErrorRender(t *testing.T) {
requireOrder(t, log, "tool_call:missing_tool", "snapshot:tool_call:missing_tool", "ui:ToolCallStart")
requireOrder(t, log, "tool_result:missing_tool", "snapshot:tool_result:missing_tool", "ui:ToolError")
}

// blockingCancelTool is a test double whose Execute blocks until its context
// is done. The started channel lets a test observe that Execute has been
// entered before it cancels the context.
type blockingCancelTool struct {
	name    string
	started chan struct{}
}

// Name returns the name the tool was constructed with.
func (b blockingCancelTool) Name() string { return b.name }

// Description returns a fixed description string.
func (b blockingCancelTool) Description() string { return "blocking cancel tool" }

// Schema returns a schema whose only populated field is Type ("object").
func (b blockingCancelTool) Schema() llm.ToolSchema {
	return llm.ToolSchema{Type: "object"}
}

// Execute signals on started, waits for ctx to be done, and returns a nil
// result together with ctx.Err(). The arguments are ignored.
func (b blockingCancelTool) Execute(ctx context.Context, _ json.RawMessage) (*tools.Result, error) {
	// Close started only when a receive would not already succeed; closing an
	// already-closed channel would panic.
	select {
	case <-b.started:
		// Already signalled.
	default:
		close(b.started)
	}

	<-ctx.Done()
	return nil, ctx.Err()
}

// TestRunAddsInterruptedToolResultOnCanceledToolCall checks that cancelling the
// context while a tool call is executing makes the run return an error matching
// context.Canceled, leaves a "tool" message for that call ID whose content
// contains "interrupted" as the last non-system message, and emits an
// EventToolError event mentioning "interrupted by user".
func TestRunAddsInterruptedToolResultOnCanceledToolCall(t *testing.T) {
args, err := json.Marshal(map[string]string{"path": "README.md"})
if err != nil {
t.Fatalf("marshal tool args: %v", err)
}

// The scripted provider is given a single response containing one "read" tool call.
provider := &scriptedStreamProvider{
responses: []*llm.CompletionResponse{{
ToolCalls: []llm.ToolCall{{
ID: "call-cancel-1",
Type: "function",
Function: llm.ToolCallFunc{
Name: "read",
Arguments: args,
},
}},
FinishReason: llm.FinishToolCalls,
}},
}

// Register a "read" tool that blocks until its context is done and signals
// on started once Execute has been entered.
started := make(chan struct{})
registry := tools.NewRegistry()
registry.MustRegister(blockingCancelTool{name: "read", started: started})

engine := NewEngine(EngineConfig{
MaxIterations: 2,
ContextWindow: 4096,
}, provider, registry)

ctx, cancel := context.WithCancel(context.Background())
// events is appended to by the run's event callback and is only read below
// after the run's result has been received from done.
var events []Event
done := make(chan error, 1)
go func() {
done <- engine.RunWithContextStream(ctx, Task{
ID: "cancel-tool",
Description: "read file",
}, func(ev Event) {
events = append(events, ev)
})
}()

// Wait until the tool is actually executing before cancelling, so the
// cancellation lands during tool execution.
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for tool execution to start")
}
cancel()

select {
case runErr := <-done:
if !errors.Is(runErr, context.Canceled) {
t.Fatalf("RunWithContextStream error = %v, want context canceled", runErr)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for cancellation")
}

msgs := engine.ctxManager.GetNonSystemMessages()
if len(msgs) < 3 {
t.Fatalf("expected at least 3 non-system messages, got %d", len(msgs))
}

// The final message must be the interrupted tool result for the cancelled call.
last := msgs[len(msgs)-1]
if last.Role != "tool" {
t.Fatalf("last role = %q, want tool", last.Role)
}
if last.ToolCallID != "call-cancel-1" {
t.Fatalf("last tool_call_id = %q, want %q", last.ToolCallID, "call-cancel-1")
}
if !strings.Contains(strings.ToLower(last.Content), "interrupted") {
t.Fatalf("last tool result = %q, want interrupted marker", last.Content)
}

// At least one emitted event must be a tool error describing the interruption.
foundInterruptedEvent := false
for _, ev := range events {
if ev.Type == EventToolError && strings.Contains(strings.ToLower(ev.Message), "interrupted by user") {
foundInterruptedEvent = true
break
}
}
if !foundInterruptedEvent {
t.Fatalf("expected interrupted tool error event, got events: %#v", events)
}
}

// TestRunBackfillsInterruptedResultsForUnexecutedToolCalls checks that when a
// response carries two tool calls and the context is cancelled while the first
// is executing, the run returns an error matching context.Canceled and the
// message history holds a "tool" message containing "interrupted" for both
// call IDs.
func TestRunBackfillsInterruptedResultsForUnexecutedToolCalls(t *testing.T) {
args, err := json.Marshal(map[string]string{"path": "README.md"})
if err != nil {
t.Fatalf("marshal tool args: %v", err)
}

// One scripted response with two tool calls: "read" first, then "grep".
provider := &scriptedStreamProvider{
responses: []*llm.CompletionResponse{{
ToolCalls: []llm.ToolCall{
{
ID: "call-cancel-1",
Type: "function",
Function: llm.ToolCallFunc{
Name: "read",
Arguments: args,
},
},
{
ID: "call-cancel-2",
Type: "function",
Function: llm.ToolCallFunc{
Name: "grep",
Arguments: args,
},
},
},
FinishReason: llm.FinishToolCalls,
}},
}

// "read" blocks until its context is done; "grep" is a stub whose content
// ("should-not-run") would not contain "interrupted" if it were executed.
started := make(chan struct{})
registry := tools.NewRegistry()
registry.MustRegister(blockingCancelTool{name: "read", started: started})
registry.MustRegister(stubTool{name: "grep", content: "should-not-run", summary: "1 line"})

engine := NewEngine(EngineConfig{
MaxIterations: 2,
ContextWindow: 4096,
}, provider, registry)

ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- engine.RunWithContextStream(ctx, Task{
ID: "cancel-two-tools",
Description: "read and grep",
}, nil)
}()

// Cancel only after the first tool has started executing.
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first tool execution to start")
}
cancel()

select {
case runErr := <-done:
if !errors.Is(runErr, context.Canceled) {
t.Fatalf("RunWithContextStream error = %v, want context canceled", runErr)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for cancellation")
}

// Index tool-role messages by call ID; a later message with the same ID
// overwrites an earlier one.
msgs := engine.ctxManager.GetNonSystemMessages()
toolResultsByID := map[string]string{}
for _, msg := range msgs {
if msg.Role != "tool" {
continue
}
toolResultsByID[msg.ToolCallID] = msg.Content
}

// Both the cancelled call and the never-started call need an interrupted result.
for _, id := range []string{"call-cancel-1", "call-cancel-2"} {
content, ok := toolResultsByID[id]
if !ok {
t.Fatalf("missing tool result for %s in messages: %#v", id, msgs)
}
if !strings.Contains(strings.ToLower(content), "interrupted") {
t.Fatalf("tool result for %s = %q, want interrupted marker", id, content)
}
}
}
5 changes: 5 additions & 0 deletions runtime/shell/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ func (r *Runner) Run(ctx context.Context, command string) (*Result, error) {
result.Error = err
}
}
if result.Error == nil {
if ctxErr := ctx.Err(); ctxErr != nil {
result.Error = ctxErr
}
}

return result, nil
}
Expand Down
4 changes: 4 additions & 0 deletions tools/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package shell
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -99,6 +100,9 @@ func (t *ShellTool) Execute(ctx context.Context, params json.RawMessage) (*tools
summary = fmt.Sprintf("exit %d", result.ExitCode)
}
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, context.DeadlineExceeded) {
return nil, result.Error
}
summary = fmt.Sprintf("error: %s", result.Error.Error())
}

Expand Down
23 changes: 23 additions & 0 deletions tools/shell/shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package shell

import (
"context"
"errors"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -34,3 +35,25 @@ func TestShellToolExecute_DoesNotDuplicateCommandOrExit0InContent(t *testing.T)
t.Fatalf("expected summary not to be 'exit 0'")
}
}

func TestShellToolExecute_ReturnsContextCanceledOnInterrupt(t *testing.T) {
runner := rshell.NewRunner(rshell.Config{
WorkDir: ".",
Timeout: 10 * time.Second,
})
tool := NewShellTool(runner)

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

result, err := tool.Execute(ctx, []byte(`{"command":"sleep 100","timeout":120}`))
if !errors.Is(err, context.Canceled) {
t.Fatalf("execute shell tool error = %v, want context canceled", err)
}
if result != nil {
t.Fatalf("expected nil result on cancel, got %#v", result)
}
}