diff --git a/examples/server/v1/components.py b/examples/server/v1/components.py index ad9035806..8263e6c65 100644 --- a/examples/server/v1/components.py +++ b/examples/server/v1/components.py @@ -13,7 +13,7 @@ from io import BytesIO from multiprocessing import Lock from pathlib import Path -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, AsyncIterable, Dict, List, Mapping, Optional, Sequence, Tuple from google.protobuf.timestamp_pb2 import Timestamp from PIL import Image @@ -144,6 +144,28 @@ async def play( self.is_playing = False + async def play_stream( + self, + info: AudioInfo, + chunks: AsyncIterable[bytes], + *, + extra: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> None: + """Play streamed audio chunks.""" + + self.is_playing = True + print( + f"Streaming audio: codec={info.codec}, sample_rate={info.sample_rate_hz}, channels={info.num_channels}" + ) + total = 0 + async for chunk in chunks: + total += len(chunk) + await asyncio.sleep(0) + print(f"Stream complete: {total} bytes") + self.is_playing = False + async def get_properties( self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs ) -> AudioOut.Properties: diff --git a/src/viam/components/audio_out/audio_out.py b/src/viam/components/audio_out/audio_out.py index e21f5b57e..2d21291a7 100644 --- a/src/viam/components/audio_out/audio_out.py +++ b/src/viam/components/audio_out/audio_out.py @@ -1,5 +1,5 @@ import abc -from typing import Any, AsyncIterator, Dict, Final, Optional, TypeAlias +from typing import Any, AsyncIterable, Dict, Final, Optional, TypeAlias from viam.proto.common import GetPropertiesResponse from viam.resource.types import API, RESOURCE_NAMESPACE_RDK, RESOURCE_TYPE_COMPONENT @@ -54,30 +54,32 @@ async def play( @abc.abstractmethod async def play_stream( self, - chunks: AsyncIterator[bytes], - info: Optional[AudioInfo] = None, + info: AudioInfo, + chunks: AsyncIterable[bytes], *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs, ) -> None: """ - Play audio data from a stream of chunks. + Stream audio chunks to the audio output device for playback. - :: + The caller provides an async iterable of raw audio bytes. Each chunk + must match the codec and format described by ``info``. Implementations + consume the iterable until it is exhausted, playing chunks as they arrive. - my_audio_out = AudioOut.from_robot(robot=machine, name="my_audio_out") + :: - async def audio_generator(): - for chunk in audio_chunks: + async def chunk_source(): + for chunk in pcm_chunks: yield chunk - audio_info = AudioInfo(codec=AudioCodec.PCM16, sample_rate_hz=44100, num_channels=2) - await my_audio_out.play_stream(audio_generator(), audio_info) + audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1) + await my_audio_out.play_stream(audio_info, chunk_source()) Args: - chunks: async iterator of audio data chunks to play - info: (optional) information about the audio data such as codec, sample rate, and channel count + info: information about the audio stream such as codec, sample rate, and channel count + chunks: async iterable of audio bytes to play in order """ @abc.abstractmethod diff --git a/src/viam/components/audio_out/client.py b/src/viam/components/audio_out/client.py index 1d453312d..cf84ab2a1 100644 --- a/src/viam/components/audio_out/client.py +++ b/src/viam/components/audio_out/client.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterator, Dict, List, Mapping, Optional +from typing import Any, AsyncIterable, Dict, List, Mapping, Optional from grpclib.client import Channel @@ -55,8 +55,8 @@ async def play( async def play_stream( self, - chunks: AsyncIterator[bytes], - info: Optional[AudioInfo] = None, + info: AudioInfo, + chunks: AsyncIterable[bytes], *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, @@ -66,17 +66,9 @@ async def play_stream( extra = {} md = kwargs.get("metadata", self.Metadata()).proto - + init = PlayStreamRequest(init=PlayStreamInit(name=self.name, audio_info=info, extra=dict_to_struct(extra))) async with self.client.PlayStream.open(timeout=timeout, metadata=md) as stream: - await stream.send_message( - PlayStreamRequest( - init=PlayStreamInit( - name=self.name, - audio_info=info, - extra=dict_to_struct(extra), - ) - ) - ) + await stream.send_message(init) async for chunk in chunks: await stream.send_message(PlayStreamRequest(audio_chunk=PlayStreamChunk(audio_data=chunk))) await stream.end() diff --git a/src/viam/components/audio_out/service.py b/src/viam/components/audio_out/service.py index 64d257133..7fa4abdc1 100644 --- a/src/viam/components/audio_out/service.py +++ b/src/viam/components/audio_out/service.py @@ -1,5 +1,6 @@ from typing import AsyncIterator +from grpclib import GRPCError, Status from grpclib.server import Stream from viam.proto.common import ( @@ -42,29 +43,29 @@ async def Play(self, stream: Stream[PlayRequest, PlayResponse]) -> None: await stream.send_message(PlayResponse()) async def PlayStream(self, stream: Stream[PlayStreamRequest, PlayStreamResponse]) -> None: - # Receive the first message which should be the init - first_request = await stream.recv_message() - assert first_request is not None - assert first_request.HasField("init"), "First message must contain init" + first = await stream.recv_message() + if first is None: + raise GRPCError(Status.INVALID_ARGUMENT, "PlayStream: stream closed before init message") + if not first.HasField("init"): + raise GRPCError(Status.INVALID_ARGUMENT, "PlayStream: first message must be PlayStreamInit") + init = first.init + if not init.HasField("audio_info"): + raise GRPCError(Status.INVALID_ARGUMENT, "PlayStream: audio_info is required on PlayStreamInit") + audio_out = self.get_resource(init.name) - init = first_request.init - name = init.name - audio_out = self.get_resource(name) - audio_info = init.audio_info if init.HasField("audio_info") else None - extra = struct_to_dict(init.extra) - timeout = stream.deadline.time_remaining() if stream.deadline else None + async def chunks() -> AsyncIterator[bytes]: + async for msg in stream: + if msg.HasField("audio_chunk"): + yield msg.audio_chunk.audio_data - # Create an async generator for the audio chunks - async def chunk_iterator() -> AsyncIterator[bytes]: - # Read remaining messages from the stream - while True: - request = await stream.recv_message() - if request is None: - break - if request.HasField("audio_chunk"): - yield request.audio_chunk.audio_data - - await audio_out.play_stream(chunk_iterator(), audio_info, extra=extra, timeout=timeout, metadata=stream.metadata) + timeout = stream.deadline.time_remaining() if stream.deadline else None + await audio_out.play_stream( + init.audio_info, + chunks(), + extra=struct_to_dict(init.extra), + timeout=timeout, + metadata=stream.metadata, + ) await stream.send_message(PlayStreamResponse()) async def GetProperties(self, stream: Stream[GetPropertiesRequest, GetPropertiesResponse]) -> None: diff --git a/tests/mocks/components.py b/tests/mocks/components.py index 192c0de05..bc949fcc9 100644 --- a/tests/mocks/components.py +++ b/tests/mocks/components.py @@ -190,6 +190,9 @@ def __init__(self, name: str, properties: AudioOut.Properties): self.geometries = GEOMETRIES self.timeout: Optional[float] = None self.extra: Optional[Dict[str, Any]] = None + self.play_stream_called = False + self.last_streamed_info: Optional[AudioInfo] = None + self.streamed_chunks: list[bytes] = [] async def play( self, @@ -206,17 +209,15 @@ async def play( async def play_stream( self, - chunks: AsyncIterator[bytes], - info: Optional[AudioInfo] = None, + info: AudioInfo, + chunks, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs, ) -> None: self.play_stream_called = True - self.last_audio_info = info - self.extra = extra - self.timeout = timeout + self.last_streamed_info = info self.streamed_chunks = [] async for chunk in chunks: self.streamed_chunks.append(chunk) diff --git a/tests/test_audio_out.py b/tests/test_audio_out.py index 8b0e7a476..b6604f5b4 100644 --- a/tests/test_audio_out.py +++ b/tests/test_audio_out.py @@ -61,6 +61,33 @@ async def test_play_without_audio_info(self, audio_out: MockAudioOut): assert audio_out.last_audio_data == audio_data assert audio_out.last_audio_info is None + @pytest.mark.asyncio + async def test_play_stream(self, audio_out: MockAudioOut): + audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1) + chunks = [b"chunk_one", b"chunk_two", b"chunk_three"] + + async def source(): + for c in chunks: + yield c + + await audio_out.play_stream(audio_info, source()) + assert audio_out.play_stream_called + assert audio_out.last_streamed_info == audio_info + assert audio_out.streamed_chunks == chunks + + @pytest.mark.asyncio + async def test_play_stream_no_chunks(self, audio_out: MockAudioOut): + audio_info = AudioInfo(codec="pcm16", sample_rate_hz=48000, num_channels=1) + + async def empty(): + if False: + yield b"" + + await audio_out.play_stream(audio_info, empty()) + assert audio_out.play_stream_called + assert audio_out.last_streamed_info == audio_info + assert audio_out.streamed_chunks == [] + @pytest.mark.asyncio async def test_get_properties(self, audio_out: MockAudioOut): properties = await audio_out.get_properties() @@ -139,6 +166,26 @@ async def test_play_without_audio_info(self, audio_out: MockAudioOut, service: A assert audio_out.last_audio_data == audio_data assert audio_out.last_audio_info is None + @pytest.mark.asyncio + async def test_play_stream(self, audio_out: MockAudioOut, service: AudioOutRPCService): + audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1) + chunks = [b"a" * 4, b"b" * 4, b"c" * 4] + + async with ChannelFor([service]) as channel: + client = AudioOutServiceStub(channel) + async with client.PlayStream.open() as stream: + await stream.send_message( + PlayStreamRequest(init=PlayStreamInit(name=audio_out.name, audio_info=audio_info, extra=dict_to_struct({}))) + ) + for c in chunks: + await stream.send_message(PlayStreamRequest(audio_chunk=PlayStreamChunk(audio_data=c))) + await stream.end() + await stream.recv_message() + + assert audio_out.play_stream_called + assert audio_out.last_streamed_info == audio_info + assert audio_out.streamed_chunks == chunks + @pytest.mark.asyncio async def test_get_properties(self, audio_out: MockAudioOut, service: AudioOutRPCService): async with ChannelFor([service]) as channel: @@ -234,6 +281,39 @@ async def test_play_without_audio_info(self, audio_out: MockAudioOut, service: A assert audio_out.last_audio_data == audio_data assert audio_out.last_audio_info is None + @pytest.mark.asyncio + async def test_play_stream(self, audio_out: MockAudioOut, service: AudioOutRPCService): + async with ChannelFor([service]) as channel: + client = AudioOutClient(audio_out.name, channel) + audio_info = AudioInfo(codec="pcm16", sample_rate_hz=22050, num_channels=1) + chunks = [b"hello ", b"world", b"!"] + + async def source(): + for c in chunks: + yield c + + await client.play_stream(audio_info, source()) + + assert audio_out.play_stream_called + assert audio_out.last_streamed_info == audio_info + assert audio_out.streamed_chunks == chunks + + @pytest.mark.asyncio + async def test_play_stream_no_chunks(self, audio_out: MockAudioOut, service: AudioOutRPCService): + async with ChannelFor([service]) as channel: + client = AudioOutClient(audio_out.name, channel) + audio_info = AudioInfo(codec="pcm16", sample_rate_hz=44100, num_channels=2) + + async def empty(): + if False: + yield b"" + + await client.play_stream(audio_info, empty()) + + assert audio_out.play_stream_called + assert audio_out.last_streamed_info == audio_info + assert audio_out.streamed_chunks == [] + @pytest.mark.asyncio async def test_get_properties(self, audio_out: MockAudioOut, service: AudioOutRPCService): async with ChannelFor([service]) as channel: