Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 96 additions & 29 deletions examples/avatar_agents/audio_wave/agent_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import os
from dataclasses import asdict, dataclass
Expand All @@ -13,9 +14,10 @@
ConversationItemAddedEvent,
JobContext,
cli,
get_job_context,
inference,
)
from livekit.agents.voice.avatar import DataStreamAudioOutput
from livekit.agents.voice.avatar import AvatarSession, DataStreamAudioOutput
from livekit.agents.voice.io import PlaybackFinishedEvent, PlaybackStartedEvent
from livekit.agents.voice.room_io import ATTRIBUTE_PUBLISH_ON_BEHALF
from livekit.plugins import silero
Expand All @@ -37,30 +39,67 @@ class AvatarConnectionInfo:
"""Token for avatar worker to join"""


async def launch_avatar(ctx: JobContext, avatar_dispatcher_url: str, avatar_identity: str) -> None:
"""
Send a request to the avatar service for it to join the room
class CustomAvatarSession(AvatarSession):
"""Minimal avatar plugin backed by the example avatar dispatcher.

This function should be wrapped in an avatar plugin.
Subclasses the base :class:`AvatarSession` so we get the join-wait, metrics
and teardown for free — :meth:`AvatarSession.aclose` already removes the
avatar participant from the room. We only add the dispatcher handshake and
route the agent audio with :meth:`swap_audio_endpoint`.
"""

# create a token for the avatar to join the room
token = (
api.AccessToken()
.with_identity(avatar_identity)
.with_name("Avatar Runner")
.with_grants(api.VideoGrants(room_join=True, room=ctx.room.name))
.with_kind("agent")
.with_attributes({ATTRIBUTE_PUBLISH_ON_BEHALF: ctx.local_participant_identity})
.to_jwt()
)
def __init__(self, *, avatar_dispatcher_url: str, avatar_identity: str) -> None:
super().__init__()
self._avatar_dispatcher_url = avatar_dispatcher_url
self._avatar_identity = avatar_identity

@property
def avatar_identity(self) -> str:
return self._avatar_identity

@property
def provider(self) -> str:
return "example-datastream-avatar"

async def start(self, agent_session: AgentSession, room: rtc.Room) -> None:
await super().start(agent_session, room)

# create a token for the avatar to join the room under our identity
token = (
api.AccessToken()
.with_identity(self._avatar_identity)
.with_name("Avatar Runner")
.with_grants(api.VideoGrants(room_join=True, room=room.name))
.with_kind("agent")
.with_attributes({ATTRIBUTE_PUBLISH_ON_BEHALF: room.local_participant.identity})
.to_jwt()
)

logger.info(f"Sending connection info to avatar dispatcher {avatar_dispatcher_url}")
connection_info = AvatarConnectionInfo(room_name=ctx.room.name, url=ctx._info.url, token=token)
async with httpx.AsyncClient() as client:
response = await client.post(avatar_dispatcher_url, json=asdict(connection_info))
response.raise_for_status()
logger.info("Avatar handshake completed")
logger.info(
f"sending connection info to avatar dispatcher {self._avatar_dispatcher_url}",
extra={"identity": self._avatar_identity},
)
connection_info = AvatarConnectionInfo(
room_name=room.name, url=get_job_context()._info.url, token=token
)
async with httpx.AsyncClient() as client:
response = await client.post(self._avatar_dispatcher_url, json=asdict(connection_info))
response.raise_for_status()
logger.info("avatar handshake completed", extra={"identity": self._avatar_identity})

# route the agent audio to this avatar. swap_audio_endpoint swaps only the
# bottom sink, keeping the TranscriptSynchronizer / recorder wrappers (and any
# event listeners attached to session.output.audio) intact across hot swaps
agent_session.output.swap_audio_endpoint(
DataStreamAudioOutput(
room,
destination_identity=self._avatar_identity,
# (optional) wait for the avatar to publish video track before generating a reply
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
# the example avatar_runner uses AvatarRunner which sends lk.playback_started
wait_playback_start=True,
)
)


server = AgentServer()
Expand All @@ -79,19 +118,47 @@ async def entrypoint(ctx: JobContext):
resume_false_interruption=False,
)

await launch_avatar(ctx, AVATAR_DISPATCHER_URL, AVATAR_IDENTITY)
session.output.audio = DataStreamAudioOutput(
ctx.room,
destination_identity=AVATAR_IDENTITY,
# (optional) wait for the avatar to publish video track before generating a reply
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
# the example avatar_runner uses AvatarRunner which sends lk.playback_started
wait_playback_start=True,
avatar = CustomAvatarSession(
avatar_dispatcher_url=AVATAR_DISPATCHER_URL,
avatar_identity=AVATAR_IDENTITY,
)
await avatar.start(session, ctx.room)

# start agent with room input and room text output
await session.start(agent=agent, room=ctx.room)

swap_lock = asyncio.Lock()

@ctx.room.local_participant.register_rpc_method("swap_avatar")
async def swap_avatar(data: rtc.RpcInvocationData) -> str:
    """RPC handler: tear down the current avatar and launch a fresh one.

    Trigger from a client with:
        room.local_participant.perform_rpc(
            destination_identity=<agent_identity>, method="swap_avatar", payload=""
        )

    Args:
        data: RPC invocation data. Not read by this handler; the call carries
            no payload.

    Returns:
        The string ``"ok"`` once the replacement avatar has been started and
        ``wait_for_join()`` has returned.
    """
    # rebinds the enclosing entrypoint's `avatar` so the next swap closes the
    # session created here
    nonlocal avatar
    # serialize swaps: a second RPC arriving mid-swap waits for the first one
    async with swap_lock:
        logger.info("swapping avatar")
        # remove the current avatar first; we reuse the same identity, so the room
        # can't hold both at once
        await avatar.aclose()

        avatar = CustomAvatarSession(
            avatar_dispatcher_url=AVATAR_DISPATCHER_URL,
            avatar_identity=AVATAR_IDENTITY,
        )
        # start() routes audio to the new avatar via swap_audio_endpoint; frames are
        # buffered until it publishes its video track, so playback resumes seamlessly
        await avatar.start(session, ctx.room)
        # called without arguments, so wait_for_join applies its own default timeout
        await avatar.wait_for_join()

    logger.info("avatar swapped")
    return "ok"

# these listeners are attached to the top of the audio chain, which swap_audio_endpoint
# leaves untouched, so they keep firing across avatar swaps
@session.output.audio.on("playback_finished")
def on_playback_finished(ev: PlaybackFinishedEvent) -> None:
# the avatar should notify when the audio playback is finished
Expand Down
11 changes: 1 addition & 10 deletions livekit-agents/livekit/agents/voice/avatar/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,6 @@ async def start(self, agent_session: AgentSession, room: rtc.Room) -> None:
"release resources when the job shuts down"
)

if agent_session._started and (audio_output := agent_session.output.audio) is not None:
logger.warning(
(
"AvatarSession.start() was called after AgentSession.start(); "
"the existing audio output may be replaced by the avatar. "
"Please start the avatar session before AgentSession.start() to avoid this."
),
extra={"audio_output": audio_output.label},
)

self._room = room
self._agent_session = agent_session
self._agent_session.on("conversation_item_added", self._on_conversation_item_added)
Expand All @@ -120,6 +110,7 @@ async def wait_for_join(self, *, timeout: float | None = 30.0) -> None:
``timeout`` seconds. Pass ``timeout=None`` to wait indefinitely.
"""
if self._wait_avatar_join_task is None:
# TODO(long): handle the case where this is called before the room is connected
return
if timeout is None:
await self._wait_avatar_join_task
Expand Down
152 changes: 138 additions & 14 deletions livekit-agents/livekit/agents/voice/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def __init__(
sample_rate: The sample rate required by the audio sink, if None, any sample rate is accepted
""" # noqa: E501
super().__init__()
self.__next_in_chain = next_in_chain
self._sample_rate = sample_rate
self.__label = label
self.__capturing = False
Expand All @@ -155,26 +154,37 @@ def __init__(
playback_position=0, interrupted=False
)

if self.next_in_chain:
self.next_in_chain.on(
"playback_finished",
lambda ev: self.on_playback_finished(
interrupted=ev.interrupted,
playback_position=ev.playback_position,
synchronized_transcript=ev.synchronized_transcript,
),
)
self.next_in_chain.on(
"playback_started", lambda ev: self.on_playback_started(created_at=ev.created_at)
)
# auto-wrap a bare leaf with a _AudioSinkProxy so the leaf can be
# hot-swapped later without disturbing wrappers above
if (
next_in_chain is not None
and next_in_chain.next_in_chain is None
and not isinstance(next_in_chain, _AudioSinkProxy)
):
next_in_chain = _AudioSinkProxy(next_in_chain)

self._next_in_chain: AudioOutput | None = next_in_chain
if next_in_chain is not None:
next_in_chain.on("playback_finished", self._forward_next_playback_finished)
next_in_chain.on("playback_started", self._forward_next_playback_started)

def _forward_next_playback_finished(self, ev: PlaybackFinishedEvent) -> None:
    """Relay a ``playback_finished`` event from the downstream output.

    Registered as a bound-method listener (rather than a lambda) so the same
    callable can later be removed with ``off()``.

    Args:
        ev: The event received from the next output in the chain; its fields
            are passed unchanged to :meth:`on_playback_finished`.
    """
    position = ev.playback_position
    was_interrupted = ev.interrupted
    transcript = ev.synchronized_transcript
    self.on_playback_finished(
        playback_position=position,
        interrupted=was_interrupted,
        synchronized_transcript=transcript,
    )

def _forward_next_playback_started(self, ev: PlaybackStartedEvent) -> None:
    """Relay a ``playback_started`` event from the downstream output.

    Args:
        ev: The event received from the next output in the chain; its
            ``created_at`` is passed unchanged to :meth:`on_playback_started`.
    """
    started_at = ev.created_at
    self.on_playback_started(created_at=started_at)

@property
def label(self) -> str:
    """The label this output was constructed with."""
    return self.__label

@property
def next_in_chain(self) -> AudioOutput | None:
    """The next output in the chain, or ``None`` when there is none.

    Reads ``_next_in_chain``, the attribute that ``__init__`` assigns and that
    ``_AudioSinkProxy.set_next_in_chain`` rebinds on a hot swap, so a replaced
    sink is visible through this property immediately.
    """
    return self._next_in_chain

def on_playback_started(self, *, created_at: float) -> None:
self.emit("playback_started", PlaybackStartedEvent(created_at=created_at))
Expand Down Expand Up @@ -228,6 +238,11 @@ def _reset_playback_count(self) -> None:
self.__playback_segments_count = 0
self.__playback_finished_count = 0

@property
def _pending_playback_count(self) -> int:
    """Count of captured segments still awaiting a ``playback_finished`` report."""
    captured = self.__playback_segments_count
    finished = self.__playback_finished_count
    return captured - finished

@property
def sample_rate(self) -> int | None:
"""The sample rate required by the audio sink, if None, any sample rate is accepted"""
Expand Down Expand Up @@ -275,6 +290,97 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(label={self.label!r}, next={self.next_in_chain!r})"


class _AudioSinkProxy(AudioOutput):
    """Stable swap point at the bottom of an audio wrapper chain.

    Wrappers above hold a reference to the proxy; the actual sink lives in
    ``next_in_chain`` and can be replaced via :meth:`set_next_in_chain` without
    disturbing them.
    """

    def __init__(self, next_in_chain: AudioOutput) -> None:
        """Create a proxy that forwards to ``next_in_chain``.

        Args:
            next_in_chain: The initial downstream sink.
        """
        # the base class is given no next_in_chain: the proxy installs the
        # downstream link and its listeners itself through set_next_in_chain
        super().__init__(
            label="AudioSinkProxy",
            capabilities=AudioOutputCapabilities(pause=True),
            next_in_chain=None,
        )
        # whether the wrapper above us has attached the proxy; set_next_in_chain
        # uses this to decide if a new/old downstream should be notified
        self._attached = False
        # per-segment state read by set_next_in_chain; assigned before the first
        # call so that method never runs against a partially initialized proxy
        self._capturing = False
        self._pushed_duration: float = 0.0

        self.set_next_in_chain(next_in_chain)

    @property
    def next_in_chain(self) -> AudioOutput:
        """The current downstream sink; never ``None`` once constructed."""
        assert self._next_in_chain is not None
        return self._next_in_chain

    def on_attached(self) -> None:
        """Record that the proxy is attached, then defer to the base class."""
        self._attached = True
        super().on_attached()

    def on_detached(self) -> None:
        """Record that the proxy is detached, then defer to the base class."""
        self._attached = False
        super().on_detached()

    def set_next_in_chain(self, new: AudioOutput) -> None:
        """Replace the downstream sink, transferring playback listeners
        and on_attached/on_detached state.

        Args:
            new: The sink to forward to from now on. Passing the current sink
                is a no-op.
        """
        if new is self._next_in_chain:
            return

        old = self._next_in_chain
        if old is not None:
            # stop relaying events from the sink being replaced
            old.off("playback_finished", self._forward_next_playback_finished)
            old.off("playback_started", self._forward_next_playback_started)
            if self._pending_playback_count > 0:
                # stop audio still playing on the old sink
                old.clear_buffer()

            if self._attached:
                old.on_detached()

        self._next_in_chain = new

        new.on("playback_finished", self._forward_next_playback_finished)
        new.on("playback_started", self._forward_next_playback_started)
        if self._attached:
            new.on_attached()

        # a segment already flushed to the old sink will never be reported by the
        # new one; finish it as interrupted so wait_for_playout() doesn't hang
        if old is not None and self._pending_playback_count > 0 and not self._capturing:
            self.on_playback_finished(playback_position=self._pushed_duration, interrupted=True)

    @property
    def sample_rate(self) -> int | None:
        """Sample rate reported by the current downstream sink."""
        return self.next_in_chain.sample_rate

    @property
    def can_pause(self) -> bool:
        """Whether the current downstream sink reports that it can pause."""
        return self.next_in_chain.can_pause

    async def capture_frame(self, frame: rtc.AudioFrame) -> None:
        """Forward a frame downstream and track the duration pushed so far.

        Args:
            frame: The audio frame to pass to the downstream sink.
        """
        if not self._capturing:
            # first frame of a new segment: restart the pushed-duration counter
            self._capturing = True
            self._pushed_duration = 0.0

        await super().capture_frame(frame)
        await self.next_in_chain.capture_frame(frame)
        self._pushed_duration += frame.duration

    def flush(self) -> None:
        """Flush the downstream sink and mark the current segment as ended."""
        super().flush()
        self.next_in_chain.flush()
        self._capturing = False

    def clear_buffer(self) -> None:
        """Delegate ``clear_buffer`` to the downstream sink."""
        self.next_in_chain.clear_buffer()
Comment thread
longcw marked this conversation as resolved.


class TextOutput(ABC):
def __init__(self, *, label: str, next_in_chain: TextOutput | None) -> None:
self.__label = label
Expand Down Expand Up @@ -568,6 +674,24 @@ def audio(self, sink: AudioOutput | None) -> None:
else:
self._audio_sink.on_detached()

def swap_audio_endpoint(self, sink: AudioOutput) -> None:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def swap_audio_endpoint(self, sink: AudioOutput) -> None:
def replace_audio_leadl(self, sink: AudioOutput) -> None:

not a fan of the name

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be something like replace_audio_sink?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"lead" reads as the head of the chain, but the method replaces the tail, how about

  • replace_audio_endpoint
  • replace_audio_destination
  • redirect_audio
  • replace_audio_tail

"""Swap the endpoint sink at the bottom of the chain, keeping wrappers attached.

Walks the chain looking for a :class:`_AudioSinkProxy` and swaps its
downstream — leaving wrappers like :class:`TranscriptSynchronizer` and
:class:`RecorderAudioOutput` in place. Falls back to ``self.audio = sink``
when no proxy is present (no wrappers, or the chain hasn't been set up yet).

Use ``self.audio = sink`` instead to replace the entire chain.
"""
cur = self._audio_sink
while cur is not None:
if isinstance(cur, _AudioSinkProxy):
cur.set_next_in_chain(sink)
return
cur = cur.next_in_chain
self.audio = sink

@property
def transcription(self) -> TextOutput | None:
return self._transcription_sink
Expand Down
Loading
Loading