Skip to content

Commit d0cd7de

Browse files
committed
support groupwise scoring
1 parent 3f7b4c3 commit d0cd7de

File tree

4 files changed: +348 additions, −192 deletions

4 files changed: +348 additions, −192 deletions

eval_protocol/benchmarks/test_aime25.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,69 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
113113
metrics=metrics,
114114
)
115115
return row
116+
117+
118+
# @evaluation_test(
119+
# input_dataset=[
120+
# "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl",
121+
# # "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl",
122+
# ],
123+
# dataset_adapter=aime2025_dataset_adapter,
124+
# completion_params=[
125+
# {
126+
# "max_tokens": 131000,
127+
# "extra_body": {"reasoning_effort": "low"},
128+
# "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
129+
# },
130+
# {
131+
# "max_tokens": 131000,
132+
# "extra_body": {"reasoning_effort": "low"},
133+
# "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b",
134+
# }
135+
# ],
136+
# rollout_processor=SingleTurnRolloutProcessor(),
137+
# aggregation_method="mean",
138+
# passed_threshold=None,
139+
# num_runs=1,
140+
# max_dataset_rows=2,
141+
# max_concurrent_rollouts=4,
142+
# mode="groupwise",
143+
# )
144+
# def test_aime25_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
145+
# output = []
146+
# for row in rows:
147+
# assistant_msgs = [m for m in row.messages if m.role == "assistant"]
148+
# content = assistant_msgs[-1].content if assistant_msgs else ""
149+
150+
# extracted_text = _extract_boxed_text(content or "")
151+
# extracted_int = _normalize_to_int_or_none(extracted_text)
152+
# gt_int = _normalize_to_int_or_none(row.ground_truth or "")
153+
154+
# is_valid = extracted_int is not None and gt_int is not None
155+
# score = 1.0 if (is_valid and extracted_int == gt_int) else 0.0
156+
157+
# metrics = {
158+
# "exact_match": MetricResult(
159+
# score=score,
160+
# is_score_valid=is_valid,
161+
# reason=(
162+
# "Parsed both integers and they matched"
163+
# if score == 1.0
164+
# else ("Parsed integers did not match" if is_valid else "Failed to parse integer")
165+
# ),
166+
# data={
167+
# "extracted_text": extracted_text,
168+
# "extracted_int": extracted_int,
169+
# "ground_truth_int": gt_int,
170+
# },
171+
# )
172+
# }
173+
174+
# row.evaluation_result = EvaluateResult(
175+
# score=score,
176+
# reason=("Answer correct" if score == 1.0 else "Answer incorrect"),
177+
# is_score_valid=is_valid,
178+
# metrics=metrics,
179+
# )
180+
# output.append(row)
181+
# return output

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
135135
async def process_row(row: EvaluationRow) -> EvaluationRow:
136136
"""Process a single row with agent rollout."""
137137
agent = Agent(
138-
model=config.completion_params["model"],
138+
model=row.input_metadata.completion_params["model"],
139139
row=row,
140140
config_path=config.mcp_config_path,
141141
logger=config.logger,

0 commit comments

Comments (0)