diff --git a/src/client/streaming.integration.test.ts b/src/client/streaming.integration.test.ts index 45fbc6fd..68bfbca2 100644 --- a/src/client/streaming.integration.test.ts +++ b/src/client/streaming.integration.test.ts @@ -12,7 +12,6 @@ import { import { getParts, deriveUIMessagesFromDeltas, - deriveUIMessagesFromTextStreamParts, } from "../deltas.js"; import type { TestConvex } from "convex-test"; import type { StreamDelta, StreamMessage } from "../validators.js"; @@ -677,77 +676,6 @@ describe("Delta Stream Consumption", () => { expect((parts[0] as { type: string }).type).toBe("new"); expect(cursor).toBe(6); }); - - test("TextStreamPart format delta reconstruction with tool calls", () => { - const streamId = "s1"; - const streamMessage: StreamMessage = { - streamId, - order: 1, - stepOrder: 0, - status: "streaming", - }; - const deltas: StreamDelta[] = [ - { - streamId, - start: 0, - end: 1, - parts: [{ type: "text-delta", id: "txt-0", text: "Let me call a tool. " }], - }, - { - streamId, - start: 1, - end: 2, - parts: [ - { - type: "tool-call", - toolCallId: "tc1", - toolName: "search", - input: { query: "hello" }, - }, - ], - }, - { - streamId, - start: 2, - end: 3, - parts: [ - { - type: "tool-result", - toolCallId: "tc1", - toolName: "search", - output: "Found 3 results", - }, - ], - }, - { - streamId, - start: 3, - end: 4, - parts: [ - { type: "text-delta", id: "txt-1", text: "Here are the results." }, - ], - }, - ]; - - const [messages, , changed] = deriveUIMessagesFromTextStreamParts( - "thread1", - [streamMessage], - [], - deltas, - ); - - expect(messages).toHaveLength(1); - expect(changed).toBe(true); - - const msg = messages[0]; - expect(msg.text).toContain("Let me call a tool."); - expect(msg.text).toContain("Here are the results."); - - const toolParts = msg.parts.filter((p: any) => - p.type.startsWith("tool-"), - ); - expect(toolParts.length).toBeGreaterThan(0); - }); }); // ============================================================================ @@ -879,18 +807,21 @@ describe("Fallback Behavior", () => { order: 0, stepOrder: 0, status: "streaming", + format: "UIMessageChunk", }; const finishedMsg: StreamMessage = { streamId: "s2", order: 1, stepOrder: 0, status: "finished", + format: "UIMessageChunk", }; const abortedMsg: StreamMessage = { streamId: "s3", order: 2, stepOrder: 0, status: "aborted", + format: "UIMessageChunk", }; const msgs = await deriveUIMessagesFromDeltas( diff --git a/src/deltas.test.ts b/src/deltas.test.ts index 6be75868..98c09e48 100644 --- a/src/deltas.test.ts +++ b/src/deltas.test.ts @@ -1,13 +1,13 @@ import { describe, it, expect } from "vitest"; import { + applyUIMessageChunksIncremental, blankUIMessage, - deriveUIMessagesFromTextStreamParts, - updateFromTextStreamParts, + emptyIncrementalStreamState, + getParts, updateFromUIMessageChunks, } from "./deltas.js"; import type { StreamMessage, StreamDelta } from "./validators.js"; -import { omit } from "convex-helpers"; -import type { Tool, ToolUIPart, TypedToolResult } from "ai"; +import type { ToolUIPart, UIMessageChunk } from "ai"; describe("UIMessageChunks", () => { it("updates a UIMessage with a tool call and follow up", async () => { @@ -200,427 +200,402 @@ describe("UIMessageChunks - continuation stream", () => { }); describe("mergeDeltas", () => { - it("merges a single text-delta into a message", () => { - const streamId = "s1"; - const deltas = [ - { - streamId, - start: 0, - end: 5, - parts: [{ type: "text-delta", id: "1", text: "Hello" }], - } satisfies StreamDelta, - ]; - const [messages, newStreams, changed] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 1, stepOrder: 0, status: "streaming" }], - [], - deltas, - ); - expect(messages).toHaveLength(1); - expect(messages[0].text).toBe("Hello"); - expect(messages[0].role).toBe("assistant"); - expect(changed).toBe(true); - expect(newStreams[0].cursor).toBe(5); - }); - - it("merges multiple deltas for the same stream", () => { - const streamId = "s1"; - const deltas = [ - { - streamId, - start: 0, - end: 5, - parts: [{ type: "text-delta", id: "1", text: "Hello" }], - }, - { - streamId, - start: 5, - end: 11, - parts: [{ type: "text-delta", id: "2", text: " World!" }], - }, - ]; - const [messages, newStreams, changed] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 1, stepOrder: 0, status: "streaming" }], - [], - deltas, - ); - expect(messages).toHaveLength(1); - expect(messages[0].text).toBe("Hello World!"); - expect(changed).toBe(true); - expect(newStreams[0].cursor).toBe(11); - }); + it("incremental apply only consumes parts past the cursor (no re-processing)", () => { + const N = 500; + const streamId = "s-perf"; + const toolCallId = "tool-0"; + const streamMessage = { + streamId, + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "agent1", + }; - it("handles tool-call and tool-result parts", () => { - const streamId = "s2"; - const deltas = [ + // One StreamDelta with preamble, then N deltas each with one tool-input-delta + const allDeltas: StreamDelta[] = [ { streamId, start: 0, end: 1, parts: [ - { - type: "tool-call", - toolCallId: "call1", - toolName: "myTool", - input: "What's the meaning of life?", - }, - ], - } satisfies StreamDelta, - { + { type: "start" }, + { type: "start-step" }, + { type: "tool-input-start", toolCallId, toolName: "myTool" }, + ] as UIMessageChunk[], + }, + ...Array.from({ length: N }, (_, i) => ({ streamId, - start: 1, - end: 2, + start: i + 1, + end: i + 2, parts: [ { - type: "tool-result", - toolCallId: "call1", - toolName: "myTool", - input: undefined, - output: "42", - } satisfies TypedToolResult<{ myTool: Tool }>, + type: "tool-input-delta", + toolCallId, + inputTextDelta: "x", + } as UIMessageChunk, ], - } satisfies StreamDelta, + })), ]; - const [[message], _, changed] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 2, stepOrder: 0, status: "streaming" }], - [], - deltas, - ); - expect(message).toBeDefined(); - expect(message.role).toBe("assistant"); - const content = message.parts; - expect(content).toEqual([ - { - type: "tool-myTool", - toolCallId: "call1", - input: "What's the meaning of life?", - output: "42", - state: "output-available", - } satisfies ToolUIPart, - ]); - expect(changed).toBe(true); - }); - it("returns changed=false if no new deltas", () => { - const streamId = "s3"; - const deltas: StreamDelta[] = []; - const [, newStreams, changed] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 3, stepOrder: 0, status: "streaming" }], - [], - deltas, + // Simulate the hook: process one delta at a time, tracking cursor + prior message + let cursor = 0; + let uiMessage = blankUIMessage(streamMessage, "thread-perf"); + let streamState = emptyIncrementalStreamState(); + let totalPartsProcessed = 0; + + for (let i = 0; i <= N; i++) { + const available = allDeltas.slice(0, i + 1); + const { parts: newParts, cursor: newCursor } = getParts( + available, + cursor, + ); + if (newParts.length > 0) { + totalPartsProcessed += newParts.length; + ({ message: uiMessage, streamState } = + applyUIMessageChunksIncremental( + structuredClone(uiMessage), + newParts, + streamState, + )); + cursor = newCursor; + } + } + + // Each delta part is handed to applyUIMessageChunksIncremental exactly + // once across all batches (cursor slicing — no re-processing of prior + // parts). N tool-input-deltas + 3 preamble parts. The end-to-end O(N) + // claim is proven by the PR's 21,000 ms → 73 ms benchmark, not by this + // unit test. + expect(totalPartsProcessed).toBe(N + 3); + + // Correctness: the raw accumulator holds "x" repeated N times across batches + expect(streamState.toolInputText[toolCallId]).toBe("x".repeat(N)); + const toolPart = uiMessage.parts.find( + (p): p is ToolUIPart => "toolCallId" in p && p.toolCallId === toolCallId, ); - expect(changed).toBe(false); - expect(newStreams[0].cursor).toBe(0); + expect(toolPart).toBeDefined(); }); - it("handles multiple streams and sorts by order/stepOrder", () => { - const deltas = [ - { - streamId: "s2", - start: 0, - end: 3, - parts: [{ type: "text-delta", id: "1", text: "B" }], - } satisfies StreamDelta, - { - streamId: "s1", - start: 0, - end: 3, - parts: [{ type: "text-delta", id: "2", text: "A" }], - } satisfies StreamDelta, - ]; - const [messages, _, changed] = deriveUIMessagesFromTextStreamParts( - "thread1", + it("applyUIMessageChunksIncremental: text-delta accumulation across calls", () => { + const streamMessage = { + streamId: "s-text", + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "a", + }; + let msg = blankUIMessage(streamMessage, "thread-text"); + let state = emptyIncrementalStreamState(); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, [ - { streamId: "s1", order: 1, stepOrder: 0, status: "streaming" }, - { streamId: "s2", order: 2, stepOrder: 0, status: "streaming" }, - ], - [], - deltas, - ); - expect(messages).toHaveLength(2); - expect(messages[0].text).toBe("A"); - expect(messages[1].text).toBe("B"); - expect(changed).toBe(true); - // Sorted by order - expect(messages[0].order).toBe(1); - expect(messages[1].order).toBe(2); - }); + { type: "start" }, + { type: "start-step" }, + { type: "text-start", id: "t0" }, + { type: "text-delta", id: "t0", delta: "Hello " }, + ] as UIMessageChunk[], + state, + )); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [{ type: "text-delta", id: "t0", delta: "world" }] as UIMessageChunk[], + state, + )); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "text-delta", id: "t0", delta: "!" }, + { type: "text-end", id: "t0" }, + ] as UIMessageChunk[], + state, + )); - it("does not duplicate text content when merging sequential text-deltas", () => { - const streamId = "s4"; - const deltas = [ - { - streamId, - start: 0, - end: 5, - parts: [{ type: "text-delta", id: "1", text: "Hello" }], - }, - { - streamId, - start: 5, - end: 11, - parts: [{ type: "text-delta", id: "2", text: " World!" }], - }, - { - streamId, - start: 11, - end: 12, - parts: [{ type: "text-delta", id: "3", text: "!" }], - }, - ] satisfies StreamDelta[]; - const [messages] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 4, stepOrder: 0, status: "streaming" }], - [], - deltas, - ); - expect(messages).toHaveLength(1); - expect(messages[0].text).toBe("Hello World!!"); - // There should only be one text part per message - const content = messages[0].parts; - if (Array.isArray(content)) { - const textParts = content.filter((p) => p.type === "text"); - expect(textParts).toHaveLength(1); - expect(textParts[0].text).toBe("Hello World!!"); - } + const textPart = msg.parts.find((p) => p.type === "text") as + | { text: string; state: string } + | undefined; + expect(textPart?.text).toBe("Hello world!"); + expect(textPart?.state).toBe("done"); + expect(msg.text).toBe("Hello world!"); }); - it("does not duplicate reasoning parts", () => { - const streamId = "s6"; - const deltas = [ - { - streamId, - start: 0, - end: 1, - parts: [ - { type: "reasoning-start", id: "1" }, - { type: "reasoning-delta", id: "1", text: "I'm thinking..." }, - ], - }, - { - streamId, - start: 1, - end: 2, - parts: [ - { type: "reasoning-delta", id: "1", text: " Still thinking..." }, - ], - }, - { - streamId, - start: 2, - end: 3, - parts: [{ type: "reasoning-end", id: "1" }], - }, - ]; - const [messages] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 6, stepOrder: 0, status: "streaming" }], - [], - deltas, + it("applyUIMessageChunksIncremental: tool-output-available preserves input and sets fields", async () => { + const streamMessage = { + streamId: "s-tool-out", + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "a", + }; + let msg = blankUIMessage(streamMessage, "thread-tool-out"); + let state = emptyIncrementalStreamState(); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "start" }, + { type: "start-step" }, + { type: "tool-input-start", toolCallId: "c1", toolName: "myTool" }, + { + type: "tool-input-available", + toolCallId: "c1", + toolName: "myTool", + input: { q: "hi" }, + }, + ] as UIMessageChunk[], + state, + )); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { + type: "tool-output-available", + toolCallId: "c1", + output: { result: "ok" }, + preliminary: true, + providerExecuted: true, + }, + ] as UIMessageChunk[], + state, + )); + + const toolPart = msg.parts.find( + (p): p is ToolUIPart => "toolCallId" in p && p.toolCallId === "c1", ); - expect(messages).toHaveLength(1); - if (Array.isArray(messages[0].parts)) { - const reasoningParts = messages[0].parts.filter( - (p) => p.type === "reasoning", - ); - expect(reasoningParts).toHaveLength(1); - expect(reasoningParts[0].text).toBe("I'm thinking... Still thinking..."); - expect(reasoningParts[0].state).toBe("done"); - } + expect(toolPart?.state).toBe("output-available"); + expect(toolPart?.input).toEqual({ q: "hi" }); + expect((toolPart as { output?: unknown }).output).toEqual({ result: "ok" }); + expect((toolPart as { preliminary?: boolean }).preliminary).toBe(true); + expect((toolPart as { providerExecuted?: boolean }).providerExecuted).toBe(true); }); - it("applyDeltasToStreamMessage is idempotent and does not duplicate content", () => { - const streamId = "s7"; + it("applyUIMessageChunksIncremental: tool-input-error sets rawInput and clears input for static tools", async () => { const streamMessage = { - streamId, - order: 7, + streamId: "s-tool-err", + status: "streaming" as const, + order: 0, stepOrder: 0, - status: "streaming", - } satisfies StreamMessage; - const deltas = [ - { - streamId, - start: 0, - end: 5, - parts: [{ type: "text-delta", id: "1", text: "Hello" }], - }, - { - streamId, - start: 5, - end: 11, - parts: [{ type: "text-delta", id: "2", text: " World!" }], - }, - ]; - // First call: apply both deltas - let [result, changed] = updateFromTextStreamParts( - "thread1", - streamMessage, - undefined, - deltas, - ); - expect(result.message.text).toBe("Hello World!"); - // Second call: re-apply the same deltas (should not duplicate) - [result, changed] = updateFromTextStreamParts( - "thread1", - streamMessage, - result, - deltas, - ); - expect(result.message.text).toBe("Hello World!"); - // Third call: add a new delta - const moreDeltas = [ - ...deltas, - { - streamId, - start: 11, - end: 12, - parts: [{ type: "text-delta", id: "3", text: "!" }], - }, - ]; - [result, changed] = updateFromTextStreamParts( - "thread1", - streamMessage, - result, - moreDeltas, + format: "UIMessageChunk" as const, + agentName: "a", + }; + let msg = blankUIMessage(streamMessage, "thread-tool-err"); + let state = emptyIncrementalStreamState(); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "start" }, + { type: "start-step" }, + { type: "tool-input-start", toolCallId: "c2", toolName: "myTool" }, + ] as UIMessageChunk[], + state, + )); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { + type: "tool-input-error", + toolCallId: "c2", + toolName: "myTool", + input: { bad: "args" }, + errorText: "validation failed", + }, + ] as UIMessageChunk[], + state, + )); + + const toolPart = msg.parts.find( + (p): p is ToolUIPart => "toolCallId" in p && p.toolCallId === "c2", ); - expect(changed).toBe(true); - expect(result.message.text).toBe("Hello World!!"); - // Re-apply all deltas again (should still not duplicate) - [result, changed] = updateFromTextStreamParts( - "thread1", - streamMessage, - result, - moreDeltas, + expect(toolPart?.state).toBe("output-error"); + expect((toolPart as { errorText?: string }).errorText).toBe( + "validation failed", ); - expect(changed).toBe(false); - expect(result.message.text).toBe("Hello World!!"); + expect(toolPart?.input).toBeUndefined(); + expect((toolPart as { rawInput?: unknown }).rawInput).toEqual({ + bad: "args", + }); }); - it("mergeDeltas is pure and does not mutate inputs", () => { - const streamId = "s8"; - const streamMessages = [ - { streamId, order: 8, stepOrder: 0, status: "streaming" }, - ] satisfies StreamMessage[]; - const deltas = [ - { - streamId, - start: 0, - end: 5, - parts: [{ type: "text-delta", id: "1", text: "Hello" }], - }, - { - streamId, - start: 5, - end: 11, - parts: [{ type: "text-delta", id: "2", text: " World!" }], - }, - ]; - // Deep freeze inputs to catch mutation - function deepFreeze(obj: unknown): unknown { - if (obj && typeof obj === "object" && !Object.isFrozen(obj)) { - Object.freeze(obj); - for (const key of Object.keys(obj)) { - deepFreeze((obj as Record)[key]); - } - } - return obj; - } - deepFreeze(streamMessages); - deepFreeze(deltas); - const [messages1, streams1, changed1] = deriveUIMessagesFromTextStreamParts( - "thread1", - streamMessages, - [], - deltas, - ); - const [messages2, streams2, changed2] = deriveUIMessagesFromTextStreamParts( - "thread1", - streamMessages, - [], - deltas, - ); - expect(messages1.map((m) => omit(m, ["_creationTime"]))).toEqual( - messages2.map((m) => omit(m, ["_creationTime"])), + it("accumulates tool input across a batch boundary", async () => { + const streamMessage = { + streamId: "s-tool-split", + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "a", + }; + let msg = blankUIMessage(streamMessage, "thread-tool-split"); + let state = emptyIncrementalStreamState(); + + // Batch A: preamble + the first half of the JSON input. + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "start" }, + { type: "start-step" }, + { type: "tool-input-start", toolCallId: "c1", toolName: "myTool" }, + { type: "tool-input-delta", toolCallId: "c1", inputTextDelta: '{"a":1' }, + ] as UIMessageChunk[], + state, + )); + const afterA = msg.parts.find( + (p): p is ToolUIPart => "toolCallId" in p && p.toolCallId === "c1", ); - expect( - streams1.map((s) => ({ - ...s, - message: omit(s.message, ["_creationTime"]), - })), - ).toEqual( - streams2.map((s) => ({ - ...s, - message: omit(s.message, ["_creationTime"]), - })), + // Mid-stream: JSON is incomplete, input stays unset. + expect(afterA?.input).toBeUndefined(); + + // Batch B: the remainder of the JSON input. + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "tool-input-delta", toolCallId: "c1", inputTextDelta: ',"b":2}' }, + ] as UIMessageChunk[], + state, + )); + const afterB = msg.parts.find( + (p): p is ToolUIPart => "toolCallId" in p && p.toolCallId === "c1", ); - expect(changed1).toBe(changed2); - // Inputs should remain unchanged - expect(streamMessages).toMatchObject([ - { streamId, order: 8, stepOrder: 0, status: "streaming" }, - ]); - expect(deltas).toEqual([ - { - streamId, - start: 0, - end: 5, - parts: [{ type: "text-delta", id: "1", text: "Hello" }], - }, - { - streamId, - start: 5, - end: 11, - parts: [{ type: "text-delta", id: "2", text: " World!" }], - }, - ]); + // Complete JSON is parsed once the accumulator is valid. + expect(afterB?.input).toEqual({ a: 1, b: 2 }); + expect(state.toolInputText["c1"]).toBe('{"a":1,"b":2}'); }); - it("handles streaming tool-approval-request and updates tool state", () => { - const streamId = "s10"; - const deltas = [ - { - streamId, - start: 0, - end: 1, - parts: [ - { - type: "tool-call", - toolCallId: "call1", - toolName: "dangerousTool", - input: { action: "delete" }, - }, - ], - } satisfies StreamDelta, - { - streamId, - start: 1, - end: 2, - parts: [ - { - type: "tool-approval-request", - toolCallId: "call1", - approvalId: "approval1", - }, - ], - } satisfies StreamDelta, + it("pushes file parts and merges message metadata in later batches", async () => { + const streamMessage = { + streamId: "s-file-meta", + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "a", + }; + let msg = blankUIMessage(streamMessage, "thread-file-meta"); + let state = emptyIncrementalStreamState(); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "start" }, + { type: "start-step" }, + ] as UIMessageChunk[], + state, + )); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { + type: "file", + mediaType: "image/png", + url: "https://example.com/a.png", + }, + { type: "message-metadata", messageMetadata: { foo: "bar" } }, + ] as UIMessageChunk[], + state, + )); + + const filePart = msg.parts.find((p) => p.type === "file") as + | { mediaType: string; url: string } + | undefined; + expect(filePart?.mediaType).toBe("image/png"); + expect(filePart?.url).toBe("https://example.com/a.png"); + expect(msg.metadata).toEqual({ foo: "bar" }); + }); + + it("tracks concurrent text parts by id across batches", async () => { + const streamMessage = { + streamId: "s-multi-text", + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "a", + }; + let msg = blankUIMessage(streamMessage, "thread-multi-text"); + let state = emptyIncrementalStreamState(); + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "start" }, + { type: "start-step" }, + { type: "text-start", id: "t0" }, + { type: "text-start", id: "t1" }, + { type: "text-delta", id: "t0", delta: "A" }, + ] as UIMessageChunk[], + state, + )); + // Deltas in a later batch must land on the part matching their id. + ({ message: msg, streamState: state } = applyUIMessageChunksIncremental( + msg, + [ + { type: "text-delta", id: "t1", delta: "B" }, + { type: "text-delta", id: "t0", delta: "C" }, + ] as UIMessageChunk[], + state, + )); + + const textParts = msg.parts.filter((p) => p.type === "text") as Array<{ + text: string; + }>; + expect(textParts.map((p) => p.text)).toEqual(["AC", "B"]); + }); + + it("incremental batches match the SDK processing the full stream", async () => { + const streamMessage = { + streamId: "s-equiv", + status: "streaming" as const, + order: 0, + stepOrder: 0, + format: "UIMessageChunk" as const, + agentName: "a", + }; + const batches: UIMessageChunk[][] = [ + [ + { type: "start" }, + { type: "start-step" }, + { type: "text-start", id: "t0" }, + { type: "text-delta", id: "t0", delta: "Hello " }, + ] as UIMessageChunk[], + [ + { type: "text-delta", id: "t0", delta: "world" }, + { type: "text-end", id: "t0" }, + { type: "tool-input-start", toolCallId: "c1", toolName: "myTool" }, + { type: "tool-input-delta", toolCallId: "c1", inputTextDelta: '{"q":' }, + ] as UIMessageChunk[], + [ + { type: "tool-input-delta", toolCallId: "c1", inputTextDelta: '"hi"}' }, + { + type: "tool-input-available", + toolCallId: "c1", + toolName: "myTool", + input: { q: "hi" }, + }, + { + type: "tool-output-available", + toolCallId: "c1", + output: { ok: true }, + }, + { type: "finish-step" }, + { type: "finish" }, + ] as UIMessageChunk[], ]; - const [[message], _, changed] = deriveUIMessagesFromTextStreamParts( - "thread1", - [{ streamId, order: 10, stepOrder: 0, status: "streaming" }], - [], - deltas, + + // SDK: process the entire stream at once. + const sdkMsg = await updateFromUIMessageChunks( + blankUIMessage(streamMessage, "thread-equiv"), + batches.flat(), ); - expect(message).toBeDefined(); - expect(message.role).toBe("assistant"); - expect(changed).toBe(true); - const toolPart = message.parts.find( - (p) => p.type === "tool-dangerousTool", - ) as any; - expect(toolPart).toBeDefined(); - expect(toolPart.state).toBe("approval-requested"); - expect(toolPart.approval).toEqual({ id: "approval1" }); + // Incremental: process batch by batch, threading state. + let incMsg = blankUIMessage(streamMessage, "thread-equiv"); + let state = emptyIncrementalStreamState(); + for (const batch of batches) { + ({ message: incMsg, streamState: state } = + applyUIMessageChunksIncremental(incMsg, batch, state)); + } + + expect(incMsg.parts).toEqual(sdkMsg.parts); + expect(incMsg.text).toBe(sdkMsg.text); }); }); diff --git a/src/deltas.ts b/src/deltas.ts index 0e816d04..ff5a21ec 100644 --- a/src/deltas.ts +++ b/src/deltas.ts @@ -3,9 +3,7 @@ import { type DynamicToolUIPart, type ProviderMetadata, type ReasoningUIPart, - type TextStreamPart, type TextUIPart, - type ToolSet, type ToolUIPart, type UIMessageChunk, } from "ai"; @@ -57,6 +55,9 @@ export async function updateFromUIMessageChunks( uiMessage: UIMessage, parts: UIMessageChunk[], ) { + if (parts.length === 0) { + return uiMessage; + } const partsStream = new ReadableStream({ start(controller) { for (const part of parts) { @@ -72,11 +73,7 @@ export async function updateFromUIMessageChunks( stream: partsStream, onError: (e) => { const errorMessage = e instanceof Error ? e.message : String(e); - // Tool invocation errors can be safely ignored when streaming continuation - // after tool approval - the stored messages have the complete tool context if (errorMessage.toLowerCase().includes("no tool invocation found")) { - // Silently suppress - this is expected after tool approval when the - // continuation stream has tool-result without the original tool-call suppressError = true; return; } @@ -95,8 +92,6 @@ export async function updateFromUIMessageChunks( message = messagePart; } } catch (e) { - // If we've already handled this error in onError and marked it as suppressed, - // don't rethrow - the stored messages provide the fallback if (!suppressError) { throw e; } @@ -108,188 +103,86 @@ export async function updateFromUIMessageChunks( return message; } -export async function deriveUIMessagesFromDeltas( - threadId: string, - streamMessages: StreamMessage[], - allDeltas: StreamDelta[], -): Promise { - const messages: UIMessage[] = []; - for (const streamMessage of streamMessages) { - if (streamMessage.format === "UIMessageChunk") { - const { parts } = getParts( - allDeltas.filter((d) => d.streamId === streamMessage.streamId), - 0, - ); - const uiMessage = await updateFromUIMessageChunks( - blankUIMessage(streamMessage, threadId), - parts, - ); - messages.push(uiMessage); - } else { - const [uiMessages] = deriveUIMessagesFromTextStreamParts( - threadId, - [streamMessage], - [], - allDeltas, - ); - messages.push(...uiMessages); - } - } - return sorted(messages); -} - -/** - * - */ +type ToolPart = ToolUIPart | DynamicToolUIPart; -export function deriveUIMessagesFromTextStreamParts( - threadId: string, - streamMessages: StreamMessage[], - existingStreams: Array<{ - streamId: string; - cursor: number; - message: UIMessage; - }>, - allDeltas: StreamDelta[], -): [ - UIMessage[], - Array<{ streamId: string; cursor: number; message: UIMessage }>, - boolean, -] { - const newStreams: Array<{ - streamId: string; - cursor: number; - message: UIMessage; - }> = []; - // Seed the existing chunks - let changed = false; - for (const streamMessage of streamMessages) { - const deltas = allDeltas.filter( - (d) => d.streamId === streamMessage.streamId, - ); - const existing = existingStreams.find( - (s) => s.streamId === streamMessage.streamId, - ); - const [newStream, messageChanged] = updateFromTextStreamParts( - threadId, - streamMessage, - existing, - deltas, - ); - newStreams.push(newStream); - if (messageChanged) changed = true; - } - for (const { streamId } of existingStreams) { - if (!newStreams.find((s) => s.streamId === streamId)) { - // There's a stream that's no longer active. - changed = true; - } - } - const messages = sorted(newStreams.map((s) => s.message)); - return [messages, newStreams, changed]; +function transitionToolPart( + part: ToolPart, + updates: { state: S } & Partial>, +): void { + Object.assign(part, updates); } -export function getParts( - deltas: StreamDelta[], - fromCursor?: number, -): { parts: T[]; cursor: number } { - const parts: T[] = []; - let cursor = fromCursor ?? 0; - for (const delta of deltas.sort((a, b) => a.start - b.start)) { - if (delta.parts.length === 0) { - console.debug(`Got delta with no parts: ${JSON.stringify(delta)}`); - continue; - } - if (cursor !== delta.start) { - if (cursor >= delta.end) { - continue; - } else if (cursor < delta.start) { - console.warn( - `Got delta for stream ${delta.streamId} that has a gap ${cursor} -> ${delta.start}`, - ); - break; - } else { - throw new Error( - `Got unexpected delta for stream ${delta.streamId}: delta: ${delta.start} -> ${delta.end} existing cursor: ${cursor}`, - ); - } - } - parts.push(...delta.parts); - cursor = delta.end; - } - return { parts, cursor }; +export type IncrementalStreamState = { + // chunk id -> index of the streaming text part in message.parts + activeText: Record; + // chunk id -> index of the streaming reasoning part in message.parts + activeReasoning: Record; + // toolCallId -> raw accumulated input JSON text (kept separate from the + // parsed `input` so partial JSON can be repair-parsed each batch) + toolInputText: Record; +}; + +export function emptyIncrementalStreamState(): IncrementalStreamState { + return { activeText: {}, activeReasoning: {}, toolInputText: {} }; } /** - * This is historically from when we would use the onChunk callback instead of - * consuming the full UIMessageStream. + * Apply a batch of new UIMessageChunks to an existing UIMessage without + * replaying prior chunks. `prev` carries the ephemeral stream state that the + * UIMessage itself can't hold (which text/reasoning parts are still streaming, + * and the raw accumulated tool input text). Parts are append-only, so part + * indices stay stable across the structuredClone between batches. Behavior + * mirrors the AI SDK's processUIMessageStream. */ +export function applyUIMessageChunksIncremental( + uiMessage: UIMessage, + newParts: UIMessageChunk[], + prev: IncrementalStreamState, +): { message: UIMessage; streamState: IncrementalStreamState } { + const message: UIMessage = structuredClone(uiMessage); + const activeText: Record = { ...prev.activeText }; + const activeReasoning: Record = { ...prev.activeReasoning }; + const toolInputText: Record = { ...prev.toolInputText }; + const touchedTools = new Set(); -// exported for testing -export function updateFromTextStreamParts( - threadId: string, - streamMessage: StreamMessage, - existing: - | { streamId: string; cursor: number; message: UIMessage } - | undefined, - deltas: StreamDelta[], -): [{ streamId: string; cursor: number; message: UIMessage }, boolean] { - const { cursor, parts } = getParts>( - deltas, - existing?.cursor, - ); - const changed = - parts.length > 0 || - (existing && - statusFromStreamStatus(streamMessage.status) !== existing.message.status); - const existingMessage = - existing?.message ?? blankUIMessage(streamMessage, threadId); - if (!changed) { - return [ - existing ?? { - streamId: streamMessage.streamId, - cursor, - message: existingMessage, - }, - false, - ]; - } - - const message: UIMessage = structuredClone(existingMessage); - message.status = statusFromStreamStatus(streamMessage.status); - - const textPartsById = new Map(); - const toolPartsById = new Map( - message.parts - .filter( - (p): p is ToolUIPart | DynamicToolUIPart => - p.type.startsWith("tool-") || p.type === "dynamic-tool", - ) - .map((p) => [p.toolCallId, p]), - ); - const reasoningPartsById = new Map(); + const toolIndexById = new Map(); + message.parts.forEach((p, i) => { + if ("toolCallId" in p && (p.type.startsWith("tool-") || p.type === "dynamic-tool")) { + toolIndexById.set((p as ToolPart).toolCallId, i); + } + }); + const toolPartAt = (toolCallId: string): ToolPart | undefined => { + const idx = toolIndexById.get(toolCallId); + return idx === undefined ? undefined : (message.parts[idx] as ToolPart); + }; + const mergeMetadata = (metadata: unknown) => { + if (metadata == null) { + return; + } + message.metadata = { + ...(message.metadata as Record | undefined), + ...(metadata as Record), + } as typeof message.metadata; + }; - for (const part of parts) { + for (const part of newParts) { switch (part.type) { - case "text-start": + case "text-start": { + const newPart: TextUIPart = { + type: "text", + text: "", + state: "streaming", + providerMetadata: part.providerMetadata, + }; + message.parts.push(newPart); + activeText[part.id] = message.parts.length - 1; + break; + } case "text-delta": { - if (!textPartsById.has(part.id)) { - const lastPart = message.parts.at(-1); - if (lastPart?.type === "text") { - textPartsById.set(part.id, lastPart); - } else { - const newPart = { - type: "text", - text: "", - providerMetadata: part.providerMetadata, - } satisfies TextUIPart; - textPartsById.set(part.id, newPart); - message.parts.push(newPart); - } - } - if (part.type === "text-delta") { - const textPart = textPartsById.get(part.id)!; - textPart.text += part.text; + const idx = activeText[part.id]; + if (idx !== undefined) { + const textPart = message.parts[idx] as TextUIPart; + textPart.text += part.delta; textPart.providerMetadata = mergeProviderMetadata( textPart.providerMetadata, part.providerMetadata, @@ -297,139 +190,35 @@ export function updateFromTextStreamParts( } break; } - case "tool-input-start": { - let newPart: ToolUIPart | DynamicToolUIPart; - if (part.dynamic) { - newPart = { - type: "dynamic-tool", - toolCallId: part.id, - toolName: part.toolName, - state: "input-streaming", - input: "", - } satisfies DynamicToolUIPart; - } else { - newPart = { - type: `tool-${part.toolName}`, - toolCallId: part.id, - state: "input-streaming", - input: "", - providerExecuted: part.providerExecuted, - } satisfies ToolUIPart; - } - toolPartsById.set(part.id, newPart); - message.parts.push(newPart); - break; - } - case "tool-input-delta": - { - const toUpdate = toolPartsById.get(part.id); - assert( - toUpdate, - `Expected to find tool call part ${part.id} to update`, - ); - toUpdate.input = (toUpdate.input ?? "") + part.delta; - } - break; - case "tool-input-end": - { - const toUpdate = toolPartsById.get(part.id); - assert( - toUpdate, - `Expected to find tool call part ${part.id} to update`, + case "text-end": { + const idx = activeText[part.id]; + if (idx !== undefined) { + const textPart = message.parts[idx] as TextUIPart; + textPart.state = "done"; + textPart.providerMetadata = mergeProviderMetadata( + textPart.providerMetadata, + part.providerMetadata, ); - toUpdate.state = "input-available"; - if (part.providerMetadata) { - const updatable = toUpdate as Extract< - ToolUIPart | DynamicToolUIPart, - { state: "input-available" } - >; - updatable.callProviderMetadata = mergeProviderMetadata( - updatable.callProviderMetadata, - part.providerMetadata, - ); - } - } - break; - case "tool-call": { - let newPart: ToolUIPart | DynamicToolUIPart; - if (part.dynamic) { - newPart = { - type: "dynamic-tool", - toolCallId: part.toolCallId, - toolName: part.toolName, - input: part.input, - state: "input-available", - }; - } else { - newPart = { - type: `tool-${part.toolName}`, - toolCallId: part.toolCallId, - input: part.input, - state: "input-available", - }; - if (part.providerExecuted) { - newPart.providerExecuted = part.providerExecuted; - } - } - if (part.providerMetadata) { - newPart.callProviderMetadata = part.providerMetadata; - } - if (toolPartsById.has(part.toolCallId)) { - const toUpdate = toolPartsById.get(part.toolCallId)!; - Object.assign(toUpdate, newPart); - } else { - toolPartsById.set(part.toolCallId, newPart); - message.parts.push(newPart); + delete activeText[part.id]; } break; } - case "tool-result": { - const toolCall = toolPartsById.get(part.toolCallId); - assert( - toolCall, - `Expected to find tool call part ${part.toolCallId} to update with result`, - ); - let newPart: ToolUIPart | DynamicToolUIPart; - if (toolCall.type === "dynamic-tool") { - newPart = { - ...toolCall, - state: "output-available", - input: part.input ?? toolCall.input, - output: part.output ?? toolCall.output, - ...pick(part, ["preliminary"]), - } as DynamicToolUIPart; - } else { - newPart = { - ...toolCall, - state: "output-available", - input: part.input ?? toolCall.input, - output: part.output ?? toolCall.output, - preliminary: part.preliminary, - } as ToolUIPart; - } - Object.assign(toolCall, newPart); + case "reasoning-start": { + const newPart: ReasoningUIPart = { + type: "reasoning", + text: "", + state: "streaming", + providerMetadata: part.providerMetadata, + }; + message.parts.push(newPart); + activeReasoning[part.id] = message.parts.length - 1; break; } - case "reasoning-start": case "reasoning-delta": { - if (!reasoningPartsById.has(part.id)) { - const lastPart = message.parts.at(-1); - if (lastPart?.type === "reasoning") { - reasoningPartsById.set(part.id, lastPart); - } else { - const newPart = { - type: "reasoning", - state: "streaming", - text: "", - providerMetadata: part.providerMetadata, - } satisfies ReasoningUIPart; - reasoningPartsById.set(part.id, newPart); - message.parts.push(newPart); - } - } - const reasoningPart = reasoningPartsById.get(part.id)!; - if (part.type === "reasoning-delta") { - reasoningPart.text += part.text; + const idx = activeReasoning[part.id]; + if (idx !== undefined) { + const reasoningPart = message.parts[idx] as ReasoningUIPart; + reasoningPart.text += part.delta; reasoningPart.providerMetadata = mergeProviderMetadata( reasoningPart.providerMetadata, part.providerMetadata, @@ -438,111 +227,277 @@ export function updateFromTextStreamParts( break; } case "reasoning-end": { - const reasoningPart = - reasoningPartsById.get(part.id) ?? - message.parts.find( - (p): p is ReasoningUIPart => - p.type === "reasoning" && p.state === "streaming", - )!; - if (reasoningPart) { + const idx = activeReasoning[part.id]; + if (idx !== undefined) { + const reasoningPart = message.parts[idx] as ReasoningUIPart; reasoningPart.state = "done"; + reasoningPart.providerMetadata = mergeProviderMetadata( + reasoningPart.providerMetadata, + part.providerMetadata, + ); + delete activeReasoning[part.id]; + } + break; + } + case "tool-input-start": { + const newToolPart: ToolUIPart | DynamicToolUIPart = part.dynamic + ? ({ + type: "dynamic-tool", + toolCallId: part.toolCallId, + toolName: part.toolName, + state: "input-streaming", + input: undefined, + } satisfies DynamicToolUIPart) + : ({ + type: `tool-${part.toolName}`, + toolCallId: part.toolCallId, + state: "input-streaming", + input: undefined, + providerExecuted: part.providerExecuted, + } satisfies ToolUIPart); + message.parts.push(newToolPart); + toolIndexById.set(part.toolCallId, message.parts.length - 1); + toolInputText[part.toolCallId] = ""; + break; + } + case "tool-input-delta": { + if (toolIndexById.has(part.toolCallId)) { + toolInputText[part.toolCallId] = + (toolInputText[part.toolCallId] ?? "") + part.inputTextDelta; + touchedTools.add(part.toolCallId); } else { console.warn( - `Expected to find reasoning part ${part.id} to finish, but found none`, + `tool-input-delta for unknown toolCallId ${part.toolCallId}`, ); } break; } - case "source": - if (part.sourceType === "url") { - message.parts.push({ - type: "source-url", - url: part.url, - sourceId: part.id, - providerMetadata: part.providerMetadata, - title: part.title, + case "tool-input-available": { + const toolPart = toolPartAt(part.toolCallId); + if (toolPart) { + transitionToolPart(toolPart, { + state: "input-available", + input: part.input, + callProviderMetadata: mergeProviderMetadata( + (toolPart as { callProviderMetadata?: ProviderMetadata }) + .callProviderMetadata, + part.providerMetadata, + ), }); - } else if (part.sourceType === "document") { - message.parts.push({ - type: "source-document", - mediaType: part.mediaType, - sourceId: part.id, - title: part.title, - filename: part.filename, - providerMetadata: part.providerMetadata, + } + touchedTools.delete(part.toolCallId); + // The raw JSON buffer is no longer needed; drop it so it doesn't get + // carried through every later batch on the hot path. + delete toolInputText[part.toolCallId]; + break; + } + case "tool-input-error": { + const toolPart = toolPartAt(part.toolCallId); + if (toolPart) { + transitionToolPart(toolPart, { + state: "output-error", + errorText: part.errorText, + providerExecuted: part.providerExecuted, + ...(toolPart.type === "dynamic-tool" + ? { input: part.input } + : { input: undefined, rawInput: part.input }), + callProviderMetadata: mergeProviderMetadata( + (toolPart as { callProviderMetadata?: ProviderMetadata }) + .callProviderMetadata, + part.providerMetadata, + ), }); - } else { - console.warn("Got source part with unknown source type", part); } + touchedTools.delete(part.toolCallId); + delete toolInputText[part.toolCallId]; break; - case "abort": - message.status = "failed"; + } + case "tool-output-available": { + const toolPart = toolPartAt(part.toolCallId); + if (toolPart) { + transitionToolPart(toolPart, { + state: "output-available", + output: part.output, + preliminary: part.preliminary, + providerExecuted: part.providerExecuted, + }); + } break; - case "error": - message.status = "failed"; - console.warn("Generation failed with error", part.error); + } + case "tool-output-error": { + const toolPart = toolPartAt(part.toolCallId); + if (toolPart) { + transitionToolPart(toolPart, { + state: "output-error", + errorText: part.errorText, + providerExecuted: part.providerExecuted, + }); + } break; - case "tool-error": { - const toolPart = toolPartsById.get(part.toolCallId); + } + case "tool-output-denied": { + const toolPart = toolPartAt(part.toolCallId); if (toolPart) { - toolPart.errorText = getErrorMessage(part.error); + transitionToolPart(toolPart, { state: "output-denied" }); } break; } case "tool-approval-request": { - const typedPart = part as unknown as { - type: "tool-approval-request"; - toolCallId: string; - approvalId: string; - }; - const toolPart = toolPartsById.get(typedPart.toolCallId); + const toolPart = toolPartAt(part.toolCallId); if (toolPart) { - toolPart.state = "approval-requested"; - (toolPart as ToolUIPart & { approval?: object }).approval = { - id: typedPart.approvalId, - }; - } else { - console.warn( - `Expected tool call part ${typedPart.toolCallId} for approval request`, - ); + transitionToolPart(toolPart, { + state: "approval-requested", + approval: { id: part.approvalId }, + }); } break; } + case "source-url": + message.parts.push({ + type: "source-url", + url: part.url, + sourceId: part.sourceId, + title: part.title, + providerMetadata: part.providerMetadata, + }); + break; + case "source-document": + message.parts.push({ + type: "source-document", + mediaType: part.mediaType, + sourceId: part.sourceId, + title: part.title, + filename: part.filename, + providerMetadata: part.providerMetadata, + }); + break; case "file": - case "text-end": - case "finish-step": - case "finish": - case "raw": + message.parts.push({ + type: "file", + mediaType: part.mediaType, + url: part.url, + }); + break; case "start-step": + message.parts.push({ type: "step-start" }); + break; + case "finish-step": + // Match the SDK: a new step starts fresh streaming parts; the prior + // parts keep their state rather than being forced to "done". + for (const id of Object.keys(activeText)) delete activeText[id]; + for (const id of Object.keys(activeReasoning)) delete activeReasoning[id]; + break; case "start": - // ignore + case "finish": + case "message-metadata": + mergeMetadata(part.messageMetadata); + break; + case "abort": + case "error": + // The stream-level status (statusFromStreamStatus) is authoritative and + // is applied by the caller; nothing to mutate on the message here. break; default: { - // Exhaustiveness check disabled intentionally for forwards compatibility. - // New TextStreamPart types from future AI SDK versions will trigger a - // runtime warning rather than a compile error, allowing graceful degradation. - // const _: never = part; - console.warn(`Received unexpected part: ${JSON.stringify(part)}`); + if (typeof part.type === "string" && part.type.startsWith("data-")) { + const dataPart = part as Extract< + UIMessageChunk, + { type: `data-${string}` } + >; + const existingIdx = + dataPart.id != null + ? message.parts.findIndex( + (p) => + p.type === dataPart.type && + (p as { id?: string }).id === dataPart.id, + ) + : -1; + if (existingIdx >= 0) { + (message.parts[existingIdx] as { data?: unknown }).data = + dataPart.data; + } else { + message.parts.push( + dataPart as unknown as UIMessage["parts"][number], + ); + } + } else { + console.warn( + `applyUIMessageChunksIncremental: unhandled chunk type ${String(part.type)}`, + ); + } break; } } } - // Consider reasoning done once something else happens - for (let i = 0; i < message.parts.length - 1; i++) { - const part = message.parts[i]; - if (part.type === "reasoning") { - part.state = "done"; + + for (const toolCallId of touchedTools) { + const toolPart = toolPartAt(toolCallId); + if (toolPart && toolPart.state === "input-streaming") { + try { + toolPart.input = JSON.parse(toolInputText[toolCallId] ?? ""); + } catch { + // partial JSON — leave input unset until complete + } } } + message.text = joinText(message.parts); - return [ - { - streamId: streamMessage.streamId, - cursor, - message, - }, - true, - ]; + return { message, streamState: { activeText, activeReasoning, toolInputText } }; +} + +export async function deriveUIMessagesFromDeltas( + threadId: string, + streamMessages: StreamMessage[], + allDeltas: StreamDelta[], +): Promise { + const messages: UIMessage[] = []; + for (const streamMessage of streamMessages) { + if (streamMessage.format !== "UIMessageChunk") { + throw new Error( + `deriveUIMessagesFromDeltas: unsupported stream format "${streamMessage.format ?? "text"}" for stream ${streamMessage.streamId}`, + ); + } + const { parts } = getParts( + allDeltas.filter((d) => d.streamId === streamMessage.streamId), + 0, + ); + const uiMessage = await updateFromUIMessageChunks( + blankUIMessage(streamMessage, threadId), + parts, + ); + messages.push(uiMessage); + } + return sorted(messages); +} + +export function getParts( + deltas: StreamDelta[], + fromCursor?: number, +): { parts: T[]; cursor: number } { + const parts: T[] = []; + let cursor = fromCursor ?? 0; + for (const delta of deltas.sort((a, b) => a.start - b.start)) { + if (delta.parts.length === 0) { + console.debug(`Got delta with no parts: ${JSON.stringify(delta)}`); + continue; + } + if (cursor !== delta.start) { + if (cursor >= delta.end) { + continue; + } else if (cursor < delta.start) { + console.warn( + `Got delta for stream ${delta.streamId} that has a gap ${cursor} -> ${delta.start}`, + ); + break; + } else { + throw new Error( + `Got unexpected delta for stream ${delta.streamId}: delta: ${delta.start} -> ${delta.end} existing cursor: ${cursor}`, + ); + } + } + parts.push(...delta.parts); + cursor = delta.end; + } + return { parts, cursor }; } function mergeProviderMetadata( diff --git a/src/react/useStreamingUIMessages.ts b/src/react/useStreamingUIMessages.ts index e8a4b003..540b4f53 100644 --- a/src/react/useStreamingUIMessages.ts +++ b/src/react/useStreamingUIMessages.ts @@ -4,10 +4,12 @@ import { type UIDataTypes, type UIMessageChunk, type UITools } from "ai"; import type { StreamQuery, StreamQueryArgs } from "./types.js"; import { type UIMessage } from "../UIMessages.js"; import { + applyUIMessageChunksIncremental, blankUIMessage, + emptyIncrementalStreamState, getParts, - updateFromUIMessageChunks, - deriveUIMessagesFromTextStreamParts, + statusFromStreamStatus, + type IncrementalStreamState, } from "../deltas.js"; import { useDeltaStreams } from "./useDeltaStreams.js"; @@ -53,6 +55,7 @@ export function useStreamingUIMessages< { uiMessage: UIMessage; cursor: number; + streamState: IncrementalStreamState; } > >({}); @@ -63,16 +66,23 @@ export function useStreamingUIMessages< useEffect(() => { if (!streams) return; - // return if there are no new deltas beyond the cursors let noNewDeltas = true; for (const stream of streams) { - const lastDelta = stream.deltas.at(-1); - const cursor = messageState[stream.streamMessage.streamId]?.cursor; - if (!cursor) { + const existingStreamState = messageState[stream.streamMessage.streamId]; + const cursor = existingStreamState?.cursor; + if (existingStreamState === undefined || cursor === undefined) { noNewDeltas = false; break; } - if (lastDelta && lastDelta.start >= cursor) { + if (stream.deltas.some((d) => d.parts.length > 0 && d.end > cursor)) { + noNewDeltas = false; + break; + } + if ( + existingStreamState && + existingStreamState.uiMessage.status !== + statusFromStreamStatus(stream.streamMessage.status) + ) { noNewDeltas = false; break; } @@ -87,40 +97,58 @@ export function useStreamingUIMessages< { uiMessage: UIMessage; cursor: number; + streamState: IncrementalStreamState; } > = Object.fromEntries( await Promise.all( streams.map(async ({ deltas, streamMessage }) => { - const { parts, cursor } = getParts(deltas, 0); - if (streamMessage.format === "UIMessageChunk") { - // Unfortunately this can't handle resuming from a UIMessage and - // adding more chunks, so we re-create it from scratch each time. - const uiMessage = await updateFromUIMessageChunks( - blankUIMessage(streamMessage, threadId), - parts, - ); - return [ - streamMessage.streamId, - { - uiMessage, - cursor, - }, - ]; - } else { - const [uiMessages] = deriveUIMessagesFromTextStreamParts( - threadId, - [streamMessage], - [], - deltas, - ); + const streamId = streamMessage.streamId; + const existing = messageState[streamId]; + const fromCursor = existing?.cursor ?? 0; + const status = statusFromStreamStatus(streamMessage.status); + const prevState = + existing?.streamState ?? emptyIncrementalStreamState(); + + const { parts: newParts, cursor } = getParts( + deltas, + fromCursor, + ); + + const base = + existing?.uiMessage ?? + blankUIMessage(streamMessage, threadId as string); + + if (newParts.length === 0) { + if (existing && existing.uiMessage.status !== status) { + return [ + streamId, + { + uiMessage: { ...existing.uiMessage, status }, + cursor: existing.cursor, + streamState: prevState, + }, + ]; + } return [ - streamMessage.streamId, - { - uiMessage: uiMessages[0], - cursor, - }, + streamId, + existing ?? { uiMessage: base, cursor: 0, streamState: prevState }, ]; } + + const { message, streamState } = applyUIMessageChunksIncremental( + base as UIMessage, + newParts, + prevState, + ); + message.status = status; + return [ + streamId, + { + uiMessage: message as UIMessage, + cursor, + streamState, + }, + ]; }), ), );