From 922325a9b843cdac43b3bfdc248086a372341667 Mon Sep 17 00:00:00 2001 From: "rosetta-livekit-bot[bot]" <282703043+rosetta-livekit-bot[bot]@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:05:12 +0000 Subject: [PATCH] feat(assemblyai): add u3 context options --- .changeset/assemblyai-u3-context.md | 5 ++++ plugins/assemblyai/src/models.ts | 1 + plugins/assemblyai/src/stt.test.ts | 33 ++++++++++++++++++++++++- plugins/assemblyai/src/stt.ts | 37 ++++++++++++++++++++++++----- 4 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 .changeset/assemblyai-u3-context.md diff --git a/.changeset/assemblyai-u3-context.md b/.changeset/assemblyai-u3-context.md new file mode 100644 index 000000000..6cf6c479c --- /dev/null +++ b/.changeset/assemblyai-u3-context.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-assemblyai': patch +--- + +Add AssemblyAI `agentContext`, `previousContextNTurns`, and `u3-rt-pro-beta-1` streaming options. diff --git a/plugins/assemblyai/src/models.ts b/plugins/assemblyai/src/models.ts index 17e0acfd0..2413b7b8f 100644 --- a/plugins/assemblyai/src/models.ts +++ b/plugins/assemblyai/src/models.ts @@ -6,6 +6,7 @@ export type STTModels = | 'universal-streaming-english' | 'universal-streaming-multilingual' | 'u3-rt-pro' + | 'u3-rt-pro-beta-1' // Deprecated alias — AssemblyAI maps this to `u3-rt-pro` server-side, but the // Python plugin emits a warning and rewrites it. Kept here so TS users don't // break if they already pass it. diff --git a/plugins/assemblyai/src/stt.test.ts b/plugins/assemblyai/src/stt.test.ts index c63ec0101..0f41ba45b 100644 --- a/plugins/assemblyai/src/stt.test.ts +++ b/plugins/assemblyai/src/stt.test.ts @@ -3,9 +3,40 @@ // SPDX-License-Identifier: Apache-2.0 import { VAD } from '@livekit/agents-plugin-silero'; import { stt } from '@livekit/agents-plugins-test'; -import { describe, it } from 'vitest'; +import { describe, expect, it } from 'vitest'; import { STT } from './stt.js'; +describe('AssemblyAI options', () => { + it('accepts u3-rt-pro-beta-1', () => { + const stt = new STT({ apiKey: 'test-key', speechModel: 'u3-rt-pro-beta-1' }); + + expect(stt.model).toBe('u3-rt-pro-beta-1'); + }); + + it('accepts u3-pro parameters for u3-rt-pro-beta-1', () => { + expect( + () => + new STT({ + apiKey: 'test-key', + speechModel: 'u3-rt-pro-beta-1', + prompt: 'medical dictation', + agentContext: "The agent asked for the patient's name.", + previousContextNTurns: 10, + }), + ).not.toThrow(); + }); + + it('requires a u3-rt-pro model for agentContext', () => { + expect(() => new STT({ apiKey: 'test-key', agentContext: 'hello' })).toThrow(/agentContext/); + }); + + it('requires a u3-rt-pro model for previousContextNTurns', () => { + expect(() => new STT({ apiKey: 'test-key', previousContextNTurns: 5 })).toThrow( + /previousContextNTurns/, + ); + }); +}); + const hasAssemblyAIApiKey = Boolean(process.env.ASSEMBLYAI_API_KEY); if (hasAssemblyAIApiKey) { diff --git a/plugins/assemblyai/src/stt.ts b/plugins/assemblyai/src/stt.ts index 871502814..82d5edac1 100644 --- a/plugins/assemblyai/src/stt.ts +++ b/plugins/assemblyai/src/stt.ts @@ -66,8 +66,12 @@ export interface STTOptions { maxTurnSilence?: number; formatTurns?: boolean; keytermsPrompt?: string[]; - /** Only supported with the `u3-rt-pro` model. */ + /** Only supported with the `u3-rt-pro` model family. */ prompt?: string; + /** Only supported with the `u3-rt-pro` model family. */ + agentContext?: string; + /** Only supported with the `u3-rt-pro` model family. Set at connection time only. */ + previousContextNTurns?: number; vadThreshold?: number; /** * Enable speaker diarization. Note: AssemblyAI will return per-word speaker @@ -92,6 +96,12 @@ const defaultSTTOptions: STTOptions = { baseUrl: 'wss://streaming.assemblyai.com', }; +const u3ProModels: STTModels[] = ['u3-rt-pro', 'u3-rt-pro-beta-1']; + +function isU3ProModel(model: STTModels | undefined) { + return model !== undefined && u3ProModels.includes(model); +} + export class STT extends stt.STT { #opts: STTOptions; #streams = new Set>(); @@ -117,8 +127,20 @@ export class STT extends stt.STT { opts.speechModel = 'u3-rt-pro'; } - if (opts.prompt !== undefined && opts.speechModel !== 'u3-rt-pro') { - throw new Error("The 'prompt' parameter is only supported with the 'u3-rt-pro' model."); + if (opts.prompt !== undefined && !isU3ProModel(opts.speechModel)) { + throw new Error("The 'prompt' parameter is only supported with the 'u3-rt-pro' models."); + } + + if (opts.agentContext !== undefined && !isU3ProModel(opts.speechModel)) { + throw new Error( + "The 'agentContext' parameter is only supported with the 'u3-rt-pro' models.", + ); + } + + if (opts.previousContextNTurns !== undefined && !isU3ProModel(opts.speechModel)) { + throw new Error( + "The 'previousContextNTurns' parameter is only supported with the 'u3-rt-pro' models.", + ); } const apiKey = opts.apiKey ?? defaultSTTOptions.apiKey; @@ -204,6 +226,7 @@ export class SpeechStream extends stt.SpeechStream { const configMsg: Record = { type: 'UpdateConfiguration' }; if (opts.prompt !== undefined) configMsg.prompt = opts.prompt; + if (opts.agentContext !== undefined) configMsg.agent_context = opts.agentContext; if (opts.keytermsPrompt !== undefined) configMsg.keyterms_prompt = opts.keytermsPrompt; if (opts.maxTurnSilence !== undefined) configMsg.max_turn_silence = opts.maxTurnSilence; if (opts.minTurnSilence !== undefined) configMsg.min_turn_silence = opts.minTurnSilence; @@ -262,17 +285,17 @@ export class SpeechStream extends stt.SpeechStream { } async #connectWS(): Promise { - // u3-rt-pro has different silence defaults — if unset, both min and max default to 100ms. + // u3-rt-pro models have different silence defaults: if unset, both min and max are 100ms. let minSilence = this.#opts.minTurnSilence; let maxSilence = this.#opts.maxTurnSilence; - if (this.#opts.speechModel === 'u3-rt-pro') { + if (isU3ProModel(this.#opts.speechModel)) { if (minSilence === undefined) minSilence = 100; if (maxSilence === undefined) maxSilence = minSilence; } // Default language_detection to true for multilingual / u3-rt-pro models, false otherwise. const defaultLanguageDetection = - this.#opts.speechModel.includes('multilingual') || this.#opts.speechModel === 'u3-rt-pro'; + this.#opts.speechModel.includes('multilingual') || isU3ProModel(this.#opts.speechModel); const languageDetection = this.#opts.languageDetection ?? defaultLanguageDetection; const liveConfig: Record = { @@ -289,6 +312,8 @@ export class SpeechStream extends stt.SpeechStream { : undefined, language_detection: languageDetection, prompt: this.#opts.prompt, + agent_context: this.#opts.agentContext, + previous_context_n_turns: this.#opts.previousContextNTurns, vad_threshold: this.#opts.vadThreshold, speaker_labels: this.#opts.speakerLabels, max_speakers: this.#opts.maxSpeakers,