From 016685f0c33b59f58c5fe098813f8c5581696acc Mon Sep 17 00:00:00 2001 From: Chenghao Mou Date: Thu, 11 Jun 2026 17:21:09 +0100 Subject: [PATCH] feat(barge-in): add default threshold support and drop http transport (#1698) Co-authored-by: Claude Opus 4.8 (1M context) --- .../bargein-default-threshold-drop-http.md | 5 + agents/src/inference/interruption/_mock_ws.ts | 58 +++ agents/src/inference/interruption/defaults.ts | 8 +- .../inference/interruption/http_transport.ts | 208 --------- .../interruption/interruption_detector.ts | 64 +-- .../interruption_failover.test.ts | 408 ++++++++++++++++++ .../interruption_session_create.test.ts | 127 ++++++ .../interruption/interruption_stream.ts | 53 ++- agents/src/inference/interruption/types.ts | 3 +- .../inference/interruption/ws_transport.ts | 231 +++++++--- agents/src/inference/utils.ts | 16 +- agents/src/voice/audio_recognition.ts | 3 + agents/src/worker.ts | 6 + turbo.json | 1 + 14 files changed, 835 insertions(+), 356 deletions(-) create mode 100644 .changeset/bargein-default-threshold-drop-http.md create mode 100644 agents/src/inference/interruption/_mock_ws.ts delete mode 100644 agents/src/inference/interruption/http_transport.ts create mode 100644 agents/src/inference/interruption/interruption_failover.test.ts create mode 100644 agents/src/inference/interruption/interruption_session_create.test.ts diff --git a/.changeset/bargein-default-threshold-drop-http.md b/.changeset/bargein-default-threshold-drop-http.md new file mode 100644 index 000000000..2d937eb5c --- /dev/null +++ b/.changeset/bargein-default-threshold-drop-http.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Adaptive interruption detection now omits the threshold from `session.create` unless the user explicitly overrides it, letting the gateway apply its fetched default (surfaced via `default_threshold` on `session.created`). The HTTP transport has been dropped — detection always connects over WebSocket and always requires LiveKit credentials, and its base URL now defaults from `LIVEKIT_INFERENCE_URL` instead of `LIVEKIT_REMOTE_EOT_URL`. Inference requests also send an `X-LiveKit-Worker-Token` header when `LIVEKIT_WORKER_TOKEN` is set (hosted agents); a token supplied via the `--worker-token` CLI flag is now re-exported into the environment so forked job subprocesses inherit it and include the header. The `X-LiveKit-Agent-Id` header is now only attached once the room is connected to avoid leaking an unset local-participant SID. The interruption WebSocket is now closed deterministically on stream teardown (including error and cancel paths) instead of only on graceful completion — previously an orphaned socket leaked per session/activity and accumulated for the worker's lifetime. Mid-session threshold/duration changes via `updateOptions` now reconnect the WebSocket in place rather than closing it and letting the next send error the stream — so option changes no longer consume a failover retry (previously enough updates in a session could exhaust the retry budget and stop interruption detection). diff --git a/agents/src/inference/interruption/_mock_ws.ts b/agents/src/inference/interruption/_mock_ws.ts new file mode 100644 index 000000000..978224a0b --- /dev/null +++ b/agents/src/inference/interruption/_mock_ws.ts @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { EventEmitter } from 'node:events'; + +/** + * Minimal stand-in for the `ws` WebSocket, used to drive the interruption transport in tests. + * + * Lives in its own module (rather than inline in each test) so the `vi.mock('ws')` factory can + * `await import()` it without a top-level await — the tsup build transpiles test files to CJS, + * which does not support top-level await. + */ +export class MockWebSocket extends EventEmitter { + static OPEN = 1; + static instances: MockWebSocket[] = []; + + readyState = 0; // CONNECTING + readonly sent: unknown[] = []; + terminated = false; + + constructor( + public url: string, + public opts: unknown, + ) { + super(); + MockWebSocket.instances.push(this); + } + + send(data: unknown): void { + this.sent.push(data); + } + + close(): void { + this.readyState = 3; // CLOSED + this.emit('close', 1000, Buffer.from('')); + } + + terminate(): void { + this.terminated = true; + this.readyState = 3; + } + + /** Simulate a successful upgrade. */ + simulateOpen(): void { + this.readyState = MockWebSocket.OPEN; + this.emit('open'); + } + + /** Simulate the server rejecting the upgrade with an HTTP status. */ + simulateUnexpectedResponse(statusCode: number): void { + this.emit('unexpected-response', {}, { statusCode }); + } + + /** Simulate a server message frame carrying a JSON payload. */ + simulateMessage(payload: unknown): void { + this.emit('message', Buffer.from(JSON.stringify(payload))); + } +} diff --git a/agents/src/inference/interruption/defaults.ts b/agents/src/inference/interruption/defaults.ts index df6134bc4..c01c127a1 100644 --- a/agents/src/inference/interruption/defaults.ts +++ b/agents/src/inference/interruption/defaults.ts @@ -5,9 +5,8 @@ import type { ApiConnectOptions } from './interruption_stream.js'; import type { InterruptionOptions } from './types.js'; export const MIN_INTERRUPTION_DURATION_IN_S = 0.025 * 2; // 25ms per frame, 2 consecutive frames -export const THRESHOLD = 0.5; export const MAX_AUDIO_DURATION_IN_S = 3.0; -export const AUDIO_PREFIX_DURATION_IN_S = 0.5; +export const AUDIO_PREFIX_DURATION_IN_S = 1.0; export const DETECTION_INTERVAL_IN_S = 0.1; export const REMOTE_INFERENCE_TIMEOUT_IN_S = 0.7; export const SAMPLE_RATE = 16000; @@ -36,12 +35,13 @@ export function intervalForRetry( } // env-derived fields are resolved in the constructor, not at module load. +// `threshold` is intentionally omitted: when the user does not override it, it stays undefined +// so the server applies its fetched default. export const interruptionOptionDefaults: Omit< InterruptionOptions, - 'baseUrl' | 'useProxy' | 'apiKey' | 'apiSecret' + 'baseUrl' | 'apiKey' | 'apiSecret' | 'threshold' > = { sampleRate: SAMPLE_RATE, - threshold: THRESHOLD, minFrames: Math.ceil(MIN_INTERRUPTION_DURATION_IN_S * FRAMES_PER_SECOND), maxAudioDurationInS: MAX_AUDIO_DURATION_IN_S, audioPrefixDurationInS: AUDIO_PREFIX_DURATION_IN_S, diff --git a/agents/src/inference/interruption/http_transport.ts b/agents/src/inference/interruption/http_transport.ts deleted file mode 100644 index bd7953688..000000000 --- a/agents/src/inference/interruption/http_transport.ts +++ /dev/null @@ -1,208 +0,0 @@ -// SPDX-FileCopyrightText: 2026 LiveKit, Inc. -// -// SPDX-License-Identifier: Apache-2.0 -import type { Throws } from '@livekit/throws-transformer/throws'; -import { FetchError, ofetch } from 'ofetch'; -import { TransformStream } from 'stream/web'; -import { z } from 'zod'; -import { APIConnectionError, APIError, APIStatusError, isAPIError } from '../../_exceptions.js'; -import { log } from '../../log.js'; -import { buildMetadataHeaders, createAccessToken } from '../utils.js'; -import { InterruptionCacheEntry } from './interruption_cache_entry.js'; -import type { OverlappingSpeechEvent } from './types.js'; -import type { BoundedCache } from './utils.js'; - -export interface PostOptions { - baseUrl: string; - token: string; - signal?: AbortSignal; - timeout?: number; - maxRetries?: number; -} - -export interface PredictOptions { - threshold: number; - minFrames: number; -} - -export const predictEndpointResponseSchema = z.object({ - created_at: z.number(), - is_bargein: z.boolean(), - probabilities: z.array(z.number()), -}); - -export type PredictEndpointResponse = z.infer; - -export interface PredictResponse { - createdAt: number; - isBargein: boolean; - probabilities: number[]; - predictionDurationInS: number; -} - -export async function predictHTTP( - data: Int16Array, - predictOptions: PredictOptions, - options: PostOptions, -): Promise> { - const createdAt = performance.now(); - const url = new URL(`/bargein`, options.baseUrl); - url.searchParams.append('threshold', predictOptions.threshold.toString()); - url.searchParams.append('min_frames', predictOptions.minFrames.toFixed()); - url.searchParams.append('created_at', createdAt.toFixed()); - - try { - const response = await ofetch(url.toString(), { - retry: 0, - headers: { - ...buildMetadataHeaders(), - 'Content-Type': 'application/octet-stream', - Authorization: `Bearer ${options.token}`, - }, - signal: options.signal, - timeout: options.timeout, - method: 'POST', - body: data, - }); - const { created_at, is_bargein, probabilities } = predictEndpointResponseSchema.parse(response); - - return { - createdAt: created_at, - isBargein: is_bargein, - probabilities, - predictionDurationInS: (performance.now() - createdAt) / 1000, - }; - } catch (err) { - if (isAPIError(err)) throw err; - if (err instanceof FetchError) { - if (err.statusCode) { - throw new APIStatusError({ - message: `error during interruption prediction: ${err.message}`, - options: { statusCode: err.statusCode, body: err.data }, - }); - } - if ( - err.cause instanceof Error && - (err.cause.name === 'TimeoutError' || err.cause.name === 'AbortError') - ) { - throw new APIStatusError({ - message: `interruption inference timeout: ${err.message}`, - options: { statusCode: 408, retryable: false }, - }); - } - throw new APIConnectionError({ - message: `interruption inference connection error: ${err.message}`, - }); - } - throw new APIError(`error during interruption prediction: ${err}`); - } -} - -export interface HttpTransportOptions { - baseUrl: string; - apiKey: string; - apiSecret: string; - threshold: number; - minFrames: number; - timeout: number; - maxRetries?: number; -} - -export interface HttpTransportState { - overlapSpeechStarted: boolean; - overlapSpeechStartedAt: number | undefined; - cache: BoundedCache; -} - -/** - * Creates an HTTP transport TransformStream for interruption detection. - * - * This transport receives Int16Array audio slices and outputs InterruptionEvents. - * Each audio slice triggers an HTTP POST request. - * - * @param options - Transport options object. This is read on each request, so mutations - * to threshold/minFrames will be picked up dynamically. - */ -export function createHttpTransport( - options: HttpTransportOptions, - getState: () => HttpTransportState, - setState: (partial: Partial) => void, - updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, - getAndResetNumRequests?: () => number, -): TransformStream { - const logger = log(); - - return new TransformStream( - { - async transform(chunk, controller) { - if (!(chunk instanceof Int16Array)) { - controller.enqueue(chunk); - return; - } - - const state = getState(); - const overlapSpeechStartedAt = state.overlapSpeechStartedAt; - if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) return; - - try { - const resp = await predictHTTP( - chunk, - { threshold: options.threshold, minFrames: options.minFrames }, - { - baseUrl: options.baseUrl, - timeout: options.timeout, - maxRetries: options.maxRetries, - token: await createAccessToken(options.apiKey, options.apiSecret), - }, - ); - - const { createdAt, isBargein, probabilities, predictionDurationInS } = resp; - const entry = state.cache.setOrUpdate( - createdAt, - () => new InterruptionCacheEntry({ createdAt }), - { - probabilities, - isInterruption: isBargein, - speechInput: chunk, - totalDurationInS: (performance.now() - createdAt) / 1000, - detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, - predictionDurationInS, - }, - ); - - if (state.overlapSpeechStarted && entry.isInterruption) { - if (updateUserSpeakingSpan) { - updateUserSpeakingSpan(entry); - } - const event: OverlappingSpeechEvent = { - type: 'overlapping_speech', - detectedAt: Date.now(), - overlapStartedAt: overlapSpeechStartedAt, - isInterruption: entry.isInterruption, - speechInput: entry.speechInput, - probabilities: entry.probabilities, - totalDurationInS: entry.totalDurationInS, - predictionDurationInS: entry.predictionDurationInS, - detectionDelayInS: entry.detectionDelayInS, - probability: entry.probability, - numRequests: getAndResetNumRequests?.() ?? 0, - }; - logger.debug( - { - detectionDelayInS: entry.detectionDelayInS, - totalDurationInS: entry.totalDurationInS, - }, - 'interruption detected', - ); - setState({ overlapSpeechStarted: false }); - controller.enqueue(event); - } - } catch (err) { - controller.error(err); - } - }, - }, - { highWaterMark: 2 }, - { highWaterMark: 2 }, - ); -} diff --git a/agents/src/inference/interruption/interruption_detector.ts b/agents/src/inference/interruption/interruption_detector.ts index 40c36e17d..0f4c341a9 100644 --- a/agents/src/inference/interruption/interruption_detector.ts +++ b/agents/src/inference/interruption/interruption_detector.ts @@ -7,7 +7,7 @@ import EventEmitter from 'events'; import { log } from '../../log.js'; import type { InterruptionMetrics } from '../../metrics/base.js'; import { asError } from '../../utils.js'; -import { DEFAULT_INFERENCE_URL, STAGING_INFERENCE_URL, getDefaultInferenceUrl } from '../utils.js'; +import { getDefaultInferenceUrl } from '../utils.js'; import { FRAMES_PER_SECOND, SAMPLE_RATE, interruptionOptionDefaults } from './defaults.js'; import { InterruptionDetectionError } from './errors.js'; import { InterruptionStreamBase } from './interruption_stream.js'; @@ -19,7 +19,7 @@ type InterruptionCallbacks = { error: (error: InterruptionDetectionError) => void; }; -export type AdaptiveInterruptionDetectorOptions = Omit, 'useProxy'>; +export type AdaptiveInterruptionDetectorOptions = Partial; export class AdaptiveInterruptionDetector extends (EventEmitter as new () => TypedEventEmitter) { options: InterruptionOptions; @@ -46,46 +46,23 @@ export class AdaptiveInterruptionDetector extends (EventEmitter as new () => Typ throw new RangeError('maxAudioDurationInS must be less than or equal to 3.0 seconds'); } - const lkBaseUrl = baseUrl ?? process.env.LIVEKIT_REMOTE_EOT_URL ?? getDefaultInferenceUrl(); - let lkApiKey = apiKey || ''; - let lkApiSecret = apiSecret || ''; - let useProxy: boolean; - - // Use LiveKit credentials if using the inference service (production or staging) - const isInferenceUrl = - lkBaseUrl === DEFAULT_INFERENCE_URL || lkBaseUrl === STAGING_INFERENCE_URL; - if (isInferenceUrl) { - lkApiKey = - apiKey || process.env.LIVEKIT_INFERENCE_API_KEY || process.env.LIVEKIT_API_KEY || ''; - if (!lkApiKey) { - throw new TypeError( - 'apiKey is required, either as argument or set LIVEKIT_API_KEY environmental variable', - ); - } - - lkApiSecret = - apiSecret || - process.env.LIVEKIT_INFERENCE_API_SECRET || - process.env.LIVEKIT_API_SECRET || - ''; - if (!lkApiSecret) { - throw new TypeError( - 'apiSecret is required, either as argument or set LIVEKIT_API_SECRET environmental variable', - ); - } - useProxy = true; - } else { - useProxy = false; + const lkBaseUrl = baseUrl ?? getDefaultInferenceUrl(); + + const lkApiKey = + apiKey || process.env.LIVEKIT_INFERENCE_API_KEY || process.env.LIVEKIT_API_KEY || ''; + if (!lkApiKey) { + throw new TypeError( + 'apiKey is required, either as argument or set LIVEKIT_API_KEY environmental variable', + ); + } + + const lkApiSecret = + apiSecret || process.env.LIVEKIT_INFERENCE_API_SECRET || process.env.LIVEKIT_API_SECRET || ''; + if (!lkApiSecret) { + throw new TypeError( + 'apiSecret is required, either as argument or set LIVEKIT_API_SECRET environmental variable', + ); } - const transport = useProxy ? 'websocket' : 'http'; - this.logger.debug( - { - baseUrl: lkBaseUrl, - useProxy, - transport, - }, - '=== Resolved interruption detector transport configuration', - ); this.options = { sampleRate: SAMPLE_RATE, @@ -98,7 +75,6 @@ export class AdaptiveInterruptionDetector extends (EventEmitter as new () => Typ baseUrl: lkBaseUrl, apiKey: lkApiKey, apiSecret: lkApiSecret, - useProxy, minInterruptionDurationInS, }; @@ -111,10 +87,8 @@ export class AdaptiveInterruptionDetector extends (EventEmitter as new () => Typ audioPrefixDurationInS: this.options.audioPrefixDurationInS, maxAudioDurationInS: this.options.maxAudioDurationInS, minFrames: this.options.minFrames, - threshold: this.options.threshold, + threshold: this.options.threshold ?? null, inferenceTimeout: this.options.inferenceTimeout, - useProxy: this.options.useProxy, - transport, }, '=== Adaptive interruption detector initialized', ); diff --git a/agents/src/inference/interruption/interruption_failover.test.ts b/agents/src/inference/interruption/interruption_failover.test.ts new file mode 100644 index 000000000..b1cd77bed --- /dev/null +++ b/agents/src/inference/interruption/interruption_failover.test.ts @@ -0,0 +1,408 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Tests for interruption detection failover (transport error production + error-emission) +// behavior. Ported from the Python `test_interruption_failover.py` suite. +// +// Covers, for the WebSocket-only transport (HTTP transport was dropped): +// - connection timeout -> non-retryable APITimeoutError +// - connection 429 -> non-retryable APIStatusError +// - cache-based inference timeout -> non-retryable APIStatusError (408) +// and that a non-retryable transport error surfaces as exactly one unrecoverable +// InterruptionDetectionError (zero recoverable) through AudioRecognition's retry loop. +// +// Also covers in-place reconnect on updateOptions: it opens a fresh socket with the updated +// settings without erroring the stream (so it never consumes a failover retry), while a genuine +// reconnect failure still surfaces as a stream error. +import { AudioFrame } from '@livekit/rtc-node'; +import { ReadableStream } from 'node:stream/web'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import type { APIError } from '../../_exceptions.js'; +import { APIStatusError, APITimeoutError } from '../../_exceptions.js'; +import { ChatContext } from '../../llm/chat_context.js'; +import { initializeLogger } from '../../log.js'; +import { AudioRecognition, type RecognitionHooks } from '../../voice/audio_recognition.js'; +import { MockWebSocket } from './_mock_ws.js'; +import { apiConnectDefaults } from './defaults.js'; +import { AdaptiveInterruptionDetector } from './interruption_detector.js'; +import { InterruptionStreamBase, InterruptionStreamSentinel } from './interruption_stream.js'; + +// --------------------------------------------------------------------------- +// Mock `ws` so the WebSocket transport can be driven deterministically. +// --------------------------------------------------------------------------- + +vi.mock('ws', async () => { + const { MockWebSocket } = await import('./_mock_ws.js'); + return { default: MockWebSocket, WebSocket: MockWebSocket }; +}); + +type MockSocket = MockWebSocket; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +initializeLogger({ pretty: false, level: 'silent' }); + +const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); + +async function waitForInstance(timeoutMs = 2000): Promise { + const start = performance.now(); + while (MockWebSocket.instances.length === 0) { + if (performance.now() - start > timeoutMs) { + throw new Error('WebSocket instance was never constructed'); + } + await sleep(5); + } + return MockWebSocket.instances[MockWebSocket.instances.length - 1]!; +} + +/** Wait until at least `n` sockets have been constructed and return the nth (1-indexed). */ +async function waitForInstanceCount(n: number, timeoutMs = 2000): Promise { + const start = performance.now(); + while (MockWebSocket.instances.length < n) { + if (performance.now() - start > timeoutMs) { + throw new Error(`expected ${n} WebSocket instances, saw ${MockWebSocket.instances.length}`); + } + await sleep(5); + } + return MockWebSocket.instances[n - 1]!; +} + +async function waitFor(predicate: () => boolean, timeoutMs = 2000): Promise { + const start = performance.now(); + while (!predicate()) { + if (performance.now() - start > timeoutMs) { + throw new Error('condition not met within timeout'); + } + await sleep(5); + } +} + +/** Resolves to 'alive' if the stream has not errored/ended within `ms`. */ +function stillAlive(errPromise: Promise, ms = 30): Promise<'alive' | unknown> { + return Promise.race([errPromise, sleep(ms).then(() => 'alive' as const)]); +} + +function sessionCreateSettings(ws: MockSocket): Record { + return JSON.parse(String(ws.sent[0])).settings; +} + +function makeAudioFrame(numSamples = 1600, sampleRate = 16000): AudioFrame { + const data = new Int16Array(numSamples); + return new AudioFrame(data, sampleRate, 1, numSamples); +} + +function createDetector(opts: { inferenceTimeout?: number } = {}): AdaptiveInterruptionDetector { + return new AdaptiveInterruptionDetector({ + baseUrl: 'http://localhost:9999', + apiKey: 'test-key', + apiSecret: 'test-secret', + ...opts, + }); +} + +/** Drain a stream's event side and return the rejection (or undefined on clean end). */ +async function readError(stream: InterruptionStreamBase): Promise { + const reader = stream.stream().getReader(); + try { + for (;;) { + const { done } = await reader.read(); + if (done) return undefined; + } + } catch (e) { + return e; + } finally { + try { + reader.releaseLock(); + } catch { + // already released + } + } +} + +beforeEach(() => { + MockWebSocket.instances.length = 0; +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +// --------------------------------------------------------------------------- +// WebSocket transport error production +// --------------------------------------------------------------------------- + +describe('interruption WebSocket transport failover', () => { + it('surfaces a non-retryable APITimeoutError on connection timeout', async () => { + const detector = createDetector(); + // Short connect timeout so the test does not wait on the default. + const stream = new InterruptionStreamBase(detector, { timeout: 50 }); + + const err = await readError(stream); + + expect(err).toBeInstanceOf(APITimeoutError); + expect((err as APIError).retryable).toBe(false); + + await stream.close(); + }); + + it('surfaces a non-retryable APIStatusError on connection 429', async () => { + const detector = createDetector(); + const stream = new InterruptionStreamBase(detector, {}); + + const errPromise = readError(stream); + const ws = await waitForInstance(); + ws.simulateUnexpectedResponse(429); + + const err = await errPromise; + + expect(err).toBeInstanceOf(APIStatusError); + expect((err as APIStatusError).statusCode).toBe(429); + expect((err as APIError).retryable).toBe(false); + + await stream.close(); + }); + + it('surfaces a non-retryable 408 APIStatusError when inference responses time out', async () => { + const inferenceTimeout = 50; + const detector = createDetector({ inferenceTimeout }); + const stream = new InterruptionStreamBase(detector, {}); + + const errPromise = readError(stream); + const ws = await waitForInstance(); + ws.simulateOpen(); + // let ensureConnection() resolve and send session.create + await sleep(5); + + // Drive overlap audio so the transport sends a request and caches it, then never + // answers — the next slice must trip the cache-timeout guard. + await stream.pushFrame(InterruptionStreamSentinel.agentSpeechStarted()); + await stream.pushFrame(InterruptionStreamSentinel.overlapSpeechStarted(500, Date.now())); + await stream.pushFrame(makeAudioFrame()); + await sleep(inferenceTimeout + 40); + await stream.pushFrame(makeAudioFrame()); + + const err = await errPromise; + + expect(err).toBeInstanceOf(APIStatusError); + expect((err as APIStatusError).statusCode).toBe(408); + expect((err as APIError).retryable).toBe(false); + + await stream.close(); + }); + + it('errors when session.created reports no threshold and the user did not override', async () => { + const detector = createDetector(); + const stream = new InterruptionStreamBase(detector, {}); + + const errPromise = readError(stream); + const ws = await waitForInstance(); + ws.simulateOpen(); + await sleep(5); + ws.simulateMessage({ type: 'session.created' }); + + const err = await errPromise; + + expect(err).toBeInstanceOf(APIStatusError); + expect((err as APIStatusError).statusCode).toBe(500); + expect((err as APIError).retryable).toBe(false); + + await stream.close(); + }); + + it('does not error when session.created carries a default_threshold', async () => { + const detector = createDetector(); + const stream = new InterruptionStreamBase(detector, {}); + + const errPromise = readError(stream); + const ws = await waitForInstance(); + ws.simulateOpen(); + await sleep(5); + ws.simulateMessage({ type: 'session.created', default_threshold: 0.42 }); + await sleep(5); + + await stream.close(); + expect(await errPromise).toBeUndefined(); + }); + + it('closes the underlying WebSocket on teardown even after the transport errored', async () => { + // The transport's flush() only runs on graceful stream completion. When the stream is torn + // down via an error (here a 408 inference timeout) the socket is still open, so close() must + // tear it down directly — otherwise the WebSocket leaks. + const inferenceTimeout = 50; + const detector = createDetector({ inferenceTimeout }); + const stream = new InterruptionStreamBase(detector, {}); + + const errPromise = readError(stream); + const ws = await waitForInstance(); + ws.simulateOpen(); + await sleep(5); + + await stream.pushFrame(InterruptionStreamSentinel.agentSpeechStarted()); + await stream.pushFrame(InterruptionStreamSentinel.overlapSpeechStarted(500, Date.now())); + await stream.pushFrame(makeAudioFrame()); + await sleep(inferenceTimeout + 40); + await stream.pushFrame(makeAudioFrame()); + + await errPromise; + expect(ws.readyState).toBe(MockWebSocket.OPEN); // socket still open after the stream errored + + await stream.close(); + + expect(ws.readyState).toBe(3); // CLOSED — close() tore the socket down despite the error + }); +}); + +// --------------------------------------------------------------------------- +// In-place reconnect on updateOptions +// --------------------------------------------------------------------------- + +describe('interruption updateOptions reconnect', () => { + it('reconnects in place with the updated threshold without erroring the stream', async () => { + const detector = createDetector(); + const stream = new InterruptionStreamBase(detector, {}); + const errPromise = readError(stream); + + const ws1 = await waitForInstance(); + ws1.simulateOpen(); + await waitFor(() => ws1.sent.length > 0); // session.create #1 + // No user override → the first session.create omits threshold (server applies its default). + expect('threshold' in sessionCreateSettings(ws1)).toBe(false); + + await stream.updateOptions({ threshold: 0.7 }); + + // A fresh socket is opened in place and the old one is closed — no error on the stream. + const ws2 = await waitForInstanceCount(2); + expect(ws2).not.toBe(ws1); + expect(ws1.readyState).toBe(3); // old socket closed + ws2.simulateOpen(); + await waitFor(() => ws2.sent.length > 0); // session.create #2 + + expect(sessionCreateSettings(ws2).threshold).toBe(0.7); + await expect(stillAlive(errPromise)).resolves.toBe('alive'); + + await stream.close(); + }); + + it('does not error the stream across more reconnects than the failover budget', async () => { + // The original bug: each updateOptions errored the stream and burned one (never-reset) retry, + // so exceeding maxRetries killed detection. In-place reconnect must keep the stream live. + const detector = createDetector(); + const stream = new InterruptionStreamBase(detector, {}); + const errPromise = readError(stream); + + const ws1 = await waitForInstance(); + ws1.simulateOpen(); + await waitFor(() => ws1.sent.length > 0); + + const rounds = apiConnectDefaults.maxRetries + 2; // deliberately exceed the retry budget + for (let i = 0; i < rounds; i++) { + await stream.updateOptions({ threshold: 0.5 + (i + 1) * 0.01 }); + const ws = await waitForInstanceCount(i + 2); + ws.simulateOpen(); + await waitFor(() => ws.sent.length > 0); + } + + expect(MockWebSocket.instances.length).toBe(rounds + 1); + await expect(stillAlive(errPromise)).resolves.toBe('alive'); + + await stream.close(); + }); + + it('errors the stream when the post-updateOptions reconnect genuinely fails', async () => { + // A real reconnect failure (here connection 429) must still surface as a stream error so the + // failover path runs. + const detector = createDetector(); + const stream = new InterruptionStreamBase(detector, {}); + const errPromise = readError(stream); + + const ws1 = await waitForInstance(); + ws1.simulateOpen(); + await waitFor(() => ws1.sent.length > 0); + + await stream.updateOptions({ threshold: 0.7 }); + const ws2 = await waitForInstanceCount(2); + ws2.simulateUnexpectedResponse(429); // the reconnect's connect attempt is rejected + + const err = await errPromise; + expect(err).toBeInstanceOf(APIStatusError); + expect((err as APIStatusError).statusCode).toBe(429); + + await stream.close(); + }); +}); + +// --------------------------------------------------------------------------- +// AudioRecognition error-emission classification +// --------------------------------------------------------------------------- + +function createHooks(): RecognitionHooks { + return { + onInterruption: vi.fn(), + onStartOfSpeech: vi.fn(), + onVADInferenceDone: vi.fn(), + onEndOfSpeech: vi.fn(), + onInterimTranscript: vi.fn(), + onFinalTranscript: vi.fn(), + onPreemptiveGeneration: vi.fn(), + retrieveChatCtx: () => ChatContext.empty(), + onEndOfTurn: vi.fn(async () => true), + }; +} + +describe('interruption failover error emission', () => { + it('emits exactly one unrecoverable error for a non-retryable transport failure', async () => { + // Mirrors how ws_transport constructs the connection-rejected error (retryable forced off). + const transportError = new APIStatusError({ + message: 'WebSocket connection rejected with status 429', + options: { statusCode: 429, retryable: false }, + }); + expect(transportError.retryable).toBe(false); + + const errors: Array<{ recoverable: boolean }> = []; + const erroringStream = { + stream: () => + new ReadableStream({ + start(controller) { + controller.error(transportError); + }, + }), + pushFrame: async () => {}, + close: async () => {}, + }; + const mockDetector = { + label: 'mock-detector', + createStream: () => erroringStream, + emitError: (e: { recoverable: boolean }) => errors.push(e), + }; + + const recognition = new AudioRecognition({ + recognitionHooks: createHooks(), + interruptionDetection: mockDetector as unknown as AdaptiveInterruptionDetector, + }); + + const ac = new AbortController(); + const task = ( + recognition as unknown as { + createInterruptionTask: ( + d: AdaptiveInterruptionDetector, + signal: AbortSignal, + ) => Promise; + } + ).createInterruptionTask(mockDetector as unknown as AdaptiveInterruptionDetector, ac.signal); + + // The non-retryable path emits the unrecoverable error then blocks in `finally` + // awaiting the (idle) input-forwarding task; abort to let it wind down. + const start = performance.now(); + while (errors.length === 0 && performance.now() - start < 2000) { + await sleep(5); + } + ac.abort(); + await task; + + expect(errors.filter((e) => e.recoverable)).toHaveLength(0); + expect(errors.filter((e) => !e.recoverable)).toHaveLength(1); + }); +}); diff --git a/agents/src/inference/interruption/interruption_session_create.test.ts b/agents/src/inference/interruption/interruption_session_create.test.ts new file mode 100644 index 000000000..7db426df9 --- /dev/null +++ b/agents/src/inference/interruption/interruption_session_create.test.ts @@ -0,0 +1,127 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Tests for the adaptive-interruption threshold negotiation contract. Ported from the Python +// `test_interruption_session_create.py` suite. +// +// The feature is server-driven: the SDK only sends `threshold` in `session.create` when the user +// explicitly overrode it, and otherwise omits the field so the server applies its fetched default. +// These tests lock that serialization contract plus the parsing of the server's `default_threshold` +// off `session.created` and the observability-only effective-threshold resolution. +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { initializeLogger } from '../../log.js'; +import { MockWebSocket } from './_mock_ws.js'; +import { AdaptiveInterruptionDetector } from './interruption_detector.js'; +import { InterruptionStreamBase } from './interruption_stream.js'; +import { resolveEffectiveThreshold, wsMessageSchema } from './ws_transport.js'; + +// --------------------------------------------------------------------------- +// Mock `ws` so the WebSocket transport can be driven deterministically. +// --------------------------------------------------------------------------- + +vi.mock('ws', async () => { + const { MockWebSocket } = await import('./_mock_ws.js'); + return { default: MockWebSocket, WebSocket: MockWebSocket }; +}); + +type MockSocket = MockWebSocket; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +initializeLogger({ pretty: false, level: 'silent' }); + +const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); + +async function waitFor(predicate: () => boolean, timeoutMs = 2000): Promise { + const start = performance.now(); + while (!predicate()) { + if (performance.now() - start > timeoutMs) { + throw new Error('condition not met within timeout'); + } + await sleep(5); + } +} + +function createDetector(opts: { threshold?: number } = {}): AdaptiveInterruptionDetector { + return new AdaptiveInterruptionDetector({ + baseUrl: 'http://localhost:9999', + apiKey: 'test-key', + apiSecret: 'test-secret', + ...opts, + }); +} + +/** Run the real connect path and return the parsed session.create payload the transport sent. */ +async function captureSessionCreate( + detector: AdaptiveInterruptionDetector, +): Promise<{ settings: Record }> { + const stream = new InterruptionStreamBase(detector, {}); + try { + await waitFor(() => MockWebSocket.instances.length > 0); + const ws = MockWebSocket.instances[MockWebSocket.instances.length - 1] as MockSocket; + ws.simulateOpen(); + await waitFor(() => ws.sent.length > 0); + return JSON.parse(String(ws.sent[0])); + } finally { + await stream.close(); + } +} + +beforeEach(() => { + MockWebSocket.instances.length = 0; +}); + +// --------------------------------------------------------------------------- +// session.create threshold serialization contract +// --------------------------------------------------------------------------- + +describe('session.create threshold', () => { + it('omits threshold when not given', async () => { + const payload = await captureSessionCreate(createDetector()); + expect('threshold' in payload.settings).toBe(false); + }); + + it('includes threshold when overridden', async () => { + const payload = await captureSessionCreate(createDetector({ threshold: 0.7 })); + expect(payload.settings.threshold).toBe(0.7); + }); +}); + +// --------------------------------------------------------------------------- +// session.created default_threshold parsing +// --------------------------------------------------------------------------- + +describe('session.created default_threshold parsing', () => { + it('parses default_threshold', () => { + const msg = wsMessageSchema.parse({ type: 'session.created', default_threshold: 0.42 }); + expect(msg).toMatchObject({ type: 'session.created', default_threshold: 0.42 }); + }); + + it('treats default_threshold as optional', () => { + const msg = wsMessageSchema.parse({ type: 'session.created' }) as { + default_threshold?: number | null; + }; + expect(msg.default_threshold ?? null).toBeNull(); + }); +}); + +// --------------------------------------------------------------------------- +// effective-threshold resolution (observability only) +// --------------------------------------------------------------------------- + +describe('resolveEffectiveThreshold', () => { + it('prefers the user override', () => { + expect(resolveEffectiveThreshold(0.7, 0.3)).toBe(0.7); + }); + + it('falls back to the server default', () => { + expect(resolveEffectiveThreshold(undefined, 0.3)).toBe(0.3); + }); + + it('returns null when neither the user nor the server provides a value', () => { + expect(resolveEffectiveThreshold(undefined, null)).toBeNull(); + }); +}); diff --git a/agents/src/inference/interruption/interruption_stream.ts b/agents/src/inference/interruption/interruption_stream.ts index 37bd2f201..486a95eb3 100644 --- a/agents/src/inference/interruption/interruption_stream.ts +++ b/agents/src/inference/interruption/interruption_stream.ts @@ -10,7 +10,6 @@ import { type StreamChannel, createStreamChannel } from '../../stream/stream_cha import { traceTypes } from '../../telemetry/index.js'; import { FRAMES_PER_SECOND, apiConnectDefaults } from './defaults.js'; import type { InterruptionDetectionError } from './errors.js'; -import { createHttpTransport } from './http_transport.js'; import { InterruptionCacheEntry } from './interruption_cache_entry.js'; import type { AdaptiveInterruptionDetector } from './interruption_detector.js'; import { @@ -99,13 +98,15 @@ export class InterruptionStreamBase { // Store reconnect function for WebSocket transport private wsReconnect?: () => Promise; + private wsClose?: () => void; + // Mutable transport options that can be updated via updateOptions() private transportOptions: { baseUrl: string; apiKey: string; apiSecret: string; sampleRate: number; - threshold: number; + threshold?: number; minFrames: number; timeout: number; connectTimeout: number; @@ -154,8 +155,8 @@ export class InterruptionStreamBase { this.options.minFrames = Math.ceil(options.minInterruptionDurationInS * FRAMES_PER_SECOND); this.transportOptions.minFrames = this.options.minFrames; } - // Trigger WebSocket reconnection if using proxy (WebSocket transport) - if (this.options.useProxy && this.wsReconnect) { + // Trigger WebSocket reconnection to apply updated settings. + if (this.wsReconnect) { await this.wsReconnect(); } } @@ -309,30 +310,20 @@ export class InterruptionStreamBase { { highWaterMark: 32 }, ); - // Second transform: transport layer (HTTP or WebSocket based on useProxy) + // Second transform: WebSocket transport layer. const transportOptions = this.transportOptions; - let transport: TransformStream; - if (this.options.useProxy) { - const wsResult = createWsTransport( - transportOptions, - getState, - setState, - handleSpanUpdate, - onRequestSent, - getAndResetNumRequests, - ); - transport = wsResult.transport; - this.wsReconnect = wsResult.reconnect; - } else { - transport = createHttpTransport( - transportOptions, - getState, - setState, - handleSpanUpdate, - getAndResetNumRequests, - ); - } + const wsResult = createWsTransport( + transportOptions, + getState, + setState, + handleSpanUpdate, + onRequestSent, + getAndResetNumRequests, + ); + const transport = wsResult.transport; + this.wsReconnect = wsResult.reconnect; + this.wsClose = wsResult.close; const eventEmitter = new TransformStream({ transform: (chunk, controller) => { @@ -415,9 +406,13 @@ export class InterruptionStreamBase { } async close(): Promise { - if (!this.inputStream.closed) await this.inputStream.close(); - this.resampler?.close(); - this.model.removeStream(this); + try { + if (!this.inputStream.closed) await this.inputStream.close(); + } finally { + this.wsClose?.(); + this.resampler?.close(); + this.model.removeStream(this); + } } } diff --git a/agents/src/inference/interruption/types.ts b/agents/src/inference/interruption/types.ts index d3aae7b95..8306684a7 100644 --- a/agents/src/inference/interruption/types.ts +++ b/agents/src/inference/interruption/types.ts @@ -22,7 +22,7 @@ export interface OverlappingSpeechEvent { */ export interface InterruptionOptions { sampleRate: number; - threshold: number; + threshold?: number; minFrames: number; maxAudioDurationInS: number; audioPrefixDurationInS: number; @@ -32,7 +32,6 @@ export interface InterruptionOptions { baseUrl: string; apiKey: string; apiSecret: string; - useProxy: boolean; } /** diff --git a/agents/src/inference/interruption/ws_transport.ts b/agents/src/inference/interruption/ws_transport.ts index 89f97e5fc..f0bd91b9f 100644 --- a/agents/src/inference/interruption/ws_transport.ts +++ b/agents/src/inference/interruption/ws_transport.ts @@ -7,6 +7,7 @@ import WebSocket from 'ws'; import { z } from 'zod'; import { APIConnectionError, APIStatusError, APITimeoutError } from '../../_exceptions.js'; import { log } from '../../log.js'; +import { Event } from '../../utils.js'; import { buildMetadataHeaders, createAccessToken } from '../utils.js'; import { InterruptionCacheEntry } from './interruption_cache_entry.js'; import type { OverlappingSpeechEvent } from './types.js'; @@ -26,7 +27,7 @@ export interface WsTransportOptions { apiKey: string; apiSecret: string; sampleRate: number; - threshold: number; + threshold?: number; minFrames: number; timeout: number; connectTimeout: number; @@ -39,9 +40,10 @@ export interface WsTransportState { cache: BoundedCache; } -const wsMessageSchema = z.discriminatedUnion('type', [ +export const wsMessageSchema = z.discriminatedUnion('type', [ z.object({ type: z.literal(MSG_SESSION_CREATED), + default_threshold: z.number().nullish(), }), z.object({ type: z.literal(MSG_SESSION_CLOSED), @@ -69,6 +71,23 @@ const wsMessageSchema = z.discriminatedUnion('type', [ type WsMessage = z.infer; +/** + * Resolve the effective interruption threshold for observability only — the server makes the + * actual decision. Precedence: user override, then server default; null when neither is known. + */ +export function resolveEffectiveThreshold( + threshold: number | undefined, + defaultThreshold: number | null | undefined, +): number | null { + if (threshold !== undefined) { + return threshold; + } + if (defaultThreshold != null) { + return defaultThreshold; + } + return null; +} + /** * Creates a WebSocket connection and waits for it to open. */ @@ -83,39 +102,46 @@ async function connectWebSocket( headers: { ...buildMetadataHeaders(), Authorization: `Bearer ${token}` }, }); - await new ThrowsPromise( - (resolve, reject) => { - const timeout = setTimeout(() => { - ws.terminate(); - reject( - new APITimeoutError({ - message: 'WebSocket connection timeout', - options: { retryable: false }, - }), - ); - }, options.connectTimeout); - ws.once('open', () => { - clearTimeout(timeout); - resolve(); - }); - ws.once('unexpected-response', (_req, res) => { - clearTimeout(timeout); - ws.terminate(); - const statusCode = res.statusCode ?? -1; - reject( - new APIStatusError({ - message: `WebSocket connection rejected with status ${statusCode}`, - options: { statusCode, retryable: false }, - }), - ); - }); - ws.once('error', (err: Error) => { - clearTimeout(timeout); - ws.terminate(); - reject(new APIConnectionError({ message: `WebSocket connection error: ${err.message}` })); - }); - }, - ); + try { + await new ThrowsPromise( + (resolve, reject) => { + const timeout = setTimeout(() => { + ws.terminate(); + reject( + new APITimeoutError({ + message: 'WebSocket connection timeout', + options: { retryable: false }, + }), + ); + }, options.connectTimeout); + ws.once('open', () => { + clearTimeout(timeout); + resolve(); + }); + ws.once('unexpected-response', (_req, res) => { + clearTimeout(timeout); + ws.terminate(); + const statusCode = res.statusCode ?? -1; + reject( + new APIStatusError({ + message: `WebSocket connection rejected with status ${statusCode}`, + options: { statusCode, retryable: false }, + }), + ); + }); + ws.once('error', (err: Error) => { + clearTimeout(timeout); + ws.terminate(); + reject(new APIConnectionError({ message: `WebSocket connection error: ${err.message}` })); + }); + }, + ); + } finally { + // Drop the connection-phase once() listeners so a later socket error can't fire the stale + // once('error') alongside the operational on('error'). Safe to remove all: the message handler + // is attached after this returns. + ws.removeAllListeners(); + } return ws; } @@ -123,6 +149,7 @@ async function connectWebSocket( export interface WsTransportResult { transport: TransformStream; reconnect: () => Promise; + close: () => void; } /** @@ -141,9 +168,15 @@ export function createWsTransport( getAndResetNumRequests?: () => number, ): WsTransportResult { const logger = log(); - let ws: WebSocket | null = null; + let activeWs: WebSocket | null = null; let outputController: TransformStreamDefaultController | null = null; + // `reconnecting` is the in-flight reconnect; transform() awaits it so it never sends on a socket + // being torn down. `closed` lets the background watcher exit its loop. + const reconnectEvent = new Event(); + let reconnecting: Promise | null = null; + let closed = false; + function setupMessageHandler(socket: WebSocket): void { socket.on('message', (data: WebSocket.Data) => { let message: WsMessage; @@ -182,31 +215,56 @@ export function createWsTransport( async function ensureConnection(): Promise< Throws > { - if (ws && ws.readyState === WebSocket.OPEN) return; - - ws = await connectWebSocket(options); - setupMessageHandler(ws); - + if (activeWs && activeWs.readyState === WebSocket.OPEN) return; + + activeWs = await connectWebSocket(options); + setupMessageHandler(activeWs); + + const settings: Record = { + sample_rate: options.sampleRate, + num_channels: 1, + min_frames: options.minFrames, + encoding: 's16le', + }; + if (options.threshold !== undefined) { + settings.threshold = options.threshold; + } const sessionCreateMsg = JSON.stringify({ type: MSG_SESSION_CREATE, - settings: { - sample_rate: options.sampleRate, - num_channels: 1, - threshold: options.threshold, - min_frames: options.minFrames, - encoding: 's16le', - }, + settings, }); - ws.send(sessionCreateMsg); + activeWs.send(sessionCreateMsg); } function handleMessage(message: WsMessage): void { const state = getState(); switch (message.type) { - case MSG_SESSION_CREATED: - logger.debug('WebSocket session created'); + case MSG_SESSION_CREATED: { + if (options.threshold === undefined && message.default_threshold == null) { + outputController?.error( + new APIStatusError({ + message: + 'adaptive interruption session created without a threshold: no user override and the server did not report a default_threshold', + options: { statusCode: 500, retryable: false }, + }), + ); + break; + } + // Observability only — the server makes the actual decision. + logger.debug( + { + defaultThreshold: message.default_threshold, + effectiveThreshold: resolveEffectiveThreshold( + options.threshold, + message.default_threshold, + ), + userOverride: options.threshold !== undefined, + }, + 'adaptive interruption session created', + ); break; + } case MSG_INTERRUPTION_DETECTED: { const createdAt = message.created_at; @@ -325,7 +383,9 @@ export function createWsTransport( } function sendAudioData(audioSlice: Int16Array): void { - if (!ws || ws.readyState !== WebSocket.OPEN) { + // Backstop for a genuine unexpected drop: throws a retryable error so the stream fails over. An + // intentional reconnect is awaited in transform() before we get here, so it won't fire then. + if (!activeWs || activeWs.readyState !== WebSocket.OPEN) { throw new APIConnectionError({ message: 'WebSocket not connected' }); } @@ -355,29 +415,69 @@ export function createWsTransport( combined.set(new Uint8Array(header), 0); combined.set(audioBytes, 8); - ws.send(combined); + activeWs.send(combined); onRequestSent?.(); } - function close(): void { - if (ws?.readyState === WebSocket.OPEN) { + // Close the current socket without ending the transport (used by both close() and reconnect). + function closeSocket(): void { + if (activeWs?.readyState === WebSocket.OPEN) { const closeMsg = JSON.stringify({ type: MSG_SESSION_CLOSE }); try { - ws.send(closeMsg); + activeWs.send(closeMsg); } catch (e: unknown) { logger.error(e, 'failed to send close message'); } } - ws?.close(1000); // signal normal websocket closure - ws = null; + // The abandoned socket can still emit 'error' during its close handshake; that handler closes + // over the shared outputController, so a late error would tear down the replacement stream. + activeWs?.removeAllListeners(); + activeWs?.close(1000); // signal normal websocket closure + activeWs = null; + } + + function close(): void { + closed = true; + reconnectEvent.set(); // wake the watcher so it exits its loop + closeSocket(); } /** - * Reconnect the WebSocket with updated options. - * This is called when options are updated via updateOptions(). + * Request an in-place reconnect to apply updated options (threshold / min frames): it does not + * error the stream and does not consume a failover retry. The work happens in reconnectWatcher(). */ async function reconnect(): Promise { - close(); + if (closed) return; + reconnectEvent.set(); + } + + // Background loop that reconnects in place when reconnect() fires, so applying new options keeps + // the stream alive and off the failover retry path. + async function reconnectWatcher(): Promise { + while (!closed) { + await reconnectEvent.wait(); + if (closed) break; + reconnectEvent.clear(); + + // `.catch` keeps `reconnecting` non-rejecting (transform awaits it); a genuine reconnect + // failure is routed to the stream as an error — a legitimate retry. + const done = (async () => { + closeSocket(); + getState().cache.clear(); // abandon the old socket's unanswered in-flight requests + await ensureConnection(); + // close() may have raced in during the await; its closeSocket() saw activeWs === null and + // was a no-op, so tear down the socket we just opened — else it leaks with live handlers. + if (closed) closeSocket(); + })().catch((e: unknown) => { + outputController?.error(e); + }); + reconnecting = done; + try { + await done; + } finally { + if (reconnecting === done) reconnecting = null; + } + } } const transport = new TransformStream< @@ -390,14 +490,19 @@ export function createWsTransport( await ensureConnection().catch((e) => { controller.error(e); }); + void reconnectWatcher(); }, - transform(chunk, controller) { + async transform(chunk, controller) { if (!(chunk instanceof Int16Array)) { controller.enqueue(chunk); return; } + // Wait out any in-flight reconnect so we don't send on a socket being torn down. It never + // rejects — a failed reconnect has already errored the stream via outputController. + if (reconnecting) await reconnecting; + // Only forwards buffered audio while overlap speech is actively on. const state = getState(); if (!state.overlapSpeechStartedAt || !state.overlapSpeechStarted) return; @@ -434,5 +539,5 @@ export function createWsTransport( { highWaterMark: 2 }, ); - return { transport, reconnect }; + return { transport, reconnect, close }; } diff --git a/agents/src/inference/utils.ts b/agents/src/inference/utils.ts index 8a9b8319e..2849ed9ab 100644 --- a/agents/src/inference/utils.ts +++ b/agents/src/inference/utils.ts @@ -56,6 +56,7 @@ export async function createAccessToken( /** * Build metadata headers for inference requests. * Includes SDK version/platform, and optionally room/job/agent IDs from the current job context. + * Includes X-LiveKit-Worker-Token when LIVEKIT_WORKER_TOKEN is set (hosted agents). */ export function buildMetadataHeaders(): Record { const headers: Record = { @@ -71,11 +72,16 @@ export function buildMetadataHeaders(): Record { if (ctx.job.id) { headers['X-LiveKit-Job-Id'] = ctx.job.id; } - if (ctx.room.isConnected) { - const agentSid = ctx.agent?.sid; - if (agentSid) { - headers['X-LiveKit-Agent-Id'] = agentSid; - } + // for hosted agents where job context is always present + const workerToken = process.env.LIVEKIT_WORKER_TOKEN; + if (workerToken) { + headers['X-LiveKit-Worker-Token'] = workerToken; + } + // Only emit the agent SID once the room is connected: before connection the + // local participant SID is unset/placeholder and would leak into requests. + const agentSid = ctx.agent?.sid; + if (ctx.room.isConnected && agentSid) { + headers['X-LiveKit-Agent-Id'] = agentSid; } } diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index c85d926bd..a46841dd3 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -1506,6 +1506,9 @@ export class AudioRecognition { if (!res) break; const { done, value: ev } = res; if (done) break; + // A healthy stream delivering events recovers the failover budget, so a later transient + // failure isn't charged against earlier ones. + numRetries = 0; this.onOverlapSpeechEvent(ev); } break; diff --git a/agents/src/worker.ts b/agents/src/worker.ts index 73d7b92d3..8238a185e 100644 --- a/agents/src/worker.ts +++ b/agents/src/worker.ts @@ -307,6 +307,12 @@ export class AgentServer { ); if (opts.workerToken) { + // Re-export into the environment so forked subprocesses inherit it (fork() + // copies process.env by default). The inference-header code in the child reads + // process.env.LIVEKIT_WORKER_TOKEN — see inference/utils.ts buildMetadataHeaders(). + // Mirrors Python worker.py, which sets os.environ before spawning job procs. + process.env.LIVEKIT_WORKER_TOKEN = opts.workerToken; + if (opts.loadFunc !== defaultCpuLoad) { this.#logger.warn( 'custom loadFunc is not supported when deploying to Cloud, using defaults', diff --git a/turbo.json b/turbo.json index 18a9cd8b1..b0bc90527 100644 --- a/turbo.json +++ b/turbo.json @@ -45,6 +45,7 @@ "LLAMA_API_KEY", "LIVEKIT_AGENT_ID", "LIVEKIT_AGENT_NAME", + "LIVEKIT_WORKER_TOKEN", "LOG_LEVEL", "OCTOAI_TOKEN", "OPENAI_API_KEY",