|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import re |
3 | 4 | from typing import Any, Dict, List, Optional |
|
7 | 8 | SingleTurnRolloutProcessor, |
8 | 9 | ) |
9 | 10 | from eval_protocol.pytest.evaluation_test import evaluation_test |
| 11 | +from eval_protocol.pytest.rollout_processor import RolloutProcessor |
| 12 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
10 | 13 |
|
11 | 14 | # ------------------------- |
12 | 15 | # Lightweight ports of LiveBench scoring utilities for data_analysis tasks |
@@ -306,6 +309,41 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]): |
306 | 309 | return 1 |
307 | 310 |
|
308 | 311 |
|
| 312 | +# ------------------------- |
| 313 | +# Custom Rollout Processor to preserve ground truth |
| 314 | +# ------------------------- |
| 315 | + |
| 316 | + |
class LiveBenchGroundTruthRolloutProcessor(RolloutProcessor):
    """Re-attach ground truth to rows, then run a single-turn rollout.

    The processor is built from the pre-loaded task rows and remembers each
    row's ``ground_truth`` keyed by the text of its second message. When it is
    later called with rows to roll out, any row whose second message matches a
    remembered key gets that ground truth assigned before the rows are handed
    to a ``SingleTurnRolloutProcessor``.

    NOTE(review): the second message (index 1) is taken to be the user turn,
    i.e. rows are expected to look like ``[system, user, ...]`` -- confirm
    against the dataset loaders. Rows with fewer than two messages, or with an
    empty second message, are neither indexed nor updated.
    """

    def __init__(self, task_rows: List[EvaluationRow]):
        """Index the ground truth of every row in *task_rows*.

        If several rows share the same key, the last one wins.
        """
        super().__init__()
        self.single_turn_processor = SingleTurnRolloutProcessor()
        # Lookup table: stringified second-message content -> ground truth.
        self.ground_truth_map: Dict[str, Any] = {}
        for task_row in task_rows:
            key = self._user_message_key(task_row)
            if key is not None:
                self.ground_truth_map[key] = task_row.ground_truth

    @staticmethod
    def _user_message_key(row: EvaluationRow) -> Optional[str]:
        """Return the lookup key for *row*, or ``None`` if it has none.

        The key is ``str()`` of the second message's content. ``None`` is
        returned when the row has fewer than two messages or when that
        content is falsy.
        """
        messages = row.messages
        if not messages or len(messages) < 2:
            return None
        content = messages[1].content
        if not content:
            return None
        return str(content)

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Set ground truth on rows based on message content, then delegate to SingleTurnRolloutProcessor.

        Rows are updated in place. Every incoming row is forwarded, whether
        or not a ground truth was found for it.
        """
        pending: List[EvaluationRow] = list(rows)

        for candidate in pending:
            key = self._user_message_key(candidate)
            # Membership test (not .get) so a stored ground truth of None is
            # still assigned, exactly as a present key implies.
            if key is not None and key in self.ground_truth_map:
                candidate.ground_truth = self.ground_truth_map[key]

        # Delegate to SingleTurnRolloutProcessor
        return self.single_turn_processor(pending, config)
| 345 | + |
| 346 | + |
309 | 347 | # ------------------------- |
310 | 348 | # Dataset loading from Hugging Face at import time |
311 | 349 | # ------------------------- |
@@ -415,7 +453,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: |
415 | 453 | completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], |
416 | 454 | input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS], |
417 | 455 | rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], |
418 | | - rollout_processor=SingleTurnRolloutProcessor(), |
| 456 | + rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS), |
419 | 457 | aggregation_method="mean", |
420 | 458 | passed_threshold=None, |
421 | 459 | num_runs=1, |
@@ -458,7 +496,7 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: |
458 | 496 | completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], |
459 | 497 | input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS], |
460 | 498 | rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], |
461 | | - rollout_processor=SingleTurnRolloutProcessor(), |
| 499 | + rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS), |
462 | 500 | aggregation_method="mean", |
463 | 501 | passed_threshold=None, |
464 | 502 | num_runs=4, |
|
0 commit comments