From 4b4c5c2bbcc7fea48f6b6f97c9e9201515d3f2d7 Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Thu, 21 Aug 2025 12:37:49 -0700
Subject: [PATCH] accept in-memory rows as input

---
 eval_protocol/pytest/evaluation_test.py | 27 ++++++++++++++++++----
 eval_protocol/pytest/types.py           |  1 +
 eval_protocol/pytest/utils.py           | 30 +++++++++++++++++++------
 tests/pytest/test_pytest_input_rows.py  | 15 +++++++++++++
 4 files changed, 62 insertions(+), 11 deletions(-)
 create mode 100644 tests/pytest/test_pytest_input_rows.py

diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index 77292a3f..ce5f8817 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -36,6 +36,7 @@
     EvaluationInputParam,
     EvaluationTestMode,
     InputMessagesParam,
+    InputRowsParam,
     ModelParam,
     RolloutProcessorConfig,
     RolloutProcessorInputParam,
@@ -238,6 +239,7 @@ def evaluation_test(  # noqa: C901
     completion_params: List[CompletionParams],
     input_messages: Optional[List[InputMessagesParam]] = None,
     input_dataset: Optional[List[DatasetPathParam]] = None,
+    input_rows: Optional[List[InputRowsParam]] = None,
     dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
     rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
     evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
@@ -299,6 +301,9 @@ def evaluation_test(  # noqa: C901
         input_dataset: Paths to JSONL datasets. This is useful if you have a dataset
             already. Provide a dataset_adapter to convert the input dataset to a list of
             EvaluationRows if you have a custom dataset format.
+        input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful
+            when you want to provide EvaluationRow objects with custom metadata, input_messages,
+            or other fields already populated. Passed to the test function as "input_rows".
         dataset_adapter: Function to convert the input dataset to a list of
             EvaluationRows. This is useful if you have a custom dataset format.
         completion_params: Generation parameters for the rollout.
@@ -413,26 +418,33 @@ async def execute_with_params(
         # Calculate all possible combinations of parameters
         if mode == "groupwise":
             combinations = generate_parameter_combinations(
-                input_dataset, None, input_messages, evaluation_test_kwargs, max_dataset_rows, combine_datasets
+                input_dataset,
+                None,
+                input_messages,
+                input_rows,
+                evaluation_test_kwargs,
+                max_dataset_rows,
+                combine_datasets,
             )
         else:
             combinations = generate_parameter_combinations(
                 input_dataset,
                 completion_params,
                 input_messages,
+                input_rows,
                 evaluation_test_kwargs,
                 max_dataset_rows,
                 combine_datasets,
             )
         if len(combinations) == 0:
             raise ValueError(
-                "No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages."
+                "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."
             )

         # Create parameter tuples for pytest.mark.parametrize
         param_tuples = []
         for combo in combinations:
-            dataset, cp, messages, etk = combo
+            dataset, cp, messages, rows, etk = combo
             param_tuple = []
             if input_dataset is not None:
                 param_tuple.append(dataset)
@@ -440,6 +452,8 @@ async def execute_with_params(
                 param_tuple.append(cp)
             if input_messages is not None:
                 param_tuple.append(messages)
+            if input_rows is not None:
+                param_tuple.append(rows)
             if evaluation_test_kwargs is not None:
                 param_tuple.append(etk)
             param_tuples.append(tuple(param_tuple))
@@ -452,6 +466,8 @@ async def execute_with_params(
             test_param_names.append("completion_params")
         if input_messages is not None:
             test_param_names.append("input_messages")
+        if input_rows is not None:
+            test_param_names.append("input_rows")
         if evaluation_test_kwargs is not None:
             test_param_names.append("evaluation_test_kwargs")

@@ -500,8 +516,11 @@ def _log_eval_error(
             else:
                 # Multiple rows: list of List[Message]
                 data = [EvaluationRow(messages=m) for m in im]
+        elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
+            # Use pre-constructed EvaluationRow objects directly
+            data = kwargs["input_rows"]
         else:
-            raise ValueError("No input dataset or input messages provided")
+            raise ValueError("No input dataset, input messages, or input rows provided")

         for row in data:
             # generate a stable row_id for each row
diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py
index b563b770..cd6c7687 100644
--- a/eval_protocol/pytest/types.py
+++ b/eval_protocol/pytest/types.py
@@ -15,6 +15,7 @@
 ModelParam = str  # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
 DatasetPathParam = str
 InputMessagesParam = List[Message]
+InputRowsParam = List[EvaluationRow]
 EvaluationInputParam = Dict[str, Any]
 RolloutProcessorInputParam = Dict[str, Any]

diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py
index 6b709782..57af60b9 100644
--- a/eval_protocol/pytest/utils.py
+++ b/eval_protocol/pytest/utils.py
@@ -13,6 +13,7 @@
     DatasetPathParam,
     EvaluationInputParam,
     InputMessagesParam,
+    InputRowsParam,
     RolloutProcessorConfig,
 )
 from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
@@ -166,6 +167,7 @@ def generate_parameter_combinations(
     input_dataset: Optional[List[DatasetPathParam]],
     completion_params: List[CompletionParams],
     input_messages: Optional[List[InputMessagesParam]],
+    input_rows: Optional[List[InputRowsParam]],
     evaluation_test_kwargs: Optional[List[EvaluationInputParam]],
     max_dataset_rows: Optional[int],
     combine_datasets: bool,
@@ -177,6 +179,7 @@
         input_dataset: Dataset paths to use
         completion_params: Completion parameters to test
        input_messages: Input messages to use
+        input_rows: Pre-constructed EvaluationRow objects to use
         evaluation_test_kwargs: Additional kwargs for evaluation tests
         max_dataset_rows: Maximum number of dataset rows to process
         combine_datasets: Whether to combine multiple datasets into one test
@@ -217,6 +220,18 @@ def generate_parameter_combinations(
     else:
         messages = [None]  # type: ignore

+    # Handle input_rows - similar to input_messages, apply max_dataset_rows if specified
+    if input_rows is not None and isinstance(input_rows, list):
+        effective_max_rows = parse_ep_max_rows(max_dataset_rows)
+        if effective_max_rows is not None:
+            sliced_rows = input_rows[:effective_max_rows]  # type: ignore
+        else:
+            sliced_rows = input_rows  # type: ignore
+        # Wrap as a single parameter payload
+        rows = [sliced_rows]  # type: ignore
+    else:
+        rows = [None]  # type: ignore
+
     kwargs: List[Optional[EvaluationInputParam]] = (
         evaluation_test_kwargs if evaluation_test_kwargs is not None else [None]
     )  # type: ignore
@@ -225,13 +240,14 @@ def generate_parameter_combinations(
     for ds in datasets:
         for cp in cps:
             for im in messages:
-                for etk in kwargs:
-                    # if no dataset and no messages, raise an error
-                    if ds is None and im is None:
-                        raise ValueError(
-                            "No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
-                        )
-                    combinations.append((ds, cp, im, etk))
+                for ir in rows:
+                    for etk in kwargs:
+                        # if no dataset, no messages, and no rows, raise an error
+                        if ds is None and im is None and ir is None:
+                            raise ValueError(
+                                "No dataset, messages, or rows provided. Please provide at least one of input_dataset, input_messages, or input_rows."
+                            )
+                        combinations.append((ds, cp, im, ir, etk))

     return combinations
diff --git a/tests/pytest/test_pytest_input_rows.py b/tests/pytest/test_pytest_input_rows.py
new file mode 100644
index 00000000..e382ee9e
--- /dev/null
+++ b/tests/pytest/test_pytest_input_rows.py
@@ -0,0 +1,15 @@
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest import evaluation_test
+from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
+
+
+@evaluation_test(
+    input_rows=[EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])],
+    completion_params=[{"model": "no-op"}],
+    rollout_processor=NoOpRolloutProcessor(),
+    mode="pointwise",
+)
+def test_input_rows_in_decorator(row: EvaluationRow) -> EvaluationRow:
+    """Verify that a pre-constructed EvaluationRow is passed through to the test function."""
+    assert row.messages[0].content == "What is the capital of France?"
+    return row
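
Usage sketch for reviewers: input_rows composes with the existing
completion_params matrix, and the max_dataset_rows slicing added to
generate_parameter_combinations applies to pre-built rows as well. The names
below (evaluation_test, input_rows, NoOpRolloutProcessor, the "no-op" model,
pointwise mode) all appear in this patch; the QUESTIONS list, the second
question, and passing max_dataset_rows=1 through the decorator are
illustrative assumptions, not part of the change.

    from eval_protocol.models import EvaluationRow, Message
    from eval_protocol.pytest import evaluation_test
    from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor

    QUESTIONS = [
        "What is the capital of France?",
        "What is the capital of Japan?",
    ]

    @evaluation_test(
        # Pre-built rows are sliced by max_dataset_rows before parametrization,
        # so only the first question should reach the test here (assumption:
        # max_dataset_rows is accepted by the decorator, as its use inside
        # execute_with_params suggests).
        input_rows=[EvaluationRow(messages=[Message(role="user", content=q)]) for q in QUESTIONS],
        completion_params=[{"model": "no-op"}],
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
        max_dataset_rows=1,
    )
    def test_sliced_input_rows(row: EvaluationRow) -> EvaluationRow:
        # Each surviving row arrives individually in pointwise mode.
        assert row.messages[0].content == QUESTIONS[0]
        return row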