diff --git a/java/README.md b/java/README.md index 7c7209bc8..12aaa3553 100644 --- a/java/README.md +++ b/java/README.md @@ -242,6 +242,49 @@ public class Consumer { The gate also applies to individual methods annotated with `@CopilotExperimental` on otherwise stable types. When a type-level annotation is present, all member accesses through that type are considered experimental. `@AllowCopilotExperimental` mirrors the same declaration-level boundary: annotating a class opts in that class and its enclosed declarations, while annotating a method or constructor opts in just that executable signature. +## Tool Handler Cancellation + +Tool handlers can observe when the session is aborted and stop in-flight work cooperatively using the `AbortSignal` available on every `ToolInvocation`. + +### Cancelling all in-flight handlers: `session.abort()` + +`session.abort()` stops the agentic loop **and** immediately fires the `AbortSignal` for every tool handler currently running in the session. + +```java +// Handler using isAborted() polling +var tool = ToolDefinition.create("long_task", "A long-running task", + Map.of("type", "object", "properties", Map.of()), + invocation -> CompletableFuture.supplyAsync(() -> { + AbortSignal signal = invocation.getAbortSignal(); + while (!signal.isAborted()) { + // perform incremental work + } + return "cancelled"; + }) +); + +// Handler using onAborted() callback +var tool2 = ToolDefinition.create("http_task", "Task with cleanup", + Map.of("type", "object", "properties", Map.of()), + invocation -> { + var future = new CompletableFuture(); + invocation.getAbortSignal().onAborted(() -> future.complete("aborted")); + // start async work that completes future normally ... + return future; + } +); +``` + +### Cancelling a single handler: `session.cancelToolCall(toolCallId)` + +`cancelToolCall(toolCallId)` fires the `AbortSignal` for **only** the named in-flight invocation, without aborting the agentic loop or affecting other running handlers. Returns `true` if the call was found and cancelled, `false` if not found (already completed or unknown id). + +```java +// Cancel a specific in-flight tool call by its ID +String toolCallId = invocation.getToolCallId(); // captured from ToolInvocation +boolean wasCancelled = session.cancelToolCall(toolCallId); +``` + ## Projects Using This SDK | Project | Description | diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index adfeac013..4c4b2bbc0 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -169,6 +169,7 @@ public final class CopilotSession implements AutoCloseable { private final Set> eventHandlers = ConcurrentHashMap.newKeySet(); private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); + private final Map activeToolSignals = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); @@ -882,15 +883,20 @@ private void handleBroadcastEventAsync(SessionEvent event) { */ private void executeToolAndRespondAsync(String requestId, String toolName, String toolCallId, Object arguments, ToolDefinition tool) { + var signal = new com.github.copilot.rpc.AbortSignal(); + String signalKey = toolCallId != null ? toolCallId : requestId; + activeToolSignals.put(signalKey, signal); Runnable task = () -> { try { JsonNode argumentsNode = arguments instanceof JsonNode jn ? jn : (arguments != null ? MAPPER.valueToTree(arguments) : null); var invocation = new com.github.copilot.rpc.ToolInvocation().setSessionId(sessionId) - .setToolCallId(toolCallId).setToolName(toolName).setArguments(argumentsNode); + .setToolCallId(toolCallId).setToolName(toolName).setArguments(argumentsNode) + .setAbortSignal(signal); tool.handler().invoke(invocation).thenAccept(result -> { + activeToolSignals.remove(signalKey); try { ToolResultObject toolResult; if (result instanceof ToolResultObject tr) { @@ -905,6 +911,7 @@ private void executeToolAndRespondAsync(String requestId, String toolName, Strin LOG.log(Level.WARNING, "Error sending tool result for requestId=" + requestId, e); } }).exceptionally(ex -> { + activeToolSignals.remove(signalKey); try { getRpc().tools.handlePendingToolCall(new SessionToolsHandlePendingToolCallParams(sessionId, requestId, null, ex.getMessage() != null ? ex.getMessage() : ex.toString())); @@ -914,6 +921,7 @@ private void executeToolAndRespondAsync(String requestId, String toolName, Strin return null; }); } catch (Exception e) { + activeToolSignals.remove(signalKey); LOG.log(Level.WARNING, "Error executing tool for requestId=" + requestId, e); try { getRpc().tools.handlePendingToolCall(new SessionToolsHandlePendingToolCallParams(sessionId, @@ -1796,9 +1804,52 @@ public CompletableFuture> getMessages() { */ public CompletableFuture abort() { ensureNotTerminated(); + for (com.github.copilot.rpc.AbortSignal signal : activeToolSignals.values()) { + signal.abort(); + } return rpc.invoke("session.abort", Map.of("sessionId", sessionId), Void.class); } + /** + * Cancels a single in-flight tool handler by its tool call ID. + *

+ * Unlike {@link #abort()}, this method fires the + * {@link com.github.copilot.rpc.AbortSignal} for only the specified tool + * invocation and does not abort the agentic loop or affect any other in-flight + * handlers. + *

+ * The signal is fired and the entry is removed from the tracking map + * immediately. The handler is responsible for observing + * {@link com.github.copilot.rpc.AbortSignal#isAborted()} or registering an + * {@link com.github.copilot.rpc.AbortSignal#onAborted(Runnable)} callback. + * + *

{@code
+     * // Cancel a specific tool invocation
+     * boolean cancelled = session.cancelToolCall(toolCallId);
+     * if (!cancelled) {
+     * 	// tool call was already complete or id was not found
+     * }
+     * }
+ * + * @param toolCallId + * the tool call ID to cancel, as provided by + * {@link com.github.copilot.rpc.ToolInvocation#getToolCallId()} + * @return {@code true} if an in-flight handler was found and its signal was + * fired; {@code false} if no in-flight handler matched the given ID + * @throws IllegalStateException + * if this session has been terminated + * @since 1.6.0 + */ + public boolean cancelToolCall(String toolCallId) { + ensureNotTerminated(); + com.github.copilot.rpc.AbortSignal signal = activeToolSignals.remove(toolCallId); + if (signal != null) { + signal.abort(); + return true; + } + return false; + } + /** * Changes the model for this session with an optional reasoning effort level. *

@@ -2136,6 +2187,7 @@ public void close() { eventHandlers.clear(); toolHandlers.clear(); commandHandlers.clear(); + activeToolSignals.clear(); permissionHandler.set(null); userInputHandler.set(null); elicitationHandler.set(null); diff --git a/java/src/main/java/com/github/copilot/rpc/AbortSignal.java b/java/src/main/java/com/github/copilot/rpc/AbortSignal.java new file mode 100644 index 000000000..2c7626814 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/AbortSignal.java @@ -0,0 +1,123 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.Objects; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A signal that indicates whether a tool invocation has been aborted. + *

+ * An {@code AbortSignal} is passed to tool handlers via + * {@link ToolInvocation#getAbortSignal()} and is triggered when + * {@link com.github.copilot.CopilotSession#abort()} is called while the tool is + * executing. Tool handlers can use this to implement cooperative cancellation, + * allowing them to stop long-running work gracefully when the session is + * aborted. + * + *

Example Usage

+ * + *
{@code
+ * ToolHandler handler = invocation -> {
+ * 	AbortSignal signal = invocation.getAbortSignal();
+ * 	return CompletableFuture.supplyAsync(() -> {
+ * 		while (!signal.isAborted()) {
+ * 			// do incremental work here
+ * 		}
+ * 		throw new CancellationException("Tool aborted");
+ * 	});
+ * };
+ * }
+ * + *

Callback Registration

+ * + *
{@code
+ * ToolHandler handler = invocation -> {
+ * 	AbortSignal signal = invocation.getAbortSignal();
+ * 	signal.onAborted(() -> System.out.println("Aborting tool!"));
+ * 	// ... perform work ...
+ * 	return CompletableFuture.completedFuture("done");
+ * };
+ * }
+ * + * @see ToolInvocation#getAbortSignal() + * @see com.github.copilot.CopilotSession#abort() + * @since 1.6.0 + */ +public final class AbortSignal { + + private final AtomicBoolean aborted = new AtomicBoolean(false); + private final CopyOnWriteArrayList listeners = new CopyOnWriteArrayList<>(); + + /** + * Returns whether this signal has been aborted. + * + * @return {@code true} if {@link com.github.copilot.CopilotSession#abort()} was + * called while this tool invocation was in progress; {@code false} + * otherwise + */ + public boolean isAborted() { + return aborted.get(); + } + + /** + * Registers a callback to be invoked when this signal is aborted. + *

+ * If the signal is already aborted at the time of registration, the callback is + * invoked immediately on the calling thread. + *

+ * The callback is guaranteed to be invoked at most once, regardless of + * concurrent calls to {@link #abort()} and {@code onAborted}. Any + * {@link Throwable} thrown by the callback is silently ignored. + * + * @param listener + * the callback to invoke on abort + * @throws NullPointerException + * if listener is null + */ + public void onAborted(Runnable listener) { + Objects.requireNonNull(listener, "listener must not be null"); + // Wrap in an AtomicBoolean-guarded runnable so the callback fires at most once + // even if abort() races with this method between listeners.add() and the + // aborted.get() check below. + AtomicBoolean fired = new AtomicBoolean(false); + Runnable once = () -> { + if (fired.compareAndSet(false, true)) { + try { + listener.run(); + } catch (Throwable ignored) { + // Throwables from listeners are silently ignored + } + } + }; + listeners.add(once); + if (aborted.get()) { + once.run(); + } + } + + /** + * Triggers this abort signal, notifying all registered listeners. + *

+ * Note: This method is intended for internal SDK use only. It + * is called by the SDK when {@link com.github.copilot.CopilotSession#abort()} + * is invoked while this tool invocation is in progress. + *

+ * Calling this method more than once has no effect — the signal fires exactly + * once. Any {@link Throwable} thrown by a listener is silently ignored. + */ + public void abort() { + if (aborted.compareAndSet(false, true)) { + for (Runnable listener : listeners) { + try { + listener.run(); + } catch (Throwable ignored) { + // Throwables from listeners are silently ignored + } + } + } + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/ToolInvocation.java b/java/src/main/java/com/github/copilot/rpc/ToolInvocation.java index dddfdd06f..3ef38206e 100644 --- a/java/src/main/java/com/github/copilot/rpc/ToolInvocation.java +++ b/java/src/main/java/com/github/copilot/rpc/ToolInvocation.java @@ -6,6 +6,7 @@ import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.core.type.TypeReference; @@ -16,11 +17,28 @@ * Represents a tool invocation request from the AI assistant. *

* When the assistant invokes a tool, this object contains the context including - * the session ID, tool call ID, tool name, and arguments parsed from the - * assistant's request. + * the session ID, tool call ID, tool name, arguments parsed from the + * assistant's request, and an {@link AbortSignal} that is triggered when + * {@link com.github.copilot.CopilotSession#abort()} is called while the tool is + * executing. + * + *

Cooperative Cancellation

+ * + *
{@code
+ * ToolHandler handler = invocation -> {
+ * 	AbortSignal signal = invocation.getAbortSignal();
+ * 	return CompletableFuture.supplyAsync(() -> {
+ * 		while (!signal.isAborted()) {
+ * 			// do incremental work here
+ * 		}
+ * 		throw new CancellationException("Tool aborted");
+ * 	});
+ * };
+ * }
* * @see ToolHandler * @see ToolDefinition + * @see AbortSignal * @since 1.0.0 */ @JsonInclude(JsonInclude.Include.NON_NULL) @@ -34,6 +52,7 @@ public final class ToolInvocation { private String toolCallId; private String toolName; private JsonNode argumentsNode; + private AbortSignal abortSignal = new AbortSignal(); /** * Gets the session ID where the tool was invoked. @@ -168,4 +187,54 @@ public ToolInvocation setArguments(JsonNode arguments) { this.argumentsNode = arguments; return this; } + + /** + * Returns the abort signal for this tool invocation. + *

+ * The signal is triggered when + * {@link com.github.copilot.CopilotSession#abort()} is called while this tool + * is executing. Use it to implement cooperative cancellation in your tool + * handler. + * + *

{@code
+     * ToolHandler handler = invocation -> {
+     * 	AbortSignal signal = invocation.getAbortSignal();
+     * 	return CompletableFuture.supplyAsync(() -> {
+     * 		while (!signal.isAborted()) {
+     * 			// do incremental work here
+     * 		}
+     * 		throw new CancellationException("Tool aborted");
+     * 	});
+     * };
+     * }
+ * + * @return the abort signal; never {@code null} + * @see AbortSignal + * @since 1.6.0 + */ + @JsonIgnore + public AbortSignal getAbortSignal() { + return abortSignal; + } + + /** + * Sets the abort signal for this tool invocation. + *

+ * Note: This method is intended for internal SDK use only. + * Users do not need to call this method directly. Passing {@code null} is + * accepted for backwards compatibility and leaves the existing signal + * unchanged. + * + * @param abortSignal + * the abort signal to associate with this invocation, or + * {@code null} to leave the existing signal unchanged + * @return this invocation for method chaining + */ + @JsonIgnore + public ToolInvocation setAbortSignal(AbortSignal abortSignal) { + if (abortSignal != null) { + this.abortSignal = abortSignal; + } + return this; + } } diff --git a/java/src/main/java/com/github/copilot/rpc/package-info.java b/java/src/main/java/com/github/copilot/rpc/package-info.java index edc7dedcf..eda97abe4 100644 --- a/java/src/main/java/com/github/copilot/rpc/package-info.java +++ b/java/src/main/java/com/github/copilot/rpc/package-info.java @@ -39,6 +39,10 @@ * tool that can be invoked by the assistant. *

  • {@link com.github.copilot.rpc.ToolInvocation} - Represents a tool * invocation request from the assistant.
  • + *
  • {@link com.github.copilot.rpc.AbortSignal} - Cancellation signal passed + * to tool handlers via + * {@link com.github.copilot.rpc.ToolInvocation#getAbortSignal()}, triggered + * when {@link com.github.copilot.CopilotSession#abort()} is called.
  • *
  • {@link com.github.copilot.rpc.Attachment} - File attachment for * messages.
  • * diff --git a/java/src/test/java/com/github/copilot/CancelToolCallTest.java b/java/src/test/java/com/github/copilot/CancelToolCallTest.java new file mode 100644 index 000000000..9f76d7a39 --- /dev/null +++ b/java/src/test/java/com/github/copilot/CancelToolCallTest.java @@ -0,0 +1,91 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.*; + +import java.lang.reflect.Field; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.github.copilot.rpc.AbortSignal; + +/** + * Unit tests for {@link CopilotSession#cancelToolCall(String)}. + *

    + * Uses reflection to inject {@link AbortSignal} instances directly into the + * session's active-tool-signal tracking map, allowing the cancellation logic to + * be verified in isolation without requiring the full E2E test harness. + */ +class CancelToolCallTest { + + /** + * Injects two signals into a session, calls cancelToolCall for one, and + * verifies that only the targeted signal is aborted while the other remains + * unaffected. + */ + @Test + void cancelToolCallFiresOnlyTargetedSignal() throws Exception { + var session = new CopilotSession("sess-cancel-test", null); + + AbortSignal signalA = new AbortSignal(); + AbortSignal signalB = new AbortSignal(); + + Map map = getActiveToolSignals(session); + map.put("call-A", signalA); + map.put("call-B", signalB); + + boolean result = session.cancelToolCall("call-A"); + + assertTrue(result, "cancelToolCall should return true for a known toolCallId"); + assertTrue(signalA.isAborted(), "signal A should be aborted after cancelToolCall(call-A)"); + assertFalse(signalB.isAborted(), "signal B must NOT be aborted — only the targeted signal fires"); + } + + /** + * Verifies that cancelToolCall returns false for an unknown tool call ID, + * without affecting any in-flight signals. + */ + @Test + void cancelToolCallReturnsFalseForUnknownId() throws Exception { + var session = new CopilotSession("sess-cancel-unknown", null); + + AbortSignal signal = new AbortSignal(); + Map map = getActiveToolSignals(session); + map.put("call-exists", signal); + + boolean result = session.cancelToolCall("call-does-not-exist"); + + assertFalse(result, "cancelToolCall should return false for an unknown toolCallId"); + assertFalse(signal.isAborted(), "existing signal must not be affected"); + } + + /** + * Verifies that a cancelled signal is removed from the tracking map so it + * cannot be double-fired. + */ + @Test + void cancelToolCallRemovesSignalFromMap() throws Exception { + var session = new CopilotSession("sess-cancel-cleanup", null); + + AbortSignal signal = new AbortSignal(); + Map map = getActiveToolSignals(session); + map.put("call-X", signal); + + session.cancelToolCall("call-X"); + + assertFalse(map.containsKey("call-X"), "signal should be removed from the map after cancellation"); + // second call must return false since the entry is gone + assertFalse(session.cancelToolCall("call-X"), "second cancelToolCall for same id should return false"); + } + + @SuppressWarnings("unchecked") + private static Map getActiveToolSignals(CopilotSession session) throws Exception { + Field f = CopilotSession.class.getDeclaredField("activeToolSignals"); + f.setAccessible(true); + return (Map) f.get(session); + } +} diff --git a/java/src/test/java/com/github/copilot/ToolInvocationTest.java b/java/src/test/java/com/github/copilot/ToolInvocationTest.java index 2bc9edb1b..1d3d14d9f 100644 --- a/java/src/test/java/com/github/copilot/ToolInvocationTest.java +++ b/java/src/test/java/com/github/copilot/ToolInvocationTest.java @@ -6,10 +6,13 @@ import static org.junit.jupiter.api.Assertions.*; +import java.util.concurrent.atomic.AtomicBoolean; + import org.junit.jupiter.api.Test; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.rpc.AbortSignal; import com.github.copilot.rpc.ToolInvocation; /** @@ -127,6 +130,102 @@ void testGetArgumentsAsThrowsOnInvalidType() { assertTrue(exception.getMessage().contains("StrictType")); } + /** + * Test that getAbortSignal returns a non-null signal by default. + */ + @Test + void testGetAbortSignalReturnedByDefault() { + ToolInvocation invocation = new ToolInvocation().setSessionId("s1").setToolCallId("c1").setToolName("my_tool"); + assertNotNull(invocation.getAbortSignal(), "getAbortSignal should not return null"); + assertFalse(invocation.getAbortSignal().isAborted(), "signal should not be aborted by default"); + } + + /** + * Test that isAborted returns true after the signal is aborted. + */ + @Test + void testAbortSignalIsAbortedAfterAbort() { + AbortSignal signal = new AbortSignal(); + assertFalse(signal.isAborted()); + signal.abort(); + assertTrue(signal.isAborted()); + } + + /** + * Test that onAborted callback is invoked when signal is aborted. + */ + @Test + void testAbortSignalOnAbortedCallbackInvoked() { + AbortSignal signal = new AbortSignal(); + var called = new AtomicBoolean(false); + signal.onAborted(() -> called.set(true)); + assertFalse(called.get()); + signal.abort(); + assertTrue(called.get()); + } + + /** + * Test that onAborted callback is invoked immediately if signal is already + * aborted. + */ + @Test + void testAbortSignalOnAbortedCallbackInvokedImmediatelyIfAlreadyAborted() { + AbortSignal signal = new AbortSignal(); + signal.abort(); + var called = new AtomicBoolean(false); + signal.onAborted(() -> called.set(true)); + assertTrue(called.get(), "callback should be invoked immediately when signal is already aborted"); + } + + /** + * Test that abort() is idempotent — callbacks fire only once. + */ + @Test + void testAbortSignalAbortIsIdempotent() { + AbortSignal signal = new AbortSignal(); + var count = new java.util.concurrent.atomic.AtomicInteger(0); + signal.onAborted(count::incrementAndGet); + signal.abort(); + signal.abort(); + assertEquals(1, count.get(), "callback should be invoked exactly once even if abort() called twice"); + } + + /** + * Test that onAborted throws NullPointerException for null listener. + */ + @Test + void testAbortSignalOnAbortedRejectsNullListener() { + AbortSignal signal = new AbortSignal(); + assertThrows(NullPointerException.class, () -> signal.onAborted(null)); + } + + /** + * Test that a callback registered via onAborted fires at most once even when + * abort() races with onAborted registration (at-most-once delivery guarantee). + */ + @Test + void testAbortSignalCallbackFiresAtMostOnce() { + AbortSignal signal = new AbortSignal(); + // Pre-abort the signal so onAborted() will fire immediately on registration + signal.abort(); + var count = new java.util.concurrent.atomic.AtomicInteger(0); + // Registering after abort fires immediately — but only once + signal.onAborted(count::incrementAndGet); + assertEquals(1, count.get(), "callback should fire exactly once when registered after abort"); + } + + /** + * Test that setAbortSignal(null) is accepted for backwards compatibility and + * leaves the existing signal unchanged. + */ + @Test + void testSetAbortSignalNullIsIgnored() { + ToolInvocation invocation = new ToolInvocation(); + AbortSignal original = invocation.getAbortSignal(); + invocation.setAbortSignal(null); // must not throw + assertSame(original, invocation.getAbortSignal(), "existing signal should be preserved when null is passed"); + } + /** * Record for testing type-safe argument deserialization. */