Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/assemblyai-u3-context.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@livekit/agents-plugin-assemblyai': patch
---

Add AssemblyAI `agentContext`, `previousContextNTurns`, and `u3-rt-pro-beta-1` streaming options.
1 change: 1 addition & 0 deletions plugins/assemblyai/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export type STTModels =
| 'universal-streaming-english'
| 'universal-streaming-multilingual'
| 'u3-rt-pro'
| 'u3-rt-pro-beta-1'
// Deprecated alias — AssemblyAI maps this to `u3-rt-pro` server-side, but the
// Python plugin emits a warning and rewrites it. Kept here so TS users don't
// break if they already pass it.
Expand Down
33 changes: 32 additions & 1 deletion plugins/assemblyai/src/stt.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,40 @@
// SPDX-License-Identifier: Apache-2.0
import { VAD } from '@livekit/agents-plugin-silero';
import { stt } from '@livekit/agents-plugins-test';
import { describe, it } from 'vitest';
import { describe, expect, it } from 'vitest';
import { STT } from './stt.js';

describe('AssemblyAI options', () => {
it('accepts u3-rt-pro-beta-1', () => {
const stt = new STT({ apiKey: 'test-key', speechModel: 'u3-rt-pro-beta-1' });

expect(stt.model).toBe('u3-rt-pro-beta-1');
});

it('accepts u3-pro parameters for u3-rt-pro-beta-1', () => {
expect(
() =>
new STT({
apiKey: 'test-key',
speechModel: 'u3-rt-pro-beta-1',
prompt: 'medical dictation',
agentContext: "The agent asked for the patient's name.",
previousContextNTurns: 10,
}),
).not.toThrow();
});

it('requires a u3-rt-pro model for agentContext', () => {
expect(() => new STT({ apiKey: 'test-key', agentContext: 'hello' })).toThrow(/agentContext/);
});

it('requires a u3-rt-pro model for previousContextNTurns', () => {
expect(() => new STT({ apiKey: 'test-key', previousContextNTurns: 5 })).toThrow(
/previousContextNTurns/,
);
});
});

const hasAssemblyAIApiKey = Boolean(process.env.ASSEMBLYAI_API_KEY);

if (hasAssemblyAIApiKey) {
Expand Down
37 changes: 31 additions & 6 deletions plugins/assemblyai/src/stt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ export interface STTOptions {
maxTurnSilence?: number;
formatTurns?: boolean;
keytermsPrompt?: string[];
/** Only supported with the `u3-rt-pro` model. */
/** Only supported with the `u3-rt-pro` model family. */
prompt?: string;
/** Only supported with the `u3-rt-pro` model family. */
agentContext?: string;
/** Only supported with the `u3-rt-pro` model family. Set at connection time only. */
previousContextNTurns?: number;
vadThreshold?: number;
/**
* Enable speaker diarization. Note: AssemblyAI will return per-word speaker
Expand All @@ -92,6 +96,12 @@ const defaultSTTOptions: STTOptions = {
baseUrl: 'wss://streaming.assemblyai.com',
};

const u3ProModels: STTModels[] = ['u3-rt-pro', 'u3-rt-pro-beta-1'];

function isU3ProModel(model: STTModels | undefined) {
return model !== undefined && u3ProModels.includes(model);
}

export class STT extends stt.STT {
#opts: STTOptions;
#streams = new Set<WeakRef<SpeechStream>>();
Expand All @@ -117,8 +127,20 @@ export class STT extends stt.STT {
opts.speechModel = 'u3-rt-pro';
}

if (opts.prompt !== undefined && opts.speechModel !== 'u3-rt-pro') {
throw new Error("The 'prompt' parameter is only supported with the 'u3-rt-pro' model.");
if (opts.prompt !== undefined && !isU3ProModel(opts.speechModel)) {
throw new Error("The 'prompt' parameter is only supported with the 'u3-rt-pro' models.");
}

if (opts.agentContext !== undefined && !isU3ProModel(opts.speechModel)) {
throw new Error(
"The 'agentContext' parameter is only supported with the 'u3-rt-pro' models.",
);
}

if (opts.previousContextNTurns !== undefined && !isU3ProModel(opts.speechModel)) {
throw new Error(
"The 'previousContextNTurns' parameter is only supported with the 'u3-rt-pro' models.",
);
}

const apiKey = opts.apiKey ?? defaultSTTOptions.apiKey;
Expand Down Expand Up @@ -204,6 +226,7 @@ export class SpeechStream extends stt.SpeechStream {

const configMsg: Record<string, unknown> = { type: 'UpdateConfiguration' };
if (opts.prompt !== undefined) configMsg.prompt = opts.prompt;
if (opts.agentContext !== undefined) configMsg.agent_context = opts.agentContext;
if (opts.keytermsPrompt !== undefined) configMsg.keyterms_prompt = opts.keytermsPrompt;
if (opts.maxTurnSilence !== undefined) configMsg.max_turn_silence = opts.maxTurnSilence;
if (opts.minTurnSilence !== undefined) configMsg.min_turn_silence = opts.minTurnSilence;
Expand Down Expand Up @@ -262,17 +285,17 @@ export class SpeechStream extends stt.SpeechStream {
}

async #connectWS(): Promise<WebSocket> {
// u3-rt-pro has different silence defaultsif unset, both min and max default to 100ms.
// u3-rt-pro models have different silence defaults: if unset, both min and max are 100ms.
let minSilence = this.#opts.minTurnSilence;
let maxSilence = this.#opts.maxTurnSilence;
if (this.#opts.speechModel === 'u3-rt-pro') {
if (isU3ProModel(this.#opts.speechModel)) {
if (minSilence === undefined) minSilence = 100;
if (maxSilence === undefined) maxSilence = minSilence;
}

// Default language_detection to true for multilingual / u3-rt-pro models, false otherwise.
const defaultLanguageDetection =
this.#opts.speechModel.includes('multilingual') || this.#opts.speechModel === 'u3-rt-pro';
this.#opts.speechModel.includes('multilingual') || isU3ProModel(this.#opts.speechModel);
const languageDetection = this.#opts.languageDetection ?? defaultLanguageDetection;

const liveConfig: Record<string, unknown> = {
Expand All @@ -289,6 +312,8 @@ export class SpeechStream extends stt.SpeechStream {
: undefined,
language_detection: languageDetection,
prompt: this.#opts.prompt,
agent_context: this.#opts.agentContext,
previous_context_n_turns: this.#opts.previousContextNTurns,
vad_threshold: this.#opts.vadThreshold,
speaker_labels: this.#opts.speakerLabels,
max_speakers: this.#opts.maxSpeakers,
Expand Down