Skip to content

Commit 54d9b11

Browse files
author
Dylan Huang
committed
move decorator into its own file
1 parent 1722e0f commit 54d9b11

2 files changed

Lines changed: 102 additions & 100 deletions

File tree

Lines changed: 6 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import asyncio
21
import inspect
32
from typing import Any, Callable, Dict, List, Optional
43

54
import pytest
65

6+
from eval_protocol.models import EvaluationRow
77
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
88
from eval_protocol.pytest.types import (
99
Dataset,
@@ -16,100 +16,9 @@
1616
RolloutProcessorConfig,
1717
TestFunction,
1818
)
19+
from eval_protocol.pytest.utils import aggregate, create_dynamically_parameterized_wrapper, execute_function
1920

2021
from ..common_utils import load_jsonl
21-
from ..models import EvaluateResult, EvaluationRow
22-
23-
24-
def _execute_function(func: Callable, **kwargs) -> Any:
25-
"""
26-
Execute a function with proper async handling.
27-
28-
This is a pure function that handles both async and non-async function execution
29-
with proper event loop management for async functions.
30-
31-
Args:
32-
func: The function to execute
33-
**kwargs: Arguments to pass to the function
34-
35-
Returns:
36-
The result of the function execution
37-
"""
38-
is_async = asyncio.iscoroutinefunction(func)
39-
if is_async:
40-
# Handle async functions with proper event loop management
41-
try:
42-
loop = asyncio.get_event_loop()
43-
if not loop.is_closed():
44-
# Use existing loop
45-
task = loop.create_task(func(**kwargs))
46-
results = loop.run_until_complete(task)
47-
else:
48-
# Loop is closed, create a new one
49-
results = asyncio.run(func(**kwargs))
50-
except RuntimeError:
51-
# No event loop or other issues, create a new one
52-
results = asyncio.run(func(**kwargs))
53-
else:
54-
results = func(**kwargs)
55-
return results
56-
57-
58-
def evaluate(
    rows: List[EvaluationRow], reward_fn: Callable[..., EvaluateResult], **kwargs: Any
) -> List[EvaluationRow]:
    """Score each row with ``reward_fn`` and return the rows in input order.

    ``reward_fn`` is called once per row with that row's ``messages`` and
    ``ground_truth`` plus any extra ``kwargs``. Its return value is stored on
    ``row.evaluation_result``, so the rows are updated in place; the returned
    list is a new list holding the same row objects.
    """
    scored = []
    for item in rows:
        item.evaluation_result = reward_fn(messages=item.messages, ground_truth=item.ground_truth, **kwargs)
        scored.append(item)
    return scored
68-
69-
70-
def _aggregate(scores: List[float], method: str) -> float:
71-
if not scores:
72-
return 0.0
73-
if method == "mean":
74-
return sum(scores) / len(scores)
75-
if method == "max":
76-
return max(scores)
77-
if method == "min":
78-
return min(scores)
79-
raise ValueError(f"Unknown aggregation method: {method}")
80-
81-
82-
def _create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names):
83-
"""
84-
Creates a wrapper function with dynamic parameters for pytest parameterization.
85-
86-
This function takes a test function and creates a wrapper that:
87-
1. Preserves the original function's metadata using functools.wraps
88-
2. Creates a new function signature with the specified parameter names that maps to pytest.mark.parametrize decorator
89-
3. Returns a callable that can be used with pytest.mark.parametrize
90-
91-
The function signature is dynamically created to match the parameter names expected by
92-
pytest.mark.parametrize, ensuring that pytest can properly map the test parameters
93-
to the function arguments.
94-
95-
Args:
96-
test_func: The original test function to wrap
97-
wrapper_body: The function body that contains the actual test logic
98-
test_param_names: List of parameter names for the dynamic signature
99-
100-
Returns:
101-
A wrapper function with the specified parameter signature that calls wrapper_body
102-
"""
103-
from functools import wraps
104-
105-
@wraps(test_func)
106-
def wrapper(**kwargs):
107-
return wrapper_body(**kwargs)
108-
109-
parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names]
110-
wrapper.__signature__ = inspect.Signature(parameters)
111-
112-
return wrapper
11322

11423

11524
def evaluation_test(
@@ -193,9 +102,6 @@ def test_func(model_name: str, input_messages: List[List[Message]]):
193102
def decorator(
194103
test_func: TestFunction,
195104
):
196-
# Check if the function is async
197-
is_async = inspect.iscoroutinefunction(test_func)
198-
199105
sig = inspect.signature(test_func)
200106

201107
# For pointwise/rowwise mode, we expect a different signature
@@ -240,7 +146,7 @@ def execute_with_params(
240146
kwargs["model"] = model
241147
if row is not None:
242148
kwargs["row"] = row
243-
return _execute_function(test_func, **kwargs)
149+
return execute_function(test_func, **kwargs)
244150

245151
# Calculate all possible combinations of parameters
246152
def generate_combinations():
@@ -315,7 +221,7 @@ def wrapper_body(**kwargs):
315221
initial_messages=kwargs.get("input_messages") if "input_messages" in kwargs else [],
316222
)
317223
for row in data:
318-
processed: List[EvaluationRow] = _execute_function(rollout_processor, row=row, config=config)
224+
processed: List[EvaluationRow] = execute_function(rollout_processor, row=row, config=config)
319225
input_dataset.extend(processed)
320226

321227
all_results: List[EvaluationRow] = []
@@ -361,13 +267,13 @@ def wrapper_body(**kwargs):
361267
all_results.extend(results)
362268

363269
scores = [r.evaluation_result.score for r in all_results if r.evaluation_result]
364-
agg_score = _aggregate(scores, aggregation_method)
270+
agg_score = aggregate(scores, aggregation_method)
365271
if threshold_of_success is not None:
366272
assert (
367273
agg_score >= threshold_of_success
368274
), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
369275

370-
return _create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
276+
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
371277

372278
wrapper = create_wrapper_with_signature()
373279
wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper)

