1+ import ast
12import inspect
23from typing import TypedDict , Protocol
34from collections .abc import Callable , Sequence , Iterable , Awaitable
910from eval_protocol .pytest .types import DatasetPathParam , EvaluationInputParam , InputMessagesParam , TestFunction
1011
1112
13+ def _has_pytest_parametrize_with_completion_params (test_func : TestFunction ) -> bool :
14+ """
15+ Check if a test function has a pytest.mark.parametrize decorator with argnames="completion_params".
16+
17+ This function uses inspect.getsource and ast to parse the function's source code and look for
18+ pytest.mark.parametrize decorators that include "completion_params" in their argnames.
19+
20+ Args:
21+ test_func: The test function to analyze
22+
23+ Returns:
24+ True if the function has a pytest.mark.parametrize decorator with "completion_params" in argnames,
25+ False otherwise
26+
27+ Raises:
28+ OSError: If the source code cannot be retrieved (e.g., function is defined in interactive mode)
29+ SyntaxError: If the source code cannot be parsed as valid Python
30+ """
31+ try :
32+ source = inspect .getsource (test_func )
33+ except OSError :
34+ # Function source cannot be retrieved (e.g., defined in interactive mode)
35+ return False
36+
37+ try :
38+ tree = ast .parse (source )
39+ except SyntaxError :
40+ # Source code cannot be parsed
41+ return False
42+
43+ # Walk through the AST to find pytest.mark.parametrize decorators
44+ for node in ast .walk (tree ):
45+ if isinstance (node , ast .FunctionDef ) or isinstance (node , ast .AsyncFunctionDef ):
46+ # Check decorators on this function
47+ for decorator in node .decorator_list :
48+ if _is_pytest_parametrize_with_completion_params (decorator ):
49+ return True
50+
51+ return False
52+
53+
54+ def _is_pytest_parametrize_with_completion_params (decorator : ast .expr ) -> bool :
55+ """
56+ Check if a decorator is pytest.mark.parametrize with "completion_params" in argnames.
57+
58+ Args:
59+ decorator: AST node representing a decorator
60+
61+ Returns:
62+ True if this is a pytest.mark.parametrize decorator with "completion_params" in argnames
63+ """
64+ # Look for pytest.mark.parametrize pattern
65+ if isinstance (decorator , ast .Call ):
66+ # Check if it's pytest.mark.parametrize
67+ if isinstance (decorator .func , ast .Attribute ):
68+ if (
69+ isinstance (decorator .func .value , ast .Attribute )
70+ and isinstance (decorator .func .value .value , ast .Name )
71+ and decorator .func .value .value .id == "pytest"
72+ and decorator .func .value .attr == "mark"
73+ and decorator .func .attr == "parametrize"
74+ ):
75+ # Check positional arguments first (argnames is typically the first positional arg)
76+ if len (decorator .args ) > 0 :
77+ argnames_arg = decorator .args [0 ]
78+ if _check_argnames_for_completion_params (argnames_arg ):
79+ return True
80+
81+ # Check keyword arguments for argnames
82+ for keyword in decorator .keywords :
83+ if keyword .arg == "argnames" :
84+ if _check_argnames_for_completion_params (keyword .value ):
85+ return True
86+
87+ return False
88+
89+
90+ def _check_argnames_for_completion_params (argnames_node : ast .expr ) -> bool :
91+ """
92+ Check if an argnames AST node contains "completion_params".
93+
94+ Args:
95+ argnames_node: AST node representing the argnames value
96+
97+ Returns:
98+ True if argnames contains "completion_params"
99+ """
100+ if isinstance (argnames_node , ast .Constant ):
101+ # Single string case: argnames="completion_params"
102+ if argnames_node .value == "completion_params" :
103+ return True
104+ elif isinstance (argnames_node , ast .List ):
105+ # List case: argnames=["completion_params", ...]
106+ for elt in argnames_node .elts :
107+ if isinstance (elt , ast .Constant ) and elt .value == "completion_params" :
108+ return True
109+ elif isinstance (argnames_node , ast .Tuple ):
110+ # Tuple case: argnames=("completion_params", ...)
111+ for elt in argnames_node .elts :
112+ if isinstance (elt , ast .Constant ) and elt .value == "completion_params" :
113+ return True
114+
115+ return False
116+
117+
12118class PytestMarkParametrizeKwargs (TypedDict ):
13119 argnames : Sequence [str ]
14120 argvalues : Iterable [ParameterSet | Sequence [object ] | object ]
@@ -96,6 +202,7 @@ def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str |
96202
97203def pytest_parametrize (
98204 combinations : list [CombinationTuple ],
205+ test_func : TestFunction | None ,
99206 input_dataset : Sequence [DatasetPathParam ] | None ,
100207 completion_params : Sequence [CompletionParams | None ] | None ,
101208 completion_params_provided : bool ,
@@ -112,16 +219,22 @@ def pytest_parametrize(
112219 API.
113220 """
114221
222+ if test_func is not None :
223+ has_pytest_parametrize = _has_pytest_parametrize_with_completion_params (test_func )
224+ else :
225+ has_pytest_parametrize = False
226+
115227 # Create parameter tuples for pytest.mark.parametrize
116228 argnames : list [str ] = []
117229 sig_parameters : list [str ] = []
118230 if input_dataset is not None :
119231 argnames .append ("dataset_path" )
120232 sig_parameters .append ("dataset_path" )
121233 if completion_params is not None :
122- if completion_params_provided :
234+ if completion_params_provided and not has_pytest_parametrize :
123235 argnames .append ("completion_params" )
124- sig_parameters .append ("completion_params" )
236+ if has_pytest_parametrize or completion_params_provided :
237+ sig_parameters .append ("completion_params" )
125238 if input_messages is not None :
126239 argnames .append ("input_messages" )
127240 sig_parameters .append ("input_messages" )
0 commit comments