diff --git a/src/services/socket-manager/message-queue.test.ts b/src/services/socket-manager/message-queue.test.ts index f2d2d924..2b1deb2a 100644 --- a/src/services/socket-manager/message-queue.test.ts +++ b/src/services/socket-manager/message-queue.test.ts @@ -281,6 +281,153 @@ describe('createMessageEventQueue', () => { }); }); + describe('multi-message assistant turns', () => { + beforeEach(() => { + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'please book a meeting', + parts: [], + created_at: new Date().toISOString(), + transcribed: true, + }); + }); + + it('should append a new assistant message when answer arrives with a different id', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Answer, { id: 'assistant-1', content: "ok, i'll book a meeting" }); + onMessage(ChatProgress.Answer, { id: 'assistant-2', content: 'i booked a meeting for you' }); + + const assistantMessages = mockItems.messages.filter(m => m.role === 'assistant'); + expect(assistantMessages).toHaveLength(2); + expect(assistantMessages[0].id).toBe('assistant-1'); + expect(assistantMessages[0].content).toBe("ok, i'll book a meeting"); + expect(assistantMessages[1].id).toBe('assistant-2'); + expect(assistantMessages[1].content).toBe('i booked a meeting for you'); + }); + + it('should preserve the first assistant message when a second arrives after a tool call', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Answer, { id: 'assistant-1', content: "ok, i'll book a meeting" }); + // Tool-call events are dispatched via a separate path and do not touch messages; + // simulating that gap here means the next answer event arrives with a fresh id. + onMessage(ChatProgress.Answer, { id: 'assistant-2', content: 'i booked a meeting for you' }); + + expect(mockItems.messages.map(m => m.content)).toEqual([ + 'please book a meeting', + "ok, i'll book a meeting", + 'i booked a meeting for you', + ]); + }); + + it('should overwrite the last assistant message when answer has the same id', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Answer, { id: 'assistant-1', content: 'first draft' }); + onMessage(ChatProgress.Answer, { id: 'assistant-1', content: 'final answer' }); + + const assistantMessages = mockItems.messages.filter(m => m.role === 'assistant'); + expect(assistantMessages).toHaveLength(1); + expect(assistantMessages[0].content).toBe('final answer'); + }); + + it('should overwrite the last assistant message when answer has no id (legacy backends)', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Answer, { content: 'first' }); + onMessage(ChatProgress.Answer, { content: 'second' }); + + const assistantMessages = mockItems.messages.filter(m => m.role === 'assistant'); + expect(assistantMessages).toHaveLength(1); + expect(assistantMessages[0].content).toBe('second'); + }); + + it('should not leak content from the previous assistant message into the new one', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Answer, { id: 'assistant-1', content: "ok, i'll book a meeting" }); + onMessage(ChatProgress.Answer, { id: 'assistant-2', content: 'done' }); + + const assistantMessages = mockItems.messages.filter(m => m.role === 'assistant'); + expect(assistantMessages).toHaveLength(2); + expect(assistantMessages[1].content).toBe('done'); + expect(assistantMessages[1].content).not.toContain("ok, i'll book"); + }); + + it('should keep streaming partials into the same assistant message', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Partial, { content: 'Hello', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' World', sequence: 1 }); + + const assistantMessages = mockItems.messages.filter(m => m.role === 'assistant'); + expect(assistantMessages).toHaveLength(1); + expect(assistantMessages[0].content).toBe('Hello World'); + }); + + it('should append a new assistant message when partials arrive with a different id', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + onMessage(ChatProgress.Partial, { id: 'assistant-1', content: 'first ', sequence: 0 }); + onMessage(ChatProgress.Partial, { id: 'assistant-1', content: 'message', sequence: 1 }); + onMessage(ChatProgress.Answer, { id: 'assistant-1', content: 'first message' }); + + onMessage(ChatProgress.Partial, { id: 'assistant-2', content: 'second ', sequence: 0 }); + onMessage(ChatProgress.Partial, { id: 'assistant-2', content: 'message', sequence: 1 }); + onMessage(ChatProgress.Answer, { id: 'assistant-2', content: 'second message' }); + + const assistantMessages = mockItems.messages.filter(m => m.role === 'assistant'); + expect(assistantMessages).toHaveLength(2); + expect(assistantMessages[0].content).toBe('first message'); + expect(assistantMessages[1].content).toBe('second message'); + expect(assistantMessages[1].content).not.toContain('first'); + }); + }); + describe('message parts population', () => { it('should populate parts on partial messages', () => { const { onMessage } = createMessageEventQueue( diff --git a/src/services/socket-manager/message-queue.ts b/src/services/socket-manager/message-queue.ts index 930d2dd5..91ba7649 100644 --- a/src/services/socket-manager/message-queue.ts +++ b/src/services/socket-manager/message-queue.ts @@ -57,7 +57,8 @@ function processChatEvent( data: any, chatEventQueue: ChatEventQueue, items: AgentManagerItems, - onNewMessage: AgentManagerOptions['callbacks']['onNewMessage'] + onNewMessage: AgentManagerOptions['callbacks']['onNewMessage'], + clearQueue: () => void ) { if (event === ChatProgress.Transcribe && data.content) { handleAudioTranscribedMessage(data, items, onNewMessage); @@ -70,10 +71,21 @@ function processChatEvent( const lastMessage = items.messages[items.messages.length - 1]; + // A new assistant message within the same turn (e.g. after a client tool call, or several + // assistant messages in a row) is signalled by a chat event whose `id` differs from the + // last assistant message. The new message typically starts with `Partial` events and ends + // with `Answer`, so both branches must detect the id change — otherwise the SDK overwrites + // the previous message on the first partial of the new one. + const isNewAssistantMessage = data.id && lastMessage?.role === 'assistant' && lastMessage.id !== data.id; + let currentMessage: Message; - if (lastMessage?.role === 'assistant') { + if (lastMessage?.role === 'assistant' && !isNewAssistantMessage) { currentMessage = lastMessage; - } else if (!lastMessage || (lastMessage.transcribed && lastMessage.role === 'user')) { + } else if (!lastMessage || (lastMessage.transcribed && lastMessage.role === 'user') || isNewAssistantMessage) { + if (isNewAssistantMessage) { + // Reset the streaming buffer so the next message does not inherit the previous one's content. + clearQueue(); + } currentMessage = { id: data.id || `assistant-${Date.now()}`, role: data.role || 'assistant', @@ -132,7 +144,7 @@ export function createMessageEventQueue( : event === StreamEvents.ChatAudioTranscribed ? ChatProgress.Transcribe : (event as ChatProgress); - processChatEvent(chatEvent, data, chatEventQueue, items, onNewMessage); + processChatEvent(chatEvent, data, chatEventQueue, items, onNewMessage, clearQueue); if (chatEvent === ChatProgress.Answer) { analytics.track('agent-message-received', { diff --git a/src/services/streaming-manager/livekit-manager.test.ts b/src/services/streaming-manager/livekit-manager.test.ts index 118fa076..933fa9ca 100644 --- a/src/services/streaming-manager/livekit-manager.test.ts +++ b/src/services/streaming-manager/livekit-manager.test.ts @@ -1141,20 +1141,14 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { }); describe('Enum values', () => { - it('should have correct StreamEvents enum values for tool events', () => { - // ASSERT: - expect(StreamEvents.ToolCalling).toBe('tool/calling'); - expect(StreamEvents.ToolResult).toBe('tool/result'); - }); - it('should have correct AgentActivityState enum value for ToolActive', () => { // ASSERT: expect(AgentActivityState.ToolActive).toBe('TOOL_ACTIVE'); }); }); - describe('handleDataReceived - tool/calling', () => { - it('should transition to ToolActive and call onToolEvent on tool/calling', async () => { + describe('handleDataReceived - tool-call/started', () => { + it('should transition to ToolActive and forward payload via onToolEvent on tool-call/started', async () => { // ARRANGE: const onAgentActivityStateChange = jest.fn(); const onToolEvent = jest.fn(); @@ -1166,11 +1160,12 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { const dataHandler = getDataReceivedHandler(); const payload = createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-123', - tool_name: 'get_weather', - arguments: { location: 'Tel Aviv' }, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-123', + name: 'get_weather', + input: { location: 'Tel Aviv' }, + output: {}, + timestamp: new Date().toISOString(), }); // ACT: @@ -1179,17 +1174,18 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // ASSERT: expect(onAgentActivityStateChange).toHaveBeenCalledWith(AgentActivityState.ToolActive); expect(onToolEvent).toHaveBeenCalledWith( - StreamEvents.ToolCalling, + StreamEvents.ToolCallStarted, expect.objectContaining({ - execution_id: 'exec-123', - tool_name: 'get_weather', + call_id: 'call-123', + name: 'get_weather', + input: { location: 'Tel Aviv' }, }) ); }); }); - describe('handleDataReceived - tool/result', () => { - it('should call onToolEvent but not change state on tool/result', async () => { + describe('handleDataReceived - tool-call/done', () => { + it('should forward payload via onToolEvent without changing activity state', async () => { // ARRANGE: const onAgentActivityStateChange = jest.fn(); const onToolEvent = jest.fn(); @@ -1201,38 +1197,92 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { const dataHandler = getDataReceivedHandler(); - // First trigger tool/calling to set ToolActive + // Set ToolActive first so we can verify done doesn't touch it dataHandler( createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-123', - tool_name: 'get_weather', - arguments: {}, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-123', + name: 'get_weather', + input: {}, + output: {}, + timestamp: new Date().toISOString(), }) ); onAgentActivityStateChange.mockClear(); - const toolResultPayload = createDataChannelPayload({ - subject: StreamEvents.ToolResult, - execution_id: 'exec-123', - tool_name: 'get_weather', - success: true, + const donePayload = createDataChannelPayload({ + subject: StreamEvents.ToolCallDone, + call_id: 'call-123', + name: 'get_weather', + input: {}, + output: { temp: 22 }, duration_ms: 500, - error_message: null, - created_at: new Date().toISOString(), + extra: {}, + timestamp: new Date().toISOString(), }); // ACT: - dataHandler(toolResultPayload); + dataHandler(donePayload); // ASSERT: expect(onAgentActivityStateChange).not.toHaveBeenCalled(); expect(onToolEvent).toHaveBeenCalledWith( - StreamEvents.ToolResult, + StreamEvents.ToolCallDone, expect.objectContaining({ - execution_id: 'exec-123', - success: true, + call_id: 'call-123', + output: { temp: 22 }, + duration_ms: 500, + }) + ); + }); + }); + + describe('handleDataReceived - tool-call/error', () => { + it('should forward payload via onToolEvent without changing activity state', async () => { + // ARRANGE: + const onAgentActivityStateChange = jest.fn(); + const onToolEvent = jest.fn(); + options.callbacks.onAgentActivityStateChange = onAgentActivityStateChange; + options.callbacks.onToolEvent = onToolEvent; + + await createLiveKitStreamingManager(agentId, sessionOptions, options); + await simulateConnection(); + + const dataHandler = getDataReceivedHandler(); + + dataHandler( + createDataChannelPayload({ + subject: StreamEvents.ToolCallStarted, + call_id: 'call-123', + name: 'get_weather', + input: {}, + output: {}, + timestamp: new Date().toISOString(), + }) + ); + onAgentActivityStateChange.mockClear(); + + const errorPayload = createDataChannelPayload({ + subject: StreamEvents.ToolCallError, + call_id: 'call-123', + name: 'get_weather', + input: {}, + output: {}, + duration_ms: 120, + extra: { message: 'upstream timeout' }, + timestamp: new Date().toISOString(), + }); + + // ACT: + dataHandler(errorPayload); + + // ASSERT: + expect(onAgentActivityStateChange).not.toHaveBeenCalled(); + expect(onToolEvent).toHaveBeenCalledWith( + StreamEvents.ToolCallError, + expect.objectContaining({ + call_id: 'call-123', + extra: { message: 'upstream timeout' }, }) ); }); @@ -1252,11 +1302,12 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // Set ToolActive state first dataHandler( createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-123', - tool_name: 'test', - arguments: {}, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-123', + name: 'test', + input: {}, + output: {}, + timestamp: new Date().toISOString(), }) ); onAgentActivityStateChange.mockClear(); @@ -1286,11 +1337,12 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // Set ToolActive state first dataHandler( createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-123', - tool_name: 'test', - arguments: {}, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-123', + name: 'test', + input: {}, + output: {}, + timestamp: new Date().toISOString(), }) ); onAgentActivityStateChange.mockClear(); @@ -1319,11 +1371,12 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // Set ToolActive state first dataHandler( createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-123', - tool_name: 'test', - arguments: {}, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-123', + name: 'test', + input: {}, + output: {}, + timestamp: new Date().toISOString(), }) ); onAgentActivityStateChange.mockClear(); @@ -1345,9 +1398,7 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { it('should stay ToolActive across multiple tool calls until final stream-video/done', async () => { // ARRANGE: const onAgentActivityStateChange = jest.fn(); - const onToolEvent = jest.fn(); options.callbacks.onAgentActivityStateChange = onAgentActivityStateChange; - options.callbacks.onToolEvent = onToolEvent; await createLiveKitStreamingManager(agentId, sessionOptions, options); await simulateConnection(); @@ -1357,22 +1408,12 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // First tool cycle dataHandler( createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-1', - tool_name: 'tool1', - arguments: {}, - created_at: new Date().toISOString(), - }) - ); - - dataHandler( - createDataChannelPayload({ - subject: StreamEvents.ToolResult, - execution_id: 'exec-1', - tool_name: 'tool1', - success: true, - duration_ms: 100, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-1', + name: 'tool1', + input: {}, + output: {}, + timestamp: new Date().toISOString(), }) ); @@ -1387,22 +1428,12 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // Second tool cycle dataHandler( createDataChannelPayload({ - subject: StreamEvents.ToolCalling, - execution_id: 'exec-2', - tool_name: 'tool2', - arguments: {}, - created_at: new Date().toISOString(), - }) - ); - - dataHandler( - createDataChannelPayload({ - subject: StreamEvents.ToolResult, - execution_id: 'exec-2', - tool_name: 'tool2', - success: true, - duration_ms: 200, - created_at: new Date().toISOString(), + subject: StreamEvents.ToolCallStarted, + call_id: 'call-2', + name: 'tool2', + input: {}, + output: {}, + timestamp: new Date().toISOString(), }) ); @@ -1417,7 +1448,6 @@ describe('LiveKit Streaming Manager - Tool Events and Activity State', () => { // ASSERT: expect(onAgentActivityStateChange).toHaveBeenCalledWith(AgentActivityState.ToolActive); expect(onAgentActivityStateChange).toHaveBeenLastCalledWith(AgentActivityState.Idle); - expect(onToolEvent).toHaveBeenCalledTimes(4); }); }); diff --git a/src/services/streaming-manager/livekit-manager.ts b/src/services/streaming-manager/livekit-manager.ts index e714d376..3ecd20a6 100644 --- a/src/services/streaming-manager/livekit-manager.ts +++ b/src/services/streaming-manager/livekit-manager.ts @@ -10,8 +10,9 @@ import { StreamingManagerOptions, StreamingState, StreamType, - ToolCallingPayload, - ToolResultPayload, + ToolCallDonePayload, + ToolCallErrorPayload, + ToolCallStartedPayload, } from '@sdk/types'; import { ChatProgress } from '@sdk/types/entities/agents/manager'; import { noop } from '@sdk/utils'; @@ -357,20 +358,25 @@ export async function createLiveKitStreamingManager sets ToolActive + * - tool-call/started -> sets ToolActive * - stream-video/done with interruptible: true -> sets Idle * - stream-video/done with interruptible: false -> stays ToolActive (more tools coming) */ function handleToolEvents(subject: string, data: any): void { - if (subject === StreamEvents.ToolCalling) { + if (subject === StreamEvents.ToolCallStarted) { currentActivityState = AgentActivityState.ToolActive; callbacks.onAgentActivityStateChange?.(AgentActivityState.ToolActive); - callbacks.onToolEvent?.(StreamEvents.ToolCalling, data as ToolCallingPayload); + callbacks.onToolEvent?.(StreamEvents.ToolCallStarted, data as ToolCallStartedPayload); return; } - if (subject === StreamEvents.ToolResult) { - callbacks.onToolEvent?.(StreamEvents.ToolResult, data as ToolResultPayload); + if (subject === StreamEvents.ToolCallDone) { + callbacks.onToolEvent?.(StreamEvents.ToolCallDone, data as ToolCallDonePayload); + return; + } + + if (subject === StreamEvents.ToolCallError) { + callbacks.onToolEvent?.(StreamEvents.ToolCallError, data as ToolCallErrorPayload); } } @@ -421,8 +427,9 @@ export async function createLiveKitStreamingManager = { [StreamEvents.ChatAnswer]: handleChatEvents, [StreamEvents.ChatPartial]: handleChatEvents, - [StreamEvents.ToolCalling]: handleToolEvents, - [StreamEvents.ToolResult]: handleToolEvents, + [StreamEvents.ToolCallStarted]: handleToolEvents, + [StreamEvents.ToolCallDone]: handleToolEvents, + [StreamEvents.ToolCallError]: handleToolEvents, [StreamEvents.StreamVideoCreated]: handleVideoEvents, [StreamEvents.StreamVideoDone]: handleVideoEvents, [StreamEvents.StreamVideoError]: handleVideoEvents, diff --git a/src/types/entities/agents/manager.ts b/src/types/entities/agents/manager.ts index 7ef2b6cd..044c15e3 100644 --- a/src/types/entities/agents/manager.ts +++ b/src/types/entities/agents/manager.ts @@ -10,8 +10,6 @@ import { StreamEvents, StreamType, StreamingState, - ToolCallingPayload, - ToolResultPayload, } from '@sdk/types/stream'; import { SupportedStreamScript } from '@sdk/types/stream-script'; import type { ManagerCallbacks as StreamManagerCallbacks } from '../../stream/stream'; @@ -107,14 +105,11 @@ interface ManagerCallbacks { */ onStreamCreated?: StreamManagerCallbacks['onStreamCreated']; /** - * Optional callback function that will be triggered when tool events occur during the call - * @param event - The tool event type (tool/calling or tool/result) - * @param data - The tool event payload + * Optional callback function that will be triggered when tool-call events occur during the call + * (tool-call/started, tool-call/done, tool-call/error). + * The payload shape is discriminated by the event argument. */ - onToolEvent?: ( - event: StreamEvents.ToolCalling | StreamEvents.ToolResult, - data: ToolCallingPayload | ToolResultPayload - ) => void; + onToolEvent?: StreamManagerCallbacks['onToolEvent']; /** * Optional callback function that will be triggered when the interruptible state changes * @param interruptible - Whether the agent can be interrupted by the user diff --git a/src/types/stream/stream.ts b/src/types/stream/stream.ts index 55f3c5df..fdadbce7 100644 --- a/src/types/stream/stream.ts +++ b/src/types/stream/stream.ts @@ -40,8 +40,9 @@ export enum StreamEvents { StreamVideoDone = 'stream-video/done', StreamVideoError = 'stream-video/error', StreamVideoRejected = 'stream-video/rejected', - ToolCalling = 'tool/calling', - ToolResult = 'tool/result', + ToolCallStarted = 'tool-call/started', + ToolCallDone = 'tool-call/done', + ToolCallError = 'tool-call/error', } export enum ConnectionState { @@ -72,10 +73,7 @@ export interface ManagerCallbacks { onStreamCreated?: (stream: { stream_id: string; session_id: string; agent_id: string }) => void; onStreamReady?: () => void; onInterruptDetected?: (interrupt: Interrupt) => void; - onToolEvent?: ( - event: StreamEvents.ToolCalling | StreamEvents.ToolResult, - data: ToolCallingPayload | ToolResultPayload - ) => void; + onToolEvent?: ToolEventCallback; onInterruptibleChange?: (interruptible: boolean) => void; onFirstAudioDetected?: (metrics: AudioDetectionMetrics) => void; } @@ -196,18 +194,38 @@ export interface StreamInterruptPayload { export type ClientToolHandler = (args: Record) => Promise; -export interface ToolCallingPayload { - execution_id: string; - tool_name: string; - arguments: Record; - created_at: string; +export interface ToolCallStartedPayload { + call_id: string; + name: string; + input: Record; + output: Record; + timestamp: string; } -export interface ToolResultPayload { - execution_id: string; - tool_name: string; +export interface ToolCallDonePayload { + call_id: string; + name: string; + input: Record; + output: Record; duration_ms: number; - result?: unknown; - error_message?: string | null; - created_at: string; + extra: Record; + timestamp: string; } + +export interface ToolCallErrorPayload { + call_id: string; + name: string; + input: Record; + output: Record; + duration_ms: number; + extra: Record; + timestamp: string; +} + +export type ToolEventPayload = ToolCallStartedPayload | ToolCallDonePayload | ToolCallErrorPayload; + +export type ToolEventCallback = { + (event: StreamEvents.ToolCallStarted, data: ToolCallStartedPayload): void; + (event: StreamEvents.ToolCallDone, data: ToolCallDonePayload): void; + (event: StreamEvents.ToolCallError, data: ToolCallErrorPayload): void; +};