1+ import asyncio
12import os
23import uuid
4+ from asyncio import Future , Queue , Task , create_task
5+ from collections .abc import AsyncIterator
6+ from copy import deepcopy
37
48import cv2 as cv
59import numpy as np
610import pytest
711import pytest_asyncio
812from dotenv import load_dotenv
13+ from grpc .aio import Channel
914
15+ from resolver_athena_client .client .athena_client import AthenaClient
1016from resolver_athena_client .client .athena_options import AthenaOptions
11- from resolver_athena_client .client .channel import CredentialHelper
17+ from resolver_athena_client .client .channel import (
18+ CredentialHelper ,
19+ create_channel_with_credentials ,
20+ )
1221from resolver_athena_client .client .consts import (
1322 EXPECTED_HEIGHT ,
1423 EXPECTED_WIDTH ,
1524 MAX_DEPLOYMENT_ID_LENGTH ,
1625)
26+ from resolver_athena_client .client .models .input_model import ImageData
27+ from resolver_athena_client .generated .athena .models_pb2 import (
28+ ClassificationOutput ,
29+ )
1730
1831
1932def _create_base_test_image_opencv (width : int , height : int ) -> np .ndarray :
@@ -79,8 +92,7 @@ async def credential_helper() -> CredentialHelper:
7992 )
8093
8194
82- @pytest .fixture
83- def athena_options () -> AthenaOptions :
95+ def _load_options () -> AthenaOptions :
8496 _ = load_dotenv ()
8597 host = os .getenv ("ATHENA_HOST" , "localhost" )
8698
@@ -99,9 +111,15 @@ def athena_options() -> AthenaOptions:
99111 timeout = 120.0 , # Maximum duration, not forced timeout
100112 keepalive_interval = 30.0 , # Longer intervals for persistent streams
101113 affiliate = affiliate ,
114+ compression_quality = 2 ,
102115 )
103116
104117
118+ @pytest .fixture
119+ def athena_options () -> AthenaOptions :
120+ return _load_options ()
121+
122+
105123@pytest .fixture (scope = "session" , params = SUPPORTED_TEST_FORMATS )
106124def valid_formatted_image (
107125 request : pytest .FixtureRequest ,
@@ -144,3 +162,75 @@ def valid_formatted_image(
144162 _ = f .write (image_bytes )
145163
146164 return image_bytes
165+
166+
167+ class StreamingSender :
168+ """Helper class to provide a single-send-like interface with speed
169+
170+ The class provides a 'send' method that can be passed an imagedata and will
171+ send it along a stream, and collect all results into an internal buffer.
172+
173+ The 'send' method will asynchronously wait for the result and return it,
174+ providing an interface that mimics a single request-response call, while
175+ under the hood it is using a streaming connection for speed.
176+ """
177+
178+ def __init__ (self , grpc_channel : Channel , options : AthenaOptions ) -> None :
179+ self ._request_queue : Queue [ImageData ] = Queue ()
180+ self ._pending_results : dict [str , Future [ClassificationOutput ]] = {}
181+
182+ # tests are run in series, so we gain nothing here from waiting for a
183+ # batch that will never fill, so just send it immediately for better
184+ # latency
185+ streaming_options = deepcopy (options )
186+ streaming_options .max_batch_size = 1
187+
188+ self ._run_task : Task [None ] = create_task (
189+ self ._run (grpc_channel , streaming_options )
190+ )
191+
192+ async def _run (self , grpc_channel : Channel , options : AthenaOptions ) -> None :
193+ async with AthenaClient (grpc_channel , options ) as client :
194+ generator = self ._send_from_queue ()
195+ responses = client .classify_images (generator )
196+ async for response in responses :
197+ for output in response .outputs :
198+ if output .correlation_id in self ._pending_results :
199+ future = self ._pending_results .pop (
200+ output .correlation_id
201+ )
202+ future .set_result (output )
203+
204+ async def _send_from_queue (self ) -> AsyncIterator [ImageData ]:
205+ """Async generator to yield requests from the queue."""
206+ while True :
207+ if image_data := await self ._request_queue .get ():
208+ yield image_data
209+ self ._request_queue .task_done ()
210+
211+ async def send (self , image_data : ImageData ) -> ClassificationOutput :
212+ """Send an image and wait for the corresponding result."""
213+ if self ._run_task .done ():
214+ self ._run_task .result ()
215+
216+ if image_data .correlation_id is None :
217+ image_data .correlation_id = str (uuid .uuid4 ())
218+ future = asyncio .get_event_loop ().create_future ()
219+ self ._pending_results [image_data .correlation_id ] = future
220+
221+ await self ._request_queue .put (image_data )
222+
223+ return await future
224+
225+
226+ @pytest_asyncio .fixture (scope = "session" , loop_scope = "session" )
227+ async def streaming_sender (
228+ credential_helper : CredentialHelper ,
229+ ) -> StreamingSender :
230+ """Fixture to provide a helper for sending over a streaming connection."""
231+ # Create gRPC channel with credentials
232+ opts = _load_options ()
233+ channel = await create_channel_with_credentials (
234+ opts .host , credential_helper
235+ )
236+ return StreamingSender (channel , opts )
0 commit comments