Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
EvaluationInputParam,
EvaluationTestMode,
InputMessagesParam,
InputRowsParam,
ModelParam,
RolloutProcessorConfig,
RolloutProcessorInputParam,
Expand Down Expand Up @@ -238,6 +239,7 @@ def evaluation_test( # noqa: C901
completion_params: List[CompletionParams],
input_messages: Optional[List[InputMessagesParam]] = None,
input_dataset: Optional[List[DatasetPathParam]] = None,
input_rows: Optional[List[InputRowsParam]] = None,
dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
Expand Down Expand Up @@ -299,6 +301,9 @@ def evaluation_test( # noqa: C901
input_dataset: Paths to JSONL datasets. This is useful if you have a
dataset already. Provide a dataset_adapter to convert the input dataset
to a list of EvaluationRows if you have a custom dataset format.
input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful
when you want to provide EvaluationRow objects with custom metadata, input_messages,
or other fields already populated. Will be passed as "input_dataset" to the test function.
dataset_adapter: Function to convert the input dataset to a list of
EvaluationRows. This is useful if you have a custom dataset format.
completion_params: Generation parameters for the rollout.
Expand Down Expand Up @@ -413,33 +418,42 @@ async def execute_with_params(
# Calculate all possible combinations of parameters
if mode == "groupwise":
combinations = generate_parameter_combinations(
input_dataset, None, input_messages, evaluation_test_kwargs, max_dataset_rows, combine_datasets
input_dataset,
None,
input_messages,
input_rows,
evaluation_test_kwargs,
max_dataset_rows,
combine_datasets,
)
else:
combinations = generate_parameter_combinations(
input_dataset,
completion_params,
input_messages,
input_rows,
evaluation_test_kwargs,
max_dataset_rows,
combine_datasets,
)
if len(combinations) == 0:
raise ValueError(
"No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages."
"No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."
)

# Create parameter tuples for pytest.mark.parametrize
param_tuples = []
for combo in combinations:
dataset, cp, messages, etk = combo
dataset, cp, messages, rows, etk = combo
param_tuple = []
if input_dataset is not None:
param_tuple.append(dataset)
if completion_params is not None:
param_tuple.append(cp)
if input_messages is not None:
param_tuple.append(messages)
if input_rows is not None:
param_tuple.append(rows)
if evaluation_test_kwargs is not None:
param_tuple.append(etk)
param_tuples.append(tuple(param_tuple))
Expand All @@ -452,6 +466,8 @@ async def execute_with_params(
test_param_names.append("completion_params")
if input_messages is not None:
test_param_names.append("input_messages")
if input_rows is not None:
test_param_names.append("input_rows")
if evaluation_test_kwargs is not None:
test_param_names.append("evaluation_test_kwargs")

Expand Down Expand Up @@ -500,8 +516,11 @@ def _log_eval_error(
else:
# Multiple rows: list of List[Message]
data = [EvaluationRow(messages=m) for m in im]
elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
# Use pre-constructed EvaluationRow objects directly
data = kwargs["input_rows"]
else:
raise ValueError("No input dataset or input messages provided")
raise ValueError("No input dataset, input messages, or input rows provided")

for row in data:
# generate a stable row_id for each row
Expand Down
1 change: 1 addition & 0 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
DatasetPathParam = str
InputMessagesParam = List[Message]
InputRowsParam = List[EvaluationRow]
EvaluationInputParam = Dict[str, Any]
RolloutProcessorInputParam = Dict[str, Any]

Expand Down
30 changes: 23 additions & 7 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DatasetPathParam,
EvaluationInputParam,
InputMessagesParam,
InputRowsParam,
RolloutProcessorConfig,
)
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
Expand Down Expand Up @@ -166,6 +167,7 @@ def generate_parameter_combinations(
input_dataset: Optional[List[DatasetPathParam]],
completion_params: List[CompletionParams],
input_messages: Optional[List[InputMessagesParam]],
input_rows: Optional[List[InputRowsParam]],
evaluation_test_kwargs: Optional[List[EvaluationInputParam]],
max_dataset_rows: Optional[int],
combine_datasets: bool,
Expand All @@ -177,6 +179,7 @@ def generate_parameter_combinations(
input_dataset: Dataset paths to use
completion_params: Completion parameters to test
input_messages: Input messages to use
input_rows: Pre-constructed EvaluationRow objects to use
evaluation_test_kwargs: Additional kwargs for evaluation tests
max_dataset_rows: Maximum number of dataset rows to process
combine_datasets: Whether to combine multiple datasets into one test
Expand Down Expand Up @@ -217,6 +220,18 @@ def generate_parameter_combinations(
else:
messages = [None] # type: ignore

# Handle input_rows - similar to input_messages, apply max_dataset_rows if specified
if input_rows is not None and isinstance(input_rows, list):
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
if effective_max_rows is not None:
sliced_rows = input_rows[:effective_max_rows] # type: ignore
else:
sliced_rows = input_rows # type: ignore
# Wrap as a single parameter payload
rows = [sliced_rows] # type: ignore
else:
rows = [None] # type: ignore

kwargs: List[Optional[EvaluationInputParam]] = (
evaluation_test_kwargs if evaluation_test_kwargs is not None else [None]
) # type: ignore
Expand All @@ -225,13 +240,14 @@ def generate_parameter_combinations(
for ds in datasets:
for cp in cps:
for im in messages:
for etk in kwargs:
# if no dataset and no messages, raise an error
if ds is None and im is None:
raise ValueError(
"No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
)
combinations.append((ds, cp, im, etk))
for ir in rows:
for etk in kwargs:
# if no dataset, no messages, and no rows, raise an error
if ds is None and im is None and ir is None:
raise ValueError(
"No dataset, messages, or rows provided. Please provide at least one of input_dataset, input_messages, or input_rows."
)
combinations.append((ds, cp, im, ir, etk))

return combinations

Expand Down
15 changes: 15 additions & 0 deletions tests/pytest/test_pytest_input_rows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor


@evaluation_test(
    input_rows=[EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])],
    completion_params=[{"model": "no-op"}],
    rollout_processor=NoOpRolloutProcessor(),
    mode="pointwise",
)
def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
    """Verify that a pre-built EvaluationRow supplied via ``input_rows`` reaches the test body intact.

    The decorator is given one EvaluationRow directly (no dataset file, no
    ``input_messages``); in ``pointwise`` mode the framework should invoke this
    function once with that row, whose first message must be unchanged.
    """
    # NOTE(review): the function name says "input_messages" but this test
    # exercises the new ``input_rows`` parameter — consider renaming to
    # test_input_rows_in_decorator for clarity.
    assert row.messages[0].content == "What is the capital of France?"
    return row
Loading