diff --git a/README.md b/README.md index 9a2d35a..bac9b7c 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,8 @@ ATHENA_NON_EXISTENT_AFFILIATE=non-existent-affiliate-id (default: thisaffiliatedoesnotexist123) - this is used to test error handling. ATHENA_NON_PERMITTED_AFFILIATE=non-permitted-affiliate-id (default: thisaffiliatedoesnothaveathenaenabled) - this is used to test error handling. +ATHENA_E2E_TESTCASE_DIR=test-case-directory (default: integrator_sample) - this is the test case directory to use for the e2e tests. +See E2E Tests section below for more details. ``` Then run the functional tests with: @@ -170,8 +172,18 @@ Then run the functional tests with: pytest -m functional ``` -To exclude the e2e tests, which require usage of the live classifier and -therefore are unsuitable for regular development runs, use: +#### E2E Tests + +The e2e tests assert that the API returns some expected _scores_ rather than +exercising different API paths. As such, they are dependent on the classifier +that you are calling through the API. Right now, there are 2 types of +classifier, benign and live. By default, the tests will run the +`integrator_sample` test set, which uses the live classifier. If you wish to +use the benign classifier instead, you may set the `ATHENA_E2E_TESTCASE_DIR` +environment variable to `benign_model`. 
@pytest.fixture
def athena_options() -> AthenaOptions:
    """Provide a test with the Athena client options built by ``_load_options``."""
    options = _load_options()
    return options
class StreamingSender:
    """Helper class to provide a single-send-like interface with speed.

    ``send`` places an image on an internal queue and waits on a future keyed
    by the image's correlation id. A background task feeds the queue into
    ``AthenaClient.classify_images`` and resolves the matching future for
    every output it receives, so callers get a request/response style call on
    top of one streaming connection.

    If the background stream stops for any reason, every caller still waiting
    is released with an error instead of blocking forever.
    """

    def __init__(self, grpc_channel: Channel, options: AthenaOptions) -> None:
        """Start the background stream task on the running event loop.

        Args:
            grpc_channel: Channel handed to ``AthenaClient``.
            options: Client options; a deep copy is adjusted, the caller's
                object is left untouched.
        """
        self._request_queue: Queue[ImageData] = Queue()
        self._pending_results: dict[str, Future[ClassificationOutput]] = {}

        # tests are run in series, so we gain nothing here from waiting for a
        # batch that will never fill, so just send it immediately for better
        # latency
        streaming_options = deepcopy(options)
        streaming_options.max_batch_size = 1

        self._run_task: Task[None] = create_task(
            self._run(grpc_channel, streaming_options)
        )
        # Without this, a stream failure would leave every in-flight ``send``
        # awaiting a future that nobody will ever resolve.
        self._run_task.add_done_callback(self._fail_pending)

    def _fail_pending(self, task: Task[None]) -> None:
        """Release every caller still waiting once the stream task stops.

        Waiters receive the task's exception if it raised, a ``RuntimeError``
        if the stream simply ended, or are cancelled if the task was cancelled.
        """
        if task.cancelled():
            failure = None
        else:
            failure = task.exception() or RuntimeError(
                "classification stream closed before a result arrived"
            )

        abandoned = list(self._pending_results.values())
        self._pending_results.clear()
        for future in abandoned:
            if future.done():
                continue
            if failure is None:
                future.cancel()
            else:
                future.set_exception(failure)

    async def _run(self, grpc_channel: Channel, options: AthenaOptions) -> None:
        """Stream queued images to the classifier and dispatch the outputs."""
        async with AthenaClient(grpc_channel, options) as client:
            generator = self._send_from_queue()
            responses = client.classify_images(generator)
            async for response in responses:
                for output in response.outputs:
                    future = self._pending_results.pop(
                        output.correlation_id, None
                    )
                    # Skip outputs nobody is waiting for, and waiters that
                    # were cancelled in the meantime: calling ``set_result``
                    # on a finished future raises and would kill the stream.
                    if future is not None and not future.done():
                        future.set_result(output)

    async def _send_from_queue(self) -> AsyncIterator[ImageData]:
        """Async generator to yield requests from the queue."""
        while True:
            # Yield unconditionally: filtering on truthiness would silently
            # drop an item (and strand its waiter) if it evaluated falsy.
            image_data = await self._request_queue.get()
            yield image_data
            self._request_queue.task_done()

    async def send(self, image_data: ImageData) -> ClassificationOutput:
        """Send an image and wait for the corresponding result.

        A correlation id is generated and stored on ``image_data`` when it
        has none.

        Raises:
            RuntimeError: If the stream task has already finished cleanly, or
                ends while this call is waiting.
            Exception: Whatever exception terminated the stream task.
        """
        if self._run_task.done():
            # Re-raises the stream's own exception or cancellation, if any.
            self._run_task.result()
            # A cleanly finished stream can no longer deliver results.
            msg = "classification stream has already finished"
            raise RuntimeError(msg)

        if image_data.correlation_id is None:
            image_data.correlation_id = str(uuid.uuid4())
        # ``get_running_loop`` is the supported way to reach the loop from
        # inside a coroutine; ``get_event_loop`` is deprecated for this use.
        future = asyncio.get_running_loop().create_future()
        self._pending_results[image_data.correlation_id] = future

        await self._request_queue.put(image_data)

        return await future
