Skip to content

Commit 94ae1b3

Browse files
author
Dylan Huang
committed
test_import_logs works
1 parent ab6d761 commit 94ae1b3

5 files changed

Lines changed: 119 additions & 5 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def decorator(
205205
# Create parameter tuples for pytest.mark.parametrize
206206
pytest_parametrize_args = pytest_parametrize(
207207
combinations,
208+
test_func,
208209
input_dataset,
209210
completion_params,
210211
completion_params_provided,
@@ -268,7 +269,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
268269
index = abs(index) % (max_index + 1)
269270
row.input_metadata.row_id = generate_id(seed=0, index=index)
270271

271-
completion_params = kwargs["completion_params"]
272+
completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
272273
# Create eval metadata with test function info and current commit hash
273274
eval_metadata = EvalMetadata(
274275
name=test_func.__name__,

eval_protocol/pytest/generate_parameter_combinations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
]
3232

3333

34-
class ParameterizedTestKwargs(TypedDict):
34+
class ParameterizedTestKwargs(TypedDict, total=False):
3535
"""
3636
These are the type of parameters that can be passed to the generated pytest
3737
function. Every experiment is a unique combination of these parameters.

eval_protocol/pytest/parameterize.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import inspect
23
from typing import TypedDict, Protocol
34
from collections.abc import Callable, Sequence, Iterable, Awaitable
@@ -9,6 +10,111 @@
910
from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction
1011

1112

13+
def _has_pytest_parametrize_with_completion_params(test_func: TestFunction) -> bool:
    """
    Check if a test function is decorated with pytest.mark.parametrize on "completion_params".

    The function's source is fetched with inspect.getsource, dedented, and parsed with ast. Only
    the decorators of the top-level function definition(s) in that source are inspected, so
    decorators on helper functions nested inside the test body are ignored.

    Args:
        test_func: The test function to analyze

    Returns:
        True if the function has a pytest.mark.parametrize decorator with "completion_params" in
        argnames. False otherwise, including when the source cannot be retrieved or parsed (this
        function never raises for those cases).
    """
    # Local import keeps this block self-contained without touching the module import list.
    import textwrap

    try:
        source = inspect.getsource(test_func)
    except (OSError, TypeError):
        # OSError: source is unavailable (e.g. defined in interactive mode).
        # TypeError: the object is a builtin or another callable without Python source.
        return False

    try:
        # Methods and nested functions come back indented; without dedent ast.parse
        # raises IndentationError and the decorator would silently go undetected.
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        # Source code cannot be parsed
        return False

    for node in tree.body:
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if any(_is_pytest_parametrize_with_completion_params(dec) for dec in node.decorator_list):
            return True

    return False
52+
53+
54+
def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
    """
    Check if a decorator is pytest.mark.parametrize with "completion_params" in argnames.

    Two spellings of the decorator are recognised:

    * ``@pytest.mark.parametrize(...)``
    * ``@mark.parametrize(...)`` -- assumes ``mark`` was imported via ``from pytest import mark``;
      the AST alone cannot confirm where the name comes from.

    Args:
        decorator: AST node representing a decorator

    Returns:
        True if this is a parametrize decorator whose argnames (first positional argument or the
        ``argnames`` keyword) contain "completion_params"
    """
    # A bare ``@decorator`` without a call cannot carry argnames.
    if not isinstance(decorator, ast.Call):
        return False

    func = decorator.func
    if not (isinstance(func, ast.Attribute) and func.attr == "parametrize"):
        return False

    mark = func.value
    is_pytest_mark = (
        isinstance(mark, ast.Attribute)
        and mark.attr == "mark"
        and isinstance(mark.value, ast.Name)
        and mark.value.id == "pytest"
    )
    is_imported_mark = isinstance(mark, ast.Name) and mark.id == "mark"
    if not (is_pytest_mark or is_imported_mark):
        return False

    # argnames is either the first positional argument or the ``argnames`` keyword.
    candidates = list(decorator.args[:1])
    candidates.extend(kw.value for kw in decorator.keywords if kw.arg == "argnames")
    return any(_check_argnames_for_completion_params(node) for node in candidates)
88+
89+
90+
def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
91+
"""
92+
Check if an argnames AST node contains "completion_params".
93+
94+
Args:
95+
argnames_node: AST node representing the argnames value
96+
97+
Returns:
98+
True if argnames contains "completion_params"
99+
"""
100+
if isinstance(argnames_node, ast.Constant):
101+
# Single string case: argnames="completion_params"
102+
if argnames_node.value == "completion_params":
103+
return True
104+
elif isinstance(argnames_node, ast.List):
105+
# List case: argnames=["completion_params", ...]
106+
for elt in argnames_node.elts:
107+
if isinstance(elt, ast.Constant) and elt.value == "completion_params":
108+
return True
109+
elif isinstance(argnames_node, ast.Tuple):
110+
# Tuple case: argnames=("completion_params", ...)
111+
for elt in argnames_node.elts:
112+
if isinstance(elt, ast.Constant) and elt.value == "completion_params":
113+
return True
114+
115+
return False
116+
117+
12118
class PytestMarkParametrizeKwargs(TypedDict):
13119
argnames: Sequence[str]
14120
argvalues: Iterable[ParameterSet | Sequence[object] | object]
@@ -96,6 +202,7 @@ def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str |
96202

97203
def pytest_parametrize(
98204
combinations: list[CombinationTuple],
205+
test_func: TestFunction | None,
99206
input_dataset: Sequence[DatasetPathParam] | None,
100207
completion_params: Sequence[CompletionParams | None] | None,
101208
completion_params_provided: bool,
@@ -112,16 +219,22 @@ def pytest_parametrize(
112219
API.
113220
"""
114221

222+
if test_func is not None:
223+
has_pytest_parametrize = _has_pytest_parametrize_with_completion_params(test_func)
224+
else:
225+
has_pytest_parametrize = False
226+
115227
# Create parameter tuples for pytest.mark.parametrize
116228
argnames: list[str] = []
117229
sig_parameters: list[str] = []
118230
if input_dataset is not None:
119231
argnames.append("dataset_path")
120232
sig_parameters.append("dataset_path")
121233
if completion_params is not None:
122-
if completion_params_provided:
234+
if completion_params_provided and not has_pytest_parametrize:
123235
argnames.append("completion_params")
124-
sig_parameters.append("completion_params")
236+
if has_pytest_parametrize or completion_params_provided:
237+
sig_parameters.append("completion_params")
125238
if input_messages is not None:
126239
argnames.append("input_messages")
127240
sig_parameters.append("input_messages")

eval_protocol/quickstart/llm_judge_openai_responses.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905",
5353
},
5454
],
55-
ids=DefaultParameterIdGenerator.generate_id_from_dict,
5655
)
5756
@evaluation_test(
5857
input_rows=[input_rows],

tests/pytest/test_parameterized_ids.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def test_pytest_parametrize_with_custom_id_generator():
149149
# Test with default generator
150150
result = pytest_parametrize(
151151
combinations=combinations,
152+
test_func=None,
152153
input_dataset=None,
153154
completion_params=[{"model": "gpt-4"}, {"model": "claude-3"}, {"temperature": 0.5}],
154155
completion_params_provided=True,

0 commit comments

Comments
 (0)