|
1 | | -import asyncio |
2 | 1 | import inspect |
3 | 2 | from typing import Any, Callable, Dict, List, Optional |
4 | 3 |
|
5 | 4 | import pytest |
6 | 5 |
|
| 6 | +from eval_protocol.models import EvaluationRow |
7 | 7 | from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor |
8 | 8 | from eval_protocol.pytest.types import ( |
9 | 9 | Dataset, |
|
16 | 16 | RolloutProcessorConfig, |
17 | 17 | TestFunction, |
18 | 18 | ) |
| 19 | +from eval_protocol.pytest.utils import aggregate, create_dynamically_parameterized_wrapper, execute_function |
19 | 20 |
|
20 | 21 | from ..common_utils import load_jsonl |
21 | | -from ..models import EvaluateResult, EvaluationRow |
22 | | - |
23 | | - |
24 | | -def _execute_function(func: Callable, **kwargs) -> Any: |
25 | | - """ |
26 | | - Execute a function with proper async handling. |
27 | | -
|
28 | | - This is a pure function that handles both async and non-async function execution |
29 | | - with proper event loop management for async functions. |
30 | | -
|
31 | | - Args: |
32 | | - func: The function to execute |
33 | | - **kwargs: Arguments to pass to the function |
34 | | -
|
35 | | - Returns: |
36 | | - The result of the function execution |
37 | | - """ |
38 | | - is_async = asyncio.iscoroutinefunction(func) |
39 | | - if is_async: |
40 | | - # Handle async functions with proper event loop management |
41 | | - try: |
42 | | - loop = asyncio.get_event_loop() |
43 | | - if not loop.is_closed(): |
44 | | - # Use existing loop |
45 | | - task = loop.create_task(func(**kwargs)) |
46 | | - results = loop.run_until_complete(task) |
47 | | - else: |
48 | | - # Loop is closed, create a new one |
49 | | - results = asyncio.run(func(**kwargs)) |
50 | | - except RuntimeError: |
51 | | - # No event loop or other issues, create a new one |
52 | | - results = asyncio.run(func(**kwargs)) |
53 | | - else: |
54 | | - results = func(**kwargs) |
55 | | - return results |
56 | | - |
57 | | - |
58 | | -def evaluate( |
59 | | - rows: List[EvaluationRow], reward_fn: Callable[..., EvaluateResult], **kwargs: Any |
60 | | -) -> List[EvaluationRow]: |
61 | | - """Apply a reward function to each row and attach the result.""" |
62 | | - evaluated: List[EvaluationRow] = [] |
63 | | - for row in rows: |
64 | | - result = reward_fn(messages=row.messages, ground_truth=row.ground_truth, **kwargs) |
65 | | - row.evaluation_result = result |
66 | | - evaluated.append(row) |
67 | | - return evaluated |
68 | | - |
69 | | - |
70 | | -def _aggregate(scores: List[float], method: str) -> float: |
71 | | - if not scores: |
72 | | - return 0.0 |
73 | | - if method == "mean": |
74 | | - return sum(scores) / len(scores) |
75 | | - if method == "max": |
76 | | - return max(scores) |
77 | | - if method == "min": |
78 | | - return min(scores) |
79 | | - raise ValueError(f"Unknown aggregation method: {method}") |
80 | | - |
81 | | - |
82 | | -def _create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names): |
83 | | - """ |
84 | | - Creates a wrapper function with dynamic parameters for pytest parameterization. |
85 | | -
|
86 | | - This function takes a test function and creates a wrapper that: |
87 | | - 1. Preserves the original function's metadata using functools.wraps |
88 | | - 2. Creates a new function signature with the specified parameter names that maps to pytest.mark.parametrize decorator |
89 | | - 3. Returns a callable that can be used with pytest.mark.parametrize |
90 | | -
|
91 | | - The function signature is dynamically created to match the parameter names expected by |
92 | | - pytest.mark.parametrize, ensuring that pytest can properly map the test parameters |
93 | | - to the function arguments. |
94 | | -
|
95 | | - Args: |
96 | | - test_func: The original test function to wrap |
97 | | - wrapper_body: The function body that contains the actual test logic |
98 | | - test_param_names: List of parameter names for the dynamic signature |
99 | | -
|
100 | | - Returns: |
101 | | - A wrapper function with the specified parameter signature that calls wrapper_body |
102 | | - """ |
103 | | - from functools import wraps |
104 | | - |
105 | | - @wraps(test_func) |
106 | | - def wrapper(**kwargs): |
107 | | - return wrapper_body(**kwargs) |
108 | | - |
109 | | - parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names] |
110 | | - wrapper.__signature__ = inspect.Signature(parameters) |
111 | | - |
112 | | - return wrapper |
113 | 22 |
|
114 | 23 |
|
115 | 24 | def evaluation_test( |
@@ -193,9 +102,6 @@ def test_func(model_name: str, input_messages: List[List[Message]]): |
193 | 102 | def decorator( |
194 | 103 | test_func: TestFunction, |
195 | 104 | ): |
196 | | - # Check if the function is async |
197 | | - is_async = inspect.iscoroutinefunction(test_func) |
198 | | - |
199 | 105 | sig = inspect.signature(test_func) |
200 | 106 |
|
201 | 107 | # For pointwise/rowwise mode, we expect a different signature |
@@ -240,7 +146,7 @@ def execute_with_params( |
240 | 146 | kwargs["model"] = model |
241 | 147 | if row is not None: |
242 | 148 | kwargs["row"] = row |
243 | | - return _execute_function(test_func, **kwargs) |
| 149 | + return execute_function(test_func, **kwargs) |
244 | 150 |
|
245 | 151 | # Calculate all possible combinations of parameters |
246 | 152 | def generate_combinations(): |
@@ -315,7 +221,7 @@ def wrapper_body(**kwargs): |
315 | 221 | initial_messages=kwargs.get("input_messages") if "input_messages" in kwargs else [], |
316 | 222 | ) |
317 | 223 | for row in data: |
318 | | - processed: List[EvaluationRow] = _execute_function(rollout_processor, row=row, config=config) |
| 224 | + processed: List[EvaluationRow] = execute_function(rollout_processor, row=row, config=config) |
319 | 225 | input_dataset.extend(processed) |
320 | 226 |
|
321 | 227 | all_results: List[EvaluationRow] = [] |
@@ -361,13 +267,13 @@ def wrapper_body(**kwargs): |
361 | 267 | all_results.extend(results) |
362 | 268 |
|
363 | 269 | scores = [r.evaluation_result.score for r in all_results if r.evaluation_result] |
364 | | - agg_score = _aggregate(scores, aggregation_method) |
| 270 | + agg_score = aggregate(scores, aggregation_method) |
365 | 271 | if threshold_of_success is not None: |
366 | 272 | assert ( |
367 | 273 | agg_score >= threshold_of_success |
368 | 274 | ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}" |
369 | 275 |
|
370 | | - return _create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) |
| 276 | + return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) |
371 | 277 |
|
372 | 278 | wrapper = create_wrapper_with_signature() |
373 | 279 | wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper) |
|
0 commit comments