diff --git a/examples/sdk-core/realtime/connection-events.ts b/examples/sdk-core/realtime/connection-events.ts index efef67e..fd49469 100644 --- a/examples/sdk-core/realtime/connection-events.ts +++ b/examples/sdk-core/realtime/connection-events.ts @@ -39,6 +39,12 @@ async function main() { case "connected": console.log("Connected! Streaming active."); break; + case "generating": + console.log("Generation started! Frames incoming."); + break; + case "reconnecting": + console.log("Connection lost, reconnecting..."); + break; case "disconnected": console.log("Disconnected from server."); break; diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index d549aee..990b8cd 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -25,7 +25,7 @@ export type { RealTimeClientInitialState, } from "./realtime/client"; export type { SetInput } from "./realtime/methods"; -export type { ConnectionState } from "./realtime/webrtc-connection"; +export type { ConnectionState } from "./realtime/types"; export { type ImageModelDefinition, type ImageModels, diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index e7f02d0..586a3ed 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -6,6 +6,7 @@ import { modelStateSchema } from "../shared/types"; import { createWebrtcError, type DecartSDKError } from "../utils/errors"; import { AudioStreamManager } from "./audio-stream-manager"; import { realtimeMethods, type SetInput } from "./methods"; +import type { ConnectionState } from "./types"; import { WebRTCManager } from "./webrtc-manager"; async function blobToBase64(blob: Blob): Promise { @@ -90,7 +91,7 @@ const realTimeClientConnectOptionsSchema = z.object({ export type RealTimeClientConnectOptions = z.infer; export type Events = { - connectionChange: "connected" | "connecting" | "disconnected" | "reconnecting"; + connectionChange: ConnectionState; error: DecartSDKError; }; @@ -98,7 +99,7 @@ export type RealTimeClient = { set: (input: SetInput) => Promise; setPrompt: (prompt: string, { enhance }?: { enhance?: boolean }) => Promise; isConnected: () => boolean; - getConnectionState: () => "connected" | "connecting" | "disconnected" | "reconnecting"; + getConnectionState: () => ConnectionState; disconnect: () => void; on: (event: K, listener: (data: Events[K]) => void) => void; off: (event: K, listener: (data: Events[K]) => void) => void; @@ -167,16 +168,38 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { : undefined; const url = `${baseUrl}${options.model.urlPath}`; + + const eventBuffer: Array<{ event: keyof Events; data: Events[keyof Events] }> = []; + let buffering = true; + + const emitOrBuffer = (event: K, data: Events[K]) => { + if (buffering) { + eventBuffer.push({ event, data: data as Events[keyof Events] }); + } else { + eventEmitter.emit(event, data); + } + }; + + const flushBufferedEvents = () => { + setTimeout(() => { + buffering = false; + for (const { event, data } of eventBuffer) { + (eventEmitter.emit as (type: keyof Events, data: Events[keyof Events]) => void)(event, data); + } + eventBuffer.length = 0; + }, 0); + }; + webrtcManager = new WebRTCManager({ webrtcUrl: `${url}?api_key=${encodeURIComponent(apiKey)}&model=${encodeURIComponent(options.model.name)}`, integration, onRemoteStream, onConnectionStateChange: (state) => { - eventEmitter.emit("connectionChange", state); + emitOrBuffer("connectionChange", state); }, onError: (error) => { console.error("WebRTC error:", error); - eventEmitter.emit("error", createWebrtcError(error)); + emitOrBuffer("error", createWebrtcError(error)); }, customizeOffer: options.customizeOffer as ((offer: RTCSessionDescriptionInit) => Promise) | undefined, vp8MinBitrate: 300, @@ -203,6 +226,8 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { isConnected: () => manager.isConnected(), getConnectionState: () => manager.getConnectionState(), disconnect: () => { + buffering = false; + eventBuffer.length = 0; manager.cleanup(); audioStreamManager?.cleanup(); }, @@ -227,6 +252,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { client.playAudio = (audio: Blob | File | ArrayBuffer) => manager.playAudio(audio); } + flushBufferedEvents(); return client; } catch (error) { webrtcManager?.cleanup(); diff --git a/packages/sdk/src/realtime/methods.ts b/packages/sdk/src/realtime/methods.ts index e542917..6755d41 100644 --- a/packages/sdk/src/realtime/methods.ts +++ b/packages/sdk/src/realtime/methods.ts @@ -28,7 +28,7 @@ export const realtimeMethods = ( ) => { const assertConnected = () => { const state = webrtcManager.getConnectionState(); - if (state !== "connected") { + if (state !== "connected" && state !== "generating") { throw new Error(`Cannot send message: connection is ${state}`); } }; diff --git a/packages/sdk/src/realtime/types.ts b/packages/sdk/src/realtime/types.ts index 5b69d09..339608f 100644 --- a/packages/sdk/src/realtime/types.ts +++ b/packages/sdk/src/realtime/types.ts @@ -60,6 +60,12 @@ export type SetImageAckMessage = { error: null | string; }; +export type GenerationStartedMessage = { + type: "generation_started"; +}; + +export type ConnectionState = "connecting" | "connected" | "generating" | "disconnected" | "reconnecting"; + // Incoming message types (from server) export type IncomingWebRTCMessage = | ReadyMessage @@ -69,7 +75,8 @@ export type IncomingWebRTCMessage = | IceRestartMessage | PromptAckMessage | ErrorMessage - | SetImageAckMessage; + | SetImageAckMessage + | GenerationStartedMessage; // Outgoing message types (to server) export type OutgoingWebRTCMessage = diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts index 35401d2..de127b2 100644 --- a/packages/sdk/src/realtime/webrtc-connection.ts +++ b/packages/sdk/src/realtime/webrtc-connection.ts @@ -1,6 +1,7 @@ import mitt from "mitt"; import { buildUserAgent } from "../utils/user-agent"; import type { + ConnectionState, IncomingWebRTCMessage, OutgoingWebRTCMessage, PromptAckMessage, @@ -23,8 +24,6 @@ interface ConnectionCallbacks { initialPrompt?: { text: string; enhance?: boolean }; } -export type ConnectionState = "connecting" | "connected" | "disconnected" | "reconnecting"; - type WsMessageEvents = { promptAck: PromptAckMessage; setImageAck: SetImageAckMessage; @@ -107,7 +106,7 @@ export class WebRTCConnection { await Promise.race([ new Promise((resolve, reject) => { const checkConnection = setInterval(() => { - if (this.state === "connected") { + if (this.state === "connected" || this.state === "generating") { clearInterval(checkConnection); resolve(); } else if (this.state === "disconnected") { @@ -151,6 +150,11 @@ export class WebRTCConnection { return; } + if (msg.type === "generation_started") { + this.setState("generating"); + return; + } + // All other messages require peer connection if (!this.pc) return; @@ -340,9 +344,11 @@ export class WebRTCConnection { this.pc.onconnectionstatechange = () => { if (!this.pc) return; const s = this.pc.connectionState; - this.setState( - s === "connected" ? "connected" : ["connecting", "new"].includes(s) ? "connecting" : "disconnected", - ); + const nextState = + s === "connected" ? "connected" : ["connecting", "new"].includes(s) ? "connecting" : "disconnected"; + // Keep "generating" sticky unless the connection is actually lost. + if (this.state === "generating" && nextState !== "disconnected") return; + this.setState(nextState); }; this.pc.oniceconnectionstatechange = () => { diff --git a/packages/sdk/src/realtime/webrtc-manager.ts b/packages/sdk/src/realtime/webrtc-manager.ts index ed53aec..024f0dc 100644 --- a/packages/sdk/src/realtime/webrtc-manager.ts +++ b/packages/sdk/src/realtime/webrtc-manager.ts @@ -1,6 +1,6 @@ import pRetry, { AbortError } from "p-retry"; -import type { OutgoingMessage } from "./types"; -import { type ConnectionState, WebRTCConnection } from "./webrtc-connection"; +import type { ConnectionState, OutgoingMessage } from "./types"; +import { WebRTCConnection } from "./webrtc-connection"; export interface WebRTCConfig { webrtcUrl: string; @@ -62,7 +62,7 @@ export class WebRTCManager { private emitState(state: ConnectionState): void { if (this.managerState !== state) { this.managerState = state; - if (state === "connected") this.hasConnected = true; + if (state === "connected" || state === "generating") this.hasConnected = true; this.config.onConnectionStateChange?.(state); } } @@ -75,12 +75,10 @@ export class WebRTCManager { // During reconnection, intercept state changes from the connection layer if (this.isReconnecting) { - if (state === "connected") { - // Reconnection succeeded + if (state === "connected" || state === "generating") { this.isReconnecting = false; - this.emitState("connected"); + this.emitState(state); } - // Swallow other states during reconnection (connecting, disconnected) return; } @@ -197,7 +195,7 @@ export class WebRTCManager { } isConnected(): boolean { - return this.managerState === "connected"; + return this.managerState === "connected" || this.managerState === "generating"; } getConnectionState(): ConnectionState { diff --git a/packages/sdk/tests/unit.test.ts b/packages/sdk/tests/unit.test.ts index 38d7e36..ea81311 100644 --- a/packages/sdk/tests/unit.test.ts +++ b/packages/sdk/tests/unit.test.ts @@ -1350,3 +1350,210 @@ describe("set()", () => { }); }); }); + +describe("WebSockets Connection", () => { + it("connect resolves when state becomes generating before poll observes connected", async () => { + const { WebRTCConnection } = await import("../src/realtime/webrtc-connection.js"); + + class FakeWebSocket { + static OPEN = 1; + static CLOSED = 3; + readyState = FakeWebSocket.OPEN; + onopen: (() => void) | null = null; + onmessage: ((event: { data: string }) => void) | null = null; + onerror: (() => void) | null = null; + onclose: (() => void) | null = null; + + constructor(_url: string) { + setTimeout(() => this.onopen?.(), 0); + } + + send(): void {} + + close(): void { + this.readyState = FakeWebSocket.CLOSED; + this.onclose?.(); + } + } + + vi.stubGlobal("WebSocket", FakeWebSocket as unknown as typeof WebSocket); + + try { + const connection = new WebRTCConnection(); + const internal = connection as unknown as { + setState: (state: import("../src/realtime/types").ConnectionState) => void; + setupNewPeerConnection: () => Promise; + }; + + vi.spyOn(internal, "setupNewPeerConnection").mockImplementation(async () => { + internal.setState("connected"); + setTimeout(() => internal.setState("generating"), 0); + }); + + await expect( + connection.connect("wss://example.com", { getTracks: () => [] } as MediaStream, 750), + ).resolves.toBeUndefined(); + } finally { + vi.unstubAllGlobals(); + } + }); + + it("transitions from generating to disconnected when peer connection disconnects", async () => { + const { WebRTCConnection } = await import("../src/realtime/webrtc-connection.js"); + + class FakePeerConnection { + connectionState: RTCPeerConnectionState = "new"; + iceConnectionState: RTCIceConnectionState = "new"; + ontrack: ((event: RTCTrackEvent) => void) | null = null; + onicecandidate: ((event: RTCPeerConnectionIceEvent) => void) | null = null; + onconnectionstatechange: (() => void) | null = null; + oniceconnectionstatechange: (() => void) | null = null; + + getSenders(): RTCRtpSender[] { + return []; + } + + removeTrack(): void {} + + close(): void {} + + addTrack(): RTCRtpSender { + return {} as RTCRtpSender; + } + + addTransceiver(): RTCRtpTransceiver { + return {} as RTCRtpTransceiver; + } + } + + vi.stubGlobal("RTCPeerConnection", FakePeerConnection as unknown as typeof RTCPeerConnection); + + try { + const connection = new WebRTCConnection(); + const internal = connection as unknown as { + handleSignalingMessage: (msg: unknown) => Promise; + localStream: { getTracks: () => MediaStreamTrack[] }; + setupNewPeerConnection: () => Promise; + pc: { connectionState: RTCPeerConnectionState; onconnectionstatechange: (() => void) | null } | null; + }; + + vi.spyOn(internal, "handleSignalingMessage").mockResolvedValue(undefined); + internal.localStream = { getTracks: () => [] }; + await internal.setupNewPeerConnection(); + + connection.state = "generating"; + if (!internal.pc?.onconnectionstatechange) { + throw new Error("Peer connection state callback was not set"); + } + + internal.pc.connectionState = "disconnected"; + internal.pc.onconnectionstatechange(); + + expect(connection.state).toBe("disconnected"); + } finally { + vi.unstubAllGlobals(); + } + }); + + it("treats generating as an established connection for reconnect decisions", async () => { + const { WebRTCManager } = await import("../src/realtime/webrtc-manager.js"); + const manager = new WebRTCManager({ + webrtcUrl: "wss://example.com", + onRemoteStream: vi.fn(), + onError: vi.fn(), + }); + + const internal = manager as unknown as { + handleConnectionStateChange: (state: import("../src/realtime/types").ConnectionState) => void; + reconnect: () => Promise; + }; + + const reconnectSpy = vi.spyOn(internal, "reconnect").mockResolvedValue(undefined); + try { + internal.handleConnectionStateChange("generating"); + internal.handleConnectionStateChange("disconnected"); + expect(reconnectSpy).toHaveBeenCalledTimes(1); + } finally { + reconnectSpy.mockRestore(); + } + }); + + it("replays connection events emitted during connect before returning client", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const { WebRTCManager } = await import("../src/realtime/webrtc-manager.js"); + + const promptAckListeners = new Set<(msg: import("../src/realtime/types").PromptAckMessage) => void>(); + const websocketEmitter = { + on: (event: string, listener: (msg: import("../src/realtime/types").PromptAckMessage) => void) => { + if (event === "promptAck") promptAckListeners.add(listener); + }, + off: (event: string, listener: (msg: import("../src/realtime/types").PromptAckMessage) => void) => { + if (event === "promptAck") promptAckListeners.delete(listener); + }, + }; + + const connectSpy = vi.spyOn(WebRTCManager.prototype, "connect").mockImplementation(async function () { + const manager = this as unknown as { + config: { onConnectionStateChange?: (state: import("../src/realtime/types").ConnectionState) => void }; + managerState: import("../src/realtime/types").ConnectionState; + }; + manager.managerState = "connected"; + manager.config.onConnectionStateChange?.("connected"); + return true; + }); + const stateSpy = vi.spyOn(WebRTCManager.prototype, "getConnectionState").mockImplementation(function () { + const manager = this as unknown as { managerState: import("../src/realtime/types").ConnectionState }; + return manager.managerState ?? "connected"; + }); + const emitterSpy = vi + .spyOn(WebRTCManager.prototype, "getWebsocketMessageEmitter") + .mockReturnValue(websocketEmitter as never); + const sendSpy = vi.spyOn(WebRTCManager.prototype, "sendMessage").mockImplementation(function (message) { + if (message.type === "prompt") { + setTimeout(() => { + const manager = this as unknown as { + config: { onConnectionStateChange?: (state: import("../src/realtime/types").ConnectionState) => void }; + managerState: import("../src/realtime/types").ConnectionState; + }; + manager.managerState = "generating"; + manager.config.onConnectionStateChange?.("generating"); + for (const listener of promptAckListeners) { + listener({ + type: "prompt_ack", + prompt: message.prompt, + success: true, + error: null, + }); + } + }, 0); + } + return true; + }); + const cleanupSpy = vi.spyOn(WebRTCManager.prototype, "cleanup").mockImplementation(() => {}); + + try { + const realtime = createRealTimeClient({ baseUrl: "wss://example.com", apiKey: "test-key" }); + const client = await realtime.connect({} as MediaStream, { + model: models.realtime("mirage_v2"), + onRemoteStream: vi.fn(), + initialState: { + prompt: { + text: "test", + }, + }, + }); + + const states: import("../src/realtime/types").ConnectionState[] = []; + client.on("connectionChange", (state) => states.push(state)); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(states).toEqual(["connected", "generating"]); + } finally { + connectSpy.mockRestore(); + stateSpy.mockRestore(); + emitterSpy.mockRestore(); + sendSpy.mockRestore(); + cleanupSpy.mockRestore(); + } + }); +});