Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ ATHENA_NON_EXISTENT_AFFILIATE=non-existent-affiliate-id (default:
thisaffiliatedoesnotexist123) - this is used to test error handling.
ATHENA_NON_PERMITTED_AFFILIATE=non-permitted-affiliate-id (default:
thisaffiliatedoesnothaveathenaenabled) - this is used to test error handling.
ATHENA_E2E_TESTCASE_DIR=test-case-directory (default: integrator_sample) - this is the test case directory to use for the e2e tests.
See E2E Tests section below for more details.
```

Then run the functional tests with:
Expand All @@ -170,8 +172,18 @@ Then run the functional tests with:
pytest -m functional
```

To exclude the e2e tests, which require usage of the live classifier and
therefore are unsuitable for regular development runs, use:
#### E2E Tests

The e2e tests assert that the API returns some expected _scores_ rather than
exercising different API paths. As such, they are dependent on the classifier
that you are calling through the API. Right now, there are two types of
classifier: benign and live. By default, the tests will run the
`integrator_sample` test set, which uses the live classifier. If you wish to
use the benign classifier instead, you may set the `ATHENA_E2E_TESTCASE_DIR`
environment variable to `benign_model`.

Alternatively, you may disable these tests altogether by excluding tests that
have the `e2e` marker, like this:

