Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion examples/server/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from io import BytesIO
from multiprocessing import Lock
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, AsyncIterable, Dict, List, Mapping, Optional, Sequence, Tuple

from google.protobuf.timestamp_pb2 import Timestamp
from PIL import Image
Expand Down Expand Up @@ -144,6 +144,28 @@ async def play(

self.is_playing = False

async def play_stream(
self,
info: AudioInfo,
chunks: AsyncIterable[bytes],
*,
extra: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
**kwargs,
) -> None:
"""Play streamed audio chunks."""

self.is_playing = True
print(
f"Streaming audio: codec={info.codec}, sample_rate={info.sample_rate_hz}, channels={info.num_channels}"
)
total = 0
async for chunk in chunks:
total += len(chunk)
await asyncio.sleep(0)
print(f"Stream complete: {total} bytes")
self.is_playing = False

async def get_properties(
self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs
) -> AudioOut.Properties:
Expand Down
26 changes: 14 additions & 12 deletions src/viam/components/audio_out/audio_out.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Any, AsyncIterator, Dict, Final, Optional, TypeAlias
from typing import Any, AsyncIterable, Dict, Final, Optional, TypeAlias

from viam.proto.common import GetPropertiesResponse
from viam.resource.types import API, RESOURCE_NAMESPACE_RDK, RESOURCE_TYPE_COMPONENT
Expand Down Expand Up @@ -54,30 +54,32 @@ async def play(
@abc.abstractmethod
async def play_stream(
self,
chunks: AsyncIterator[bytes],
info: Optional[AudioInfo] = None,
info: AudioInfo,
chunks: AsyncIterable[bytes],
*,
extra: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
**kwargs,
) -> None:
"""
Play audio data from a stream of chunks.
Stream audio chunks to the audio output device for playback.

::
The caller provides an async iterable of raw audio bytes. Each chunk
must match the codec and format described by ``info``. Implementations
consume the iterable until it is exhausted, playing chunks as they arrive.

my_audio_out = AudioOut.from_robot(robot=machine, name="my_audio_out")
::

async def audio_generator():
for chunk in audio_chunks:
async def chunk_source():
for chunk in pcm_chunks:
yield chunk

audio_info = AudioInfo(codec=AudioCodec.PCM16, sample_rate_hz=44100, num_channels=2)
await my_audio_out.play_stream(audio_generator(), audio_info)
audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1)
await my_audio_out.play_stream(audio_info, chunk_source())

Args:
chunks: async iterator of audio data chunks to play
info: (optional) information about the audio data such as codec, sample rate, and channel count
info: information about the audio stream such as codec, sample rate, and channel count
chunks: async iterable of audio bytes to play in order
"""

@abc.abstractmethod
Expand Down
18 changes: 5 additions & 13 deletions src/viam/components/audio_out/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, AsyncIterator, Dict, List, Mapping, Optional
from typing import Any, AsyncIterable, Dict, List, Mapping, Optional

from grpclib.client import Channel

Expand Down Expand Up @@ -55,8 +55,8 @@ async def play(

async def play_stream(
self,
chunks: AsyncIterator[bytes],
info: Optional[AudioInfo] = None,
info: AudioInfo,
chunks: AsyncIterable[bytes],
*,
extra: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
Expand All @@ -66,17 +66,9 @@ async def play_stream(
extra = {}

md = kwargs.get("metadata", self.Metadata()).proto

init = PlayStreamRequest(init=PlayStreamInit(name=self.name, audio_info=info, extra=dict_to_struct(extra)))
async with self.client.PlayStream.open(timeout=timeout, metadata=md) as stream:
await stream.send_message(
PlayStreamRequest(
init=PlayStreamInit(
name=self.name,
audio_info=info,
extra=dict_to_struct(extra),
)
)
)
await stream.send_message(init)
async for chunk in chunks:
await stream.send_message(PlayStreamRequest(audio_chunk=PlayStreamChunk(audio_data=chunk)))
await stream.end()
Expand Down
43 changes: 22 additions & 21 deletions src/viam/components/audio_out/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import AsyncIterator

from grpclib import GRPCError, Status
from grpclib.server import Stream

from viam.proto.common import (
Expand Down Expand Up @@ -42,29 +43,29 @@ async def Play(self, stream: Stream[PlayRequest, PlayResponse]) -> None:
await stream.send_message(PlayResponse())

async def PlayStream(self, stream: Stream[PlayStreamRequest, PlayStreamResponse]) -> None:
# Receive the first message which should be the init
first_request = await stream.recv_message()
assert first_request is not None
assert first_request.HasField("init"), "First message must contain init"
first = await stream.recv_message()
if first is None:
raise GRPCError(Status.INVALID_ARGUMENT, "PlayStream: stream closed before init message")
if not first.HasField("init"):
raise GRPCError(Status.INVALID_ARGUMENT, "PlayStream: first message must be PlayStreamInit")
init = first.init
if not init.HasField("audio_info"):
raise GRPCError(Status.INVALID_ARGUMENT, "PlayStream: audio_info is required on PlayStreamInit")
audio_out = self.get_resource(init.name)

init = first_request.init
name = init.name
audio_out = self.get_resource(name)
audio_info = init.audio_info if init.HasField("audio_info") else None
extra = struct_to_dict(init.extra)
timeout = stream.deadline.time_remaining() if stream.deadline else None
async def chunks() -> AsyncIterator[bytes]:
async for msg in stream:
if msg.HasField("audio_chunk"):
yield msg.audio_chunk.audio_data

# Create an async generator for the audio chunks
async def chunk_iterator() -> AsyncIterator[bytes]:
# Read remaining messages from the stream
while True:
request = await stream.recv_message()
if request is None:
break
if request.HasField("audio_chunk"):
yield request.audio_chunk.audio_data

await audio_out.play_stream(chunk_iterator(), audio_info, extra=extra, timeout=timeout, metadata=stream.metadata)
timeout = stream.deadline.time_remaining() if stream.deadline else None
await audio_out.play_stream(
init.audio_info,
chunks(),
extra=struct_to_dict(init.extra),
timeout=timeout,
metadata=stream.metadata,
)
await stream.send_message(PlayStreamResponse())

async def GetProperties(self, stream: Stream[GetPropertiesRequest, GetPropertiesResponse]) -> None:
Expand Down
11 changes: 6 additions & 5 deletions tests/mocks/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ def __init__(self, name: str, properties: AudioOut.Properties):
self.geometries = GEOMETRIES
self.timeout: Optional[float] = None
self.extra: Optional[Dict[str, Any]] = None
self.play_stream_called = False
self.last_streamed_info: Optional[AudioInfo] = None
self.streamed_chunks: list[bytes] = []

async def play(
self,
Expand All @@ -206,17 +209,15 @@ async def play(

async def play_stream(
self,
chunks: AsyncIterator[bytes],
info: Optional[AudioInfo] = None,
info: AudioInfo,
chunks,
*,
extra: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
**kwargs,
) -> None:
self.play_stream_called = True
self.last_audio_info = info
self.extra = extra
self.timeout = timeout
self.last_streamed_info = info
self.streamed_chunks = []
async for chunk in chunks:
self.streamed_chunks.append(chunk)
Expand Down
80 changes: 80 additions & 0 deletions tests/test_audio_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,33 @@ async def test_play_without_audio_info(self, audio_out: MockAudioOut):
assert audio_out.last_audio_data == audio_data
assert audio_out.last_audio_info is None

@pytest.mark.asyncio
async def test_play_stream(self, audio_out: MockAudioOut):
audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1)
chunks = [b"chunk_one", b"chunk_two", b"chunk_three"]

async def source():
for c in chunks:
yield c

await audio_out.play_stream(audio_info, source())
assert audio_out.play_stream_called
assert audio_out.last_streamed_info == audio_info
assert audio_out.streamed_chunks == chunks

@pytest.mark.asyncio
async def test_play_stream_no_chunks(self, audio_out: MockAudioOut):
audio_info = AudioInfo(codec="pcm16", sample_rate_hz=48000, num_channels=1)

async def empty():
if False:
yield b""

await audio_out.play_stream(audio_info, empty())
assert audio_out.play_stream_called
assert audio_out.last_streamed_info == audio_info
assert audio_out.streamed_chunks == []

@pytest.mark.asyncio
async def test_get_properties(self, audio_out: MockAudioOut):
properties = await audio_out.get_properties()
Expand Down Expand Up @@ -139,6 +166,26 @@ async def test_play_without_audio_info(self, audio_out: MockAudioOut, service: A
assert audio_out.last_audio_data == audio_data
assert audio_out.last_audio_info is None

@pytest.mark.asyncio
async def test_play_stream(self, audio_out: MockAudioOut, service: AudioOutRPCService):
audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1)
chunks = [b"a" * 4, b"b" * 4, b"c" * 4]

async with ChannelFor([service]) as channel:
client = AudioOutServiceStub(channel)
async with client.PlayStream.open() as stream:
await stream.send_message(
PlayStreamRequest(init=PlayStreamInit(name=audio_out.name, audio_info=audio_info, extra=dict_to_struct({})))
)
for c in chunks:
await stream.send_message(PlayStreamRequest(audio_chunk=PlayStreamChunk(audio_data=c)))
await stream.end()
await stream.recv_message()

assert audio_out.play_stream_called
assert audio_out.last_streamed_info == audio_info
assert audio_out.streamed_chunks == chunks

@pytest.mark.asyncio
async def test_get_properties(self, audio_out: MockAudioOut, service: AudioOutRPCService):
async with ChannelFor([service]) as channel:
Expand Down Expand Up @@ -234,6 +281,39 @@ async def test_play_without_audio_info(self, audio_out: MockAudioOut, service: A
assert audio_out.last_audio_data == audio_data
assert audio_out.last_audio_info is None

@pytest.mark.asyncio
async def test_play_stream(self, audio_out: MockAudioOut, service: AudioOutRPCService):
async with ChannelFor([service]) as channel:
client = AudioOutClient(audio_out.name, channel)
audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1)
chunks = [b"hello ", b"world", b"!"]

async def source():
for c in chunks:
yield c

await client.play_stream(audio_info, source())

assert audio_out.play_stream_called
assert audio_out.last_streamed_info == audio_info
assert audio_out.streamed_chunks == chunks

@pytest.mark.asyncio
async def test_play_stream_no_chunks(self, audio_out: MockAudioOut, service: AudioOutRPCService):
async with ChannelFor([service]) as channel:
client = AudioOutClient(audio_out.name, channel)
audio_info = AudioInfo(codec="pcm16", sample_rate_hz=44100, num_channels=2)

async def empty():
if False:
yield b""

await client.play_stream(audio_info, empty())

assert audio_out.play_stream_called
assert audio_out.last_streamed_info == audio_info
assert audio_out.streamed_chunks == []

@pytest.mark.asyncio
async def test_get_properties(self, audio_out: MockAudioOut, service: AudioOutRPCService):
async with ChannelFor([service]) as channel:
Expand Down
Loading