Skip to content

Commit 4cd3aef

Browse files
anna-singleton-resolveranna-singleton-resolver
andauthored
Speed up E2E test cases (#117)
* feat: configurable e2e test cases * test: load dotenv before loading e2e test cases like other fixtures * doc: update docs to include information about e2e testing * perf: streaming connection for e2e test cases * test: athena_options is function scoped again * test: remove results buffer * test: always create futures in the correct event loop --------- Co-authored-by: anna-singleton-resolver <anna.singleton@resolver.com>
1 parent 4fd580f commit 4cd3aef

4 files changed

Lines changed: 146 additions & 45 deletions

File tree

README.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ ATHENA_NON_EXISTENT_AFFILIATE=non-existent-affiliate-id (default:
162162
thisaffiliatedoesnotexist123) - this is used to test error handling.
163163
ATHENA_NON_PERMITTED_AFFILIATE=non-permitted-affiliate-id (default:
164164
thisaffiliatedoesnothaveathenaenabled) - this is used to test error handling.
165+
ATHENA_E2E_TESTCASE_DIR=test-case-directory (default: integrator_sample) - this is the test case directory to use for the e2e tests.
166+
See E2E Tests section below for more details.
165167
```
166168

167169
Then run the functional tests with:
@@ -170,8 +172,18 @@ Then run the functional tests with:
170172
pytest -m functional
171173
```
172174

173-
To exclude the e2e tests, which require usage of the live classifier and
174-
therefore are unsuitable for regular development runs, use:
175+
#### E2E Tests
176+
177+
The e2e tests assert that the API returns some expected _scores_ rather than
178+
exercising different API paths. As such, they are dependent on the classifier
179+
that you are calling through the API. Right now, there are two types of
180+
classifier: benign and live. By default, the tests will run the
181+
`integrator_sample` test set, which uses the live classifier. If you wish to
182+
use the benign classifier instead, you may set the `ATHENA_E2E_TESTCASE_DIR`
183+
environment variable to `benign_model`.
184+
185+
Alternatively, you may disable these tests altogether, by excluding tests that
186+
have the `e2e` marker, for example:
175187

176188
```bash
177189
pytest -m 'functional and not e2e'

tests/functional/conftest.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
1+
import asyncio
12
import os
23
import uuid
4+
from asyncio import Future, Queue, Task, create_task
5+
from collections.abc import AsyncIterator
6+
from copy import deepcopy
37

48
import cv2 as cv
59
import numpy as np
610
import pytest
711
import pytest_asyncio
812
from dotenv import load_dotenv
13+
from grpc.aio import Channel
914

15+
from resolver_athena_client.client.athena_client import AthenaClient
1016
from resolver_athena_client.client.athena_options import AthenaOptions
11-
from resolver_athena_client.client.channel import CredentialHelper
17+
from resolver_athena_client.client.channel import (
18+
CredentialHelper,
19+
create_channel_with_credentials,
20+
)
1221
from resolver_athena_client.client.consts import (
1322
EXPECTED_HEIGHT,
1423
EXPECTED_WIDTH,
1524
MAX_DEPLOYMENT_ID_LENGTH,
1625
)
26+
from resolver_athena_client.client.models.input_model import ImageData
27+
from resolver_athena_client.generated.athena.models_pb2 import (
28+
ClassificationOutput,
29+
)
1730

1831

1932
def _create_base_test_image_opencv(width: int, height: int) -> np.ndarray:
@@ -79,8 +92,7 @@ async def credential_helper() -> CredentialHelper:
7992
)
8093

8194

82-
@pytest.fixture
83-
def athena_options() -> AthenaOptions:
95+
def _load_options() -> AthenaOptions:
8496
_ = load_dotenv()
8597
host = os.getenv("ATHENA_HOST", "localhost")
8698

@@ -99,9 +111,15 @@ def athena_options() -> AthenaOptions:
99111
timeout=120.0, # Maximum duration, not forced timeout
100112
keepalive_interval=30.0, # Longer intervals for persistent streams
101113
affiliate=affiliate,
114+
compression_quality=2,
102115
)
103116

104117

118+
@pytest.fixture
def athena_options() -> AthenaOptions:
    """Provide the Athena client options for a single test.

    The fixture uses pytest's default (function) scope, so every test that
    requests it gets its own ``AthenaOptions`` instance built by
    ``_load_options``.
    """
    return _load_options()
121+
122+
105123
@pytest.fixture(scope="session", params=SUPPORTED_TEST_FORMATS)
106124
def valid_formatted_image(
107125
request: pytest.FixtureRequest,
@@ -144,3 +162,75 @@ def valid_formatted_image(
144162
_ = f.write(image_bytes)
145163

146164
return image_bytes
165+
166+
167+
class StreamingSender:
168+
"""Helper class to provide a single-send-like interface with speed
169+
170+
The class provides a 'send' method that can be passed an imagedata and will
171+
send it along a stream, and collect all results into an internal buffer.
172+
173+
The 'send' method will asynchronously wait for the result and return it,
174+
providing an interface that mimics a single request-response call, while
175+
under the hood it is using a streaming connection for speed.
176+
"""
177+
178+
def __init__(self, grpc_channel: Channel, options: AthenaOptions) -> None:
179+
self._request_queue: Queue[ImageData] = Queue()
180+
self._pending_results: dict[str, Future[ClassificationOutput]] = {}
181+
182+
# tests are run in series, so we gain nothing here from waiting for a
183+
# batch that will never fill, so just send it immediately for better
184+
# latency
185+
streaming_options = deepcopy(options)
186+
streaming_options.max_batch_size = 1
187+
188+
self._run_task: Task[None] = create_task(
189+
self._run(grpc_channel, streaming_options)
190+
)
191+
192+
async def _run(self, grpc_channel: Channel, options: AthenaOptions) -> None:
193+
async with AthenaClient(grpc_channel, options) as client:
194+
generator = self._send_from_queue()
195+
responses = client.classify_images(generator)
196+
async for response in responses:
197+
for output in response.outputs:
198+
if output.correlation_id in self._pending_results:
199+
future = self._pending_results.pop(
200+
output.correlation_id
201+
)
202+
future.set_result(output)
203+
204+
async def _send_from_queue(self) -> AsyncIterator[ImageData]:
205+
"""Async generator to yield requests from the queue."""
206+
while True:
207+
if image_data := await self._request_queue.get():
208+
yield image_data
209+
self._request_queue.task_done()
210+
211+
async def send(self, image_data: ImageData) -> ClassificationOutput:
212+
"""Send an image and wait for the corresponding result."""
213+
if self._run_task.done():
214+
self._run_task.result()
215+
216+
if image_data.correlation_id is None:
217+
image_data.correlation_id = str(uuid.uuid4())
218+
future = asyncio.get_event_loop().create_future()
219+
self._pending_results[image_data.correlation_id] = future
220+
221+
await self._request_queue.put(image_data)
222+
223+
return await future
224+
225+
226+
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def streaming_sender(
    credential_helper: CredentialHelper,
) -> StreamingSender:
    """Provide one ``StreamingSender`` shared by the whole test session.

    The options come from ``_load_options()`` directly rather than from the
    ``athena_options`` fixture.
    NOTE(review): presumably because that fixture is function scoped and so
    cannot be requested from a session-scoped fixture - confirm.
    """
    options = _load_options()
    # Open the authenticated gRPC channel that the sender will stream over.
    grpc_channel = await create_channel_with_credentials(
        options.host, credential_helper
    )
    return StreamingSender(grpc_channel, options)

tests/functional/e2e/test_classify_single.py

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,24 @@
22

33
import pytest
44

5-
from resolver_athena_client.client.athena_client import AthenaClient
6-
from resolver_athena_client.client.athena_options import AthenaOptions
7-
from resolver_athena_client.client.channel import (
8-
CredentialHelper,
9-
create_channel_with_credentials,
10-
)
115
from resolver_athena_client.client.models import ImageData
6+
from tests.functional.conftest import StreamingSender
127
from tests.functional.e2e.testcases.parser import (
138
AthenaTestCase,
14-
load_test_cases,
9+
load_test_cases_by_env,
1510
)
1611

17-
TEST_CASES = load_test_cases("integrator_sample")
12+
TEST_CASES = load_test_cases_by_env()
1813

1914
FP_ERROR_TOLERANCE = 1e-4
2015

2116

22-
@pytest.mark.asyncio
17+
@pytest.mark.asyncio(loop_scope="session")
2318
@pytest.mark.functional
2419
@pytest.mark.e2e
2520
@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda tc: tc.id)
26-
async def test_classify_single(
27-
athena_options: AthenaOptions,
28-
credential_helper: CredentialHelper,
21+
async def test_e2e_case(
22+
streaming_sender: StreamingSender,
2923
test_case: AthenaTestCase,
3024
) -> None:
3125
"""Functional test for ClassifySingle endpoint and API methods.
@@ -34,38 +28,33 @@ async def test_classify_single(
3428
3529
"""
3630

37-
# Create gRPC channel with credentials
38-
channel = await create_channel_with_credentials(
39-
athena_options.host, credential_helper
40-
)
4131
with Path.open(Path(test_case.filepath), "rb") as f:
4232
image_bytes = f.read()
4333

44-
async with AthenaClient(channel, athena_options) as client:
45-
image_data = ImageData(image_bytes)
34+
image_data = ImageData(image_bytes)
4635

47-
# Classify with auto-generated correlation ID
48-
result = await client.classify_single(image_data)
36+
# Classify with auto-generated correlation ID
37+
result = await streaming_sender.send(image_data)
4938

50-
if result.error.code:
51-
msg = f"Image Result Error: {result.error.message}"
52-
pytest.fail(msg)
39+
if result.error.code:
40+
msg = f"Image Result Error: {result.error.message}"
41+
pytest.fail(msg)
5342

54-
actual_output = {c.label: c.weight for c in result.classifications}
55-
assert set(test_case.expected_output.keys()).issubset(
56-
set(actual_output.keys())
57-
), (
58-
"Expected output to contain labels: ",
59-
f"{test_case.expected_output.keys() - actual_output.keys()}",
43+
actual_output = {c.label: c.weight for c in result.classifications}
44+
assert set(test_case.expected_output.keys()).issubset(
45+
set(actual_output.keys())
46+
), (
47+
"Expected output to contain labels: ",
48+
f"{test_case.expected_output.keys() - actual_output.keys()}",
49+
)
50+
actual_output = {k: actual_output[k] for k in test_case.expected_output}
51+
52+
for label in test_case.expected_output:
53+
expected = test_case.expected_output[label]
54+
actual = actual_output[label]
55+
diff = abs(expected - actual)
56+
assert diff < FP_ERROR_TOLERANCE, (
57+
f"Weight for label '{label}' differs by more than "
58+
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
59+
f"diff={diff}"
6060
)
61-
actual_output = {k: actual_output[k] for k in test_case.expected_output}
62-
63-
for label in test_case.expected_output:
64-
expected = test_case.expected_output[label]
65-
actual = actual_output[label]
66-
diff = abs(expected - actual)
67-
assert diff < FP_ERROR_TOLERANCE, (
68-
f"Weight for label '{label}' differs by more than "
69-
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
70-
f"diff={diff}"
71-
)

tests/functional/e2e/testcases/parser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import json
2+
import os
23
from pathlib import Path
34

5+
from dotenv import load_dotenv
6+
47
# Path to the shared testcases directory in athena-protobufs
58
_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent
69
TESTCASES_DIR = _REPO_ROOT / "athena-protobufs" / "testcases"
@@ -23,6 +26,13 @@ def __init__(
2326
self.classification_labels: list[str] = classification_labels
2427

2528

29+
def load_test_cases_by_env() -> list[AthenaTestCase]:
    """Load the e2e test cases from the directory named in the environment.

    The directory name is read from ``ATHENA_E2E_TESTCASE_DIR`` after
    ``load_dotenv()`` has been called, and falls back to
    ``integrator_sample`` when the variable is unset.
    """
    _ = load_dotenv()
    testcase_dir = os.environ.get(
        "ATHENA_E2E_TESTCASE_DIR", "integrator_sample"
    )
    return load_test_cases(testcase_dir)
34+
35+
2636
def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]:
2737
with Path.open(
2838
Path(TESTCASES_DIR / dirname / "expected_outputs.json"),

0 commit comments

Comments
 (0)