```bash
pytest -m 'functional and not e2e'
Expand Down
96 changes: 93 additions & 3 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
import asyncio
import os
import uuid
from asyncio import Future, Queue, Task, create_task
from collections.abc import AsyncIterator
from copy import deepcopy

import cv2 as cv
import numpy as np
import pytest
import pytest_asyncio
from dotenv import load_dotenv
from grpc.aio import Channel

from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
from resolver_athena_client.client.channel import CredentialHelper
from resolver_athena_client.client.channel import (
CredentialHelper,
create_channel_with_credentials,
)
from resolver_athena_client.client.consts import (
EXPECTED_HEIGHT,
EXPECTED_WIDTH,
MAX_DEPLOYMENT_ID_LENGTH,
)
from resolver_athena_client.client.models.input_model import ImageData
from resolver_athena_client.generated.athena.models_pb2 import (
ClassificationOutput,
)


def _create_base_test_image_opencv(width: int, height: int) -> np.ndarray:
Expand Down Expand Up @@ -79,8 +92,7 @@ async def credential_helper() -> CredentialHelper:
)


@pytest.fixture
def athena_options() -> AthenaOptions:
def _load_options() -> AthenaOptions:
_ = load_dotenv()
host = os.getenv("ATHENA_HOST", "localhost")

Expand All @@ -99,9 +111,15 @@ def athena_options() -> AthenaOptions:
timeout=120.0, # Maximum duration, not forced timeout
keepalive_interval=30.0, # Longer intervals for persistent streams
affiliate=affiliate,
compression_quality=2,
)


@pytest.fixture
def athena_options() -> AthenaOptions:
    """Per-test fixture that returns freshly loaded client options."""
    options = _load_options()
    return options


@pytest.fixture(scope="session", params=SUPPORTED_TEST_FORMATS)
def valid_formatted_image(
request: pytest.FixtureRequest,
Expand Down Expand Up @@ -144,3 +162,75 @@ def valid_formatted_image(
_ = f.write(image_bytes)

return image_bytes


class StreamingSender:
    """Provide a single-send-like interface on top of a streaming connection.

    The class exposes a ``send`` method that accepts an ``ImageData``, pushes
    it onto an internal request queue feeding one long-lived classification
    stream, and asynchronously waits for the matching result. Callers get an
    interface that mimics a single request-response call, while under the
    hood a streaming connection is reused for speed.

    Results are matched to requests by ``correlation_id`` via a dict of
    pending futures.
    """

    def __init__(self, grpc_channel: Channel, options: AthenaOptions) -> None:
        """Start the background streaming task over *grpc_channel*.

        The task runs until the stream ends or errors; failures surface on
        the next call to ``send``.
        """
        self._request_queue: Queue[ImageData] = Queue()
        self._pending_results: dict[str, Future[ClassificationOutput]] = {}

        # tests are run in series, so we gain nothing here from waiting for a
        # batch that will never fill, so just send it immediately for better
        # latency
        streaming_options = deepcopy(options)
        streaming_options.max_batch_size = 1

        self._run_task: Task[None] = create_task(
            self._run(grpc_channel, streaming_options)
        )

    async def _run(self, grpc_channel: Channel, options: AthenaOptions) -> None:
        """Pump queued requests through the stream and resolve futures."""
        async with AthenaClient(grpc_channel, options) as client:
            generator = self._send_from_queue()
            responses = client.classify_images(generator)
            async for response in responses:
                for output in response.outputs:
                    # Unknown correlation ids (e.g. duplicates) are ignored.
                    if output.correlation_id in self._pending_results:
                        future = self._pending_results.pop(
                            output.correlation_id
                        )
                        future.set_result(output)

    async def _send_from_queue(self) -> AsyncIterator[ImageData]:
        """Async generator to yield requests from the queue."""
        while True:
            image_data = await self._request_queue.get()
            yield image_data
            self._request_queue.task_done()

    async def send(self, image_data: ImageData) -> ClassificationOutput:
        """Send an image and wait for the corresponding result.

        Raises whatever exception terminated the background task, or
        ``RuntimeError`` if the stream finished cleanly (in which case a new
        request could never be answered and awaiting it would hang forever).
        """
        if self._run_task.done():
            # Re-raise the background task's exception, if any.
            self._run_task.result()
            msg = "streaming task has already finished; cannot send"
            raise RuntimeError(msg)

        if image_data.correlation_id is None:
            image_data.correlation_id = str(uuid.uuid4())
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        future = asyncio.get_running_loop().create_future()
        self._pending_results[image_data.correlation_id] = future

        await self._request_queue.put(image_data)

        return await future


@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def streaming_sender(
    credential_helper: CredentialHelper,
) -> AsyncIterator[StreamingSender]:
    """Fixture to provide a helper for sending over a streaming connection.

    Yields (rather than returns) so the gRPC channel and the sender's
    background task are torn down at the end of the session instead of
    leaking.
    """
    # Create gRPC channel with credentials
    opts = _load_options()
    channel = await create_channel_with_credentials(
        opts.host, credential_helper
    )
    sender = StreamingSender(channel, opts)
    yield sender
    # Teardown: stop the background streaming task and close the channel.
    _ = sender._run_task.cancel()  # noqa: SLF001 - test helper cleanup
    await channel.close()
69 changes: 29 additions & 40 deletions tests/functional/e2e/test_classify_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,24 @@

import pytest

from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
from resolver_athena_client.client.channel import (
CredentialHelper,
create_channel_with_credentials,
)
from resolver_athena_client.client.models import ImageData
from tests.functional.conftest import StreamingSender
from tests.functional.e2e.testcases.parser import (
AthenaTestCase,
load_test_cases,
load_test_cases_by_env,
)

TEST_CASES = load_test_cases("integrator_sample")
TEST_CASES = load_test_cases_by_env()

FP_ERROR_TOLERANCE = 1e-4


@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.functional
@pytest.mark.e2e
@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda tc: tc.id)
async def test_classify_single(
athena_options: AthenaOptions,
credential_helper: CredentialHelper,
async def test_e2e_case(
streaming_sender: StreamingSender,
test_case: AthenaTestCase,
) -> None:
"""Functional test for ClassifySingle endpoint and API methods.
Expand All @@ -34,38 +28,33 @@ async def test_classify_single(

"""

# Create gRPC channel with credentials
channel = await create_channel_with_credentials(
athena_options.host, credential_helper
)
with Path.open(Path(test_case.filepath), "rb") as f:
image_bytes = f.read()

async with AthenaClient(channel, athena_options) as client:
image_data = ImageData(image_bytes)
image_data = ImageData(image_bytes)

# Classify with auto-generated correlation ID
result = await client.classify_single(image_data)
# Classify with auto-generated correlation ID
result = await streaming_sender.send(image_data)

if result.error.code:
msg = f"Image Result Error: {result.error.message}"
pytest.fail(msg)
if result.error.code:
msg = f"Image Result Error: {result.error.message}"
pytest.fail(msg)

actual_output = {c.label: c.weight for c in result.classifications}
assert set(test_case.expected_output.keys()).issubset(
set(actual_output.keys())
), (
"Expected output to contain labels: ",
f"{test_case.expected_output.keys() - actual_output.keys()}",
actual_output = {c.label: c.weight for c in result.classifications}
assert set(test_case.expected_output.keys()).issubset(
set(actual_output.keys())
), (
"Expected output to contain labels: ",
f"{test_case.expected_output.keys() - actual_output.keys()}",
)
actual_output = {k: actual_output[k] for k in test_case.expected_output}

for label in test_case.expected_output:
expected = test_case.expected_output[label]
actual = actual_output[label]
diff = abs(expected - actual)
assert diff < FP_ERROR_TOLERANCE, (
f"Weight for label '{label}' differs by more than "
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
f"diff={diff}"
)
actual_output = {k: actual_output[k] for k in test_case.expected_output}

for label in test_case.expected_output:
expected = test_case.expected_output[label]
actual = actual_output[label]
diff = abs(expected - actual)
assert diff < FP_ERROR_TOLERANCE, (
f"Weight for label '{label}' differs by more than "
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
f"diff={diff}"
)
10 changes: 10 additions & 0 deletions tests/functional/e2e/testcases/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
import os
from pathlib import Path

from dotenv import load_dotenv

# Path to the shared testcases directory in athena-protobufs
_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent
TESTCASES_DIR = _REPO_ROOT / "athena-protobufs" / "testcases"
Expand All @@ -23,6 +26,13 @@ def __init__(
self.classification_labels: list[str] = classification_labels


def load_test_cases_by_env() -> list[AthenaTestCase]:
    """Load the test-case set named by ``ATHENA_E2E_TESTCASE_DIR``.

    Falls back to the ``integrator_sample`` directory when the environment
    variable is unset.
    """
    _ = load_dotenv()
    dirname = os.getenv("ATHENA_E2E_TESTCASE_DIR", "integrator_sample")
    return load_test_cases(dirname)


def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]:
with Path.open(
Path(TESTCASES_DIR / dirname / "expected_outputs.json"),
Expand Down