11from collections .abc import Sequence
22from inspect import Signature
3+ from typing import get_origin , get_args
34
45from eval_protocol .models import CompletionParams , EvaluationRow
56from eval_protocol .pytest .types import EvaluationTestMode
67
78
9+ def _is_list_of_evaluation_row (annotation ) -> bool : # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
10+ """Check if annotation is list[EvaluationRow] or equivalent."""
11+ origin = get_origin (annotation ) # pyright: ignore[reportUnknownArgumentType, reportAny]
12+ if origin is not list :
13+ return False
14+
15+ args = get_args (annotation )
16+ if len (args ) != 1 :
17+ return False
18+
19+ # Check if the single argument is EvaluationRow or equivalent
20+ arg = args [0 ] # pyright: ignore[reportAny]
21+ return arg is EvaluationRow or str (arg ) == str (EvaluationRow ) # pyright: ignore[reportAny]
22+
23+
824def validate_signature (
925 signature : Signature , mode : EvaluationTestMode , completion_params : Sequence [CompletionParams | None ] | None
1026) -> None :
@@ -29,11 +45,13 @@ def validate_signature(
2945 raise ValueError ("In groupwise mode, your eval function must have a parameter named 'rows'" )
3046
3147 # validate that "Rows" is of type List[EvaluationRow]
32- if signature .parameters ["rows" ].annotation is not list [EvaluationRow ]: # pyright: ignore[reportAny]
33- raise ValueError ("In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow" )
48+ if not _is_list_of_evaluation_row (signature .parameters ["rows" ].annotation ): # pyright: ignore[reportAny]
49+ raise ValueError (
50+ f"In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow]. Got { str (signature .parameters ['rows' ].annotation )} instead" # pyright: ignore[reportAny]
51+ )
3452
3553 # validate that the function has a return type of List[EvaluationRow]
36- if signature .return_annotation is not list [ EvaluationRow ] : # pyright: ignore[reportAny]
54+ if not _is_list_of_evaluation_row ( signature .return_annotation ) : # pyright: ignore[reportAny]
3755 raise ValueError ("In groupwise mode, your eval function must return a list of EvaluationRow instances" )
3856 if completion_params is not None and len (completion_params ) < 2 :
3957 raise ValueError ("In groupwise mode, you must provide at least 2 completion parameters" )
@@ -43,9 +61,11 @@ def validate_signature(
4361 raise ValueError ("In all mode, your eval function must have a parameter named 'rows'" )
4462
4563 # validate that "Rows" is of type List[EvaluationRow]
46- if signature .parameters ["rows" ].annotation is not list [EvaluationRow ]: # pyright: ignore[reportAny]
47- raise ValueError ("In all mode, the 'rows' parameter must be of type List[EvaluationRow" )
64+ if not _is_list_of_evaluation_row (signature .parameters ["rows" ].annotation ): # pyright: ignore[reportAny]
65+ raise ValueError (
66+ f"In all mode, the 'rows' parameter must be of type list[EvaluationRow]. Got { str (signature .parameters ['rows' ].annotation )} instead" # pyright: ignore[reportAny]
67+ )
4868
4969 # validate that the function has a return type of List[EvaluationRow]
50- if signature .return_annotation is not list [ EvaluationRow ] : # pyright: ignore[reportAny]
70+ if not _is_list_of_evaluation_row ( signature .return_annotation ) : # pyright: ignore[reportAny]
5171 raise ValueError ("In all mode, your eval function must return a list of EvaluationRow instances" )
0 commit comments