Skip to content

Commit 6f69941

Browse files
author
Dylan Huang
committed
Fix input_messages data type
1 parent d330ee0 commit 6f69941

20 files changed

Lines changed: 119 additions & 97 deletions

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import csv
33
import io
44
import re
5-
from typing import List
65

76
import requests
87

@@ -20,12 +19,12 @@
2019
)
2120

2221

23-
def _load_gpqa_messages_from_csv() -> List[List[Message]]:
22+
def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
2423
url = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
2524
resp = requests.get(url, timeout=60)
2625
resp.raise_for_status()
2726

28-
messages_list: List[List[Message]] = []
27+
messages_list: list[list[Message]] = []
2928
reader = csv.DictReader(io.StringIO(resp.text))
3029
for ex in reader:
3130
q = str(ex.get("Question", ""))
@@ -45,7 +44,7 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]:
4544
)
4645
if not messages_list:
4746
raise RuntimeError("Failed to load GPQA messages: no rows found from source")
48-
return messages_list
47+
return [messages_list]
4948

5049

5150
def _extract_abcd_letter(text: str) -> str | None:
@@ -58,7 +57,7 @@ def _extract_abcd_letter(text: str) -> str | None:
5857
_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()
5958

6059

61-
def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
60+
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
6261
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
6362

6463

@@ -69,9 +68,9 @@ def __init__(self):
6968
super().__init__()
7069
self.single_turn_processor = SingleTurnRolloutProcessor()
7170

72-
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
71+
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
7372
"""Preprocess rows and delegate to SingleTurnRolloutProcessor."""
74-
processed: List[EvaluationRow] = []
73+
processed: list[EvaluationRow] = []
7574

