diff --git a/core/src/agents/functions.ts b/core/src/agents/functions.ts index e46ce9a7..50aa32e0 100644 --- a/core/src/agents/functions.ts +++ b/core/src/agents/functions.ts @@ -5,11 +5,14 @@ */ import {Content, createUserContent, FunctionCall, Part} from '@google/genai'; -import {isEmpty} from 'lodash-es'; +import {cloneDeep, isEmpty, isPlainObject} from 'lodash-es'; import {InvocationContext} from '../agents/invocation_context.js'; import {createEvent, Event, getFunctionCalls} from '../events/event.js'; -import {mergeEventActions} from '../events/event_actions.js'; +import { + createEventActions, + mergeEventActions, +} from '../events/event_actions.js'; import {BaseTool} from '../tools/base_tool.js'; import {ToolConfirmation} from '../tools/tool_confirmation.js'; import {randomUUID} from '../utils/env_aware_utils.js'; @@ -34,6 +37,8 @@ export const REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = // Export these items for testing purposes only export const functionsExportedForTestingOnly = { handleFunctionCallList, + handleFunctionCallsAsync, + executeSingleFunctionCall, generateAuthEvent, generateRequestConfirmationEvent, }; @@ -206,7 +211,13 @@ export function generateRequestConfirmationEvent({ parts: parts, role: functionResponseEvent.content!.role, }, - actions: functionResponseEvent.actions, + // Carry only requestedToolConfirmations — not the full actions — to avoid + // double-applying stateDelta/artifactDelta/transferToAgent that were + // already applied by the streamed individual tool response events. + actions: createEventActions({ + requestedToolConfirmations: + functionResponseEvent.actions.requestedToolConfirmations, + }), longRunningToolIds: Array.from(longRunningToolIds), }); } @@ -284,7 +295,7 @@ function buildResponseEvent( * - If the tool is long-running and the response is null, continue. !!state * - Merge all function response events into a single event. */ -export async function handleFunctionCallsAsync({ +export async function* handleFunctionCallsAsync({ invocationContext, functionCallEvent, toolsDict, @@ -300,200 +311,441 @@ export async function handleFunctionCallsAsync({ afterToolCallbacks: SingleAfterToolCallback[]; filters?: Set; toolConfirmationDict?: Record; -}): Promise { +}): AsyncGenerator { const functionCalls = getFunctionCalls(functionCallEvent); - return await handleFunctionCallList({ - invocationContext: invocationContext, - functionCalls: functionCalls, - toolsDict: toolsDict, - beforeToolCallbacks: beforeToolCallbacks, - afterToolCallbacks: afterToolCallbacks, - filters: filters, - toolConfirmationDict: toolConfirmationDict, + // Note: only function ids INCLUDED in the filters will be executed. + const filteredFunctionCalls = functionCalls.filter((functionCall) => { + return !filters || (functionCall.id && filters.has(functionCall.id)); }); + + if (!filteredFunctionCalls.length) { + return; + } + + const parallel = invocationContext.runConfig?.parallelToolExecution ?? false; + + const executeSingle = (fc: FunctionCall) => + executeSingleFunctionCall({ + invocationContext, + functionCall: fc, + toolsDict, + beforeToolCallbacks, + afterToolCallbacks, + toolConfirmation: resolveToolConfirmation(fc, toolConfirmationDict), + }); + + const functionResponseEvents: Event[] = []; + if (parallel) { + for await (const event of dispatchParallelStreaming( + filteredFunctionCalls, + executeSingle, + invocationContext, + )) { + functionResponseEvents.push(event); + yield event; + } + } else { + for await (const event of dispatchSequentialStreaming( + filteredFunctionCalls, + executeSingle, + invocationContext, + )) { + functionResponseEvents.push(event); + yield event; + } + } + + if (functionResponseEvents.length > 1) { + const mergedEvent = mergeParallelFunctionResponseEvents( + functionResponseEvents, + ); + tracer.startActiveSpan('execute_tool (merged)', (span) => { + try { + logger.debug('execute_tool (merged)'); + // TODO - b/436079721: implement [traceMergedToolCalls] + logger.debug('traceMergedToolCalls', { + responseEventId: mergedEvent.id, + functionResponseEvent: mergedEvent.id, + }); + traceMergedToolCalls({ + responseEventId: mergedEvent.id, + functionResponseEvent: mergedEvent, + }); + } finally { + span.end(); + } + }); + } } /** - * The underlying implementation of handleFunctionCalls, but takes a list of - * function calls instead of an event. - * This is also used by llm_agent execution flow in preprocessing. + * Executes a single function call through the full callback pipeline. + * + * Extracted from the former sequential loop to enable parallel execution + * via Promise.allSettled. Mirrors adk-python's _execute_single_function_call_async. + * + * Pipeline: plugin before → canonical before → tool exec → plugin after → canonical after → build event. */ -export async function handleFunctionCallList({ +async function executeSingleFunctionCall({ invocationContext, - functionCalls, + functionCall, toolsDict, beforeToolCallbacks, afterToolCallbacks, - filters, - toolConfirmationDict, + toolConfirmation, }: { invocationContext: InvocationContext; - functionCalls: FunctionCall[]; + functionCall: FunctionCall; toolsDict: Record; beforeToolCallbacks: SingleBeforeToolCallback[]; afterToolCallbacks: SingleAfterToolCallback[]; - filters?: Set; - toolConfirmationDict?: Record; + toolConfirmation?: ToolConfirmation; }): Promise { - const functionResponseEvents: Event[] = []; - - // Note: only function ids INCLUDED in the filters will be executed. - const filteredFunctionCalls = functionCalls.filter((functionCall) => { - return !filters || (functionCall.id && filters.has(functionCall.id)); - }); - - for (const functionCall of filteredFunctionCalls) { - let toolConfirmation = undefined; - if (toolConfirmationDict && functionCall.id) { - toolConfirmation = toolConfirmationDict[functionCall.id]; - } + const functionArgs = functionCall.args ? cloneDeep(functionCall.args) : {}; - const {tool, toolContext} = getToolAndContext({ + let tool: BaseTool; + let toolContext: ToolContext; + try { + ({tool, toolContext} = getToolAndContext({ invocationContext, functionCall, toolsDict, toolConfirmation, + })); + } catch (e) { + toolContext = new ToolContext({ + invocationContext, + functionCallId: functionCall.id, + toolConfirmation, }); - // TODO - b/436079721: implement [tracer.start_as_current_span] - logger.debug(`execute_tool ${tool.name}`); - const functionArgs = functionCall.args ?? {}; - - // Step 1: Check if plugin before_tool_callback overrides the function - // response. - let functionResponse = null; - let functionResponseError: string | unknown | undefined; - functionResponse = - await invocationContext.pluginManager.runBeforeToolCallback({ - tool: tool, + const toolError = e instanceof Error ? e : new Error(String(e)); + const errorResponse = + await invocationContext.pluginManager.runOnToolErrorCallback({ + tool: { + name: functionCall.name || 'unknown', + description: 'Tool not found', + isLongRunning: false, + } as BaseTool, toolArgs: functionArgs, - toolContext: toolContext, + toolContext, + error: toolError, }); - // Step 2: If no overrides are provided from the plugins, further run the - // canonical callback. - // TODO - b/425992518: validate the callback response type matches. - if (functionResponse == null) { - // Cover both null and undefined - for (const callback of beforeToolCallbacks) { - functionResponse = await callback({ - tool: tool, - args: functionArgs, - context: toolContext, - }); - if (functionResponse) { - break; - } + if (errorResponse) { + const response = + typeof errorResponse !== 'object' || errorResponse == null + ? {result: errorResponse} + : errorResponse; + + return createEvent({ + invocationId: invocationContext.invocationId, + author: invocationContext.agent.name, + content: createUserContent({ + functionResponse: { + id: functionCall.id, + name: functionCall.name || 'unknown', + response, + }, + }), + actions: toolContext.actions, + branch: invocationContext.branch, + }); + } + throw e; + } + + logger.debug(`execute_tool ${tool.name}`); + + let functionResponse = null; + let functionResponseError: string | unknown | undefined; + + // Step 1: plugin before_tool_callback + functionResponse = + await invocationContext.pluginManager.runBeforeToolCallback({ + tool: tool, + toolArgs: functionArgs, + toolContext: toolContext, + }); + + // Step 2: canonical beforeToolCallbacks + // TODO - b/425992518: validate the callback response type matches. + if (functionResponse == null) { + for (const callback of beforeToolCallbacks) { + functionResponse = await callback({ + tool: tool, + args: functionArgs, + context: toolContext, + }); + if (functionResponse) { + break; } } + } - // Step 3: Otherwise, proceed calling the tool normally. - if (functionResponse == null) { - // Cover both null and undefined - try { - functionResponse = await callToolAsync(tool, functionArgs, toolContext); - } catch (e: unknown) { - if (e instanceof Error) { - const onToolErrorResponse = - await invocationContext.pluginManager.runOnToolErrorCallback({ - tool: tool, - toolArgs: functionArgs, - toolContext: toolContext, - error: e, - }); - - // Set function response to the result of the error callback and - // continue execution, do not shortcut - if (onToolErrorResponse) { - functionResponse = onToolErrorResponse; - } else { - // If the error callback returns undefined, use the error message - // as the function response error. - functionResponseError = e.message; - } + // Step 3: call the tool + if (functionResponse == null) { + try { + functionResponse = await callToolAsync(tool, functionArgs, toolContext); + } catch (e: unknown) { + if (e instanceof Error) { + const onToolErrorResponse = + await invocationContext.pluginManager.runOnToolErrorCallback({ + tool: tool, + toolArgs: functionArgs, + toolContext: toolContext, + error: e, + }); + + if (onToolErrorResponse) { + functionResponse = onToolErrorResponse; } else { - // If the error is not an Error, use the error object as the function - // response error. - functionResponseError = e; + functionResponseError = e.message; } + } else { + functionResponseError = e; } } + } - // Step 4: Check if plugin after_tool_callback overrides the function - // response. - let alteredFunctionResponse = - await invocationContext.pluginManager.runAfterToolCallback({ + // Step 4: plugin after_tool_callback + let alteredFunctionResponse = + await invocationContext.pluginManager.runAfterToolCallback({ + tool: tool, + toolArgs: functionArgs, + toolContext: toolContext, + result: functionResponse, + }); + + // Step 5: canonical afterToolCallbacks + if (alteredFunctionResponse == null) { + for (const callback of afterToolCallbacks) { + alteredFunctionResponse = await callback({ tool: tool, - toolArgs: functionArgs, - toolContext: toolContext, - result: functionResponse, + args: functionArgs, + context: toolContext, + response: functionResponse, }); - - // Step 5: If no overrides are provided from the plugins, further run the - // canonical after_tool_callbacks. - if (alteredFunctionResponse == null) { - // Cover both null and undefined - for (const callback of afterToolCallbacks) { - alteredFunctionResponse = await callback({ - tool: tool, - args: functionArgs, - context: toolContext, - response: functionResponse, - }); - if (alteredFunctionResponse) { - break; - } + if (alteredFunctionResponse) { + break; } } + } - // Step 6: If alternative response exists from after_tool_callback, use it - // instead of the original function response. - if (alteredFunctionResponse != null) { - functionResponse = alteredFunctionResponse; - } + // Step 6: apply altered response + if (alteredFunctionResponse != null) { + functionResponse = alteredFunctionResponse; + } - // TODO - b/425992518: state event polluting runtime, consider fix. - // Allow long running function to return None as response. - if (tool.isLongRunning && !functionResponse) { - continue; - } + // TODO - b/425992518: state event polluting runtime, consider fix. + if (tool.isLongRunning && !functionResponse) { + return null; + } - if (functionResponseError) { - functionResponse = {error: functionResponseError}; - } else if ( - typeof functionResponse !== 'object' || - functionResponse == null - ) { - functionResponse = {result: functionResponse}; + if (functionResponseError) { + functionResponse = {error: functionResponseError}; + } else if (typeof functionResponse !== 'object' || functionResponse == null) { + functionResponse = {result: functionResponse}; + } + + const functionResponseEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: invocationContext.agent.name, + content: createUserContent({ + functionResponse: { + id: toolContext.functionCallId, + name: tool.name, + response: functionResponse, + }, + }), + actions: toolContext.actions, + branch: invocationContext.branch, + }); + + // TODO - b/436079721: implement [traceToolCall] + logger.debug('traceToolCall', { + tool: tool.name, + args: functionArgs, + functionResponseEvent: functionResponseEvent.id, + }); + + return functionResponseEvent; +} + +function resolveToolConfirmation( + functionCall: FunctionCall, + toolConfirmationDict?: Record, +): ToolConfirmation | undefined { + return toolConfirmationDict && functionCall.id + ? toolConfirmationDict[functionCall.id] + : undefined; +} + +function createErrorResponseEvent( + invocationContext: InvocationContext, + functionCall: FunctionCall, + error: unknown, +): Event { + const errorMessage = error instanceof Error ? error.message : String(error); + return createEvent({ + invocationId: invocationContext.invocationId, + author: invocationContext.agent.name, + content: createUserContent({ + functionResponse: { + id: functionCall.id, + name: functionCall.name!, + response: {error: errorMessage}, + }, + }), + branch: invocationContext.branch, + }); +} + +async function executeInBatches( + tasks: Array<() => Promise>, + batchSize: number, +): Promise[]> { + const results: PromiseSettledResult[] = []; + for (let i = 0; i < tasks.length; i += batchSize) { + const batch = tasks.slice(i, i + batchSize); + results.push(...(await Promise.allSettled(batch.map((t) => t())))); + } + return results; +} + +function detectStateDeltaConflicts(events: Event[]): void { + if (events.length <= 1) return; + + const seenKeys = new Map(); + const conflicts: Record = {}; + for (const event of events) { + if (!event.actions?.stateDelta) continue; + for (const [key, value] of Object.entries(event.actions.stateDelta)) { + if (seenKeys.has(key)) { + (conflicts[key] ??= [seenKeys.get(key)!]).push(value); + } + seenKeys.set(key, value); } + } - // Builds the function response event. - const functionResponseEvent = createEvent({ - invocationId: invocationContext.invocationId, - author: invocationContext.agent.name, - content: createUserContent({ - functionResponse: { - id: toolContext.functionCallId, - name: tool.name, - response: functionResponse, - }, - }), - actions: toolContext.actions, - branch: invocationContext.branch, - }); + const conflictKeys = Object.keys(conflicts); + if (!conflictKeys.length) return; + + const deepMergedKeys: string[] = []; + const lastWriteWinsKeys: string[] = []; + const details = conflictKeys + .map((k) => { + const values = conflicts[k]; + const allPlainObjects = values.every((v) => isPlainObject(v)); + const resolution = allPlainObjects ? 'deep-merged' : 'last-write-wins'; + if (allPlainObjects) { + deepMergedKeys.push(k); + } else { + lastWriteWinsKeys.push(k); + } + const serialized = values + .map((v) => { + try { + return JSON.stringify(v); + } catch { + return '[unserializable]'; + } + }) + .join(' → '); + return `${k} (${resolution}): [${serialized}]`; + }) + .join('; '); + + if (lastWriteWinsKeys.length) { + logger.warn( + `Parallel tool calls wrote to the same stateDelta key(s) with last-write-wins: [${lastWriteWinsKeys.join(', ')}]. ` + + `Values: ${details}. ` + + `Consider sequential mode if ordering matters.`, + ); + } else { + logger.debug( + `Parallel tool calls wrote to the same stateDelta key(s) (safely deep-merged): [${deepMergedKeys.join(', ')}]. ` + + `Values: ${details}.`, + ); + } +} - // TODO - b/436079721: implement [traceToolCall] - logger.debug('traceToolCall', { - tool: tool.name, - args: functionArgs, - functionResponseEvent: functionResponseEvent.id, - }); - functionResponseEvents.push(functionResponseEvent); +/** + * The underlying implementation of handleFunctionCalls, but takes a list of + * function calls instead of an event. + * This is also used by llm_agent execution flow in preprocessing. + * + * Execution mode is controlled by `RunConfig.parallelToolExecution`: + * - true: tool calls run concurrently via Promise.allSettled, + * matching adk-python's asyncio.gather pattern. Individual failures + * do not affect other calls — failed tools produce error response events. + * - false (default): tool calls execute sequentially in order, preserving original + * behavior for tools with interdependencies or ordering requirements. + * + * When parallel, `RunConfig.maxConcurrentToolCalls` controls back-pressure: + * - undefined/0: all tool calls dispatch at once (no limit). + * - positive int: tool calls dispatch in batches of this size. + * + * In parallel mode, overlapping `stateDelta` keys across tools are detected + * and logged as a warning (last-write-wins applies). + * + * NOTE: In parallel mode, beforeToolCallback / afterToolCallback may fire + * concurrently. Callbacks must not depend on execution order across calls. + */ +export async function handleFunctionCallList({ + invocationContext, + functionCalls, + toolsDict, + beforeToolCallbacks, + afterToolCallbacks, + filters, + toolConfirmationDict, +}: { + invocationContext: InvocationContext; + functionCalls: FunctionCall[]; + toolsDict: Record; + beforeToolCallbacks: SingleBeforeToolCallback[]; + afterToolCallbacks: SingleAfterToolCallback[]; + filters?: Set; + toolConfirmationDict?: Record; +}): Promise { + // Note: only function ids INCLUDED in the filters will be executed. + const filteredFunctionCalls = functionCalls.filter((functionCall) => { + return !filters || (functionCall.id && filters.has(functionCall.id)); + }); + + if (!filteredFunctionCalls.length) { + return null; } + const parallel = invocationContext.runConfig?.parallelToolExecution ?? false; + + const executeSingle = (fc: FunctionCall) => + executeSingleFunctionCall({ + invocationContext, + functionCall: fc, + toolsDict, + beforeToolCallbacks, + afterToolCallbacks, + toolConfirmation: resolveToolConfirmation(fc, toolConfirmationDict), + }); + + const functionResponseEvents: Event[] = parallel + ? await dispatchParallel( + filteredFunctionCalls, + executeSingle, + invocationContext, + ) + : await dispatchSequential( + filteredFunctionCalls, + executeSingle, + invocationContext, + ); + if (!functionResponseEvents.length) { return null; } + const mergedEvent = mergeParallelFunctionResponseEvents( functionResponseEvents, ); @@ -519,6 +771,157 @@ export async function handleFunctionCallList({ return mergedEvent; } +async function dispatchParallel( + functionCalls: FunctionCall[], + executeSingle: (fc: FunctionCall) => Promise, + invocationContext: InvocationContext, +): Promise { + if (functionCalls.length > 1) { + logger.info( + `parallel_tool_execution: ${functionCalls.length} tools ` + + `[${functionCalls.map((fc) => fc.name).join(', ')}]`, + ); + } + + const tasks = functionCalls.map((fc) => () => executeSingle(fc)); + const maxConcurrency = Math.floor( + invocationContext.runConfig?.maxConcurrentToolCalls ?? 0, + ); + const results = + maxConcurrency > 0 && functionCalls.length > maxConcurrency + ? await executeInBatches(tasks, maxConcurrency) + : await Promise.allSettled(tasks.map((t) => t())); + + const events: Event[] = []; + for (const [i, result] of results.entries()) { + if (result.status === 'fulfilled') { + if (result.value) { + events.push(result.value); + } + } else { + const fc = functionCalls[i]; + logger.warn(`Parallel tool call failed: ${fc.name}`, { + error: result.reason, + }); + events.push( + createErrorResponseEvent(invocationContext, fc, result.reason), + ); + } + } + + detectStateDeltaConflicts(events); + return events; +} + +async function dispatchSequential( + functionCalls: FunctionCall[], + executeSingle: (fc: FunctionCall) => Promise, + invocationContext: InvocationContext, +): Promise { + const events: Event[] = []; + for (const fc of functionCalls) { + try { + const event = await executeSingle(fc); + if (event) { + events.push(event); + } + } catch (e) { + logger.warn(`Sequential tool call failed: ${fc.name}`, {error: e}); + events.push(createErrorResponseEvent(invocationContext, fc, e)); + } + } + return events; +} + +async function* dispatchParallelStreaming( + functionCalls: FunctionCall[], + executeSingle: (fc: FunctionCall) => Promise, + invocationContext: InvocationContext, +): AsyncGenerator { + if (functionCalls.length > 1) { + logger.info( + `parallel_tool_execution: ${functionCalls.length} tools ` + + `[${functionCalls.map((fc) => fc.name).join(', ')}]`, + ); + } + + const tasks = functionCalls.map((fc) => () => executeSingle(fc)); + const maxConcurrency = Math.floor( + invocationContext.runConfig?.maxConcurrentToolCalls ?? 0, + ); + + const allEvents: Event[] = []; + if (maxConcurrency > 0 && functionCalls.length > maxConcurrency) { + for (let i = 0; i < tasks.length; i += maxConcurrency) { + const batch = tasks.slice(i, i + maxConcurrency); + const batchCalls = functionCalls.slice(i, i + maxConcurrency); + const batchResults = await Promise.allSettled(batch.map((t) => t())); + for (const [batchIndex, result] of batchResults.entries()) { + if (result.status === 'fulfilled') { + if (result.value) { + allEvents.push(result.value); + yield result.value; + } + } else { + const fc = batchCalls[batchIndex]; + logger.warn(`Parallel tool call failed: ${fc.name}`, { + error: result.reason, + }); + const errorEvent = createErrorResponseEvent( + invocationContext, + fc, + result.reason, + ); + allEvents.push(errorEvent); + yield errorEvent; + } + } + } + } else { + const results = await Promise.allSettled(tasks.map((t) => t())); + for (const [i, result] of results.entries()) { + if (result.status === 'fulfilled') { + if (result.value) { + allEvents.push(result.value); + yield result.value; + } + } else { + const fc = functionCalls[i]; + logger.warn(`Parallel tool call failed: ${fc.name}`, { + error: result.reason, + }); + const errorEvent = createErrorResponseEvent( + invocationContext, + fc, + result.reason, + ); + allEvents.push(errorEvent); + yield errorEvent; + } + } + } + + detectStateDeltaConflicts(allEvents); +} + +async function* dispatchSequentialStreaming( + functionCalls: FunctionCall[], + executeSingle: (fc: FunctionCall) => Promise, + invocationContext: InvocationContext, +): AsyncGenerator { + for (const fc of functionCalls) { + try { + const event = await executeSingle(fc); + if (event) { + yield event; + } + } catch (e) { + logger.warn(`Sequential tool call failed: ${fc.name}`, {error: e}); + yield createErrorResponseEvent(invocationContext, fc, e); + } + } +} + // TODO - b/425992518: consider inline, which is much cleaner. function getToolAndContext({ invocationContext, @@ -539,7 +942,7 @@ function getToolAndContext({ const toolContext = new Context({ invocationContext: invocationContext, - functionCallId: functionCall.id || undefined, + functionCallId: functionCall.id, toolConfirmation, }); @@ -578,6 +981,7 @@ export function mergeParallelFunctionResponseEvents( return createEvent({ author: baseEvent.author, + invocationId: baseEvent.invocationId, branch: baseEvent.branch, content: {role: 'user', parts: mergedParts}, actions: mergedActions, diff --git a/core/src/agents/llm_agent.ts b/core/src/agents/llm_agent.ts index 9e72f02a..472615f5 100644 --- a/core/src/agents/llm_agent.ts +++ b/core/src/agents/llm_agent.ts @@ -50,6 +50,7 @@ import { generateRequestConfirmationEvent, getLongRunningFunctionCalls, handleFunctionCallsAsync, + mergeParallelFunctionResponseEvents, populateClientFunctionCallId, } from './functions.js'; @@ -831,23 +832,64 @@ export class LlmAgent extends BaseAgent { // Call functions // TODO - b/425992518: bloated funciton input, fix. // Tool callback passed to get rid of cyclic dependency. - const functionResponseEvent = await handleFunctionCallsAsync({ + const functionResponseEvents: Event[] = []; + for await (const functionResponseEvent of handleFunctionCallsAsync({ invocationContext: invocationContext, functionCallEvent: mergedEvent, toolsDict: llmRequest.toolsDict, beforeToolCallbacks: this.canonicalBeforeToolCallbacks, afterToolCallbacks: this.canonicalAfterToolCallbacks, - }); + })) { + functionResponseEvents.push(functionResponseEvent); + yield functionResponseEvent; + } - if (!functionResponseEvent) { + if (!functionResponseEvents.length) { return; } + const mergedFunctionResponseEvent = mergeParallelFunctionResponseEvents( + functionResponseEvents, + ); + + // Persist an internal completion marker for streamed parallel tool batches. + // This allows resumption logic to distinguish complete vs partial batches. + if ((getFunctionCalls(mergedEvent)?.length ?? 0) > 1) { + const functionCallIds = new Set( + getFunctionCalls(mergedEvent) + .map((fc) => fc.id) + .filter(Boolean), + ); + const longRunningIds = new Set(mergedEvent.longRunningToolIds ?? []); + const expectedResponseCount = Array.from(functionCallIds).filter( + (id): id is string => !!id && !longRunningIds.has(id), + ).length; + const completionEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: invocationContext.agent.name, + branch: invocationContext.branch, + actions: createEventActions({ + customMetadata: { + parallelToolBatchCompletion: { + functionCallEventId: mergedEvent.id, + expectedResponseCount, + }, + }, + }), + }); + if (invocationContext.sessionService) { + await invocationContext.sessionService.appendEvent({ + session: invocationContext.session, + event: completionEvent, + }); + } + } + // Yiels an authentication event if any. // TODO - b/425992518: transaction log session, simplify. const authEvent = generateAuthEvent( invocationContext, - functionResponseEvent, + mergedFunctionResponseEvent, ); if (authEvent) { yield authEvent; @@ -857,7 +899,7 @@ export class LlmAgent extends BaseAgent { const toolConfirmationEvent = generateRequestConfirmationEvent({ invocationContext: invocationContext, functionCallEvent: mergedEvent, - functionResponseEvent: functionResponseEvent, + functionResponseEvent: mergedFunctionResponseEvent, }); if (toolConfirmationEvent) { yield toolConfirmationEvent; @@ -865,11 +907,8 @@ export class LlmAgent extends BaseAgent { return; } - // Yields the function response event. - yield functionResponseEvent; - // If model instruct to transfer to an agent, run the transferred agent. - const nextAgentName = functionResponseEvent.actions.transferToAgent; + const nextAgentName = mergedFunctionResponseEvent.actions.transferToAgent; if (nextAgentName) { const nextAgent = this.getAgentByName(invocationContext, nextAgentName); for await (const event of nextAgent.runAsync(invocationContext)) { diff --git a/core/src/agents/run_config.ts b/core/src/agents/run_config.ts index b46b27e2..784122e4 100644 --- a/core/src/agents/run_config.ts +++ b/core/src/agents/run_config.ts @@ -95,10 +95,44 @@ export interface RunConfig { maxLlmCalls?: number; /** - * If true, the agent loop will suspend on ANY tool call, allowing the client - * to intercept and execute tools (Client-Side Tool Execution). + * Controls whether multiple tool calls from a single LLM response are + * executed concurrently (true) or sequentially (false). + * + * When true: tool calls run via Promise.allSettled, matching + * adk-python's asyncio.gather pattern. Individual failures don't affect + * other calls. + * + * When false (default): tool calls execute one-by-one in order, preserving + * backward compatibility for tools with interdependencies, shared state + * mutations, or deterministic ordering requirements. + * + * @default false + */ + parallelToolExecution?: boolean; + + /** + * When true, execution pauses after receiving tool calls from the model, + * allowing client-side tool execution patterns. + * + * @default false */ pauseOnToolCalls?: boolean; + + /** + * Maximum number of tool calls to execute concurrently when + * `parallelToolExecution` is true. + * + * When set to a positive integer, tool calls are dispatched in batches of + * this size — each batch runs via Promise.allSettled, and the next batch + * starts only after the current one settles. This provides back-pressure + * for rate-limited APIs or resource-constrained environments. + * + * When undefined or <= 0, all tool calls run concurrently (no limit). + * Ignored when `parallelToolExecution` is false. + * + * @default undefined + */ + maxConcurrentToolCalls?: number; } export function createRunConfig(params: Partial = {}) { @@ -108,6 +142,7 @@ export function createRunConfig(params: Partial = {}) { enableAffectiveDialog: false, streamingMode: StreamingMode.NONE, maxLlmCalls: validateMaxLlmCalls(params.maxLlmCalls || 500), + parallelToolExecution: false, pauseOnToolCalls: false, ...params, }; diff --git a/core/src/events/event_actions.ts b/core/src/events/event_actions.ts index a4f966e8..79327e2a 100644 --- a/core/src/events/event_actions.ts +++ b/core/src/events/event_actions.ts @@ -4,6 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +import {isPlainObject} from 'lodash-es'; + import {ToolConfirmation} from '../tools/tool_confirmation.js'; // TODO: b/425992518 - Replace 'any' with a proper AuthConfig. @@ -59,6 +61,11 @@ export interface EventActions { * call id. */ requestedToolConfirmations: {[key: string]: ToolConfirmation}; + + /** + * Optional metadata for framework-level runtime coordination. + */ + customMetadata?: {[key: string]: unknown}; } /** @@ -72,6 +79,7 @@ export function createEventActions( artifactDelta: {}, requestedAuthConfigs: {}, requestedToolConfirmations: {}, + customMetadata: {}, ...state, }; } @@ -85,6 +93,25 @@ export function createEventActions( * 2. For other properties (skipSummarization,transferToAgent, escalate), the * last one wins. */ +function deepMergeStateDelta( + target: Record, + source: Record, +): void { + for (const [key, srcValue] of Object.entries(source)) { + const targetValue = target[key]; + + if (isPlainObject(targetValue) && isPlainObject(srcValue)) { + const nestedTarget = {...(targetValue as Record)}; + deepMergeStateDelta(nestedTarget, srcValue as Record); + target[key] = nestedTarget; + continue; + } + + // Preserve explicit undefined writes as last-write-wins for clear semantics. + target[key] = srcValue; + } +} + export function mergeEventActions( sources: Array>, target?: EventActions, @@ -99,7 +126,7 @@ export function mergeEventActions( if (!source) continue; if (source.stateDelta) { - Object.assign(result.stateDelta, source.stateDelta); + deepMergeStateDelta(result.stateDelta, source.stateDelta); } if (source.artifactDelta) { Object.assign(result.artifactDelta, source.artifactDelta); @@ -113,6 +140,12 @@ export function mergeEventActions( source.requestedToolConfirmations, ); } + if (source.customMetadata) { + result.customMetadata = Object.assign( + result.customMetadata ?? {}, + source.customMetadata, + ); + } if (source.skipSummarization !== undefined) { result.skipSummarization = source.skipSummarization; diff --git a/core/src/runner/runner.ts b/core/src/runner/runner.ts index 75849dcf..e910c31b 100644 --- a/core/src/runner/runner.ts +++ b/core/src/runner/runner.ts @@ -20,7 +20,12 @@ import { BuiltInCodeExecutor, isBuiltInCodeExecutor, } from '../code_executors/built_in_code_executor.js'; -import {createEvent, Event, getFunctionCalls} from '../events/event.js'; +import { + createEvent, + Event, + getFunctionCalls, + getFunctionResponses, +} from '../events/event.js'; import {createEventActions} from '../events/event_actions.js'; import {BaseMemoryService} from '../memory/base_memory_service.js'; import {BasePlugin} from '../plugins/base_plugin.js'; @@ -347,6 +352,10 @@ export class Runner { session: Session, rootAgent: BaseAgent, ): BaseAgent { + if (hasIncompleteParallelToolBatch(session.events)) { + return rootAgent; + } + // ========================================================================= // Case 1: If the last event is a function response, this returns the // agent that made the original function call. @@ -428,14 +437,46 @@ function findEventByLastFunctionResponseId(events: Event[]): Event | null { return null; } - const lastEvent = events[events.length - 1]; - const functionCallId = lastEvent.content?.parts?.find( + let latestCompletionMarkerEvent: Event | null = null; + let latestFunctionResponseEvent: Event | null = null; + for (let i = events.length - 1; i >= 0; i--) { + const event = events[i]; + if ( + !latestCompletionMarkerEvent && + getParallelBatchCompletionMetadata(event) + ) { + latestCompletionMarkerEvent = event; + } + if ( + !latestFunctionResponseEvent && + getFunctionResponses(event).length > 0 + ) { + latestFunctionResponseEvent = event; + } + if (latestCompletionMarkerEvent && latestFunctionResponseEvent) { + break; + } + } + + const completionMarker = latestCompletionMarkerEvent + ? getParallelBatchCompletionMetadata(latestCompletionMarkerEvent) + : null; + if (completionMarker) { + const functionCallEvent = events.find( + (event) => event.id === completionMarker.functionCallEventId, + ); + return functionCallEvent ?? null; + } + + const functionCallId = latestFunctionResponseEvent?.content?.parts?.find( (part) => part.functionResponse, )?.functionResponse?.id; if (!functionCallId) { return null; } + let matchingFunctionCallEventIndex = -1; + let matchingFunctionCallEvent: Event | null = null; // TODO - b/425992518: inefficient search, fix. for (let i = events.length - 2; i >= 0; i--) { const event = events[i]; @@ -447,9 +488,134 @@ function findEventByLastFunctionResponseId(events: Event[]): Event | null { for (const functionCall of functionCalls) { if (functionCall.id === functionCallId) { - return event; + matchingFunctionCallEventIndex = i; + matchingFunctionCallEvent = event; + break; } } + if (matchingFunctionCallEvent) { + break; + } + } + if (!matchingFunctionCallEvent) { + return null; + } + + const expectedResponseCount = getExpectedResponseCount( + matchingFunctionCallEvent, + ); + if (expectedResponseCount > 1) { + const hasCompletionSentinel = events + .slice(matchingFunctionCallEventIndex + 1) + .some((event) => { + const marker = getParallelBatchCompletionMetadata(event); + return ( + marker?.functionCallEventId === matchingFunctionCallEvent?.id && + marker.expectedResponseCount === expectedResponseCount + ); + }); + if (!hasCompletionSentinel) { + const observedResponseCount = getObservedResponseCount( + events, + matchingFunctionCallEventIndex, + matchingFunctionCallEvent, + ); + if (observedResponseCount < expectedResponseCount) { + return null; + } + } + } + + return matchingFunctionCallEvent; +} + +type ParallelBatchCompletionMetadata = { + functionCallEventId: string; + expectedResponseCount: number; +}; + +function getParallelBatchCompletionMetadata( + event: Event, +): ParallelBatchCompletionMetadata | null { + const maybeMarker = event.actions?.customMetadata?.[ + 'parallelToolBatchCompletion' + ] as Partial | undefined; + if ( + maybeMarker && + typeof maybeMarker.functionCallEventId === 'string' && + typeof maybeMarker.expectedResponseCount === 'number' + ) { + return { + functionCallEventId: maybeMarker.functionCallEventId, + expectedResponseCount: maybeMarker.expectedResponseCount, + }; } return null; } + +function getExpectedResponseCount(functionCallEvent: Event): number { + const functionCalls = getFunctionCalls(functionCallEvent); + if (!functionCalls.length) { + return 0; + } + const longRunningToolIds = new Set( + functionCallEvent.longRunningToolIds ?? [], + ); + return functionCalls.filter((functionCall) => { + return !!functionCall.id && !longRunningToolIds.has(functionCall.id); + }).length; +} + +function getObservedResponseCount( + events: Event[], + functionCallEventIndex: number, + functionCallEvent: Event, +): number { + const functionCallIds = new Set( + getFunctionCalls(functionCallEvent).map((functionCall) => functionCall.id), + ); + const observedIds = new Set(); + for (const event of events.slice(functionCallEventIndex + 1)) { + for (const response of getFunctionResponses(event)) { + if (response.id && functionCallIds.has(response.id)) { + observedIds.add(response.id); + } + } + } + return observedIds.size; +} + +function hasIncompleteParallelToolBatch(events: Event[]): boolean { + for (let i = events.length - 1; i >= 0; i--) { + const functionCallEvent = events[i]; + const expectedResponseCount = getExpectedResponseCount(functionCallEvent); + if (expectedResponseCount <= 1) { + continue; + } + + const hasCompletionSentinel = events.slice(i + 1).some((event) => { + const marker = getParallelBatchCompletionMetadata(event); + return ( + marker?.functionCallEventId === functionCallEvent.id && + marker.expectedResponseCount === expectedResponseCount + ); + }); + if (hasCompletionSentinel) { + return false; + } + + const observedResponseCount = getObservedResponseCount( + events, + i, + functionCallEvent, + ); + if ( + observedResponseCount > 0 && + observedResponseCount < expectedResponseCount + ) { + return true; + } + return false; + } + return false; +} diff --git a/core/test/agents/functions_test.ts b/core/test/agents/functions_test.ts index 0c9c2635..9dfaa7c4 100644 --- a/core/test/agents/functions_test.ts +++ b/core/test/agents/functions_test.ts @@ -6,6 +6,7 @@ import { BasePlugin, BaseTool, + createEvent, Event, functionsExportedForTestingOnly, FunctionTool, @@ -25,6 +26,18 @@ const { generateAuthEvent, generateRequestConfirmationEvent, } = functionsExportedForTestingOnly; +const handleFunctionCallsAsync = ( + functionsExportedForTestingOnly as unknown as { + handleFunctionCallsAsync: (args: { + invocationContext: InvocationContext; + functionCallEvent: Event; + toolsDict: Record; + beforeToolCallbacks: SingleBeforeToolCallback[]; + afterToolCallbacks: SingleAfterToolCallback[]; + filters?: Set; + }) => AsyncGenerator; + } +).handleFunctionCallsAsync; // Tool for testing const testTool = new FunctionTool({ @@ -283,78 +296,93 @@ describe('handleFunctionCallList', () => { error: "Error in tool 'errorTool': tool error message content", }); }); -}); -describe('generateAuthEvent', () => { - let invocationContext: InvocationContext; - let pluginManager: PluginManager; + it('should deep-copy args so callback mutations do not affect original FunctionCall', async () => { + const originalArgs = {key: 'original'}; + const mutatingCallback: SingleBeforeToolCallback = async ({args}) => { + (args as Record).key = 'mutated-by-callback'; + return undefined; + }; - beforeEach(() => { - pluginManager = new PluginManager(); - const agent = new LlmAgent({name: 'test_agent', model: 'test_model'}); - invocationContext = new InvocationContext({ - invocationId: 'inv_123', - session: {} as Session, - agent, - pluginManager, - }); - }); + const calls: FunctionCall[] = [ + {id: randomIdForTestingOnly(), name: 'testTool', args: originalArgs}, + ]; - it('should return undefined if no requestedAuthConfigs', () => { - const functionResponseEvent = { - actions: {}, - content: {role: 'model'}, - } as unknown as Event; + await handleFunctionCallList({ + invocationContext, + functionCalls: calls, + toolsDict: {testTool}, + beforeToolCallbacks: [mutatingCallback], + afterToolCallbacks: [], + }); - const event = generateAuthEvent(invocationContext, functionResponseEvent); - expect(event).toBeUndefined(); + expect(originalArgs.key).toBe('original'); + expect(calls[0].args!.key).toBe('original'); }); - it('should return undefined if requestedAuthConfigs is empty', () => { - const functionResponseEvent = { - actions: {requestedAuthConfigs: {}}, - content: {role: 'model'}, - } as unknown as Event; + it('should invoke onToolErrorCallback when tool is not found', async () => { + const plugin = new TestPlugin('testPlugin'); + plugin.onToolErrorCallbackResponse = { + result: 'error handled gracefully', + }; + pluginManager.registerPlugin(plugin); - const event = generateAuthEvent(invocationContext, functionResponseEvent); - expect(event).toBeUndefined(); + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [{id: 'id-1', name: 'nonexistentTool', args: {}}], + toolsDict: {}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts![0].functionResponse!.response).toEqual({ + result: 'error handled gracefully', + }); }); - it('should return auth event if requestedAuthConfigs is present', () => { - const functionResponseEvent = { - actions: { - requestedAuthConfigs: { - 'call_1': 'auth_config_1', - 'call_2': 'auth_config_2', - }, - }, - content: {role: 'model'}, - } as unknown as Event; + // PR #167 Fix 1 — structuredClone → cloneDeep + // structuredClone throws DataCloneError on function values; + // cloneDeep from lodash-es copies them by reference without throwing. + it('args containing non-serializable values do not throw (cloneDeep, not structuredClone)', async () => { + const capturedArgs: Record[] = []; + const captureCallback: SingleBeforeToolCallback = async ({args}) => { + capturedArgs.push(args as Record); + return undefined; + }; - const event = generateAuthEvent(invocationContext, functionResponseEvent); - expect(event).toBeDefined(); - expect(event!.invocationId).toBe('inv_123'); - expect(event!.author).toBe('test_agent'); - expect(event!.content!.parts!.length).toBe(2); + const fnArg = () => 'sentinel'; + await expect( + handleFunctionCallList({ + invocationContext, + functionCalls: [{id: 'id-1', name: 'testTool', args: {fn: fnArg}}], + toolsDict: {testTool}, + beforeToolCallbacks: [captureCallback], + afterToolCallbacks: [], + }), + ).resolves.toBeDefined(); - const parts = event!.content!.parts!; - const call1 = parts.find( - (p) => p.functionCall?.args?.['function_call_id'] === 'call_1', - ); - expect(call1).toBeDefined(); - expect(call1!.functionCall!.name).toBe('adk_request_credential'); - expect(call1!.functionCall!.args!['auth_config']).toBe('auth_config_1'); + expect(typeof capturedArgs[0]?.['fn']).toBe('function'); + }); - const call2 = parts.find( - (p) => p.functionCall?.args?.['function_call_id'] === 'call_2', - ); - expect(call2).toBeDefined(); - expect(call2!.functionCall!.name).toBe('adk_request_credential'); - expect(call2!.functionCall!.args!['auth_config']).toBe('auth_config_2'); + // PR #167 Fix 2 — functionCall.id || undefined → functionCall.id + // The redundant `|| undefined` guard was removed; the id must round-trip verbatim. + it('functionCall.id is preserved verbatim in functionResponse.id', async () => { + const specificId = 'fc-exact-id-42'; + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [{id: specificId, name: 'testTool', args: {}}], + toolsDict: {testTool}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts![0].functionResponse!.id).toBe(specificId); }); }); -describe('generateRequestConfirmationEvent', () => { +describe('parallel tool execution', () => { let invocationContext: InvocationContext; let pluginManager: PluginManager; @@ -369,142 +397,1567 @@ describe('generateRequestConfirmationEvent', () => { }); }); - it('should return undefined if no requestedToolConfirmations', () => { - const functionCallEvent = {content: {parts: []}} as unknown as Event; - const functionResponseEvent = { - actions: {}, - content: {role: 'model'}, - } as unknown as Event; + function makeDelayedTool(name: string, delayMs: number, result: string) { + return new FunctionTool({ + name, + description: name, + parameters: z.object({}), + execute: async () => { + await new Promise((resolve) => setTimeout(resolve, delayMs)); + return {result}; + }, + }); + } - const event = generateRequestConfirmationEvent({ + function makeFailingTool(name: string, delayMs: number) { + return new FunctionTool({ + name, + description: name, + parameters: z.object({}), + execute: async () => { + await new Promise((resolve) => setTimeout(resolve, delayMs)); + throw new Error(`${name} failed`); + }, + }); + } + + function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + it('should execute multiple tools concurrently (faster than sequential)', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const DELAY = 100; + const toolA = makeDelayedTool('toolA', DELAY, 'A done'); + const toolB = makeDelayedTool('toolB', DELAY, 'B done'); + const toolC = makeDelayedTool('toolC', DELAY, 'C done'); + + const toolsDict = {toolA, toolB, toolC}; + const calls: FunctionCall[] = [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + {id: 'id-c', name: 'toolC', args: {}}, + ]; + + const start = Date.now(); + const event = await handleFunctionCallList({ invocationContext, - functionCallEvent, - functionResponseEvent, + functionCalls: calls, + toolsDict, + beforeToolCallbacks: [], + afterToolCallbacks: [], }); - expect(event).toBeUndefined(); + const elapsed = Date.now() - start; + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(3); + + const responses = parts.map( + (p) => (p.functionResponse!.response as Record).result, + ); + expect(responses).toContain('A done'); + expect(responses).toContain('B done'); + expect(responses).toContain('C done'); + + // Parallel: should take ~DELAY, not ~3*DELAY. + // Use 2*DELAY as threshold to account for test runner overhead. + expect(elapsed).toBeLessThan(DELAY * 2); }); - it('should return undefined if requestedToolConfirmations is empty', () => { - const functionCallEvent = {content: {parts: []}} as unknown as Event; - const functionResponseEvent = { - actions: {requestedToolConfirmations: {}}, - content: {role: 'model'}, - } as unknown as Event; + it('should isolate errors — failed tool does not prevent other tools from returning', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const toolA = makeDelayedTool('toolA', 50, 'A done'); + const toolB = makeFailingTool('toolB', 50); + const toolC = makeDelayedTool('toolC', 50, 'C done'); - const event = generateRequestConfirmationEvent({ + const toolsDict = {toolA, toolB, toolC}; + const calls: FunctionCall[] = [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + {id: 'id-c', name: 'toolC', args: {}}, + ]; + + const event = await handleFunctionCallList({ invocationContext, - functionCallEvent, - functionResponseEvent, + functionCalls: calls, + toolsDict, + beforeToolCallbacks: [], + afterToolCallbacks: [], }); - expect(event).toBeUndefined(); + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(3); + + const responseA = parts.find((p) => p.functionResponse!.name === 'toolA'); + expect( + (responseA!.functionResponse!.response as Record).result, + ).toBe('A done'); + + const responseB = parts.find((p) => p.functionResponse!.name === 'toolB'); + expect( + (responseB!.functionResponse!.response as Record).error, + ).toContain('toolB failed'); + + const responseC = parts.find((p) => p.functionResponse!.name === 'toolC'); + expect( + (responseC!.functionResponse!.response as Record).result, + ).toBe('C done'); }); - it('should return confirmation event if requestedToolConfirmations is present', () => { - const functionCallEvent = { - content: { - parts: [ - { - functionCall: { - name: 'tool_1', - args: {arg: 'val1'}, - id: 'call_1', - }, - }, - { - functionCall: { - name: 'tool_2', - args: {arg: 'val2'}, - id: 'call_2', - }, - }, - ], - }, - } as unknown as Event; + it('should preserve result order matching input function call order', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + // Tool A is slow, B is fast — results should still be in [A, B] order + const toolA = makeDelayedTool('toolA', 100, 'A done'); + const toolB = makeDelayedTool('toolB', 10, 'B done'); - const functionResponseEvent = { - actions: { - requestedToolConfirmations: { - 'call_1': {message: 'confirm tool 1'}, - 'call_2': {message: 'confirm tool 2'}, - }, - }, - content: {role: 'model'}, - } as unknown as Event; + const toolsDict = {toolA, toolB}; + const calls: FunctionCall[] = [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ]; - const event = generateRequestConfirmationEvent({ + const event = await handleFunctionCallList({ invocationContext, - functionCallEvent, - functionResponseEvent, + functionCalls: calls, + toolsDict, + beforeToolCallbacks: [], + afterToolCallbacks: [], }); - expect(event).toBeDefined(); - expect(event!.invocationId).toBe('inv_123'); - expect(event!.author).toBe('test_agent'); - expect(event!.content!.parts!.length).toBe(2); - + expect(event).not.toBeNull(); const parts = event!.content!.parts!; - const call1 = parts.find( - (p) => - (p.functionCall?.args?.['originalFunctionCall'] as FunctionCall)?.id === - 'call_1', - ); - expect(call1).toBeDefined(); - expect(call1!.functionCall!.name).toBe('adk_request_confirmation'); - expect(call1!.functionCall!.args!['toolConfirmation']).toEqual({ - message: 'confirm tool 1', - }); + expect(parts[0].functionResponse!.name).toBe('toolA'); + expect(parts[1].functionResponse!.name).toBe('toolB'); + }); - const call2 = parts.find( - (p) => - (p.functionCall?.args?.['originalFunctionCall'] as FunctionCall)?.id === - 'call_2', - ); - expect(call2).toBeDefined(); - expect(call2!.functionCall!.name).toBe('adk_request_confirmation'); - expect(call2!.functionCall!.args!['toolConfirmation']).toEqual({ - message: 'confirm tool 2', + it('parallel mode: order preserved even when fast tool finishes before slow tool', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const toolA = makeDelayedTool('toolA', 150, 'A done'); + const toolB = makeDelayedTool('toolB', 10, 'B done'); + const toolC = makeDelayedTool('toolC', 80, 'C done'); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + {id: 'id-c', name: 'toolC', args: {}}, + ], + toolsDict: {toolA, toolB, toolC}, + beforeToolCallbacks: [], + afterToolCallbacks: [], }); + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(3); + // B finishes first (~10ms), C second (~80ms), A last (~150ms) + // but results must follow input order: A, B, C + expect(parts[0].functionResponse!.name).toBe('toolA'); + expect(parts[1].functionResponse!.name).toBe('toolB'); + expect(parts[2].functionResponse!.name).toBe('toolC'); }); - it('should skip confirmation if original function call is not found', () => { - const functionCallEvent = { - content: { - parts: [ - { - functionCall: { - name: 'tool_1', - args: {arg: 'val1'}, - id: 'call_1', - }, - }, - ], - }, - } as unknown as Event; + it('should run callbacks concurrently for each parallel tool call', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const callbackOrder: string[] = []; - const functionResponseEvent = { - actions: { - requestedToolConfirmations: { - 'call_1': {message: 'confirm tool 1'}, - 'call_missing': {message: 'confirm tool missing'}, - }, - }, - content: {role: 'model'}, - } as unknown as Event; + const toolA = makeDelayedTool('toolA', 50, 'A done'); + const toolB = makeDelayedTool('toolB', 50, 'B done'); - const event = generateRequestConfirmationEvent({ + const beforeCallback: SingleBeforeToolCallback = async ({tool}) => { + callbackOrder.push(`before:${tool.name}`); + return undefined; + }; + const afterCallback: SingleAfterToolCallback = async ({tool}) => { + callbackOrder.push(`after:${tool.name}`); + return undefined; + }; + + const toolsDict = {toolA, toolB}; + const calls: FunctionCall[] = [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ]; + + await handleFunctionCallList({ invocationContext, - functionCallEvent, - functionResponseEvent, + functionCalls: calls, + toolsDict, + beforeToolCallbacks: [beforeCallback], + afterToolCallbacks: [afterCallback], }); - expect(event).toBeDefined(); - expect(event!.content!.parts!.length).toBe(1); - const parts = event!.content!.parts!; - const call1 = parts.find( - (p) => - (p.functionCall?.args?.['originalFunctionCall'] as FunctionCall)?.id === - 'call_1', - ); - expect(call1).toBeDefined(); + // Both before callbacks should fire, both after callbacks should fire + expect(callbackOrder).toContain('before:toolA'); + expect(callbackOrder).toContain('before:toolB'); + expect(callbackOrder).toContain('after:toolA'); + expect(callbackOrder).toContain('after:toolB'); + expect(callbackOrder).toHaveLength(4); + }); + + it('single function call behaves identically to previous sequential implementation', async () => { + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [{id: 'id-1', name: 'testTool', args: {}}], + toolsDict: {testTool}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts!).toHaveLength(1); + expect(event!.content!.parts![0].functionResponse!.response).toEqual({ + result: 'tool executed', + }); + }); + + it('should fall back to sequential when parallelToolExecution is false', async () => { + const executionOrder: string[] = []; + const toolA = new FunctionTool({ + name: 'toolA', + description: 'A', + parameters: z.object({}), + execute: async () => { + executionOrder.push('A-start'); + await new Promise((r) => setTimeout(r, 50)); + executionOrder.push('A-end'); + return {result: 'A done'}; + }, + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'B', + parameters: z.object({}), + execute: async () => { + executionOrder.push('B-start'); + await new Promise((r) => setTimeout(r, 50)); + executionOrder.push('B-end'); + return {result: 'B done'}; + }, + }); + + invocationContext.runConfig = { + parallelToolExecution: false, + maxLlmCalls: 500, + }; + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + // Sequential: A must finish before B starts + expect(executionOrder).toEqual(['A-start', 'A-end', 'B-start', 'B-end']); + }); + + it('sequential mode: error in one tool does not stop subsequent tools', async () => { + const toolA = makeFailingTool('toolA', 10); + const toolB = makeDelayedTool('toolB', 10, 'B done'); + + invocationContext.runConfig = { + parallelToolExecution: false, + maxLlmCalls: 500, + }; + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(2); + + const respA = parts.find((p) => p.functionResponse!.name === 'toolA'); + expect( + (respA!.functionResponse!.response as Record).error, + ).toContain('toolA failed'); + + const respB = parts.find((p) => p.functionResponse!.name === 'toolB'); + expect( + (respB!.functionResponse!.response as Record).result, + ).toBe('B done'); + }); + + it('defaults to sequential when runConfig is undefined', async () => { + invocationContext.runConfig = undefined; + + const executionOrder: string[] = []; + const toolA = new FunctionTool({ + name: 'toolA', + description: 'tool A', + parameters: z.object({}), + execute: async () => { + executionOrder.push('A-start'); + await new Promise((r) => setTimeout(r, 30)); + executionOrder.push('A-end'); + return {result: 'A done'}; + }, + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'tool B', + parameters: z.object({}), + execute: async () => { + executionOrder.push('B-start'); + await new Promise((r) => setTimeout(r, 30)); + executionOrder.push('B-end'); + return {result: 'B done'}; + }, + }); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts!).toHaveLength(2); + // Sequential: A must fully finish before B starts + expect(executionOrder).toEqual(['A-start', 'A-end', 'B-start', 'B-end']); + }); + + it('parallel mode: tool-not-found produces error event without crashing others', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const toolA = makeDelayedTool('toolA', 10, 'A done'); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-missing', name: 'missingTool', args: {}}, + ], + toolsDict: {toolA}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(2); + + const respA = parts.find((p) => p.functionResponse!.name === 'toolA'); + expect( + (respA!.functionResponse!.response as Record).result, + ).toBe('A done'); + + const respMissing = parts.find( + (p) => p.functionResponse!.name === 'missingTool', + ); + expect( + (respMissing!.functionResponse!.response as Record).error, + ).toContain('missingTool'); + }); + + it('returns null when all function calls are filtered out', async () => { + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [{id: 'id-a', name: 'testTool', args: {}}], + toolsDict: {testTool}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + filters: new Set(['some-other-id']), + }); + + expect(event).toBeNull(); + }); + + it('sequential mode takes longer than parallel for same workload', async () => { + const DELAY = 60; + const toolA = makeDelayedTool('toolA', DELAY, 'A'); + const toolB = makeDelayedTool('toolB', DELAY, 'B'); + const toolC = makeDelayedTool('toolC', DELAY, 'C'); + const tools = {toolA, toolB, toolC}; + const calls: FunctionCall[] = [ + {id: 'a', name: 'toolA', args: {}}, + {id: 'b', name: 'toolB', args: {}}, + {id: 'c', name: 'toolC', args: {}}, + ]; + + // Parallel run + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const pStart = Date.now(); + await handleFunctionCallList({ + invocationContext, + functionCalls: calls, + toolsDict: tools, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + const pElapsed = Date.now() - pStart; + + // Sequential run + invocationContext.runConfig = { + parallelToolExecution: false, + maxLlmCalls: 500, + }; + const sStart = Date.now(); + await handleFunctionCallList({ + invocationContext, + functionCalls: calls, + toolsDict: tools, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + const sElapsed = Date.now() - sStart; + + // Sequential should take at least 2x longer than parallel + expect(sElapsed).toBeGreaterThan(pElapsed * 1.5); + }); + + it('maxConcurrentToolCalls limits batch size in parallel mode', async () => { + const concurrencyTracker: number[] = []; + let activeCalls = 0; + + function makeTrackedTool(name: string, delayMs: number) { + return new FunctionTool({ + name, + description: name, + parameters: z.object({}), + execute: async () => { + activeCalls++; + concurrencyTracker.push(activeCalls); + await new Promise((r) => setTimeout(r, delayMs)); + activeCalls--; + return {result: `${name} done`}; + }, + }); + } + + const tools = ['t1', 't2', 't3', 't4', 't5'].reduce( + (acc, name) => ({...acc, [name]: makeTrackedTool(name, 50)}), + {} as Record, + ); + + invocationContext.runConfig = { + parallelToolExecution: true, + maxConcurrentToolCalls: 2, + maxLlmCalls: 500, + }; + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: Object.keys(tools).map((name) => ({ + id: name, + name, + args: {}, + })), + toolsDict: tools, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts!).toHaveLength(5); + // Peak concurrency should never exceed the batch size of 2 + expect(Math.max(...concurrencyTracker)).toBeLessThanOrEqual(2); + }); + + it('maxConcurrentToolCalls is ignored in sequential mode', async () => { + const executionOrder: string[] = []; + + function makeOrderTool(name: string) { + return new FunctionTool({ + name, + description: name, + parameters: z.object({}), + execute: async () => { + executionOrder.push(`${name}-start`); + await new Promise((r) => setTimeout(r, 20)); + executionOrder.push(`${name}-end`); + return {result: name}; + }, + }); + } + + const toolA = makeOrderTool('toolA'); + const toolB = makeOrderTool('toolB'); + const toolC = makeOrderTool('toolC'); + + invocationContext.runConfig = { + parallelToolExecution: false, + maxConcurrentToolCalls: 2, + maxLlmCalls: 500, + }; + + await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'a', name: 'toolA', args: {}}, + {id: 'b', name: 'toolB', args: {}}, + {id: 'c', name: 'toolC', args: {}}, + ], + toolsDict: {toolA, toolB, toolC}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(executionOrder).toEqual([ + 'toolA-start', + 'toolA-end', + 'toolB-start', + 'toolB-end', + 'toolC-start', + 'toolC-end', + ]); + }); + + it('warns on stateDelta key conflicts in parallel mode', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const warnSpy = vi.spyOn(console, 'warn'); + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'sets counter', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['counter'] = 1; + return {result: 'A'}; + }, + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'also sets counter', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['counter'] = 2; + return {result: 'B'}; + }, + }); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'a', name: 'toolA', args: {}}, + {id: 'b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + + const warnCalls = warnSpy.mock.calls + .map((args) => args.join(' ')) + .filter((msg) => msg.includes('stateDelta')); + expect(warnCalls.length).toBeGreaterThan(0); + expect(warnCalls[0]).toContain('counter'); + + warnSpy.mockRestore(); + }); + + it('no stateDelta warning when parallel tools write different keys', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const warnSpy = vi.spyOn(console, 'warn'); + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'sets key_a', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['key_a'] = 1; + return {result: 'A'}; + }, + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'sets key_b', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['key_b'] = 2; + return {result: 'B'}; + }, + }); + + await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'a', name: 'toolA', args: {}}, + {id: 'b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + const stateDeltaWarns = warnSpy.mock.calls + .map((args) => args.join(' ')) + .filter((msg) => msg.includes('stateDelta')); + expect(stateDeltaWarns).toHaveLength(0); + + warnSpy.mockRestore(); + }); + + it('parallel mode: each tool gets independent deep-copied args', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const sharedArgs = {counter: 0}; + const mutatingCallback: SingleBeforeToolCallback = async ({tool, args}) => { + if (tool.name === 'toolA') { + (args as Record).counter = 999; + } + return undefined; + }; + + const toolA = makeDelayedTool('toolA', 10, 'A done'); + const toolB = makeDelayedTool('toolB', 10, 'B done'); + + await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'a', name: 'toolA', args: sharedArgs}, + {id: 'b', name: 'toolB', args: sharedArgs}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [mutatingCallback], + afterToolCallbacks: [], + }); + + expect(sharedArgs.counter).toBe(0); + }); + + it('parallel mode: nested stateDelta is deep-merged across tools', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'sets user.name', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['user'] = {name: 'Alice'}; + return {result: 'A'}; + }, + }); + + const toolB = new FunctionTool({ + name: 'toolB', + description: 'sets user.age', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['user'] = {age: 30}; + return {result: 'B'}; + }, + }); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'a', name: 'toolA', args: {}}, + {id: 'b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event!.actions!.stateDelta['user']).toEqual({ + name: 'Alice', + age: 30, + }); + }); + + it('parallel mode: tool-not-found invokes error callback instead of generic error', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const plugin = new TestPlugin('testPlugin'); + plugin.onToolErrorCallbackResponse = {result: 'missing tool handled'}; + pluginManager.registerPlugin(plugin); + + const toolA = makeDelayedTool('toolA', 10, 'A done'); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-missing', name: 'nonexistentTool', args: {}}, + ], + toolsDict: {toolA}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(2); + + const respA = parts.find((p) => p.functionResponse!.name === 'toolA'); + expect( + (respA!.functionResponse!.response as Record).result, + ).toBe('A done'); + + const respMissing = parts.find( + (p) => p.functionResponse!.name === 'nonexistentTool', + ); + expect( + (respMissing!.functionResponse!.response as Record) + .result, + ).toBe('missing tool handled'); + }); + + it('tool-not-found error callback receives a tool with description (BUG 2)', async () => { + invocationContext.runConfig = { + parallelToolExecution: false, + maxLlmCalls: 500, + }; + let receivedDescription: string | undefined; + + class CapturingPlugin extends BasePlugin { + override async onToolErrorCallback({ + tool, + }: { + tool: BaseTool; + toolArgs: Record; + toolContext: unknown; + error: Error; + }): Promise | undefined> { + receivedDescription = tool.description; + return {result: 'handled'}; + } + } + + pluginManager.registerPlugin(new CapturingPlugin('capturingPlugin')); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [{id: 'id-m', name: 'missingTool', args: {}}], + toolsDict: {}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(receivedDescription).toBeDefined(); + expect(typeof receivedDescription).toBe('string'); + }); + + it('maxConcurrentToolCalls with fractional value is floored to nearest integer (BUG 3)', async () => { + let activeCalls = 0; + const peakConcurrency: number[] = []; + + function makeTracked(name: string) { + return new FunctionTool({ + name, + description: name, + parameters: z.object({}), + execute: async () => { + activeCalls++; + peakConcurrency.push(activeCalls); + await new Promise((r) => setTimeout(r, 30)); + activeCalls--; + return {result: name}; + }, + }); + } + + invocationContext.runConfig = { + parallelToolExecution: true, + maxConcurrentToolCalls: 1.5, + maxLlmCalls: 500, + }; + + const tools: Record = { + a: makeTracked('a'), + b: makeTracked('b'), + c: makeTracked('c'), + d: makeTracked('d'), + }; + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: Object.keys(tools).map((n) => ({ + id: n, + name: n, + args: {}, + })), + toolsDict: tools, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts!).toHaveLength(4); + expect(Math.max(...peakConcurrency)).toBeLessThanOrEqual(1); + }); + + it('parallel mode: circular stateDelta values do not crash conflict warning (BUG 4)', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const circular: Record = {a: 1}; + circular.self = circular; + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'A', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['data'] = circular; + return {result: 'A'}; + }, + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'B', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['data'] = {b: 2}; + return {result: 'B'}; + }, + }); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'a', name: 'toolA', args: {}}, + {id: 'b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + expect(event!.content!.parts!).toHaveLength(2); + }); + + it('sequential mode: tool-not-found does not prevent subsequent tools from running (BUG 5)', async () => { + invocationContext.runConfig = { + parallelToolExecution: false, + maxLlmCalls: 500, + }; + const toolB = makeDelayedTool('toolB', 10, 'B done'); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-missing', name: 'missingTool', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ], + toolsDict: {toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + const parts = event!.content!.parts!; + expect(parts).toHaveLength(2); + + const respMissing = parts.find( + (p) => p.functionResponse!.name === 'missingTool', + ); + expect( + (respMissing!.functionResponse!.response as Record).error, + ).toContain('missingTool'); + + const respB = parts.find((p) => p.functionResponse!.name === 'toolB'); + expect( + (respB!.functionResponse!.response as Record).result, + ).toBe('B done'); + }); + + it('merged event preserves invocationId from all source events', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const toolA = makeDelayedTool('toolA', 10, 'A done'); + const toolB = makeDelayedTool('toolB', 10, 'B done'); + + const event = await handleFunctionCallList({ + invocationContext, + functionCalls: [ + {id: 'id-a', name: 'toolA', args: {}}, + {id: 'id-b', name: 'toolB', args: {}}, + ], + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + expect(event).not.toBeNull(); + // The merged event must carry the invocationId from the source events. + // Without the fix, createEvent defaults invocationId to '' when not passed. + expect(event!.invocationId).toBe('inv_123'); + }); + + it('handleFunctionCallsAsync streams per batch when maxConcurrentToolCalls is set', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxConcurrentToolCalls: 2, + maxLlmCalls: 500, + }; + + const toolA = makeDelayedTool('toolA', 80, 'A done'); + const toolB = makeDelayedTool('toolB', 80, 'B done'); + const toolC = makeDelayedTool('toolC', 80, 'C done'); + + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + {functionCall: {id: 'id-c', name: 'toolC', args: {}}}, + ], + }, + }); + + const iterator = handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {toolA, toolB, toolC}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + const first = await iterator.next(); + const second = await iterator.next(); + expect(first.done).toBe(false); + expect(second.done).toBe(false); + const firstEvent = first.value as Event; + const secondEvent = second.value as Event; + expect(firstEvent.content!.parts![0].functionResponse!.name).toBe('toolA'); + expect(secondEvent.content!.parts![0].functionResponse!.name).toBe('toolB'); + + const thirdPending = iterator.next(); + const earlyResolution = await Promise.race([ + thirdPending.then(() => 'resolved'), + sleep(30).then(() => 'timeout'), + ]); + expect(earlyResolution).toBe('timeout'); + + const third = await thirdPending; + expect(third.done).toBe(false); + const thirdEvent = third.value as Event; + expect(thirdEvent.content!.parts![0].functionResponse!.name).toBe('toolC'); + expect((await iterator.next()).done).toBe(true); + }); + + it('handleFunctionCallsAsync with unlimited parallel has no early streaming', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const toolA = makeDelayedTool('toolA', 80, 'A done'); + const toolB = makeDelayedTool('toolB', 10, 'B done'); + const toolC = makeDelayedTool('toolC', 10, 'C done'); + + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + {functionCall: {id: 'id-c', name: 'toolC', args: {}}}, + ], + }, + }); + + const iterator = handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {toolA, toolB, toolC}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + const firstPending = iterator.next(); + const earlyResolution = await Promise.race([ + firstPending.then(() => 'resolved'), + sleep(30).then(() => 'timeout'), + ]); + expect(earlyResolution).toBe('timeout'); + + const first = await firstPending; + const second = await iterator.next(); + const third = await iterator.next(); + + expect(first.done).toBe(false); + expect(second.done).toBe(false); + expect(third.done).toBe(false); + const firstEvent = first.value as Event; + const secondEvent = second.value as Event; + const thirdEvent = third.value as Event; + expect(firstEvent.content!.parts![0].functionResponse!.name).toBe('toolA'); + expect(secondEvent.content!.parts![0].functionResponse!.name).toBe('toolB'); + expect(thirdEvent.content!.parts![0].functionResponse!.name).toBe('toolC'); + expect((await iterator.next()).done).toBe(true); + }); + + it('handleFunctionCallsAsync filters out null results from long-running tools', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const longRunningTool = new FunctionTool({ + name: 'longTool', + description: 'long running tool', + parameters: z.object({}), + isLongRunning: true, + execute: async () => undefined, + }); + const normalTool = makeDelayedTool('normalTool', 10, 'normal done'); + + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-long', name: 'longTool', args: {}}}, + {functionCall: {id: 'id-normal', name: 'normalTool', args: {}}}, + ], + }, + }); + + const streamed: Event[] = []; + for await (const event of handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {longTool: longRunningTool, normalTool}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + })) { + streamed.push(event); + } + + expect(streamed).toHaveLength(1); + expect(streamed[0].content!.parts![0].functionResponse!.name).toBe( + 'normalTool', + ); + }); + + it('handleFunctionCallsAsync with maxConcurrentToolCalls=1 behaves like sequential streaming', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxConcurrentToolCalls: 1, + maxLlmCalls: 500, + }; + + const toolA = makeDelayedTool('toolA', 60, 'A done'); + const toolB = makeDelayedTool('toolB', 60, 'B done'); + const toolC = makeDelayedTool('toolC', 60, 'C done'); + + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + {functionCall: {id: 'id-c', name: 'toolC', args: {}}}, + ], + }, + }); + + const iterator = handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {toolA, toolB, toolC}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + const first = await iterator.next(); + expect(first.done).toBe(false); + expect( + (first.value as Event).content!.parts![0].functionResponse!.name, + ).toBe('toolA'); + + const secondPending = iterator.next(); + const earlyResolution = await Promise.race([ + secondPending.then(() => 'resolved'), + sleep(20).then(() => 'timeout'), + ]); + expect(earlyResolution).toBe('timeout'); + + const second = await secondPending; + const third = await iterator.next(); + expect( + (second.value as Event).content!.parts![0].functionResponse!.name, + ).toBe('toolB'); + expect( + (third.value as Event).content!.parts![0].functionResponse!.name, + ).toBe('toolC'); + expect((await iterator.next()).done).toBe(true); + }); + + it('handleFunctionCallsAsync: failed tool does not block streaming of other results', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const goodTool = makeDelayedTool('goodTool', 10, 'good result'); + const badTool = makeFailingTool('badTool', 5); + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-good', name: 'goodTool', args: {}}}, + {functionCall: {id: 'id-bad', name: 'badTool', args: {}}}, + ], + }, + }); + + const streamed: Event[] = []; + for await (const event of handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {goodTool, badTool}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + })) { + streamed.push(event); + } + + expect(streamed).toHaveLength(2); + expect(streamed[0].content!.parts![0].functionResponse!.name).toBe( + 'goodTool', + ); + expect( + ( + streamed[1].content!.parts![0].functionResponse!.response as Record< + string, + string + > + ).error, + ).toContain('badTool failed'); + }); + + it('handleFunctionCallsAsync in parallel mode: all events yielded despite stateDelta key conflicts', async () => { + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + const toolA = new FunctionTool({ + name: 'toolA', + description: 'sets counter to 1', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['counter'] = 1; + return {result: 'A'}; + }, + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'sets counter to 2', + parameters: z.object({}), + execute: async (_args, context) => { + context!.actions.stateDelta['counter'] = 2; + return {result: 'B'}; + }, + }); + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }); + + const streamed: Event[] = []; + for await (const event of handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + })) { + streamed.push(event); + } + + // Both events streamed despite the conflict; conflict is a warning not an error + expect(streamed).toHaveLength(2); + expect(streamed[0].content!.parts![0].functionResponse!.name).toBe('toolA'); + expect(streamed[1].content!.parts![0].functionResponse!.name).toBe('toolB'); + }); + + it('handleFunctionCallsAsync in sequential mode yields each response individually', async () => { + invocationContext.runConfig = { + parallelToolExecution: false, + maxLlmCalls: 500, + }; + const toolA = makeDelayedTool('toolA', 60, 'A done'); + const toolB = makeDelayedTool('toolB', 60, 'B done'); + const functionCallEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'test_agent', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }); + + const iterator = handleFunctionCallsAsync({ + invocationContext, + functionCallEvent, + toolsDict: {toolA, toolB}, + beforeToolCallbacks: [], + afterToolCallbacks: [], + }); + + // toolA resolves first; toolB has not started yet + const first = await iterator.next(); + expect(first.done).toBe(false); + expect( + (first.value as Event).content!.parts![0].functionResponse!.name, + ).toBe('toolA'); + + // toolB hasn't started yet so second should not resolve within 20ms + const secondPending = iterator.next(); + const earlyResolution = await Promise.race([ + secondPending.then(() => 'resolved'), + sleep(20).then(() => 'timeout'), + ]); + expect(earlyResolution).toBe('timeout'); + + const second = await secondPending; + expect(second.done).toBe(false); + expect( + (second.value as Event).content!.parts![0].functionResponse!.name, + ).toBe('toolB'); + expect((await iterator.next()).done).toBe(true); + }); +}); + +describe('generateAuthEvent', () => { + let invocationContext: InvocationContext; + let pluginManager: PluginManager; + + beforeEach(() => { + pluginManager = new PluginManager(); + const agent = new LlmAgent({name: 'test_agent', model: 'test_model'}); + invocationContext = new InvocationContext({ + invocationId: 'inv_123', + session: {} as Session, + agent, + pluginManager, + }); + }); + + it('should return undefined if no requestedAuthConfigs', () => { + const functionResponseEvent = { + actions: {}, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateAuthEvent(invocationContext, functionResponseEvent); + expect(event).toBeUndefined(); + }); + + it('should return undefined if requestedAuthConfigs is empty', () => { + const functionResponseEvent = { + actions: {requestedAuthConfigs: {}}, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateAuthEvent(invocationContext, functionResponseEvent); + expect(event).toBeUndefined(); + }); + + it('should return auth event if requestedAuthConfigs is present', () => { + const functionResponseEvent = { + actions: { + requestedAuthConfigs: { + 'call_1': 'auth_config_1', + 'call_2': 'auth_config_2', + }, + }, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateAuthEvent(invocationContext, functionResponseEvent); + expect(event).toBeDefined(); + expect(event!.invocationId).toBe('inv_123'); + expect(event!.author).toBe('test_agent'); + expect(event!.content!.parts!.length).toBe(2); + + const parts = event!.content!.parts!; + const call1 = parts.find( + (p) => p.functionCall?.args?.['function_call_id'] === 'call_1', + ); + expect(call1).toBeDefined(); + expect(call1!.functionCall!.name).toBe('adk_request_credential'); + expect(call1!.functionCall!.args!['auth_config']).toBe('auth_config_1'); + + const call2 = parts.find( + (p) => p.functionCall?.args?.['function_call_id'] === 'call_2', + ); + expect(call2).toBeDefined(); + expect(call2!.functionCall!.name).toBe('adk_request_credential'); + expect(call2!.functionCall!.args!['auth_config']).toBe('auth_config_2'); + }); +}); + +describe('generateRequestConfirmationEvent', () => { + let invocationContext: InvocationContext; + let pluginManager: PluginManager; + + beforeEach(() => { + pluginManager = new PluginManager(); + const agent = new LlmAgent({name: 'test_agent', model: 'test_model'}); + invocationContext = new InvocationContext({ + invocationId: 'inv_123', + session: {} as Session, + agent, + pluginManager, + }); + }); + + it('should return undefined if no requestedToolConfirmations', () => { + const functionCallEvent = {content: {parts: []}} as unknown as Event; + const functionResponseEvent = { + actions: {}, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateRequestConfirmationEvent({ + invocationContext, + functionCallEvent, + functionResponseEvent, + }); + expect(event).toBeUndefined(); + }); + + it('should return undefined if requestedToolConfirmations is empty', () => { + const functionCallEvent = {content: {parts: []}} as unknown as Event; + const functionResponseEvent = { + actions: {requestedToolConfirmations: {}}, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateRequestConfirmationEvent({ + invocationContext, + functionCallEvent, + functionResponseEvent, + }); + expect(event).toBeUndefined(); + }); + + it('should return confirmation event if requestedToolConfirmations is present', () => { + const functionCallEvent = { + content: { + parts: [ + { + functionCall: { + name: 'tool_1', + args: {arg: 'val1'}, + id: 'call_1', + }, + }, + { + functionCall: { + name: 'tool_2', + args: {arg: 'val2'}, + id: 'call_2', + }, + }, + ], + }, + } as unknown as Event; + + const functionResponseEvent = { + actions: { + requestedToolConfirmations: { + 'call_1': {message: 'confirm tool 1'}, + 'call_2': {message: 'confirm tool 2'}, + }, + }, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateRequestConfirmationEvent({ + invocationContext, + functionCallEvent, + functionResponseEvent, + }); + + expect(event).toBeDefined(); + expect(event!.invocationId).toBe('inv_123'); + expect(event!.author).toBe('test_agent'); + expect(event!.content!.parts!.length).toBe(2); + + const parts = event!.content!.parts!; + const call1 = parts.find( + (p) => + (p.functionCall?.args?.['originalFunctionCall'] as FunctionCall)?.id === + 'call_1', + ); + expect(call1).toBeDefined(); + expect(call1!.functionCall!.name).toBe('adk_request_confirmation'); + expect(call1!.functionCall!.args!['toolConfirmation']).toEqual({ + message: 'confirm tool 1', + }); + + const call2 = parts.find( + (p) => + (p.functionCall?.args?.['originalFunctionCall'] as FunctionCall)?.id === + 'call_2', + ); + expect(call2).toBeDefined(); + expect(call2!.functionCall!.name).toBe('adk_request_confirmation'); + expect(call2!.functionCall!.args!['toolConfirmation']).toEqual({ + message: 'confirm tool 2', + }); + }); + + it('should skip confirmation if original function call is not found', () => { + const functionCallEvent = { + content: { + parts: [ + { + functionCall: { + name: 'tool_1', + args: {arg: 'val1'}, + id: 'call_1', + }, + }, + ], + }, + } as unknown as Event; + + const functionResponseEvent = { + actions: { + requestedToolConfirmations: { + 'call_1': {message: 'confirm tool 1'}, + 'call_missing': {message: 'confirm tool missing'}, + }, + }, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateRequestConfirmationEvent({ + invocationContext, + functionCallEvent, + functionResponseEvent, + }); + + expect(event).toBeDefined(); + expect(event!.content!.parts!.length).toBe(1); + const parts = event!.content!.parts!; + const call1 = parts.find( + (p) => + (p.functionCall?.args?.['originalFunctionCall'] as FunctionCall)?.id === + 'call_1', + ); + expect(call1).toBeDefined(); + }); + + // PR #167 Fix 3 — actions on confirmation event + // The confirmation event must carry ONLY requestedToolConfirmations. + // Copying the full actions object caused stateDelta / artifactDelta / + // transferToAgent to be double-applied by appendEvent in the streaming path. + it('carries only requestedToolConfirmations — stateDelta/artifactDelta/transferToAgent are not copied', () => { + const functionCallEvent = { + content: { + parts: [ + { + functionCall: {name: 'myTool', args: {}, id: 'fc-1'}, + }, + ], + }, + } as unknown as Event; + + const functionResponseEvent = { + actions: { + stateDelta: {counter: 1}, + artifactDelta: {report: 1}, + transferToAgent: 'other_agent', + requestedToolConfirmations: {'fc-1': {message: 'please confirm'}}, + }, + content: {role: 'model'}, + } as unknown as Event; + + const event = generateRequestConfirmationEvent({ + invocationContext, + functionCallEvent, + functionResponseEvent, + }); + + expect(event).toBeDefined(); + expect(event!.actions!.requestedToolConfirmations).toEqual({ + 'fc-1': {message: 'please confirm'}, + }); + // stateDelta and artifactDelta must be empty — not copied from functionResponseEvent + expect(event!.actions!.stateDelta).toEqual({}); + expect(event!.actions!.artifactDelta).toEqual({}); + // transferToAgent must not be set on the confirmation event + expect(event!.actions!.transferToAgent).toBeUndefined(); }); }); diff --git a/core/test/agents/llm_agent_test.ts b/core/test/agents/llm_agent_test.ts index f40df1a1..57bd9b60 100644 --- a/core/test/agents/llm_agent_test.ts +++ b/core/test/agents/llm_agent_test.ts @@ -8,16 +8,22 @@ import { BaseLlm, BaseLlmConnection, BasePlugin, + CallbackContext, + createEvent, Context, Event, + FunctionTool, + InMemorySessionService, InvocationContext, LlmAgent, LlmRequest, LlmResponse, PluginManager, + REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, Session, } from '@google/adk'; import {Content, Schema, Type} from '@google/genai'; +import {z} from 'zod'; import {z as z3} from 'zod/v3'; import {z as z4} from 'zod/v4'; @@ -65,6 +71,29 @@ class MockLlm extends BaseLlm { } } +class MultiStepMockLlm extends BaseLlm { + private readonly responses: LlmResponse[]; + private callCount = 0; + + constructor(responses: LlmResponse[]) { + super({model: 'mock-llm'}); + this.responses = responses; + } + + async *generateContentAsync( + _request: LlmRequest, + ): AsyncGenerator { + const response = this.responses[this.callCount++]; + if (response) { + yield response; + } + } + + async connect(_llmRequest: LlmRequest): Promise { + return new MockLlmConnection(); + } +} + class MockPlugin extends BasePlugin { beforeModelResponse?: LlmResponse; afterModelResponse?: LlmResponse; @@ -122,6 +151,25 @@ class TestLlmAgent extends LlmAgent { } } +class TransferTargetAgent extends LlmAgent { + private readonly responseText: string; + + constructor(name: string, responseText: string) { + super({name}); + this.responseText = responseText; + } + + protected override async *runAsyncImpl( + context: InvocationContext, + ): AsyncGenerator { + yield createEvent({ + invocationId: context.invocationId, + author: this.name, + content: {role: 'model', parts: [{text: this.responseText}]}, + }); + } +} + describe('LlmAgent.callLlm', () => { let agent: TestLlmAgent; let invocationContext: InvocationContext; @@ -384,3 +432,434 @@ describe('LlmAgent Output Processing', () => { expect(lastEvent.actions?.stateDelta?.['result']).toEqual(invalidJson); }); }); + +describe('LlmAgent tool streaming postprocess', () => { + function buildInvocationContext(agent: LlmAgent): InvocationContext { + return new InvocationContext({ + invocationId: 'inv_tools', + session: { + id: 'sess_tools', + state: { + hasDelta: () => false, + get: () => undefined, + set: () => {}, + }, + events: [], + } as unknown as Session, + agent, + pluginManager: new PluginManager(), + }); + } + + it('yields individual tool responses before confirmation event', async () => { + const toolA = new FunctionTool({ + name: 'toolA', + description: 'requests confirmation', + parameters: z.object({}), + execute: async (_args, context) => { + context!.requestConfirmation({hint: 'approve toolA'}); + return {result: 'A'}; + }, + }); + + const toolB = new FunctionTool({ + name: 'toolB', + description: 'normal tool', + parameters: z.object({}), + execute: async () => ({result: 'B'}), + }); + + const llmResponse: LlmResponse = { + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }; + + const agent = new LlmAgent({ + name: 'test_agent', + model: new MockLlm(llmResponse), + tools: [toolA, toolB], + }); + const invocationContext = buildInvocationContext(agent); + invocationContext.runConfig = { + parallelToolExecution: true, + maxConcurrentToolCalls: 2, + maxLlmCalls: 500, + }; + + const events: Event[] = []; + for await (const event of agent.runAsync(invocationContext)) { + events.push(event); + } + + const functionResponseEvents = events.filter( + (event) => event.content?.parts?.[0]?.functionResponse, + ); + expect(functionResponseEvents).toHaveLength(2); + expect( + functionResponseEvents.every( + (event) => event.content?.parts?.length === 1, + ), + ).toBe(true); + + const confirmationEventIndex = events.findIndex((event) => + event.content?.parts?.some( + (part) => + part.functionCall?.name === REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + ), + ); + expect(confirmationEventIndex).toBeGreaterThan(-1); + + const lastFunctionResponseIndex = events.reduce((lastIndex, event, idx) => { + return event.content?.parts?.[0]?.functionResponse ? idx : lastIndex; + }, -1); + expect(confirmationEventIndex).toBeGreaterThan(lastFunctionResponseIndex); + }); + + it('uses merged transferToAgent (last input tool wins) after streamed tool responses', async () => { + const transferSlow = new FunctionTool({ + name: 'transferSlow', + description: 'slow transfer setter', + parameters: z.object({}), + execute: async (_args, context) => { + await new Promise((resolve) => setTimeout(resolve, 60)); + context!.actions.transferToAgent = 'agent_a'; + return {result: 'slow'}; + }, + }); + const transferFast = new FunctionTool({ + name: 'transferFast', + description: 'fast transfer setter', + parameters: z.object({}), + execute: async (_args, context) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + context!.actions.transferToAgent = 'agent_b'; + return {result: 'fast'}; + }, + }); + + const llmResponse: LlmResponse = { + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-slow', name: 'transferSlow', args: {}}}, + {functionCall: {id: 'id-fast', name: 'transferFast', args: {}}}, + ], + }, + }; + + const childA = new TransferTargetAgent('agent_a', 'from A'); + const childB = new TransferTargetAgent('agent_b', 'from B'); + + const root = new LlmAgent({ + name: 'root_agent', + model: new MockLlm(llmResponse), + tools: [transferSlow, transferFast], + subAgents: [childA, childB], + }); + + const invocationContext = buildInvocationContext(root); + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const events: Event[] = []; + for await (const event of root.runAsync(invocationContext)) { + events.push(event); + } + + const transferredAgentEvent = events.find( + (event) => + event.author === 'agent_b' && + event.content?.parts?.some((part) => part.text === 'from B'), + ); + expect(transferredAgentEvent).toBeDefined(); + }); + + it('yields streamed tool responses before merged auth event', async () => { + const toolA = new FunctionTool({ + name: 'toolA', + description: 'requests auth', + parameters: z.object({}), + execute: async (_args, context) => { + context!.requestCredential({ + authScheme: {type: 'apiKey', in: 'header', name: 'x-api-key'}, + credentialKey: 'toolA-key', + }); + return {result: 'A'}; + }, + }); + + const toolB = new FunctionTool({ + name: 'toolB', + description: 'normal tool', + parameters: z.object({}), + execute: async () => ({result: 'B'}), + }); + + const llmResponse: LlmResponse = { + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }; + + const agent = new LlmAgent({ + name: 'test_agent', + model: new MockLlm(llmResponse), + tools: [toolA, toolB], + }); + const invocationContext = buildInvocationContext(agent); + invocationContext.runConfig = { + parallelToolExecution: true, + maxConcurrentToolCalls: 2, + maxLlmCalls: 500, + }; + + const events: Event[] = []; + for await (const event of agent.runAsync(invocationContext)) { + events.push(event); + } + + const functionResponseEvents = events.filter( + (event) => event.content?.parts?.[0]?.functionResponse, + ); + expect(functionResponseEvents).toHaveLength(2); + + const authEventIndex = events.findIndex((event) => + event.content?.parts?.some( + (part) => part.functionCall?.name === 'adk_request_credential', + ), + ); + expect(authEventIndex).toBeGreaterThan(-1); + + const lastFunctionResponseIndex = events.reduce((lastIndex, event, idx) => { + return event.content?.parts?.[0]?.functionResponse ? idx : lastIndex; + }, -1); + expect(authEventIndex).toBeGreaterThan(lastFunctionResponseIndex); + }); + + async function buildInvocationContextWithSession(agent: LlmAgent): Promise<{ + ctx: InvocationContext; + sessionService: InMemorySessionService; + sessionId: string; + }> { + const sessionService = new InMemorySessionService(); + const session = await sessionService.createSession({ + appName: 'test_app', + userId: 'test_user', + sessionId: 'test_session', + }); + const ctx = new InvocationContext({ + invocationId: 'inv_sentinel', + session, + agent, + pluginManager: new PluginManager(), + sessionService, + }); + return {ctx, sessionService, sessionId: session.id}; + } + + it('does not dispatch tools when pauseOnToolCalls is set', async () => { + const toolA = new FunctionTool({ + name: 'toolA', + description: 'normal tool', + parameters: z.object({}), + execute: async () => ({result: 'A'}), + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'normal tool', + parameters: z.object({}), + execute: async () => ({result: 'B'}), + }); + const llmResponse: LlmResponse = { + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }; + const agent = new LlmAgent({ + name: 'test_agent', + model: new MockLlm(llmResponse), + tools: [toolA, toolB], + }); + const invocationContext = buildInvocationContext(agent); + invocationContext.runConfig = { + parallelToolExecution: true, + pauseOnToolCalls: true, + maxLlmCalls: 500, + }; + + const events: Event[] = []; + for await (const event of agent.runAsync(invocationContext)) { + events.push(event); + } + + // Only the model event with function calls is yielded; no tool responses + const functionResponseEvents = events.filter((e) => + e.content?.parts?.some((p) => p.functionResponse), + ); + expect(functionResponseEvents).toHaveLength(0); + expect(events).toHaveLength(1); + expect(events[0].content?.parts?.some((p) => p.functionCall)).toBe(true); + }); + + it('completion sentinel is appended to session but not yielded in event stream', async () => { + const toolA = new FunctionTool({ + name: 'toolA', + description: 'tool A', + parameters: z.object({}), + execute: async () => ({result: 'A'}), + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'tool B', + parameters: z.object({}), + execute: async () => ({result: 'B'}), + }); + const llmResponse: LlmResponse = { + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }; + const agent = new LlmAgent({ + name: 'test_agent', + model: new MockLlm(llmResponse), + tools: [toolA, toolB], + }); + const {ctx, sessionService, sessionId} = + await buildInvocationContextWithSession(agent); + ctx.runConfig = {parallelToolExecution: true, maxLlmCalls: 500}; + + const yieldedEvents: Event[] = []; + for await (const event of agent.runAsync(ctx)) { + yieldedEvents.push(event); + } + + // No yielded event should carry the internal sentinel metadata + const sentinelInYield = yieldedEvents.some( + (e) => e.actions?.customMetadata?.['parallelToolBatchCompletion'], + ); + expect(sentinelInYield).toBe(false); + + // Session must contain the sentinel + const session = await sessionService.getSession({ + appName: 'test_app', + userId: 'test_user', + sessionId, + }); + const sentinelInSession = session!.events.some( + (e) => e.actions?.customMetadata?.['parallelToolBatchCompletion'], + ); + expect(sentinelInSession).toBe(true); + }); + + it('single tool call does not append a completion sentinel to session', async () => { + const toolA = new FunctionTool({ + name: 'toolA', + description: 'tool A', + parameters: z.object({}), + execute: async () => ({result: 'A'}), + }); + const llmResponse: LlmResponse = { + content: { + role: 'model', + parts: [{functionCall: {id: 'id-a', name: 'toolA', args: {}}}], + }, + }; + const agent = new LlmAgent({ + name: 'test_agent', + model: new MockLlm(llmResponse), + tools: [toolA], + }); + const {ctx, sessionService, sessionId} = + await buildInvocationContextWithSession(agent); + ctx.runConfig = {parallelToolExecution: true, maxLlmCalls: 500}; + + for await (const _ of agent.runAsync(ctx)) { + // consume + } + + const session = await sessionService.getSession({ + appName: 'test_app', + userId: 'test_user', + sessionId, + }); + const hasSentinel = session!.events.some( + (e) => e.actions?.customMetadata?.['parallelToolBatchCompletion'], + ); + expect(hasSentinel).toBe(false); + }); + + it('outer loop calls LLM again after parallel tool batch completes', async () => { + const toolA = new FunctionTool({ + name: 'toolA', + description: 'tool A', + parameters: z.object({}), + execute: async () => ({result: 'A done'}), + }); + const toolB = new FunctionTool({ + name: 'toolB', + description: 'tool B', + parameters: z.object({}), + execute: async () => ({result: 'B done'}), + }); + // Step 1: model emits 2 parallel function calls + const step1Response: LlmResponse = { + content: { + role: 'model', + parts: [ + {functionCall: {id: 'id-a', name: 'toolA', args: {}}}, + {functionCall: {id: 'id-b', name: 'toolB', args: {}}}, + ], + }, + }; + // Step 2: model returns a final text answer + const step2Response: LlmResponse = { + content: {role: 'model', parts: [{text: 'final answer'}]}, + }; + const agent = new LlmAgent({ + name: 'test_agent', + model: new MultiStepMockLlm([step1Response, step2Response]), + tools: [toolA, toolB], + }); + const invocationContext = buildInvocationContext(agent); + invocationContext.runConfig = { + parallelToolExecution: true, + maxLlmCalls: 500, + }; + + const events: Event[] = []; + for await (const event of agent.runAsync(invocationContext)) { + events.push(event); + } + + // Expect both tool responses streamed from step 1 + const functionResponseEvents = events.filter((e) => + e.content?.parts?.some((p) => p.functionResponse), + ); + expect(functionResponseEvents).toHaveLength(2); + + // Expect final text event produced by step 2 LLM call + const finalTextEvent = events.find((e) => + e.content?.parts?.some((p) => p.text === 'final answer'), + ); + expect(finalTextEvent).toBeDefined(); + }); +}); diff --git a/core/test/events/event_actions_test.ts b/core/test/events/event_actions_test.ts new file mode 100644 index 00000000..63408b9f --- /dev/null +++ b/core/test/events/event_actions_test.ts @@ -0,0 +1,25 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import {describe, expect, it} from 'vitest'; +import {mergeEventActions} from '../../src/events/event_actions.js'; + +describe('mergeEventActions', () => { + it('overwrites existing nested value with undefined', () => { + const merged = mergeEventActions([ + {stateDelta: {user: {name: 'Alice', age: 30}}}, + {stateDelta: {user: {name: undefined}}}, + ]); + + expect(merged.stateDelta['user']).toEqual({name: undefined, age: 30}); + expect( + Object.prototype.hasOwnProperty.call( + merged.stateDelta['user'] as Record, + 'name', + ), + ).toBe(true); + }); +}); diff --git a/core/test/runner/runner_test.ts b/core/test/runner/runner_test.ts index cf3c2408..56c5430a 100644 --- a/core/test/runner/runner_test.ts +++ b/core/test/runner/runner_test.ts @@ -8,6 +8,7 @@ import { BaseAgent, BasePlugin, createEvent, + createEventActions, Event, InMemoryArtifactService, InMemorySessionService, @@ -48,6 +49,33 @@ class MockLlmAgent extends LlmAgent { } } +class StreamedFunctionResponseAgent extends LlmAgent { + constructor(name: string) { + super({ + name, + model: 'gemini-2.5-flash', + }); + } + + protected override async *runAsyncImpl( + context: InvocationContext, + ): AsyncGenerator { + const responses: FunctionResponse[] = [ + {id: 'fc-a', name: 'toolA', response: {result: 'A'}}, + {id: 'fc-b', name: 'toolB', response: {result: 'B'}}, + {id: 'fc-c', name: 'toolC', response: {result: 'C'}}, + ]; + + for (const functionResponse of responses) { + yield createEvent({ + invocationId: context.invocationId, + author: this.name, + content: {role: 'user', parts: [{functionResponse}]}, + }); + } + } +} + class MockPlugin extends BasePlugin { static ON_USER_CALLBACK_MSG = 'Modified user message ON_USER_CALLBACK_MSG from MockPlugin'; @@ -60,6 +88,8 @@ class MockPlugin extends BasePlugin { enableEventCallback = false; enableBeforeRunCallback = false; afterRunCallbackCalled = false; + onEventCallbackInvocationCount = 0; + onEventFunctionResponseCount = 0; constructor() { super('mock_plugin'); @@ -84,6 +114,10 @@ class MockPlugin extends BasePlugin { invocationContext: InvocationContext; event: Event; }): Promise { + this.onEventCallbackInvocationCount++; + if (event.content?.parts?.some((part) => part.functionResponse)) { + this.onEventFunctionResponseCount++; + } if (!this.enableEventCallback) { return undefined; } @@ -214,6 +248,207 @@ describe('Runner.determineAgentForResumption', () => { expect(events[0].author).toBe('sub_agent1'); }); + it('should resolve author from partial responses when completion sentinel exists', async () => { + const partialCallEvent = createEvent({ + invocationId: 'inv-parallel-call', + author: 'sub_agent2', + content: { + role: 'model', + parts: [ + { + functionCall: { + id: 'func_1', + name: 'tool_1', + args: {}, + }, + }, + { + functionCall: { + id: 'func_2', + name: 'tool_2', + args: {}, + }, + }, + { + functionCall: { + id: 'func_3', + name: 'tool_3', + args: {}, + }, + }, + ], + }, + }); + + const partialResponseEvent = createEvent({ + invocationId: 'inv-partial-response', + author: 'user', + content: { + role: 'user', + parts: [ + { + functionResponse: { + id: 'func_1', + name: 'tool_1', + response: {result: 'partial'}, + }, + }, + ], + }, + }); + + const completionSentinelEvent = createEvent({ + invocationId: 'inv-completion-sentinel', + author: 'sub_agent2', + actions: createEventActions({ + customMetadata: { + parallelToolBatchCompletion: { + functionCallEventId: partialCallEvent.id, + expectedResponseCount: 3, + }, + }, + }), + }); + + const events = await runTest([ + partialCallEvent, + partialResponseEvent, + completionSentinelEvent, + ]); + expect(events[0].author).toBe('sub_agent2'); + }); + + it('should not resume from partial parallel responses without a completion sentinel', async () => { + const partialCallEvent = createEvent({ + invocationId: 'inv-parallel-call', + author: 'sub_agent2', + content: { + role: 'model', + parts: [ + { + functionCall: { + id: 'func_1', + name: 'tool_1', + args: {}, + }, + }, + { + functionCall: { + id: 'func_2', + name: 'tool_2', + args: {}, + }, + }, + { + functionCall: { + id: 'func_3', + name: 'tool_3', + args: {}, + }, + }, + ], + }, + }); + + const partialResponseEvent = createEvent({ + invocationId: 'inv-partial-response', + author: 'user', + content: { + role: 'user', + parts: [ + { + functionResponse: { + id: 'func_1', + name: 'tool_1', + response: {result: 'partial'}, + }, + }, + ], + }, + }); + + // Desired behavior after sentinel support: + // without a completion sentinel for the parallel batch, resumption should + // not continue from partial responses. + const events = await runTest([partialCallEvent, partialResponseEvent]); + expect(events[0].author).toBe('root_agent'); + }); + + it('should resume to originating agent when all parallel responses are present but no sentinel', async () => { + const partialCallEvent = createEvent({ + invocationId: 'inv-parallel-call', + author: 'sub_agent2', + content: { + role: 'model', + parts: [ + {functionCall: {id: 'func_1', name: 'tool_1', args: {}}}, + {functionCall: {id: 'func_2', name: 'tool_2', args: {}}}, + {functionCall: {id: 'func_3', name: 'tool_3', args: {}}}, + ], + }, + }); + + const responseEvent1 = createEvent({ + invocationId: 'inv-response-1', + author: 'user', + content: { + role: 'user', + parts: [ + { + functionResponse: { + id: 'func_1', + name: 'tool_1', + response: {result: 'r1'}, + }, + }, + ], + }, + }); + + const responseEvent2 = createEvent({ + invocationId: 'inv-response-2', + author: 'user', + content: { + role: 'user', + parts: [ + { + functionResponse: { + id: 'func_2', + name: 'tool_2', + response: {result: 'r2'}, + }, + }, + ], + }, + }); + + const responseEvent3 = createEvent({ + invocationId: 'inv-response-3', + author: 'user', + content: { + role: 'user', + parts: [ + { + functionResponse: { + id: 'func_3', + name: 'tool_3', + response: {result: 'r3'}, + }, + }, + ], + }, + }); + + // All 3 responses present with no sentinel — batch is complete, should route to originating agent + const events = await runTest([ + partialCallEvent, + responseEvent1, + responseEvent2, + responseEvent3, + ]); + expect(events[0].author).toBe('sub_agent2'); + }); + it('should return root agent when session has no non-user events', async () => { const nonUserEvent = createEvent({ invocationId: 'inv1', @@ -417,6 +652,86 @@ describe('Runner with plugins', () => { await runTest(); expect(plugin.afterRunCallbackCalled).toBe(true); }); + + it('should invoke onEventCallback once per streamed function-response event', async () => { + sessionService = new InMemorySessionService(); + artifactService = new InMemoryArtifactService(); + plugin = new MockPlugin(); + runner = new Runner({ + appName: TEST_APP_ID, + agent: new StreamedFunctionResponseAgent('stream_agent'), + sessionService, + artifactService, + plugins: [plugin], + }); + + await sessionService.createSession({ + appName: TEST_APP_ID, + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + }); + + const events: Event[] = []; + for await (const event of runner.runAsync({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + newMessage: {role: 'user', parts: [{text: 'Hello'}]}, + })) { + events.push(event); + } + + expect(events).toHaveLength(3); + expect(plugin.onEventFunctionResponseCount).toBe(3); + expect(plugin.onEventCallbackInvocationCount).toBe(3); + }); + + it('should append each streamed function-response event independently', async () => { + class CountingSessionService extends InMemorySessionService { + appendEventCallCount = 0; + functionResponseAppendCount = 0; + + override async appendEvent(params: { + session: Parameters< + InMemorySessionService['appendEvent'] + >[0]['session']; + event: Parameters[0]['event']; + }) { + this.appendEventCallCount++; + if ( + params.event.content?.parts?.some((part) => part.functionResponse) + ) { + this.functionResponseAppendCount++; + } + return await super.appendEvent(params); + } + } + + const countingSessionService = new CountingSessionService(); + runner = new Runner({ + appName: TEST_APP_ID, + agent: new StreamedFunctionResponseAgent('stream_agent'), + sessionService: countingSessionService, + artifactService: new InMemoryArtifactService(), + }); + + await countingSessionService.createSession({ + appName: TEST_APP_ID, + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + }); + + for await (const _event of runner.runAsync({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + newMessage: {role: 'user', parts: [{text: 'Hello'}]}, + })) { + // consume + } + + // 1 append for user input + 3 appends for streamed function responses. + expect(countingSessionService.appendEventCallCount).toBe(4); + expect(countingSessionService.functionResponseAppendCount).toBe(3); + }); }); describe('Runner error handling', () => {