Skip to content

Commit dd2c6b6

Browse files
author
Dylan Huang
committed
fix completion params not propagating in parametrized test for input_rows
1 parent ea9ec83 commit dd2c6b6

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
239239
im = kwargs["input_messages"]
240240
data = [EvaluationRow(messages=dataset_messages) for dataset_messages in im]
241241
elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
242-
# Use pre-constructed EvaluationRow objects directly
243-
data = kwargs["input_rows"]
242+
# Deep copy pre-constructed EvaluationRow objects
243+
data = [row.model_copy(deep=True) for row in kwargs["input_rows"]]
244244
else:
245245
raise ValueError("No input dataset, input messages, or input rows provided")
246246

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from eval_protocol.models import EvaluationRow, Message
2+
from eval_protocol.pytest import evaluation_test
3+
4+
5+
@evaluation_test(
6+
completion_params=[{"model": "gpt-4"}, {"model": "gpt-4o"}],
7+
input_rows=[[EvaluationRow(messages=[Message(role="user", content="Hello, how are you?")])]],
8+
evaluation_test_kwargs=[{"seen_models": set()}],
9+
)
10+
def test_pytest_input_rows_parametrized_completion_params(row: EvaluationRow, **kwargs) -> EvaluationRow:
11+
"""Tests that parametrized completion params are working correctly for input_rows"""
12+
seen_models = kwargs["seen_models"]
13+
model = row.input_metadata.completion_params["model"]
14+
if len(seen_models) == 1:
15+
# assert that the other model was seen
16+
if model == "gpt-4":
17+
assert "gpt-4o" in seen_models
18+
else:
19+
assert "gpt-4" in seen_models
20+
seen_models.add(model)
21+
return row

0 commit comments

Comments
 (0)