diff --git a/src/providers/codex.test.ts b/src/providers/codex.test.ts index a879ebdd..1b01c195 100644 --- a/src/providers/codex.test.ts +++ b/src/providers/codex.test.ts @@ -311,6 +311,103 @@ Deno.test("codex provider streams assistant text deltas from agent_message event assertEquals(streamedText, ["hello ", "world"]); }); +Deno.test("codex provider preserves multiple assistant message items in order", async () => { + const streamEvents: Array< + { + type?: string; + item_id?: string; + output_index?: number; + item?: { id?: string; type?: string }; + } + > = []; + const provider = createCodexProvider({ + runCommand: ({ onStdoutLine }) => { + const lines = [ + JSON.stringify({ + type: "item.delta", + item: { id: "msg_1", type: "agent_message", text: "first " }, + }), + JSON.stringify({ + type: "item.completed", + item: { id: "msg_1", type: "agent_message", text: "first reply" }, + }), + JSON.stringify({ + type: "item.delta", + item: { id: "msg_2", type: "agent_message", text: "second " }, + }), + JSON.stringify({ + type: "item.completed", + item: { id: "msg_2", type: "agent_message", text: "second reply" }, + }), + ]; + lines.forEach((line) => onStdoutLine?.(line)); + return Promise.resolve({ + success: true, + code: 0, + stdout: enc.encode(lines.join("\n")), + stderr: new Uint8Array(), + }); + }, + }); + + const result = await provider.responses?.({ + request: { + model: "codex-cli/default", + stream: true, + input: [{ + type: "message", + role: "user", + content: [{ type: "input_text", text: "hi" }], + }], + }, + onStreamEvent: (event) => { + streamEvents.push( + event as { + type?: string; + item_id?: string; + output_index?: number; + item?: { id?: string; type?: string }; + }, + ); + }, + }); + + assertEquals( + result?.output + .filter((item) => item.type === "message") + .map((item) => + item.type === "message" + ? { + id: item.id, + text: item.content.map((part) => part.text).join(""), + } + : null + ), + [ + { id: "msg_1", text: "first reply" }, + { id: "msg_2", text: "second reply" }, + ], + ); + assertEquals( + streamEvents + .filter((event) => event.type === "response.output_text.delta") + .map((event) => ({ + item_id: event.item_id, + output_index: event.output_index, + })), + [ + { item_id: "msg_1", output_index: 0 }, + { item_id: "msg_2", output_index: 1 }, + ], + ); + assertEquals( + streamEvents + .filter((event) => event.type === "response.output_item.done") + .map((event) => event.item?.id), + ["msg_1", "msg_2"], + ); +}); + Deno.test("codex provider streams completed-only assistant text once", async () => { const streamedText: Array = []; const provider = createCodexProvider({ diff --git a/src/providers/codex.ts b/src/providers/codex.ts index 5608c22a..ade66f3f 100644 --- a/src/providers/codex.ts +++ b/src/providers/codex.ts @@ -60,6 +60,11 @@ type CommandOutput = { stderr: Uint8Array; }; +type CodexAssistantMessage = { + itemId: string | null; + text: string; +}; + type CommandRunner = (input: { args: Array; cwd: string; @@ -439,6 +444,8 @@ function extractCodexItemText(record: Record): string { type CodexAssistantStreamState = { streamedText: string; sawAssistantTextStream: boolean; + assistantOutputIndexByItemId: Map; + emittedTerminalAssistantItemIds: Set; }; function emitCodexAssistantTextEvents(input: { @@ -446,6 +453,7 @@ function emitCodexAssistantTextEvents(input: { emit: (event: Record) => void; emitText?: (text: string) => void; assistantState: CodexAssistantStreamState; + nextOutputIndexRef: { value: number }; }): void { const payloadType = typeof input.event.type === "string" ? input.event.type @@ -456,7 +464,19 @@ function emitCodexAssistantTextEvents(input: { const record = item as Record; if (record.type !== "agent_message") return; - const outputIndex = 0; + const itemId = typeof record.id === "string" && record.id.trim().length > 0 + ? record.id.trim() + : `assistant_${input.nextOutputIndexRef.value}`; + const outputIndex = (() => { + const existing = input.assistantState.assistantOutputIndexByItemId.get( + itemId, + ); + if (typeof existing === "number") return existing; + const next = input.nextOutputIndexRef.value; + input.nextOutputIndexRef.value += 1; + input.assistantState.assistantOutputIndexByItemId.set(itemId, next); + return next; + })(); const text = extractCodexItemText(record); if (!text) return; @@ -467,6 +487,7 @@ function emitCodexAssistantTextEvents(input: { type: "response.output_text.delta", output_index: outputIndex, delta: text, + item_id: itemId, }); input.emitText?.(text); return; @@ -480,7 +501,21 @@ function emitCodexAssistantTextEvents(input: { type: "response.output_text.done", output_index: outputIndex, text, + item_id: itemId, }); + if (!input.assistantState.emittedTerminalAssistantItemIds.has(itemId)) { + input.assistantState.emittedTerminalAssistantItemIds.add(itemId); + input.emit({ + type: "response.output_item.done", + output_index: outputIndex, + item: { + type: "message", + role: "assistant", + id: itemId, + content: [{ type: "output_text", text }], + }, + }); + } if (!hadPriorAssistantDelta) { input.emitText?.(text); } @@ -640,6 +675,21 @@ function responseItemsFromAssistantMessage( return output; } +function responseItemsFromAssistantMessages( + messages: Array, +): Array { + return messages + .filter((message) => message.text.length > 0) + .map((message) => + ({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: message.text }], + ...(message.itemId ? { id: message.itemId } : {}), + }) satisfies ResponseMessageItem + ); +} + function stringContent(content: ModelMessage["content"]): string { if (typeof content === "string") return content.trim(); return ""; @@ -684,6 +734,7 @@ function parseNumber(input: unknown): number { function parseCodexStdout(stdout: string): { threadId?: string; assistantText: string; + assistantMessages: Array; usage?: { promptTokens: number; completionTokens: number; @@ -691,7 +742,7 @@ function parseCodexStdout(stdout: string): { }; } { let threadId: string | undefined; - let assistantText = ""; + const assistantMessages: Array = []; let usage: { promptTokens: number; completionTokens: number; @@ -716,7 +767,7 @@ function parseCodexStdout(stdout: string): { continue; } - if (parsed.type === "item.completed") { + if (parsed.type === "item.completed" || parsed.type === "item.done") { const item = parsed.item as Record | undefined; if (!item || typeof item !== "object") continue; if (item.type !== "agent_message") continue; @@ -732,7 +783,14 @@ function parseCodexStdout(stdout: string): { .join("") .trim() : ""; - if (content) assistantText = content; + if (content) { + assistantMessages.push({ + itemId: typeof item.id === "string" && item.id.trim().length > 0 + ? item.id.trim() + : null, + text: content, + }); + } continue; } @@ -747,18 +805,33 @@ function parseCodexStdout(stdout: string): { } } - return { threadId, assistantText, usage }; + return { + threadId, + assistantText: assistantMessages.map((message) => message.text).join(""), + assistantMessages, + usage, + }; } function buildUpdatedState(input: { priorState?: SavedState; messages: Array; assistantText: string; + assistantMessages?: Array; threadId?: string; }): SavedState { const priorState = input.priorState; const baseMessages = input.messages.map((message) => ({ ...message })); - baseMessages.push({ role: "assistant", content: input.assistantText }); + if (input.assistantMessages && input.assistantMessages.length > 0) { + baseMessages.push( + ...input.assistantMessages.map((message) => ({ + role: "assistant" as const, + content: message.text, + })), + ); + } else { + baseMessages.push({ role: "assistant", content: input.assistantText }); + } const meta = { ...(priorState?.meta ?? {}) }; if (input.threadId) { meta[CODEX_THREAD_META_KEY] = input.threadId; @@ -867,6 +940,7 @@ function buildCodexStreamHandler(input: { emit: input.emitTool, emitText: input.emitText, assistantState: input.assistantState, + nextOutputIndexRef, }); emitCodexReasoningEvents({ event, @@ -890,13 +964,21 @@ export function createCodexProvider(opts?: { runCommand?: CommandRunner; }): ModelProvider { const runCommand = opts?.runCommand ?? defaultCommandRunner; - const runChat: ModelProvider["chat"] = async (input) => { + const runCodexTurn = async ( + input: Parameters>[0], + ): Promise< + Awaited>> & { + assistantMessages: Array; + } + > => { if (input.signal?.aborted) { throw new DOMException("Run canceled", "AbortError"); } const assistantState: CodexAssistantStreamState = { streamedText: "", sawAssistantTextStream: false, + assistantOutputIndexByItemId: new Map(), + emittedTerminalAssistantItemIds: new Set(), }; const streamHandler = (input.onStreamEvent || input.onTraceEvent || (input.stream && input.onStreamText)) @@ -993,6 +1075,7 @@ export function createCodexProvider(opts?: { priorState: input.state, messages: input.messages, assistantText: parsed.assistantText, + assistantMessages: parsed.assistantMessages, threadId, }); @@ -1001,6 +1084,16 @@ export function createCodexProvider(opts?: { finishReason: "stop" as const, updatedState, usage: parsed.usage, + assistantMessages: parsed.assistantMessages, + }; + }; + const runChat: ModelProvider["chat"] = async (input) => { + const result = await runCodexTurn(input); + return { + message: result.message, + finishReason: result.finishReason, + updatedState: result.updatedState, + usage: result.usage, }; }; @@ -1017,6 +1110,8 @@ export function createCodexProvider(opts?: { const assistantState: CodexAssistantStreamState = { streamedText: "", sawAssistantTextStream: false, + assistantOutputIndexByItemId: new Map(), + emittedTerminalAssistantItemIds: new Set(), }; return { assistantState, @@ -1037,7 +1132,7 @@ export function createCodexProvider(opts?: { }; })() : undefined; - const result = await runChat({ + const result = await runCodexTurn({ model: input.request.model, messages: responseItemsToChatMessages( input.request.input, @@ -1051,7 +1146,9 @@ export function createCodexProvider(opts?: { onStreamEvent: streamHandler?.handle, }); - const output = responseItemsFromAssistantMessage(result.message); + const output = result.assistantMessages.length > 0 + ? responseItemsFromAssistantMessages(result.assistantMessages) + : responseItemsFromAssistantMessage(result.message); const responseId = `codex-${crypto.randomUUID()}`; const createdAt = Math.floor(Date.now() / 1000); if (input.request.stream) { @@ -1070,24 +1167,43 @@ export function createCodexProvider(opts?: { }, }); if ( - typeof result.message.content === "string" && - result.message.content && !streamHandler?.assistantState.sawAssistantTextStream ) { - input.onStreamEvent?.({ - type: "response.output_text.delta", - sequence_number: 1, - output_index: 0, - delta: result.message.content, - }); - input.onStreamEvent?.({ - type: "response.output_text.done", - sequence_number: 2, - output_index: 0, - text: result.message.content, + const fallbackMessages = result.assistantMessages.length > 0 + ? result.assistantMessages + : typeof result.message.content === "string" && + result.message.content + ? [{ itemId: null, text: result.message.content }] + : []; + fallbackMessages.forEach((message, index) => { + if (!message.text) return; + input.onStreamEvent?.({ + type: "response.output_text.delta", + sequence_number: 1 + (index * 2), + output_index: index, + delta: message.text, + ...(message.itemId ? { item_id: message.itemId } : {}), + }); + input.onStreamEvent?.({ + type: "response.output_text.done", + sequence_number: 2 + (index * 2), + output_index: index, + text: message.text, + ...(message.itemId ? { item_id: message.itemId } : {}), + }); }); } output.forEach((item, index) => { + if ( + item.type === "message" && + item.role === "assistant" && + typeof item.id === "string" && + streamHandler?.assistantState.emittedTerminalAssistantItemIds.has( + item.id, + ) + ) { + return; + } input.onStreamEvent?.({ type: "response.output_item.added", sequence_number: 3 + (index * 2),