@@ -34,38 +28,33 @@ async def test_classify_single( """ - # Create gRPC channel with credentials - channel = await create_channel_with_credentials( - athena_options.host, credential_helper - ) with Path.open(Path(test_case.filepath), "rb") as f: image_bytes = f.read() - async with AthenaClient(channel, athena_options) as client: - image_data = ImageData(image_bytes) + image_data = ImageData(image_bytes) - # Classify with auto-generated correlation ID - result = await client.classify_single(image_data) + # Classify with auto-generated correlation ID + result = await streaming_sender.send(image_data) - if result.error.code: - msg = f"Image Result Error: {result.error.message}" - pytest.fail(msg) + if result.error.code: + msg = f"Image Result Error: {result.error.message}" + pytest.fail(msg) - actual_output = {c.label: c.weight for c in result.classifications} - assert set(test_case.expected_output.keys()).issubset( - set(actual_output.keys()) - ), ( - "Expected output to contain labels: ", - f"{test_case.expected_output.keys() - actual_output.keys()}", + actual_output = {c.label: c.weight for c in result.classifications} + assert set(test_case.expected_output.keys()).issubset( + set(actual_output.keys()) + ), ( + "Expected output to contain labels: ", + f"{test_case.expected_output.keys() - actual_output.keys()}", + ) + actual_output = {k: actual_output[k] for k in test_case.expected_output} + + for label in test_case.expected_output: + expected = test_case.expected_output[label] + actual = actual_output[label] + diff = abs(expected - actual) + assert diff < FP_ERROR_TOLERANCE, ( + f"Weight for label '{label}' differs by more than " + f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, " + f"diff={diff}" ) - actual_output = {k: actual_output[k] for k in test_case.expected_output} - - for label in test_case.expected_output: - expected = test_case.expected_output[label] - actual = actual_output[label] - diff = abs(expected - actual) - assert diff < 
def load_test_cases_by_env() -> list[AthenaTestCase]:
    """Load the e2e test cases selected by ``ATHENA_E2E_TESTCASE_DIR``.

    ``load_dotenv`` is called first, then the variable is read from the
    environment. When it is unset or blank, the ``integrator_sample``
    directory is used.

    Returns:
        The result of ``load_test_cases`` for the chosen directory name.
    """
    _ = load_dotenv()
    # ``os.getenv`` only falls back to its default when the variable is
    # missing; a blank ``ATHENA_E2E_TESTCASE_DIR=`` entry would otherwise be
    # used as the directory name, so treat blank the same as unset.
    dirname = os.getenv("ATHENA_E2E_TESTCASE_DIR", "").strip()
    return load_test_cases(dirname or "integrator_sample")