diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index f7220ce..6755cba 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -42,6 +42,9 @@ const avatarOptionsSchema = z.object({ }); export type AvatarOptions = z.infer; +type OnStatusFn = (status: string) => void; +type OnQueuePositionFn = (data: { position: number; queueSize: number }) => void; + const realTimeClientConnectOptionsSchema = z.object({ model: modelDefinitionSchema, onRemoteStream: z.custom((val) => typeof val === "function", { @@ -50,12 +53,24 @@ const realTimeClientConnectOptionsSchema = z.object({ initialState: realTimeClientInitialStateSchema.optional(), customizeOffer: createAsyncFunctionSchema(z.function()).optional(), avatar: avatarOptionsSchema.optional(), + onStatus: z + .custom((val) => typeof val === "function", { + message: "onStatus must be a function", + }) + .optional(), + onQueuePosition: z + .custom((val) => typeof val === "function", { + message: "onQueuePosition must be a function", + }) + .optional(), }); export type RealTimeClientConnectOptions = z.infer; export type Events = { connectionChange: "connected" | "connecting" | "disconnected"; error: DecartSDKError; + status: string; + queuePosition: { position: number; queueSize: number }; }; export type RealTimeClient = { @@ -66,7 +81,10 @@ export type RealTimeClient = { on: (event: K, listener: (data: Events[K]) => void) => void; off: (event: K, listener: (data: Events[K]) => void) => void; sessionId: string; - setImage: (image: Blob | File | string | null, options?: { prompt?: string; enhance?: boolean }) => Promise; + setImage: ( + image: Blob | File | string | null, + options?: { prompt?: string; enhance?: boolean; timeout?: number }, + ) => Promise; // live_avatar audio method (only available when model is live_avatar and no stream is provided) playAudio?: (audio: Blob | File | ArrayBuffer) => Promise; }; @@ -146,6 +164,18 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { initialPrompt, }); + // Wire up queue status events (called before connect so we don't miss early messages) + const wsEmitter = webrtcManager.getWebsocketMessageEmitter(); + wsEmitter.on("status", (msg) => { + eventEmitter.emit("status", msg.status); + options.onStatus?.(msg.status); + }); + wsEmitter.on("queuePosition", (msg) => { + const data = { position: msg.position, queueSize: msg.queue_size }; + eventEmitter.emit("queuePosition", data); + options.onQueuePosition?.(data); + }); + await webrtcManager.connect(inputStream); const methods = realtimeMethods(webrtcManager); @@ -167,7 +197,10 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { on: eventEmitter.on, off: eventEmitter.off, sessionId, - setImage: async (image: Blob | File | string | null, options?: { prompt?: string; enhance?: boolean }) => { + setImage: async ( + image: Blob | File | string | null, + options?: { prompt?: string; enhance?: boolean; timeout?: number }, + ) => { if (image === null) { return webrtcManager.setImage(null, options); } diff --git a/packages/sdk/src/realtime/types.ts b/packages/sdk/src/realtime/types.ts index bd07c2c..2df6a18 100644 --- a/packages/sdk/src/realtime/types.ts +++ b/packages/sdk/src/realtime/types.ts @@ -60,6 +60,18 @@ export type SetImageAckMessage = { error: null | string; }; +// Queue message types +export type StatusMessage = { + type: "status"; + status: string; +}; + +export type QueuePositionMessage = { + type: "queue_position"; + position: number; + queue_size: number; +}; + // Incoming message types (from server) export type IncomingWebRTCMessage = | ReadyMessage @@ -69,7 +81,9 @@ export type IncomingWebRTCMessage = | IceRestartMessage | PromptAckMessage | ErrorMessage - | SetImageAckMessage; + | SetImageAckMessage + | StatusMessage + | QueuePositionMessage; // Outgoing message types (to server) export type OutgoingWebRTCMessage = diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts index 107923a..0f22379 100644 --- a/packages/sdk/src/realtime/webrtc-connection.ts +++ b/packages/sdk/src/realtime/webrtc-connection.ts @@ -4,7 +4,9 @@ import type { IncomingWebRTCMessage, OutgoingWebRTCMessage, PromptAckMessage, + QueuePositionMessage, SetImageAckMessage, + StatusMessage, TurnConfig, } from "./types"; @@ -28,6 +30,8 @@ export type ConnectionState = "connecting" | "connected" | "disconnected"; type WsMessageEvents = { promptAck: PromptAckMessage; setImageAck: SetImageAckMessage; + status: StatusMessage; + queuePosition: QueuePositionMessage; }; export class WebRTCConnection { @@ -122,6 +126,16 @@ export class WebRTCConnection { return; } + if (msg.type === "status") { + this.websocketMessagesEmitter.emit("status", msg); + return; + } + + if (msg.type === "queue_position") { + this.websocketMessagesEmitter.emit("queuePosition", msg); + return; + } + // All other messages require peer connection if (!this.pc) return; @@ -181,12 +195,15 @@ export class WebRTCConnection { * Pass null to clear the reference image or use a placeholder. * Optionally include a prompt to send with the image. */ - async setImageBase64(imageBase64: string | null, options?: { prompt?: string; enhance?: boolean }): Promise { + async setImageBase64( + imageBase64: string | null, + options?: { prompt?: string; enhance?: boolean; timeout?: number }, + ): Promise { return new Promise((resolve, reject) => { const timeoutId = setTimeout(() => { this.websocketMessagesEmitter.off("setImageAck", listener); reject(new Error("Image send timed out")); - }, AVATAR_SETUP_TIMEOUT_MS); + }, options?.timeout ?? AVATAR_SETUP_TIMEOUT_MS); const listener = (msg: SetImageAckMessage) => { clearTimeout(timeoutId); diff --git a/packages/sdk/src/realtime/webrtc-manager.ts b/packages/sdk/src/realtime/webrtc-manager.ts index 026269d..894a4b9 100644 --- a/packages/sdk/src/realtime/webrtc-manager.ts +++ b/packages/sdk/src/realtime/webrtc-manager.ts @@ -52,7 +52,7 @@ export class WebRTCManager { async connect(localStream: MediaStream): Promise { return pRetry( async () => { - await this.connection.connect(this.config.webrtcUrl, localStream, 60000, this.config.integration); + await this.connection.connect(this.config.webrtcUrl, localStream, 300000, this.config.integration); return true; }, { @@ -92,7 +92,10 @@ export class WebRTCManager { return this.connection.websocketMessagesEmitter; } - setImage(imageBase64: string | null, options?: { prompt?: string; enhance?: boolean }): Promise { + setImage( + imageBase64: string | null, + options?: { prompt?: string; enhance?: boolean; timeout?: number }, + ): Promise { return this.connection.setImageBase64(imageBase64, options); } } diff --git a/packages/sdk/tests/unit.test.ts b/packages/sdk/tests/unit.test.ts index c79fd3b..a35e431 100644 --- a/packages/sdk/tests/unit.test.ts +++ b/packages/sdk/tests/unit.test.ts @@ -1,6 +1,6 @@ import { HttpResponse, http } from "msw"; import { setupServer } from "msw/node"; -import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it } from "vitest"; +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import { createDecartClient, models } from "../src/index.js"; const MOCK_RESPONSE_DATA = new Uint8Array([0x00, 0x01, 0x02]).buffer; @@ -943,6 +943,67 @@ describe("Lucy 14b realtime", () => { }); }); +describe("WebRTCConnection", () => { + describe("setImageBase64 timeout", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("uses custom timeout when provided", async () => { + const { WebRTCConnection } = await import("../src/realtime/webrtc-connection.js"); + const connection = new WebRTCConnection(); + + const customTimeout = 5000; + let rejected = false; + let rejectionError: Error | null = null; + + const promise = connection.setImageBase64("base64data", { timeout: customTimeout }).catch((err) => { + rejected = true; + rejectionError = err; + }); + + // Advance time to just before the custom timeout - should not have rejected yet + await vi.advanceTimersByTimeAsync(customTimeout - 1); + expect(rejected).toBe(false); + + // Advance past the custom timeout - now it should reject + await vi.advanceTimersByTimeAsync(2); + await promise; + + expect(rejected).toBe(true); + expect(rejectionError?.message).toBe("Image send timed out"); + }); + + it("uses default timeout (15000ms) when not provided", async () => { + const { WebRTCConnection } = await import("../src/realtime/webrtc-connection.js"); + const connection = new WebRTCConnection(); + + let rejected = false; + let rejectionError: Error | null = null; + + const promise = connection.setImageBase64("base64data").catch((err) => { + rejected = true; + rejectionError = err; + }); + + // Advance to just before the default timeout (15000ms) - should not reject yet + await vi.advanceTimersByTimeAsync(14999); + expect(rejected).toBe(false); + + // Now advance past the default timeout + await vi.advanceTimersByTimeAsync(2); + await promise; + + expect(rejected).toBe(true); + expect(rejectionError?.message).toBe("Image send timed out"); + }); + }); +}); + describe("live_avatar Model", () => { describe("Model Definition", () => { it("has correct model name", () => {