Skip to content

Commit d2992d9

Browse files
committed
LiveBench Fix
1 parent de246c6 commit d2992d9

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import json
23
import re
34
from typing import Any, Dict, List, Optional
@@ -7,6 +8,8 @@
78
SingleTurnRolloutProcessor,
89
)
910
from eval_protocol.pytest.evaluation_test import evaluation_test
11+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
12+
from eval_protocol.pytest.types import RolloutProcessorConfig
1013

1114
# -------------------------
1215
# Lightweight ports of LiveBench scoring utilities for data_analysis tasks
@@ -306,6 +309,41 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]):
306309
return 1
307310

308311

312+
# -------------------------
313+
# Custom Rollout Processor to preserve ground truth
314+
# -------------------------
315+
316+
317+
class LiveBenchGroundTruthRolloutProcessor(RolloutProcessor):
318+
"""Rollout processor that preserves ground truth data from pre-loaded datasets."""
319+
320+
def __init__(self, task_rows: List[EvaluationRow]):
321+
super().__init__()
322+
self.single_turn_processor = SingleTurnRolloutProcessor()
323+
# Create a mapping from message content to ground truth
324+
self.ground_truth_map = {}
325+
for row in task_rows:
326+
if row.messages and len(row.messages) >= 2: # system + user messages
327+
user_msg = row.messages[1].content # user message is typically second
328+
if user_msg:
329+
self.ground_truth_map[str(user_msg)] = row.ground_truth
330+
331+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
332+
"""Set ground truth on rows based on message content, then delegate to SingleTurnRolloutProcessor."""
333+
processed: List[EvaluationRow] = []
334+
335+
for row in rows:
336+
# Find matching ground truth based on user message content
337+
if row.messages and len(row.messages) >= 2:
338+
user_msg = row.messages[1].content # user message
339+
if user_msg and str(user_msg) in self.ground_truth_map:
340+
row.ground_truth = self.ground_truth_map[str(user_msg)]
341+
processed.append(row)
342+
343+
# Delegate to SingleTurnRolloutProcessor
344+
return self.single_turn_processor(processed, config)
345+
346+
309347
# -------------------------
310348
# Dataset loading from Hugging Face at import time
311349
# -------------------------
@@ -415,7 +453,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
415453
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
416454
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
417455
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
418-
rollout_processor=SingleTurnRolloutProcessor(),
456+
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
419457
aggregation_method="mean",
420458
passed_threshold=None,
421459
num_runs=1,
@@ -458,7 +496,7 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
458496
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
459497
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
460498
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
461-
rollout_processor=SingleTurnRolloutProcessor(),
499+
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS),
462500
aggregation_method="mean",
463501
passed_threshold=None,
464502
num_runs=4,

0 commit comments

Comments
 (0)