7675
for r in rows:
7776
gt_tokens = [

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
409409

410410
@evaluation_test(
411411
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
412-
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
412+
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
413413
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
414414
rollout_processor=SingleTurnRolloutProcessor(),
415415
aggregation_method="mean",
@@ -451,7 +451,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
451451

452452
@evaluation_test(
453453
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
454-
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
454+
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
455455
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
456456
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
457457
aggregation_method="mean",

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
def evaluation_test(
6363
*,
6464
completion_params: Sequence[CompletionParams | None] | None = None,
65-
input_messages: Sequence[InputMessagesParam | None] | None = None,
65+
input_messages: Sequence[list[InputMessagesParam] | None] | None = None,
6666
input_dataset: Sequence[DatasetPathParam] | None = None,
6767
input_rows: Sequence[list[EvaluationRow]] | None = None,
6868
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
@@ -232,7 +232,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
232232
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
233233
# Support either a single row (List[Message]) or many rows (List[List[Message]])
234234
im = kwargs["input_messages"]
235-
data = [EvaluationRow(messages=im)]
235+
data = [EvaluationRow(messages=dataset_messages) for dataset_messages in im]
236236
elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
237237
# Use pre-constructed EvaluationRow objects directly
238238
data = kwargs["input_rows"]

eval_protocol/pytest/generate_parameter_combinations.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Either a single completion params object or None.
1919
"""
2020

21-
InputMessagesKwarg = InputMessagesParam | None
21+
InputMessagesKwarg = list[InputMessagesParam] | None
2222
InputRowsKwarg = Dataset | None
2323
EvaluationTestKwargs = EvaluationInputParam | None
2424

@@ -47,7 +47,7 @@ class ParameterizedTestKwargs(TypedDict):
4747
def generate_parameter_combinations(
4848
input_dataset: Sequence[DatasetPathParam] | None,
4949
completion_params: Sequence[CompletionParams | None],
50-
input_messages: Sequence[InputMessagesParam | None] | None,
50+
input_messages: Sequence[list[InputMessagesParam] | None] | None,
5151
input_rows: Sequence[list[EvaluationRow] | None] | None,
5252
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
5353
max_dataset_rows: int | None,
@@ -83,11 +83,15 @@ def generate_parameter_combinations(
8383
# Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
8484
# each row. Instead, pass the entire sliced list through in a single test run
8585
# so summaries aggregate all rows together (AIME-style behavior).
86-
messages: Sequence[InputMessagesParam | None] = [None]
86+
messages: Sequence[list[InputMessagesParam] | None] = [None]
8787
if input_messages is not None:
8888
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
8989
if effective_max_rows is not None:
90-
sliced_messages: Sequence[InputMessagesParam | None] = input_messages[:effective_max_rows]
90+
sliced_messages: Sequence[list[InputMessagesParam] | None] = [
91+
dataset_messages[:effective_max_rows]
92+
for dataset_messages in input_messages
93+
if dataset_messages is not None
94+
]
9195
else:
9296
sliced_messages = input_messages
9397
# Wrap as a single parameter payload

eval_protocol/pytest/parameterize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def pytest_parametrize(
1818
combinations: list[CombinationTuple],
1919
input_dataset: Sequence[DatasetPathParam] | None,
2020
completion_params: Sequence[CompletionParams | None] | None,
21-
input_messages: Sequence[InputMessagesParam | None] | None,
21+
input_messages: Sequence[list[InputMessagesParam] | None] | None,
2222
input_rows: Sequence[list[EvaluationRow]] | None,
2323
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
2424
) -> PytestParametrizeArgs:

eval_protocol/pytest/rollout_processor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
from abc import ABC, abstractmethod
3-
from typing import List
43

54
from eval_protocol.models import EvaluationRow
65
from eval_protocol.pytest.types import RolloutProcessorConfig
@@ -12,7 +11,7 @@ class RolloutProcessor(ABC):
1211
"""
1312

1413
@abstractmethod
15-
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
14+
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
1615
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""
1716
pass
1817

examples/healthbench/tests/test_evaluation.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import json
2-
from typing import Dict, List
3-
41
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
52
from eval_protocol.pytest.default_single_turn_rollout_process import (
63
SingleTurnRolloutProcessor,
@@ -34,13 +31,15 @@
3431
},
3532
]
3633

37-
_HB_INPUT_MESSAGES: List[List[Message]] = []
38-
_HB_RUBRICS_MAP: Dict[str, List[Dict]] = {}
34+
_HB_INPUT_MESSAGES: list[list[list[Message]]] = []
35+
_HB_RUBRICS_MAP: dict[str, list[dict]] = {}
3936
for s in _HB_SAMPLES:
4037
_HB_INPUT_MESSAGES.append(
4138
[
42-
Message(role="system", content=SYSTEM_PROMPT),
43-
Message(role="user", content=s["prompt_text"]),
39+
[
40+
Message(role="system", content=SYSTEM_PROMPT),
41+
Message(role="user", content=s["prompt_text"]),
42+
]
4443
]
4544
)
4645
_HB_RUBRICS_MAP[s["prompt_text"]] = s["rubrics"]

tests/chinook/test_pydantic_chinook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@pytest.mark.asyncio
2323
@evaluation_test(
24-
input_messages=[[Message(role="user", content="What is the total number of tracks in the database?")]],
24+
input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
2525
completion_params=[
2626
{
2727
"model": {

tests/pytest/test_get_metadata.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
@evaluation_test(
88
input_messages=[
99
[
10-
Message(role="user", content="What is the capital of France?"),
11-
],
12-
[
13-
Message(role="user", content="What is the capital of the moon?"),
14-
],
10+
[
11+
Message(role="user", content="What is the capital of France?"),
12+
],
13+
[
14+
Message(role="user", content="What is the capital of the moon?"),
15+
],
16+
]
1517
],
1618
completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}] * 2,
1719
mode="groupwise",

tests/pytest/test_pydantic_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
@pytest.mark.asyncio
1313
@evaluation_test(
14-
input_messages=[[Message(role="user", content="Hello, how are you?")]],
14+
input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
1515
completion_params=[
1616
{"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
1717
],

0 commit comments

Comments (0)