Skip to content

Commit 10a4381

Browse files
committed
fix live bench and rollout processor
1 parent 971d6e4 commit 10a4381

File tree

3 files changed

+47
-38
lines changed

3 files changed

+47
-38
lines changed

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,29 @@ def _extract_abcd_letter(text: str) -> str | None:
5656

5757
_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()
5858

59+
def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
60+
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
61+
62+
63+
async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]:
64+
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor."""
65+
processed: List[EvaluationRow] = []
66+
for r in rows:
67+
gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
68+
if gt_tokens:
69+
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
70+
r.ground_truth = gt_val
71+
r.messages = [m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
72+
processed.append(r)
73+
return await default_single_turn_rollout_processor(processed, config)
74+
5975

6076
@export_benchmark("gpqa")
6177
@evaluation_test(
6278
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
6379
input_messages=_GPQA_INPUT_MESSAGES,
6480
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
65-
rollout_processor=default_single_turn_rollout_processor,
81+
rollout_processor=gpqa_strip_gt_rollout_processor,
6682
aggregation_method="mean",
6783
passed_threshold=None,
6884
num_runs=8,
@@ -73,9 +89,8 @@ def gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
7389
content = assistant_msgs[-1].content if assistant_msgs else ""
7490

7591
pred = _extract_abcd_letter(content or "")
76-
# Retrieve GT from the trailing system message we appended
77-
gt_tokens = [m.content for m in row.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
78-
gt = gt_tokens[-1].split(":", 1)[1].strip() if gt_tokens else None
92+
# GPQA diamond CSV constructs options so that the correct answer is always A
93+
gt = "A"
7994

8095
is_valid = pred is not None and gt in {"A", "B", "C", "D"}
8196
score = 1.0 if (is_valid and pred == gt) else 0.0

eval_protocol/benchmarks/suites/livebench_data_analysis.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]):
315315
SYSTEM_PROMPT = "You are a helpful data analyst. Read the task and answer precisely."
316316

317317

318-
def _load_livebench_da_messages(task_name: str) -> List[List[Message]]:
318+
def _load_livebench_da_messages(task_name: str) -> List[EvaluationRow]:
319319
try:
320320
from datasets import load_dataset # type: ignore
321321
except Exception as e: # pragma: no cover
@@ -324,58 +324,57 @@ def _load_livebench_da_messages(task_name: str) -> List[List[Message]]:
324324
) from e
325325

326326
ds = load_dataset("livebench/data_analysis", split="test")
327-
rows: List[List[Message]] = []
327+
rows: List[EvaluationRow] = []
328328
for ex in ds:
329329
if str(ex.get("task", "")) != task_name:
330330
continue
331331
question_text = str(ex.get("turns", [""])[0])
332332
ground_truth = ex.get("ground_truth")
333+
release = ex.get("livebench_release_date", "")
333334
try:
334-
gt_json = json.dumps({
335-
"ground_truth": ground_truth,
336-
"release": ex.get("livebench_release_date", ""),
337-
}, ensure_ascii=False)
335+
gt_payload = json.dumps({"ground_truth": ground_truth, "release": release}, ensure_ascii=False)
338336
except TypeError:
339-
# Some rows may include non-serializable types; fall back to string cast
340-
gt_json = json.dumps({"ground_truth": str(ground_truth), "release": str(ex.get("livebench_release_date", ""))})
337+
gt_payload = json.dumps({"ground_truth": str(ground_truth), "release": str(release)})
341338
rows.append(
342-
[
343-
Message(role="system", content=SYSTEM_PROMPT),
344-
Message(role="user", content=question_text),
345-
Message(role="system", content=f"__GT__:{gt_json}"),
346-
]
339+
EvaluationRow(
340+
messages=[
341+
Message(role="system", content=SYSTEM_PROMPT),
342+
Message(role="user", content=question_text),
343+
],
344+
ground_truth=gt_payload,
345+
)
347346
)
348347
if not rows:
349348
raise RuntimeError(f"No rows found for LiveBench data_analysis task '{task_name}'")
350349
return rows
351350

352351

353352
def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
354-
gt_tokens = [
355-
m.content
356-
for m in row.messages
357-
if m.role == "system" and (m.content or "").startswith("__GT__:")
358-
]
359-
if not gt_tokens:
353+
# For LiveBench Data Analysis, we fetch the ground truth from the HF dataset
354+
# and store it in the top-level ground_truth field in the adapter below.
355+
# Here, just parse row.ground_truth if it contains a JSON payload, else string.
356+
if row.ground_truth is None:
360357
return {"ground_truth": None, "release": None}
361358
try:
362-
payload = json.loads(gt_tokens[-1].split(":", 1)[1])
363-
return payload if isinstance(payload, dict) else {"ground_truth": payload, "release": None}
359+
payload = json.loads(row.ground_truth)
360+
if isinstance(payload, dict):
361+
return payload
364362
except Exception:
365-
return {"ground_truth": gt_tokens[-1].split(":", 1)[1], "release": None}
363+
pass
364+
return {"ground_truth": row.ground_truth, "release": None}
366365

367366

368367
# -------------------------
369368
# CTA
370369
# -------------------------
371370

372-
_CTA_MESSAGES = _load_livebench_da_messages("cta")
371+
_CTA_ROWS = _load_livebench_da_messages("cta")
373372

374373

375374
@export_benchmark("live_bench/data_analysis/cta")
376375
@evaluation_test(
377376
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
378-
input_messages=_CTA_MESSAGES,
377+
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
379378
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
380379
rollout_processor=default_single_turn_rollout_processor,
381380
aggregation_method="mean",
@@ -412,13 +411,13 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
412411
# Table Join
413412
# -------------------------
414413

415-
_TABLEJOIN_MESSAGES = _load_livebench_da_messages("tablejoin")
414+
_TABLEJOIN_ROWS = _load_livebench_da_messages("tablejoin")
416415

417416

418417
@export_benchmark("live_bench/data_analysis/tablejoin")
419418
@evaluation_test(
420419
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
421-
input_messages=_TABLEJOIN_MESSAGES,
420+
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
422421
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
423422
rollout_processor=default_single_turn_rollout_processor,
424423
aggregation_method="mean",
@@ -456,13 +455,13 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
456455
# Table Reformat
457456
# -------------------------
458457

459-
_TABLEREFORMAT_MESSAGES = _load_livebench_da_messages("tablereformat")
458+
_TABLEREFORMAT_ROWS = _load_livebench_da_messages("tablereformat")
460459

461460

462461
@export_benchmark("live_bench/data_analysis/tablereformat")
463462
@evaluation_test(
464463
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
465-
input_messages=_TABLEREFORMAT_MESSAGES,
464+
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
466465
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
467466
rollout_processor=default_single_turn_rollout_processor,
468467
aggregation_method="mean",

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
3131
if len(row.messages) == 0:
3232
raise ValueError("Messages is empty. Please provide a non-empty dataset")
3333

34-
# Filter out any sentinel ground-truth system messages (e.g., "__GT__:") before sending to the model
35-
messages_payload = [
36-
{"role": m.role, "content": m.content}
37-
for m in row.messages
38-
if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
39-
]
34+
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
4035

4136
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
4237
# Ensure caching is disabled only for this request (review feedback)

0 commit comments

Comments (0)