Skip to content

Commit eae0959

Browse files
author
Dylan Huang
committed
test_pydantic_multi_agent runs
1 parent badc4d9 commit eae0959

File tree

5 files changed

+38
-22
lines changed

5 files changed

+38
-22
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
experiment_results/
2+
13
# Byte-compiled / optimized / DLL files
24
__pycache__/
35
*.py[cod]

eval_protocol/pytest/parameterize.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def pytest_parametrize(
4444
argnames.append("evaluation_test_kwargs")
4545

4646
argvalues: list[ParameterSet | Sequence[object] | object] = []
47-
param_tuples: list[tuple[object, ...]] = []
4847
for combo in combinations:
4948
dataset, cp, messages, rows, etk = combo
5049
param_tuple: list[object] = []
@@ -63,7 +62,7 @@ def pytest_parametrize(
6362
raise ValueError(
6463
f"The length of argnames ({len(argnames)}) is not the same as the length of param_tuple ({len(param_tuple)})"
6564
)
66-
param_tuples.append(tuple(param_tuple))
65+
argvalues.append(tuple(param_tuple))
6766

6867
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues)
6968

eval_protocol/pytest/validate_signature.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
11
from collections.abc import Sequence
22
from inspect import Signature
3+
from typing import get_origin, get_args
34

45
from eval_protocol.models import CompletionParams, EvaluationRow
56
from eval_protocol.pytest.types import EvaluationTestMode
67

78

9+
def _is_list_of_evaluation_row(annotation) -> bool: # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
10+
"""Check if annotation is list[EvaluationRow] or equivalent."""
11+
origin = get_origin(annotation) # pyright: ignore[reportUnknownArgumentType, reportAny]
12+
if origin is not list:
13+
return False
14+
15+
args = get_args(annotation)
16+
if len(args) != 1:
17+
return False
18+
19+
# Check if the single argument is EvaluationRow or equivalent
20+
arg = args[0] # pyright: ignore[reportAny]
21+
return arg is EvaluationRow or str(arg) == str(EvaluationRow) # pyright: ignore[reportAny]
22+
23+
824
def validate_signature(
925
signature: Signature, mode: EvaluationTestMode, completion_params: Sequence[CompletionParams | None] | None
1026
) -> None:
@@ -29,11 +45,13 @@ def validate_signature(
2945
raise ValueError("In groupwise mode, your eval function must have a parameter named 'rows'")
3046

3147
# validate that "Rows" is of type List[EvaluationRow]
32-
if signature.parameters["rows"].annotation is not list[EvaluationRow]: # pyright: ignore[reportAny]
33-
raise ValueError("In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow")
48+
if not _is_list_of_evaluation_row(signature.parameters["rows"].annotation): # pyright: ignore[reportAny]
49+
raise ValueError(
50+
f"In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow]. Got {str(signature.parameters['rows'].annotation)} instead" # pyright: ignore[reportAny]
51+
)
3452

3553
# validate that the function has a return type of List[EvaluationRow]
36-
if signature.return_annotation is not list[EvaluationRow]: # pyright: ignore[reportAny]
54+
if not _is_list_of_evaluation_row(signature.return_annotation): # pyright: ignore[reportAny]
3755
raise ValueError("In groupwise mode, your eval function must return a list of EvaluationRow instances")
3856
if completion_params is not None and len(completion_params) < 2:
3957
raise ValueError("In groupwise mode, you must provide at least 2 completion parameters")
@@ -43,9 +61,11 @@ def validate_signature(
4361
raise ValueError("In all mode, your eval function must have a parameter named 'rows'")
4462

4563
# validate that "Rows" is of type List[EvaluationRow]
46-
if signature.parameters["rows"].annotation is not list[EvaluationRow]: # pyright: ignore[reportAny]
47-
raise ValueError("In all mode, the 'rows' parameter must be of type List[EvaluationRow")
64+
if not _is_list_of_evaluation_row(signature.parameters["rows"].annotation): # pyright: ignore[reportAny]
65+
raise ValueError(
66+
f"In all mode, the 'rows' parameter must be of type list[EvaluationRow]. Got {str(signature.parameters['rows'].annotation)} instead" # pyright: ignore[reportAny]
67+
)
4868

4969
# validate that the function has a return type of List[EvaluationRow]
50-
if signature.return_annotation is not list[EvaluationRow]: # pyright: ignore[reportAny]
70+
if not _is_list_of_evaluation_row(signature.return_annotation): # pyright: ignore[reportAny]
5171
raise ValueError("In all mode, your eval function must return a list of EvaluationRow instances")

tests/pytest/test_get_metadata.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
from typing import Dict, List
32

43
from eval_protocol.pytest import evaluation_test
54
from eval_protocol.models import EvaluationRow, Message
@@ -19,16 +18,16 @@
1918
max_concurrent_rollouts=5,
2019
max_concurrent_evaluations=10,
2120
)
22-
def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
21+
def test_pytest_async(rows: list[EvaluationRow]) -> list[EvaluationRow]:
2322
"""Run math evaluation on sample dataset using pytest interface."""
2423
return rows
2524

2625

2726
def test_pytest_func_metainfo():
2827
assert hasattr(test_pytest_async, "_origin_func")
29-
origin_func = test_pytest_async._origin_func
30-
assert not asyncio.iscoroutinefunction(origin_func)
28+
origin_func = test_pytest_async._origin_func # pyright: ignore[reportAny, reportFunctionMemberAccess]
29+
assert not asyncio.iscoroutinefunction(origin_func) # pyright: ignore[reportAny]
3130
assert asyncio.iscoroutinefunction(test_pytest_async)
32-
assert test_pytest_async._metainfo["mode"] == "groupwise"
33-
assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5
34-
assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10
31+
assert test_pytest_async._metainfo["mode"] == "groupwise" # pyright: ignore[reportAny, reportFunctionMemberAccess]
32+
assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5 # pyright: ignore[reportAny, reportFunctionMemberAccess]
33+
assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10 # pyright: ignore[reportAny, reportFunctionMemberAccess]

tests/pytest/test_pytest_async.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import asyncio
2-
from typing import List
3-
41
import pytest
52

63
from eval_protocol.models import EvaluationRow, Message
74
from eval_protocol.pytest import evaluation_test
8-
from examples.math_example.main import evaluate as math_evaluate
95

106

117
@evaluation_test(
@@ -20,7 +16,7 @@
2016
completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}],
2117
mode="all",
2218
)
23-
async def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
19+
async def test_pytest_async(rows: list[EvaluationRow]) -> list[EvaluationRow]:
2420
"""Run math evaluation on sample dataset using pytest interface."""
2521
return rows
2622

@@ -51,7 +47,7 @@ async def test_pytest_async_main():
5147
],
5248
)
5349
]
54-
result = await test_pytest_async(rows)
50+
result = await test_pytest_async(rows) # pyright: ignore[reportGeneralTypeIssues, reportUnknownVariableType, reportArgumentType, reportCallIssue]
5551
assert result == rows
5652

5753

@@ -65,5 +61,5 @@ async def test_pytest_async_pointwise_main():
6561
Message(role="user", content="What is the capital of France?"),
6662
],
6763
)
68-
result = await test_pytest_async_pointwise(row)
64+
result = await test_pytest_async_pointwise(row) # pyright: ignore[reportGeneralTypeIssues, reportArgumentType, reportUnknownVariableType, reportCallIssue]
6965
assert result == row

0 commit comments

Comments
 (0)