diff --git a/jest.setup.ts b/jest.setup.ts index ee9039a1..d50b4883 100644 --- a/jest.setup.ts +++ b/jest.setup.ts @@ -1,3 +1,8 @@ +import { TextDecoder, TextEncoder } from 'util'; + +global.TextDecoder = TextDecoder as any; +global.TextEncoder = TextEncoder as any; + const mockDataChannel = { onopen: null, onmessage: null, send: jest.fn(), readyState: 'open' }; const mockPeerConnection = { diff --git a/package.json b/package.json index 17847dd7..c3e10a1e 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@d-id/client-sdk", "private": false, - "version": "1.1.29", + "version": "1.1.30", "type": "module", "description": "d-id client sdk", "repository": { diff --git a/src/services/socket-manager/message-queue.test.ts b/src/services/socket-manager/message-queue.test.ts new file mode 100644 index 00000000..52e4277b --- /dev/null +++ b/src/services/socket-manager/message-queue.test.ts @@ -0,0 +1,235 @@ +import { ChatMode, ChatProgress } from '@sdk/types'; +import { AgentManagerItems } from '../agent-manager'; +import { createMessageEventQueue } from './message-queue'; + +jest.mock('@sdk/utils/analytics', () => ({ + getStreamAnalyticsProps: jest.fn(() => ({})), +})); + +describe('createMessageEventQueue', () => { + let mockAnalytics: any; + let mockItems: AgentManagerItems; + let mockOptions: any; + let mockAgent: any; + let mockOnStreamDone: jest.Mock; + let mockOnNewMessage: jest.Mock; + + beforeEach(() => { + mockAnalytics = { + track: jest.fn(), + linkTrack: jest.fn(), + }; + + mockItems = { + messages: [], + chatMode: ChatMode.Functional, + } as AgentManagerItems; + + mockOnNewMessage = jest.fn(); + mockOptions = { + callbacks: { + onNewMessage: mockOnNewMessage, + onError: jest.fn(), + }, + }; + + mockAgent = { id: 'agent-1' }; + mockOnStreamDone = jest.fn(); + }); + + describe('queue clearing behavior', () => { + it('should clear queue when user event is received', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'first question', + created_at: new Date().toISOString(), + transcribed: true, + }); + + onMessage(ChatProgress.Partial, { content: 'Old', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' response', sequence: 1 }); + + onMessage(ChatProgress.Transcribe, { + content: 'new user message', + role: 'user', + id: 'user-2', + }); + + onMessage(ChatProgress.Partial, { content: 'New', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' response', sequence: 1 }); + + const lastCall = mockOnNewMessage.mock.calls[mockOnNewMessage.mock.calls.length - 1]; + const lastMessage = lastCall[0][lastCall[0].length - 1]; + expect(lastMessage.content).toBe('New response'); + expect(lastMessage.content).not.toContain('Old'); + }); + + it('should NOT clear queue when partial event is received', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'test', + created_at: new Date().toISOString(), + transcribed: true, + }); + + onMessage(ChatProgress.Partial, { content: 'Hello', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' World', sequence: 1 }); + + const lastCall = mockOnNewMessage.mock.calls[mockOnNewMessage.mock.calls.length - 1]; + const lastMessage = lastCall[0][lastCall[0].length - 1]; + expect(lastMessage.content).toBe('Hello World'); + }); + + it('should NOT clear queue when answer event is received', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'test', + created_at: new Date().toISOString(), + transcribed: true, + }); + + onMessage(ChatProgress.Partial, { content: 'Hello', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' World', sequence: 1 }); + onMessage(ChatProgress.Answer, { content: 'Hello World!' }); + + const lastCall = mockOnNewMessage.mock.calls[mockOnNewMessage.mock.calls.length - 1]; + const lastMessage = lastCall[0][lastCall[0].length - 1]; + expect(lastMessage.content).toBe('Hello World!'); + }); + + it('should accumulate partials correctly without clearing', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'test', + created_at: new Date().toISOString(), + transcribed: true, + }); + + onMessage(ChatProgress.Partial, { content: 'A', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: 'B', sequence: 1 }); + onMessage(ChatProgress.Partial, { content: 'C', sequence: 2 }); + onMessage(ChatProgress.Partial, { content: 'D', sequence: 3 }); + + const lastCall = mockOnNewMessage.mock.calls[mockOnNewMessage.mock.calls.length - 1]; + const lastMessage = lastCall[0][lastCall[0].length - 1]; + expect(lastMessage.content).toBe('ABCD'); + }); + + it('should clear stale partials when new transcription arrives', () => { + const { onMessage } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'first message', + created_at: new Date().toISOString(), + transcribed: true, + }); + + onMessage(ChatProgress.Partial, { content: 'Old', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' response', sequence: 1 }); + + const assistantMessageBeforeInterrupt = mockItems.messages.find(m => m.role === 'assistant'); + expect(assistantMessageBeforeInterrupt?.content).toBe('Old response'); + + onMessage(ChatProgress.Transcribe, { + content: 'interrupt message', + role: 'user', + id: 'user-2', + }); + + onMessage(ChatProgress.Partial, { content: 'Fresh', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' start', sequence: 1 }); + + const newAssistantMessage = mockItems.messages[mockItems.messages.length - 1]; + expect(newAssistantMessage.role).toBe('assistant'); + expect(newAssistantMessage.content).toBe('Fresh start'); + expect(newAssistantMessage.content).not.toContain('Old'); + }); + }); + + describe('clearQueue function', () => { + it('should expose clearQueue for external use', () => { + const { clearQueue } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + expect(typeof clearQueue).toBe('function'); + }); + + it('should clear queue when called directly', () => { + const { onMessage, clearQueue } = createMessageEventQueue( + mockAnalytics, + mockItems, + mockOptions, + mockAgent, + mockOnStreamDone + ); + + mockItems.messages.push({ + id: 'user-1', + role: 'user', + content: 'test', + created_at: new Date().toISOString(), + transcribed: true, + }); + + onMessage(ChatProgress.Partial, { content: 'Old', sequence: 0 }); + onMessage(ChatProgress.Partial, { content: ' data', sequence: 1 }); + + clearQueue(); + + onMessage(ChatProgress.Partial, { content: 'Fresh', sequence: 0 }); + + const lastCall = mockOnNewMessage.mock.calls[mockOnNewMessage.mock.calls.length - 1]; + const lastMessage = lastCall[0][lastCall[0].length - 1]; + expect(lastMessage.content).toBe('Fresh'); + }); + }); +}); diff --git a/src/services/socket-manager/message-queue.ts b/src/services/socket-manager/message-queue.ts index 8d21bab5..d2d26571 100644 --- a/src/services/socket-manager/message-queue.ts +++ b/src/services/socket-manager/message-queue.ts @@ -109,9 +109,16 @@ export function createMessageEventQueue( onStreamDone: () => void ) { let chatEventQueue: ChatEventQueue = {}; + const clearQueue = () => (chatEventQueue = {}); + const onNewMessage: AgentManagerOptions['callbacks']['onNewMessage'] = (messages, event) => { + if (event === 'user') { + clearQueue(); + } + options.callbacks.onNewMessage?.(messages, event); + }; return { - clearQueue: () => (chatEventQueue = {}), + clearQueue, onMessage: (event: ChatProgress | StreamEvents, data: any) => { if ('content' in data) { const chatEvent = @@ -120,7 +127,7 @@ export function createMessageEventQueue( : event === StreamEvents.ChatAudioTranscribed ? ChatProgress.Transcribe : (event as ChatProgress); - processChatEvent(chatEvent, data, chatEventQueue, items, options.callbacks.onNewMessage); + processChatEvent(chatEvent, data, chatEventQueue, items, onNewMessage); if (chatEvent === ChatProgress.Answer) { analytics.track('agent-message-received', { diff --git a/src/services/streaming-manager/livekit-manager.test.ts b/src/services/streaming-manager/livekit-manager.test.ts index 82009e85..7c085541 100644 --- a/src/services/streaming-manager/livekit-manager.test.ts +++ b/src/services/streaming-manager/livekit-manager.test.ts @@ -1,5 +1,5 @@ import { StreamingManagerOptionsFactory } from '../../test-utils/factories'; -import { CreateSessionV2Options, StreamingManagerOptions } from '../../types/index'; +import { AgentActivityState, CreateSessionV2Options, StreamEvents, StreamingManagerOptions } from '../../types/index'; import { createLiveKitStreamingManager } from './livekit-manager'; // Mock livekit-client @@ -119,6 +119,11 @@ function getConnectionStateHandler(index?: number) { return calls.length > 0 ? calls[calls.length - 1][1] : undefined; } +function getDataReceivedHandler() { + const calls = mockRoom.on.mock.calls.filter((call: any[]) => call[0] === 'DataReceived'); + return calls.length > 0 ? calls[calls.length - 1][1] : undefined; +} + async function simulateConnection(handlerIndex?: number) { const handler = getConnectionStateHandler(handlerIndex); if (handler) { @@ -307,4 +312,51 @@ describe('LiveKit Streaming Manager - Microphone Stream', () => { await expect(manager.publishMicrophoneStream?.(mockStream)).rejects.toThrow('Room is not connected'); }); }); + + describe('Agent Activity State Changes', () => { + let mockOnAgentActivityStateChange: jest.Mock; + let sendDataEvent: (event: StreamEvents, extraData?: object) => void; + + beforeEach(async () => { + mockOnAgentActivityStateChange = jest.fn(); + options.callbacks.onAgentActivityStateChange = mockOnAgentActivityStateChange; + + await createLiveKitStreamingManager(agentId, sessionOptions, options); + await simulateConnection(); + + const dataHandler = getDataReceivedHandler(); + sendDataEvent = (event: StreamEvents, extraData = {}) => { + const payload = Buffer.from(JSON.stringify({ subject: event, ...extraData })); + dataHandler(payload, undefined, undefined, event); + }; + }); + + it.each([ + [StreamEvents.StreamVideoCreated, AgentActivityState.Talking], + [StreamEvents.StreamVideoDone, AgentActivityState.Idle], + ])('should set activity state on %s event', (event, expectedState) => { + sendDataEvent(event); + + expect(mockOnAgentActivityStateChange).toHaveBeenCalledTimes(1); + expect(mockOnAgentActivityStateChange).toHaveBeenCalledWith(expectedState); + }); + + it('should set activity state to Loading on ChatAudioTranscribed event', async () => { + sendDataEvent(StreamEvents.ChatAudioTranscribed, { content: 'test', role: 'user' }); + + await new Promise(resolve => setTimeout(resolve, 0)); + + expect(mockOnAgentActivityStateChange).toHaveBeenCalledTimes(1); + expect(mockOnAgentActivityStateChange).toHaveBeenCalledWith(AgentActivityState.Loading); + }); + + it('should transition from Talking to Idle when video ends', () => { + sendDataEvent(StreamEvents.StreamVideoCreated); + sendDataEvent(StreamEvents.StreamVideoDone); + + expect(mockOnAgentActivityStateChange).toHaveBeenCalledTimes(2); + expect(mockOnAgentActivityStateChange).toHaveBeenNthCalledWith(1, AgentActivityState.Talking); + expect(mockOnAgentActivityStateChange).toHaveBeenNthCalledWith(2, AgentActivityState.Idle); + }); + }); }); diff --git a/src/services/streaming-manager/livekit-manager.ts b/src/services/streaming-manager/livekit-manager.ts index bd167c55..b744dc85 100644 --- a/src/services/streaming-manager/livekit-manager.ts +++ b/src/services/streaming-manager/livekit-manager.ts @@ -225,15 +225,6 @@ export async function createLiveKitStreamingManager !speaker.isLocal); - - if (isRemoteParticipantSpeaking) { - currentActivityState = AgentActivityState.Talking; - callbacks.onAgentActivityStateChange?.(AgentActivityState.Talking); - } else { - callbacks.onAgentActivityStateChange?.(AgentActivityState.Idle); - currentActivityState = AgentActivityState.Idle; - } } function handleParticipantConnected(participant: RemoteParticipant): void { @@ -312,6 +303,10 @@ export async function createLiveKitStreamingManager