Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions packages/agent-runtime/src/__tests__/tool-validation-error.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
import { getInitialSessionState } from '@codebuff/common/types/session-state'
import { promptSuccess } from '@codebuff/common/util/error'
import { jsonToolResult } from '@codebuff/common/util/messages'
import { beforeEach, describe, expect, it } from 'bun:test'

import { mockFileContext } from './test-utils'
Expand All @@ -12,6 +13,10 @@ import type {
AgentRuntimeScopedDeps,
} from '@codebuff/common/types/contracts/agent-runtime'
import type { StreamChunk } from '@codebuff/common/types/contracts/llm'
import type {
AssistantMessage,
ToolMessage,
} from '@codebuff/common/types/messages/codebuff-message'
import type { PrintModeEvent } from '@codebuff/common/types/print-mode'

describe('tool validation error handling', () => {
Expand Down Expand Up @@ -225,4 +230,127 @@ describe('tool validation error handling', () => {
)
expect(errorEvents.length).toBe(0)
})

it('should preserve tool_call/tool_result ordering when custom tool setup is async', async () => {
const toolName = 'delayed_custom_tool'
const agentWithCustomTool: AgentTemplate = {
...testAgentTemplate,
toolNames: [toolName, 'end_turn'],
}

const delayedToolCallChunk: StreamChunk = {
type: 'tool-call',
toolName,
toolCallId: 'delayed-custom-tool-call-id',
input: {
query: 'test',
},
}

async function* mockStream() {
yield delayedToolCallChunk
return promptSuccess('mock-message-id')
}

const fileContextWithCustomTool = {
...mockFileContext,
customToolDefinitions: {
[toolName]: {
inputSchema: {
type: 'object',
properties: {
query: { type: 'string' },
},
required: ['query'],
additionalProperties: false,
},
endsAgentStep: false,
description: 'A delayed custom tool for ordering tests',
},
},
}

const sessionState = getInitialSessionState(fileContextWithCustomTool)
const agentState = sessionState.mainAgentState

agentRuntimeImpl.requestMcpToolData = async () => {
// Force an async gap so tool_call emission happens after stream completion.
await new Promise((resolve) => setTimeout(resolve, 20))
return []
}
agentRuntimeImpl.requestToolCall = async () => ({
output: jsonToolResult({ ok: true }),
})

await processStream({
...agentRuntimeImpl,
agentContext: {},
agentState,
agentStepId: 'test-step-id',
agentTemplate: agentWithCustomTool,
ancestorRunIds: [],
clientSessionId: 'test-session',
fileContext: fileContextWithCustomTool,
fingerprintId: 'test-fingerprint',
fullResponse: '',
localAgentTemplates: { 'test-agent': agentWithCustomTool },
messages: [],
prompt: 'test prompt',
repoId: undefined,
repoUrl: undefined,
runId: 'test-run-id',
signal: new AbortController().signal,
stream: mockStream(),
system: 'test system',
tools: {},
userId: 'test-user',
userInputId: 'test-input-id',
onCostCalculated: async () => {},
onResponseChunk: () => {},
})

const assistantToolCallMessages = agentState.messageHistory.filter(
(m): m is AssistantMessage =>
m.role === 'assistant' &&
m.content.some((c) => c.type === 'tool-call' && c.toolName === toolName),
)
const toolMessages = agentState.messageHistory.filter(
(m): m is ToolMessage => m.role === 'tool' && m.toolName === toolName,
)

expect(assistantToolCallMessages.length).toBe(1)
expect(toolMessages.length).toBe(1)

const assistantToolCallPart = assistantToolCallMessages[0].content.find(
(
c,
): c is Extract<AssistantMessage['content'][number], { type: 'tool-call' }> =>
c.type === 'tool-call' && c.toolName === toolName,
)
expect(assistantToolCallPart).toBeDefined()
expect(toolMessages[0].toolCallId).toBe(assistantToolCallPart!.toolCallId)

const assistantIndex = agentState.messageHistory.indexOf(
assistantToolCallMessages[0],
)
const toolResultIndex = agentState.messageHistory.indexOf(toolMessages[0])
expect(assistantIndex).toBeGreaterThanOrEqual(0)
expect(toolResultIndex).toBeGreaterThan(assistantIndex)

const assistantToolCallIds = new Set(
agentState.messageHistory.flatMap((message) => {
if (message.role !== 'assistant') {
return []
}
return message.content.flatMap((part) =>
part.type === 'tool-call' ? [part.toolCallId] : [],
)
}),
)
const orphanToolResults = agentState.messageHistory.filter(
(message): message is ToolMessage =>
message.role === 'tool' && !assistantToolCallIds.has(message.toolCallId),
)
expect(orphanToolResults.length).toBe(0)
})
})
30 changes: 21 additions & 9 deletions packages/agent-runtime/src/tool-stream-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,25 @@ export async function* processStreamWithTools(params: {
processors: Record<
string,
{
onTagStart: (tagName: string, attributes: Record<string, string>) => void
onTagEnd: (tagName: string, params: Record<string, any>) => void
onTagStart: (
tagName: string,
attributes: Record<string, string>,
) => void | Promise<void>
onTagEnd: (
tagName: string,
params: Record<string, any>,
) => void | Promise<void>
}
>
defaultProcessor: (toolName: string) => {
onTagStart: (tagName: string, attributes: Record<string, string>) => void
onTagEnd: (tagName: string, params: Record<string, any>) => void
onTagStart: (
tagName: string,
attributes: Record<string, string>,
) => void | Promise<void>
onTagEnd: (
tagName: string,
params: Record<string, any>,
) => void | Promise<void>
}
onError: (tagName: string, errorMessage: string) => void
onResponseChunk: (chunk: PrintModeText | PrintModeError) => void
Expand Down Expand Up @@ -62,11 +74,11 @@ export async function* processStreamWithTools(params: {
// State for parsing XML tool calls from text stream
const xmlParserState: StreamParserState = createStreamParserState()

function processToolCallObject(params: {
async function processToolCallObject(params: {
toolName: string
input: any
contents?: string
}): void {
}): Promise<void> {
const { toolName, input, contents } = params

const processor = processors[toolName] ?? defaultProcessor(toolName)
Expand All @@ -85,8 +97,8 @@ export async function* processStreamWithTools(params: {
logger,
})

processor.onTagStart(toolName, {})
processor.onTagEnd(toolName, input)
await processor.onTagStart(toolName, {})
await processor.onTagEnd(toolName, input)
}

function flush() {
Expand Down Expand Up @@ -146,7 +158,7 @@ export async function* processStreamWithTools(params: {
}

if (chunk.type === 'tool-call') {
processToolCallObject(chunk)
await processToolCallObject(chunk)
}

yield chunk
Expand Down
34 changes: 20 additions & 14 deletions packages/agent-runtime/src/tools/stream-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ export async function processStream(
userId,
} = params
const fullResponseChunks: string[] = [fullResponse]
const messageHistoryBeforeStream = expireMessages(
agentState.messageHistory,
'agentStep',
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is necessary, this is run right before each step already (in runAgentStep)


// === MUTABLE STATE ===
const toolResults: ToolMessage[] = []
Expand Down Expand Up @@ -111,9 +115,11 @@ export async function processStream(
return (chunk: string | PrintModeEvent) => {
if (typeof chunk !== 'string') {
if (chunk.type === 'tool_call') {
assistantMessages.push(
assistantMessage({ ...chunk, type: 'tool-call' }),
)
if (chunk.includeToolCall !== false) {
assistantMessages.push(
assistantMessage({ ...chunk, type: 'tool-call' }),
)
}
} else if (isXmlMode && chunk.type === 'tool_result') {
const toolResultMessage: ToolMessage = {
role: 'tool',
Expand Down Expand Up @@ -182,7 +188,7 @@ export async function processStream(
: (toolName as ToolName),
input: transformed ? transformed.input : input,
fromHandleSteps: false,
skipDirectResultPush: isXmlMode,
skipDirectResultPush: true,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we should refactor to delete this param entirely if always true

fileProcessingState,
fullResponse: fullResponseChunks.join(''),
previousToolCallFinished: previousPromise,
Expand All @@ -199,7 +205,7 @@ export async function processStream(
...params,
toolName,
input,
skipDirectResultPush: isXmlMode,
skipDirectResultPush: true,
fileProcessingState,
fullResponse: fullResponseChunks.join(''),
previousToolCallFinished: previousPromise,
Expand Down Expand Up @@ -327,20 +333,20 @@ export async function processStream(
}
}

// === FINALIZATION ===
agentState.messageHistory = buildArray<Message>([
...expireMessages(agentState.messageHistory, 'agentStep'),
...assistantMessages,
...toolResultsToAddAfterStream,
])

if (!signal.aborted) {
resolveStreamDonePromise()
await previousToolCallFinished
}

// Error messages must come AFTER tool results for proper API ordering
agentState.messageHistory.push(...errorMessages)
// === FINALIZATION ===
// Build message history from the pre-stream snapshot so tool_calls and
// tool_results are always appended in deterministic order.
agentState.messageHistory = buildArray<Message>([
...messageHistoryBeforeStream,
...assistantMessages,
...toolResultsToAddAfterStream,
...errorMessages,
])
Comment on lines +314 to +322
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptually, I think it does make more sense to do all the mutating of message history here, in one place.


return {
fullResponse: fullResponseChunks.join(''),
Expand Down
12 changes: 10 additions & 2 deletions packages/agent-runtime/src/tools/tool-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ export async function executeToolCall<T extends ToolName>(
previousToolCallFinished,
toolCalls,
toolResults,
toolResultsToAddAfterStream: _toolResultsToAddAfterStream,
toolResultsToAddAfterStream,
userInputId,

onCostCalculated,
Expand Down Expand Up @@ -350,6 +350,10 @@ export async function executeToolCall<T extends ToolName>(

toolResults.push(toolResult)

if (!excludeToolFromMessageHistory) {
toolResultsToAddAfterStream.push(toolResult)
}

if (!excludeToolFromMessageHistory && !params.skipDirectResultPush) {
agentState.messageHistory.push(toolResult)
}
Expand Down Expand Up @@ -450,7 +454,7 @@ export async function executeCustomToolCall(
toolCallId,
toolCalls,
toolResults,
toolResultsToAddAfterStream: _toolResultsToAddAfterStream,
toolResultsToAddAfterStream,
userInputId,
} = params
const toolCall: CustomToolCall | ToolCallError = parseRawCustomToolCall({
Expand Down Expand Up @@ -560,6 +564,10 @@ export async function executeCustomToolCall(

toolResults.push(toolResult)

if (!excludeToolFromMessageHistory) {
toolResultsToAddAfterStream.push(toolResult)
}

if (!excludeToolFromMessageHistory && !params.skipDirectResultPush) {
agentState.messageHistory.push(toolResult)
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the new paradigm, we can delete this if and
agentState.messageHistory.push(toolResult).

We might want to rename toolResultsToAddAfterStream to toolResultsToAddToMessageHistory, since all tool results are added after stream and these are the ones included in the message history in particular

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,25 @@ export function convertToOpenAICompatibleChatMessages(
}
}

// Debug: dump OpenAI-format message summary to catch tool_use_id mismatches
console.error('[SDK DEBUG] OpenAI-format messages (' + messages.length + '):')
for (let i = 0; i < messages.length; i++) {
const m = messages[i] as Record<string, unknown>
const role = m.role as string
if (role === 'tool') {
console.error(` [${i}] tool tool_call_id=${(m as { tool_call_id?: string }).tool_call_id}`)
} else if (role === 'assistant') {
const toolCalls = (m as { tool_calls?: Array<{ id: string; function?: { name: string } }> }).tool_calls
if (toolCalls?.length) {
const ids = toolCalls.map(tc => `${tc.function?.name}:${tc.id}`)
console.error(` [${i}] assistant tool_calls=[${ids.join(', ')}]`)
} else {
console.error(` [${i}] assistant (text)`)
}
} else {
console.error(` [${i}] ${role}`)
}
}

return messages
}
25 changes: 24 additions & 1 deletion sdk/src/impl/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,34 @@ export async function* promptAiSdkStream(
}
}

const convertedMessages = convertCbToModelMessages(params)

// Debug: dump message summary to catch tool_use_id mismatches before they hit the API
console.error('[SDK DEBUG] promptAiSdkStream messages (' + convertedMessages.length + '):')
for (let i = 0; i < convertedMessages.length; i++) {
const m = convertedMessages[i] as Record<string, unknown>
const role = m.role as string
const content = m.content
if (role === 'tool' && Array.isArray(content)) {
const toolIds = (content as Array<{ toolCallId?: string; type?: string }>)
.filter(c => c.type === 'tool-result')
.map(c => c.toolCallId)
console.error(` [${i}] ${role} toolCallIds=${JSON.stringify(toolIds)}`)
} else if (role === 'assistant' && Array.isArray(content)) {
const parts = (content as Array<{ type?: string; toolCallId?: string; toolName?: string }>)
.map(c => c.type === 'tool-call' ? `tool-call(${c.toolName}:${c.toolCallId})` : c.type)
console.error(` [${i}] ${role} parts=[${parts.join(', ')}]`)
} else {
const tags = (m as { tags?: string[] }).tags
console.error(` [${i}] ${role}${tags ? ' tags=' + JSON.stringify(tags) : ''}`)
}
}

const response = streamText({
...params,
prompt: undefined,
model: aiSDKModel,
messages: convertCbToModelMessages(params),
messages: convertedMessages,
// When using Claude OAuth, disable retries so we can immediately fall back to Codebuff
// backend on rate limit errors instead of retrying 4 times first
...(isClaudeOAuth && { maxRetries: 0 }),
Expand Down
Loading