Skip to content

Commit 999ed18

Browse files
committed
format
1 parent 44b6f27 commit 999ed18

3 files changed

Lines changed: 27 additions & 11 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,13 +729,15 @@ async def _collect_result(config, lst, max_retry):
729729
except Exception:
730730
_log_eval_error("error", data if "data" in locals() else None, passed=False)
731731
raise
732+
732733
if asyncio.iscoroutinefunction(test_func):
733734
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
734735
else:
736+
735737
def sync_wrapper_body(**kwargs):
736738
return asyncio.run(wrapper_body(**kwargs))
737-
return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names)
738739

740+
return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names)
739741

740742
# Create the pytest wrapper
741743
pytest_wrapper = create_wrapper_with_signature()
@@ -763,6 +765,7 @@ def create_dual_mode_wrapper() -> Callable:
763765
is_async = asyncio.iscoroutinefunction(test_func)
764766

765767
if is_async:
768+
766769
async def dual_mode_wrapper(*args, **kwargs):
767770
# Check if this is a direct call with the expected signature
768771
if mode == "pointwise":
@@ -789,20 +792,30 @@ async def dual_mode_wrapper(*args, **kwargs):
789792

790793
# If not a direct call, use the pytest wrapper
791794
return await pytest_wrapper(*args, **kwargs)
792-
795+
793796
_dual_model_wrapper_fn = dual_mode_wrapper
794797
else:
798+
795799
def dual_mode_wrapper(*args, **kwargs):
796800
if mode == "pointwise":
797801
if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
798802
return test_func(row=args[0])
799803
else:
800-
if len(args) == 1 and isinstance(args[0], list) and all(isinstance(r, EvaluationRow) for r in args[0]) and not kwargs:
804+
if (
805+
len(args) == 1
806+
and isinstance(args[0], list)
807+
and all(isinstance(r, EvaluationRow) for r in args[0])
808+
and not kwargs
809+
):
801810
return test_func(rows=args[0])
802-
if "rows" in kwargs and isinstance(kwargs["rows"], list) and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]):
811+
if (
812+
"rows" in kwargs
813+
and isinstance(kwargs["rows"], list)
814+
and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
815+
):
803816
return test_func(**kwargs)
804817
return pytest_wrapper(*args, **kwargs)
805-
818+
806819
_dual_model_wrapper_fn = dual_mode_wrapper
807820

808821
# Copy all attributes from the pytest wrapper to our dual mode wrapper

eval_protocol/pytest/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,12 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param
9999
from functools import wraps
100100

101101
if asyncio.iscoroutinefunction(wrapper_body):
102+
102103
@wraps(test_func)
103104
async def wrapper(**kwargs):
104105
return await wrapper_body(**kwargs)
105106
else:
107+
106108
@wraps(test_func)
107109
def wrapper(**kwargs):
108110
return wrapper_body(**kwargs)

tests/pytest/test_direct_run.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
],
1616
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
1717
rollout_processor=SingleTurnRolloutProcessor(),
18-
mode="listwise",
18+
mode="all",
1919
)
2020
def test_direct_run(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2121
"""Run math evaluation on sample dataset using pytest interface."""
@@ -53,7 +53,7 @@ def test_direct_run_main():
5353
],
5454
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
5555
rollout_processor=SingleTurnRolloutProcessor(),
56-
mode="listwise",
56+
mode="all",
5757
)
5858
async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
5959
"""Run math evaluation on sample dataset using pytest interface."""
@@ -62,21 +62,22 @@ async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow
6262
return rows
6363

6464

65-
6665
@pytest.mark.asyncio
6766
async def test_direct_run_async_main():
6867
rows = [
6968
EvaluationRow(
7069
messages=[
71-
Message(role="user", content="What is the capital of France?"),
70+
Message(role="user", content="1"),
7271
],
7372
),
7473
EvaluationRow(
7574
messages=[
76-
Message(role="user", content="What is the capital of the moon?"),
75+
Message(role="user", content="2"),
7776
],
7877
),
7978
]
8079
res = await test_direct_run_async(rows)
80+
assert res[0].messages[0].content == "1"
81+
assert res[1].messages[0].content == "2"
8182
assert res[0].evaluation_result.score == 0
82-
assert res[1].evaluation_result.score == 1
83+
assert res[1].evaluation_result.score == 1

0 commit comments

Comments
 (0)