Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/amd-prediction-event.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@livekit/agents': patch
---

Port the `AMDResult` → `AMDPredictionEvent` rename and the `amd_prediction` event emission from Python (livekit/agents#5621). The `AMD` detector now extends an `EventEmitter` and emits `amd_prediction` with an `AMDPredictionEvent` payload (`type: 'amd_prediction'`, plus the existing `category` / `reason` / `transcript` / `rawResponse` / `isMachine` fields and a new optional `speechDurationMs`). `AMDResult` is kept as a deprecated type alias for `AMDPredictionEvent` for backward compatibility. The remote-session wire serialization for AMD predictions is intentionally deferred until `@livekit/protocol` ships the corresponding `AgentSessionEvent.AmdPrediction` / `AmdCategory` message types; a TODO marker has been left in `voice/remote_session.ts` where it will be wired up.
11 changes: 10 additions & 1 deletion agents/src/voice/amd.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { LLM, type LLMStream } from '../llm/llm.js';
import type { ToolChoice, ToolContext } from '../llm/tool_context.js';
import type { APIConnectOptions } from '../types.js';
import type { AgentSession } from './agent_session.js';
import { AMD, AMDCategory } from './amd.js';
import { AMD, AMDCategory, type AMDPredictionEvent } from './amd.js';
import { AgentSessionEventTypes } from './events.js';

class StaticLLM extends LLM {
Expand Down Expand Up @@ -70,6 +70,9 @@ describe('AMD', () => {
llm.on('error', () => {});
const amd = new AMD(asAgentSession(session), { llm, detectionTimeoutMs: 50 });

const events: AMDPredictionEvent[] = [];
amd.on('amd_prediction', (ev) => events.push(ev));

const promise = amd.execute();
session.emit(AgentSessionEventTypes.UserInputTranscribed, {
type: 'user_input_transcribed',
Expand All @@ -81,12 +84,18 @@ describe('AMD', () => {
});

await expect(promise).resolves.toMatchObject({
type: 'amd_prediction',
category: AMDCategory.MACHINE_VM,
isMachine: true,
});
expect(session.pauseReplyAuthorization).toHaveBeenCalledTimes(1);
expect(session.resumeReplyAuthorization).toHaveBeenCalled();
expect(session.interrupt).toHaveBeenCalledWith({ force: true });
expect(events).toHaveLength(1);
expect(events[0]).toMatchObject({
type: 'amd_prediction',
category: AMDCategory.MACHINE_VM,
});
});

it('should classify unavailable mailbox as machine', async () => {
Expand Down
56 changes: 45 additions & 11 deletions agents/src/voice/amd.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// SPDX-FileCopyrightText: 2026 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import type { TypedEventEmitter } from '@livekit/typed-emitter';
import type { Span } from '@opentelemetry/api';
import EventEmitter from 'node:events';
import { ChatContext } from '../llm/chat_context.js';
import { LLM } from '../llm/llm.js';
import { traceTypes, tracer } from '../telemetry/index.js';
Expand All @@ -21,14 +23,26 @@ export enum AMDCategory {
UNCERTAIN = 'uncertain',
}

export interface AMDResult {
/**
 * Payload describing a settled answering-machine-detection (AMD) prediction.
 *
 * The same object is passed to `amd_prediction` listeners and used to resolve
 * the promise returned by `AMD.execute()`.
 */
export interface AMDPredictionEvent {
/** Literal discriminator; matches the name of the emitted event. */
type: 'amd_prediction';
/** Category assigned to the call. */
category: AMDCategory;
/** User transcript associated with the prediction. */
transcript: string;
/** Reason string accompanying the category. */
reason: string;
/** Raw LLM response text; an empty string when the verdict comes from the silence-timer path. */
rawResponse: string;
/** Whether `category` counts as a machine (derived via `isMachineCategory` where events are built). */
isMachine: boolean;
/** Duration of detected user speech in milliseconds, when known. */
speechDurationMs?: number;
}

/**
 * Legacy name for the AMD prediction payload, retained as an alias so code
 * written against the previous name keeps compiling.
 *
 * @deprecated Use {@link AMDPredictionEvent}.
 */
export type AMDResult = AMDPredictionEvent;

/**
 * Event map for the `AMD` emitter: event name to listener signature.
 */
export type AMDCallbacks = {
/** Invoked with the prediction payload when an AMD prediction is emitted. */
amd_prediction: (ev: AMDPredictionEvent) => void;
};

export interface AMDOptions {
llm?: LLM;
interruptOnMachine?: boolean;
Expand Down Expand Up @@ -80,8 +94,11 @@ Do not include markdown fences or extra text.`;
* Mirrors Python's `_AMDClassifier` two-gate architecture:
* a result is only emitted when both a **verdict** (from LLM or heuristic) and
* a **silence gate** (from VAD or timeout) are satisfied.
*
* Emits `amd_prediction` with an {@link AMDPredictionEvent} payload when a
* prediction settles. Callers can subscribe via `amd.on('amd_prediction', ...)`.
*/
export class AMD {
export class AMD extends (EventEmitter as new () => TypedEventEmitter<AMDCallbacks>) {
private readonly llm: LLM;
private readonly interruptOnMachine: boolean;
private readonly noSpeechTimeoutMs: number;
Expand All @@ -92,23 +109,25 @@ export class AMD {
private active = false;
private settled = false;
private transcriptParts: string[] = [];
private verdictResult: AMDResult | undefined;
private verdictResult: AMDPredictionEvent | undefined;
private machineSilenceReached = false;
private speechStartedAt: number | undefined;
private speechEndedAt: number | undefined;
private detectGeneration = 0;

private noSpeechTimer: ReturnType<typeof setTimeout> | undefined;
private detectionTimer: ReturnType<typeof setTimeout> | undefined;
private silenceTimer: ReturnType<typeof setTimeout> | undefined;

private resolveRun: ((value: AMDResult) => void) | undefined;
private resolveRun: ((value: AMDPredictionEvent) => void) | undefined;
private rejectRun: ((reason?: unknown) => void) | undefined;
private span: Span | undefined;

constructor(
private readonly session: AgentSession,
options: AMDOptions = {},
) {
super();
const llm = options.llm ?? this.resolveSessionLLM();
if (!llm) {
throw new Error(
Expand All @@ -125,7 +144,7 @@ export class AMD {

// ─── public API ──────────────────────────────────────────────────────────────

async execute(): Promise<AMDResult> {
async execute(): Promise<AMDPredictionEvent> {
return tracer.startActiveSpan(
async (span) => {
if (this.active) {
Expand All @@ -146,7 +165,7 @@ export class AMD {
}

try {
const result = await new Promise<AMDResult>((resolve, reject) => {
const result = await new Promise<AMDPredictionEvent>((resolve, reject) => {
this.resolveRun = resolve;
this.rejectRun = reject;
this.subscribe();
Expand Down Expand Up @@ -181,6 +200,7 @@ export class AMD {
this.verdictResult = undefined;
this.machineSilenceReached = false;
this.speechStartedAt = undefined;
this.speechEndedAt = undefined;
this.detectGeneration = 0;
this.resolveRun = undefined;
this.rejectRun = undefined;
Expand Down Expand Up @@ -220,7 +240,7 @@ export class AMD {
* Ref: python classifier.py `_set_verdict` — stores the LLM/heuristic verdict.
* Emission is deferred until the silence gate also opens.
*/
private setVerdict(result: AMDResult): void {
private setVerdict(result: AMDPredictionEvent): void {
// Record the verdict before attempting emission so tryEmitResult() can read it.
this.verdictResult = result;
this.tryEmitResult();
}
Expand All @@ -237,7 +257,7 @@ export class AMD {
this.finish(this.verdictResult);
}

private finish(result: AMDResult): void {
private finish(result: AMDPredictionEvent): void {
if (this.settled) {
return;
}
Expand All @@ -247,6 +267,7 @@ export class AMD {
if (result.isMachine && this.interruptOnMachine) {
this.session.interrupt({ force: true }).await.catch(() => {});
}
this.emit('amd_prediction', result);
this.resolveRun?.(result);
}

Expand All @@ -260,11 +281,13 @@ export class AMD {
private onSilenceTimerFired(category?: AMDCategory, reason?: string): void {
if (category && reason && !this.verdictResult) {
this.setVerdict({
type: 'amd_prediction',
category,
reason,
transcript: this.joinTranscript(),
rawResponse: '',
isMachine: isMachineCategory(category),
speechDurationMs: this.computeSpeechDurationMs(),
});
}
this.machineSilenceReached = true;
Expand Down Expand Up @@ -304,6 +327,7 @@ export class AMD {
return;
}

this.speechEndedAt = ev.createdAt;
const speechDurationMs = ev.createdAt - (this.speechStartedAt ?? ev.createdAt);

this.clearTimer('silence');
Expand Down Expand Up @@ -390,14 +414,22 @@ export class AMD {
return this.transcriptParts.join('\n');
}

private setSpanAttributes(result: AMDResult): void {
/**
 * Computes the length of the detected user speech in milliseconds.
 *
 * The interval runs from the recorded speech start to the recorded speech end;
 * when no end has been recorded, the current time is used instead.
 *
 * @returns The duration clamped to be non-negative, or `undefined` when no
 * speech start has been recorded.
 */
private computeSpeechDurationMs(): number | undefined {
  const startedAt = this.speechStartedAt;
  if (startedAt === undefined) {
    return undefined;
  }

  // Fall back to "now" only when no end timestamp has been recorded.
  const endedAt = this.speechEndedAt ?? Date.now();
  const elapsedMs = endedAt - startedAt;
  return Math.max(0, elapsedMs);
}

/**
 * Copies the prediction's category, reason, machine flag and transcript onto
 * the current span. Does nothing when no span is set.
 *
 * @param result - The prediction whose fields are recorded as span attributes.
 */
private setSpanAttributes(result: AMDPredictionEvent): void {
  const activeSpan = this.span;
  if (!activeSpan) {
    return;
  }

  activeSpan.setAttribute(traceTypes.ATTR_AMD_CATEGORY, result.category);
  activeSpan.setAttribute(traceTypes.ATTR_AMD_REASON, result.reason);
  activeSpan.setAttribute(traceTypes.ATTR_AMD_IS_MACHINE, result.isMachine);
  activeSpan.setAttribute(traceTypes.ATTR_USER_TRANSCRIPT, result.transcript);
}

private async detect(transcript: string): Promise<AMDResult> {
private async detect(transcript: string): Promise<AMDPredictionEvent> {
const chatCtx = new ChatContext();
chatCtx.addMessage({ role: 'system', content: AMD_PROMPT });
chatCtx.addMessage({
Expand All @@ -417,14 +449,16 @@ export class AMD {

const parsed = this.parseDetection(rawResponse);
return {
type: 'amd_prediction',
...parsed,
transcript,
rawResponse,
isMachine: isMachineCategory(parsed.category),
speechDurationMs: this.computeSpeechDurationMs(),
};
}

private parseDetection(rawResponse: string): Pick<AMDResult, 'category' | 'reason'> {
private parseDetection(rawResponse: string): Pick<AMDPredictionEvent, 'category' | 'reason'> {
const normalized = rawResponse.trim();
const jsonStart = normalized.indexOf('{');
const jsonEnd = normalized.lastIndexOf('}');
Expand Down
6 changes: 6 additions & 0 deletions agents/src/voice/remote_session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,12 @@ export class SessionHost {
this.emitEvent({ case: 'overlappingSpeech', value });
};

// TODO(amd_prediction): mirror Python `_on_amd_prediction` once
// `@livekit/protocol` ships `AgentSessionEvent.AmdPrediction` / `AmdCategory`.
// The AMD detector now emits an `amd_prediction` event (see voice/amd.ts);
// wire this handler up via `amd.on('amd_prediction', ...)` and forward the
// payload through `emitEvent({ case: 'amdPrediction', value: ... })`.

private onMetricsCollected = (event: MetricsCollectedEvent): void => {
if (!this.session) return;
this.emitEvent(
Expand Down
Loading