diff --git a/go/README.md b/go/README.md index b89a76318..ac91e2944 100644 --- a/go/README.md +++ b/go/README.md @@ -355,6 +355,65 @@ session, _ := client.CreateSession(context.Background(), &copilot.SessionConfig{ When the model selects a tool, the SDK automatically runs your handler (in parallel with other calls) and responds to the CLI's `tool.call` with the handler's result. +#### Cooperative Cancellation via session.Abort + +`ToolInvocation.Context` is a `context.Context` that is cancelled when `session.Abort` is called, or when that individual call is cancelled with `session.CancelToolCall` (see below). Pass it to any cancellable operation (HTTP requests, DB queries, sleeps) so the handler stops promptly when the session is aborted: + +```go +lookupIssue := copilot.DefineTool("lookup_issue", "Fetch issue details from our tracker", + func(params LookupIssueParams, inv copilot.ToolInvocation) (any, error) { + // Pass inv.Context so the HTTP request is cancelled on session.Abort. + req, err := http.NewRequestWithContext(inv.Context, "GET", + "https://api.example.com/issues/"+params.ID, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err // wraps context.Canceled when aborted (check with errors.Is) + } + defer resp.Body.Close() + // ... + return summary, nil + }) +``` + +Handlers that don't use `inv.Context` are unaffected; they run to completion as before. + +#### Cancelling a single tool call + +Use `session.CancelToolCall(toolCallID)` to cancel one specific in-flight handler without aborting the session or any other concurrent handlers. It returns `true` if the tool call was found and cancelled, `false` if it was not in flight. + +```go +// Cancel a specific tool call by its ID. +if cancelled := session.CancelToolCall(toolCallID); !cancelled { + log.Println("tool call was not in flight") +} +``` + +`toolCallID` is available inside the handler as `inv.ToolCallID`. 
You can capture it to enable external cancellation of a specific operation: + +```go +var mu sync.Mutex +activeCalls := map[string]string{} // label → toolCallID + +slowTool := copilot.DefineTool("slow_op", "A long-running operation", + func(params SlowOpParams, inv copilot.ToolInvocation) (any, error) { + mu.Lock() + activeCalls[params.Label] = inv.ToolCallID + mu.Unlock() + + // ... do work, checking inv.Context.Done() ... + return result, nil + }) + +// Elsewhere, cancel by label: +mu.Lock() +id := activeCalls["my-label"] +mu.Unlock() +session.CancelToolCall(id) +``` + #### Overriding Built-in Tools If you register a tool with the same name as a built-in CLI tool (e.g. `edit_file`, `read_file`), the SDK will throw an error unless you explicitly opt in by setting `OverridesBuiltInTool = true`. This flag signals that you intend to replace the built-in tool with your custom implementation. diff --git a/go/session.go b/go/session.go index acd698677..c35cc367e 100644 --- a/go/session.go +++ b/go/session.go @@ -82,6 +82,11 @@ type Session struct { capabilities SessionCapabilities capabilitiesMu sync.RWMutex + // toolCallCancels tracks cancel functions for in-flight tool calls so that + // Abort and CancelToolCall can propagate cancellation into handler contexts. + toolCallCancels map[string]context.CancelFunc + toolCallCancelsMu sync.Mutex + // eventCh serializes user event handler dispatch. dispatchEvent enqueues; // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. eventCh chan SessionEvent @@ -1337,11 +1342,35 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { // executeToolAndRespond executes a tool handler and sends the result back via RPC. 
func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, arguments any, handler ToolHandler, traceparent, tracestate string) { - ctx := contextWithTraceParent(context.Background(), traceparent, tracestate) + // traceCtx carries OTel trace propagation but is not subject to abort cancellation. + // It is used for administrative RPC calls that must complete regardless of abort. + traceCtx := contextWithTraceParent(context.Background(), traceparent, tracestate) + // ctx is passed to the tool handler and is cancelled when session.Abort is called, + // giving handlers a cooperative cancellation signal. + ctx, cancel := context.WithCancel(traceCtx) + + s.toolCallCancelsMu.Lock() + if s.toolCallCancels == nil { + s.toolCallCancels = make(map[string]context.CancelFunc) + } + s.toolCallCancels[toolCallID] = cancel + s.toolCallCancelsMu.Unlock() + + // Cleanup runs last (registered first). Removes the cancel from the in-flight map + // and releases context resources. + defer func() { + s.toolCallCancelsMu.Lock() + delete(s.toolCallCancels, toolCallID) + s.toolCallCancelsMu.Unlock() + cancel() + }() + + // Panic recovery runs first (registered second, LIFO). Uses traceCtx to ensure + // the error response is sent even if ctx was already cancelled by Abort. 
defer func() { if r := recover(); r != nil { errMsg := fmt.Sprintf("tool panic: %v", r) - s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ + s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, Error: &errMsg, }) @@ -1353,13 +1382,14 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, ToolCallID: toolCallID, ToolName: toolName, Arguments: arguments, - TraceContext: ctx, + Context: ctx, + TraceContext: traceCtx, } result, err := handler(invocation) if err != nil { errMsg := err.Error() - s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ + s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, Error: &errMsg, }) @@ -1389,7 +1419,7 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, if result.Error != "" { rpcResult.Error = &result.Error } - s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ + s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, Result: rpcResult, }) @@ -1555,9 +1585,44 @@ func (s *Session) Abort(ctx context.Context) error { return fmt.Errorf("failed to abort session: %w", err) } + s.toolCallCancelsMu.Lock() + for id, cancel := range s.toolCallCancels { + cancel() + delete(s.toolCallCancels, id) + } + s.toolCallCancelsMu.Unlock() + return nil } +// CancelToolCall cancels a single in-flight tool handler identified by toolCallID +// without aborting the agentic loop or any other concurrent tool handlers. +// +// It looks up the cancel func registered when the handler was dispatched, calls it +// (cancelling the context passed to that handler via ToolInvocation.Context), removes +// the entry from the in-flight map, and returns true. If no handler with the given +// toolCallID is currently executing, CancelToolCall is a no-op and returns false. 
+// +// Example: +// +// // Start a session with a long-running tool registered. +// // Later, cancel only a specific tool call without aborting the session: +// if cancelled := session.CancelToolCall("tool-call-id-123"); !cancelled { +// log.Println("tool call was not in flight") +// } +func (s *Session) CancelToolCall(toolCallID string) bool { + s.toolCallCancelsMu.Lock() + defer s.toolCallCancelsMu.Unlock() + + cancel, ok := s.toolCallCancels[toolCallID] + if !ok { + return false + } + cancel() + delete(s.toolCallCancels, toolCallID) + return true +} + // SetModelOptions configures optional parameters for SetModel. type SetModelOptions struct { // ReasoningEffort sets the reasoning effort level for the new model (e.g., "low", "medium", "high", "xhigh"). diff --git a/go/session_test.go b/go/session_test.go index 15cfbcf57..e2356f3d3 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -1031,3 +1031,278 @@ func TestSession_ElicitationRequestSchema(t *testing.T) { } }) } + +// TestToolInvocation_ContextCancelledOnAbort verifies that the context passed to a +// tool handler is cancelled when the in-flight cancel func (as used by Abort) fires. +func TestToolInvocation_ContextCancelledOnAbort(t *testing.T) { + session, cleanup := newRPCDrainSession(t, "session-abort-test") + defer cleanup() + + // Channel to receive the invocation context from the handler. + ctxCh := make(chan context.Context, 1) + + // The handler blocks until its context is cancelled, then reports. + handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + ctxCh <- inv.Context + <-inv.Context.Done() + return ToolResult{TextResultForLLM: "cancelled"}, nil + }) + + done := make(chan struct{}) + go func() { + defer close(done) + session.executeToolAndRespond("req-1", "my_tool", "tc-1", nil, handler, "", "") + }() + + // Wait for the handler to start and capture its context. 
+ var handlerCtx context.Context + select { + case handlerCtx = <-ctxCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler to start") + } + + // Verify the context is not yet cancelled. + if handlerCtx.Err() != nil { + t.Fatalf("expected context to be active, got %v", handlerCtx.Err()) + } + + // Simulate what Abort() does: cancel all in-flight tool call contexts. + session.toolCallCancelsMu.Lock() + for _, cancel := range session.toolCallCancels { + cancel() + } + session.toolCallCancelsMu.Unlock() + + // Wait for the handler to finish. + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler to finish after cancellation") + } + + // The handler context must be cancelled. + if handlerCtx.Err() == nil { + t.Fatal("expected handler context to be cancelled after abort") + } + + // The cancel func must have been removed from the map. + session.toolCallCancelsMu.Lock() + remaining := len(session.toolCallCancels) + session.toolCallCancelsMu.Unlock() + if remaining != 0 { + t.Fatalf("expected toolCallCancels to be empty after execution, got %d entries", remaining) + } +} + +// TestToolInvocation_ContextPopulated verifies that executeToolAndRespond sets +// both Context and TraceContext on the ToolInvocation passed to the handler. 
+func TestToolInvocation_ContextPopulated(t *testing.T) { + session, cleanup := newRPCDrainSession(t, "session-ctx-test") + defer cleanup() + + invCh := make(chan ToolInvocation, 1) + handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + invCh <- inv + return ToolResult{TextResultForLLM: "ok"}, nil + }) + + done := make(chan struct{}) + go func() { + defer close(done) + session.executeToolAndRespond("req-2", "check_tool", "tc-2", map[string]any{"x": 1}, handler, "", "") + }() + + var inv ToolInvocation + select { + case inv = <-invCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler invocation") + } + + if inv.Context == nil { + t.Fatal("expected ToolInvocation.Context to be set") + } + if inv.TraceContext == nil { + t.Fatal("expected ToolInvocation.TraceContext to be set") + } + // Context is a cancellable child of TraceContext; they are different instances. + // Both must be non-nil and Context must be cancellable independently. + if inv.Context == inv.TraceContext { + t.Error("expected Context and TraceContext to be different instances (Context is a cancellable child)") + } + if inv.SessionID != "session-ctx-test" { + t.Errorf("expected SessionID session-ctx-test, got %q", inv.SessionID) + } + if inv.ToolCallID != "tc-2" { + t.Errorf("expected ToolCallID tc-2, got %q", inv.ToolCallID) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for executeToolAndRespond to complete") + } +} + +// newRPCDrainSession creates a Session backed by a pipe-based JSON-RPC client +// whose server side drains requests and returns empty success responses. +// The caller must invoke the returned cleanup func when done; it stops the client and closes both pipes. 
+func newRPCDrainSession(t *testing.T, sessionID string) (*Session, func()) { + t.Helper() + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + + client := jsonrpc2.NewClient(stdinW, stdoutR) + client.Start() + + session := &Session{ + SessionID: sessionID, + client: client, + RPC: rpc.NewSessionRPC(client, sessionID), + } + + // Drain goroutine: read every RPC request and send an empty success response. + // Uses a single bufio.Reader so that header parsing and body reads share the + // same read buffer — mixing bufio.Scanner with io.ReadFull on the same reader + // causes data corruption because Scanner may buffer-ahead bytes that + // io.ReadFull then misses. + go func() { + br := bufio.NewReader(stdinR) + for { + // Read headers until blank line. + var contentLen int + for { + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break // end of headers + } + fmt.Sscanf(line, "Content-Length: %d", &contentLen) + } + if contentLen == 0 { + continue + } + + body := make([]byte, contentLen) + if _, err := io.ReadFull(br, body); err != nil { + return + } + + var req struct { + ID json.RawMessage `json:"id"` + } + if err := json.Unmarshal(body, &req); err != nil || req.ID == nil { + continue + } + resp, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(req.ID), + "result": map[string]any{}, + }) + fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp) //nolint:errcheck + } + }() + + cleanup := func() { + client.Stop() + stdinR.Close() + stdinW.Close() + stdoutR.Close() + stdoutW.Close() + } + return session, cleanup +} + +// TestCancelToolCall_cancelsTargetedHandlerOnly verifies that CancelToolCall +// cancels only the specified handler's context while leaving concurrent +// handlers unaffected, and returns false for an unknown tool call ID. 
+func TestCancelToolCall_cancelsTargetedHandlerOnly(t *testing.T) { + session, cleanup := newRPCDrainSession(t, "session-cancel-test") + defer cleanup() + + type handlerState struct { + ctx context.Context + done chan struct{} + } + + makeBlockingHandler := func(id string) (ToolHandler, *handlerState) { + state := &handlerState{done: make(chan struct{})} + h := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + state.ctx = inv.Context + <-inv.Context.Done() // block until cancelled + return ToolResult{TextResultForLLM: "cancelled"}, nil + }) + return h, state + } + + h1, s1 := makeBlockingHandler("tc-a") + h2, s2 := makeBlockingHandler("tc-b") + + // Start both handlers concurrently. + go func() { + defer close(s1.done) + session.executeToolAndRespond("req-a", "tool_a", "tc-a", nil, h1, "", "") + }() + go func() { + defer close(s2.done) + session.executeToolAndRespond("req-b", "tool_b", "tc-b", nil, h2, "", "") + }() + + // Wait for both handlers to start (ctx will be set once they block). + deadline := time.After(2 * time.Second) + for s1.ctx == nil || s2.ctx == nil { + select { + case <-deadline: + t.Fatal("timed out waiting for handlers to start") + default: + time.Sleep(5 * time.Millisecond) + } + } + + // Unknown ID returns false and leaves both handlers running. + if got := session.CancelToolCall("nonexistent"); got { + t.Fatal("expected CancelToolCall(unknown) to return false") + } + if s1.ctx.Err() != nil { + t.Fatal("handler 1 should still be running after unknown CancelToolCall") + } + if s2.ctx.Err() != nil { + t.Fatal("handler 2 should still be running after unknown CancelToolCall") + } + + // Cancel only handler 1. + if got := session.CancelToolCall("tc-a"); !got { + t.Fatal("expected CancelToolCall(tc-a) to return true") + } + + // Handler 1 should finish; handler 2 should remain live. 
+ select { + case <-s1.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler 1 to finish after CancelToolCall") + } + + if s1.ctx.Err() == nil { + t.Fatal("expected handler 1 context to be cancelled") + } + if s2.ctx.Err() != nil { + t.Fatal("handler 2 context should still be live") + } + + // CancelToolCall on the same ID again returns false (already removed). + if got := session.CancelToolCall("tc-a"); got { + t.Fatal("expected second CancelToolCall(tc-a) to return false") + } + + // Cancel handler 2 to let it finish and avoid goroutine leak. + session.CancelToolCall("tc-b") + select { + case <-s2.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler 2 to finish") + } +} diff --git a/go/types.go b/go/types.go index 5ed0b6931..749b44a7e 100644 --- a/go/types.go +++ b/go/types.go @@ -1154,10 +1154,25 @@ type ToolInvocation struct { ToolName string Arguments any + // Context is the primary context for this tool invocation. It carries + // W3C Trace Context propagation (for OpenTelemetry) and is cancelled + // when session.Abort or session.CancelToolCall is called, allowing handlers to cooperatively stop + // in-flight work (e.g. pass to http.NewRequestWithContext, (*sql.DB).QueryContext). + // + // Handlers that do not inspect the context continue to work unchanged. + Context context.Context + + // TraceContext is deprecated: use Context instead. // TraceContext carries the W3C Trace Context propagated from the CLI's - // execute_tool span. Pass this to OpenTelemetry-aware code so that + // execute_tool span. Pass this to OpenTelemetry-aware code so that // child spans created inside the handler are parented to the CLI span. // When no trace context is available this will be context.Background(). + // Unlike Context, TraceContext is never cancelled — it remains valid for + // the lifetime of the RPC call regardless of session.Abort or session.CancelToolCall. 
+ // + // Deprecated: Use Context, which carries the same trace information and + // is additionally cancelled when session.Abort or session.CancelToolCall + // is called. TraceContext context.Context }