eval_protocol/pytest/utils.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import asyncio
2+
import inspect
3+
from typing import Any, Callable, List
4+
5+
from ..models import EvaluateResult, EvaluationRow
6+
7+
8+
def execute_function(func: Callable, **kwargs) -> Any:
    """
    Execute a function, transparently handling coroutine functions.

    Regular callables are simply called with ``kwargs``. Coroutine functions
    are run to completion on the current event loop if one exists and is open;
    otherwise ``asyncio.run`` creates a fresh loop for the call.

    Args:
        func: The function to execute (regular or ``async def``).
        **kwargs: Arguments to pass to the function.

    Returns:
        The result of the function execution (the awaited result for
        coroutine functions).

    Raises:
        RuntimeError: If the current event loop is already running; a blocking
            call cannot be nested inside a running loop.
        Exception: Any exception raised by ``func`` is propagated as-is.
    """
    if not asyncio.iscoroutinefunction(func):
        return func(**kwargs)

    # Keep the try body limited to the loop lookup. If the execution were also
    # covered by ``except RuntimeError``, a RuntimeError raised by ``func``
    # would trigger a second, unintended execution through ``asyncio.run``.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # This thread has no current event loop.
        loop = None

    if loop is None or loop.is_closed():
        # Nothing usable to run on: let asyncio.run manage a new loop.
        return asyncio.run(func(**kwargs))
    if loop.is_running():
        # Raise before building the coroutine so it is never left un-awaited.
        raise RuntimeError(
            "execute_function() cannot run an async function while the event loop is already running"
        )
    return loop.run_until_complete(func(**kwargs))
40+
41+
42+
def evaluate(
    rows: List[EvaluationRow], reward_fn: Callable[..., EvaluateResult], **kwargs: Any
) -> List[EvaluationRow]:
    """Run ``reward_fn`` on every row and record its result on the row.

    For each row, ``reward_fn`` receives the row's ``messages`` and
    ``ground_truth`` along with any extra ``kwargs``; the value it returns is
    assigned to ``row.evaluation_result``. The rows themselves are mutated and
    returned, in input order, inside a newly created list.
    """

    def _attach(row):
        # Score first, then store, so a failing reward_fn leaves this row untouched.
        row.evaluation_result = reward_fn(messages=row.messages, ground_truth=row.ground_truth, **kwargs)
        return row

    return [_attach(row) for row in rows]
52+
53+
54+
def aggregate(scores: List[float], method: str) -> float:
    """Combine individual scores into one aggregate value.

    Args:
        scores: Scores to aggregate.
        method: Aggregation method: ``"mean"``, ``"max"`` or ``"min"``.

    Returns:
        The aggregated score. An empty ``scores`` list yields ``0.0``.

    Raises:
        ValueError: If ``method`` is not one of the supported methods. The
            method is validated first so that an invalid value is reported
            even when ``scores`` is empty.
    """
    if method not in ("mean", "max", "min"):
        raise ValueError(f"Unknown aggregation method: {method}")
    if not scores:
        return 0.0
    if method == "mean":
        return sum(scores) / len(scores)
    if method == "max":
        return max(scores)
    return min(scores)
64+
65+
66+
def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names):
    """
    Creates a wrapper function with dynamic parameters for pytest parameterization.

    The returned wrapper:
    1. Preserves the original function's metadata using ``functools.wraps``.
    2. Exposes a ``__signature__`` built from ``test_param_names`` so that the
       names line up with those given to ``pytest.mark.parametrize``.
    3. Forwards its arguments to ``wrapper_body`` as keyword arguments.

    Args:
        test_func: The original test function to wrap.
        wrapper_body: The callable that contains the actual test logic; it is
            always invoked with keyword arguments.
        test_param_names: Parameter names for the dynamic signature.

    Returns:
        A wrapper function with the specified parameter signature that calls
        ``wrapper_body``.
    """
    from functools import wraps

    # Materialize once: the names are needed for the signature and for
    # mapping positional arguments.
    param_names = list(test_param_names)

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        # The declared parameters are POSITIONAL_OR_KEYWORD; accept positional
        # arguments too so the callable agrees with its advertised signature.
        if len(args) > len(param_names):
            raise TypeError(
                f"{wrapper.__name__}() takes {len(param_names)} positional arguments "
                f"but {len(args)} were given"
            )
        for name, value in zip(param_names, args):
            if name in kwargs:
                raise TypeError(f"{wrapper.__name__}() got multiple values for argument '{name}'")
            kwargs[name] = value
        return wrapper_body(**kwargs)

    wrapper.__signature__ = inspect.Signature(
        [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in param_names]
    )

    return wrapper

0 commit comments

Comments
 (0)