Skip to content

Commit 2431dbe

Browse files
committed
format
1 parent d0cd7de commit 2431dbe

File tree

2 files changed

+52
-38
lines changed

2 files changed

+52
-38
lines changed

eval_protocol/benchmarks/test_aime25.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,4 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
178178
# metrics=metrics,
179179
# )
180180
# output.append(row)
181-
# return output
181+
# return output

eval_protocol/pytest/evaluation_test.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,18 @@
5858
from ..common_utils import load_jsonl
5959

6060

61-
def postprocess(all_results: List[List[EvaluationRow]],
62-
aggregation_method: AggregationMethod,
63-
threshold: Optional[EvaluationThreshold],
64-
active_logger: DatasetLogger,
65-
mode: EvaluationTestMode,
66-
completion_params: CompletionParams,
67-
test_func_name: str,
68-
num_runs: int):
61+
def postprocess(
62+
all_results: List[List[EvaluationRow]],
63+
aggregation_method: AggregationMethod,
64+
threshold: Optional[EvaluationThreshold],
65+
active_logger: DatasetLogger,
66+
mode: EvaluationTestMode,
67+
completion_params: CompletionParams,
68+
test_func_name: str,
69+
num_runs: int,
70+
):
6971
scores = [
70-
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
71-
for result in all_results
72+
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) for result in all_results
7273
]
7374
agg_score = aggregate(scores, aggregation_method)
7475

@@ -220,9 +221,7 @@ def postprocess(all_results: List[List[EvaluationRow]],
220221

221222
# Check threshold after logging
222223
if threshold is not None and not passed:
223-
assert agg_score >= threshold.success, (
224-
f"Aggregated score {agg_score:.3f} below threshold {threshold.success}"
225-
)
224+
assert agg_score >= threshold.success, f"Aggregated score {agg_score:.3f} below threshold {threshold.success}"
226225
if threshold.standard_error is not None and standard_error is not None:
227226
assert standard_error <= threshold.standard_error, (
228227
f"Standard error {standard_error:.3f} above threshold {threshold.standard_error}"
@@ -350,21 +349,15 @@ def decorator(
350349
# additional check for groupwise evaluation
351350
elif mode == "groupwise":
352351
if "rows" not in sig.parameters:
353-
raise ValueError(
354-
"In listwise mode, your eval function must have a parameter named 'rows'"
355-
)
352+
raise ValueError("In listwise mode, your eval function must have a parameter named 'rows'")
356353

357354
# validate that "Rows" is of type List[EvaluationRow]
358355
if sig.parameters["rows"].annotation is not List[EvaluationRow]:
359-
raise ValueError(
360-
"In listwise mode, the 'rows' parameter must be of type List[EvaluationRow"
361-
)
356+
raise ValueError("In listwise mode, the 'rows' parameter must be of type List[EvaluationRow")
362357

363358
# validate that the function has a return type of List[EvaluationRow]
364359
if sig.return_annotation is not List[EvaluationRow]:
365-
raise ValueError(
366-
"In listwise mode, your eval function must return a list of EvaluationRow instances"
367-
)
360+
raise ValueError("In listwise mode, your eval function must return a list of EvaluationRow instances")
368361
if len(completion_params) < 2:
369362
raise ValueError("In groupwise mode, you must provide at least 2 completion parameters")
370363
else:
@@ -378,9 +371,7 @@ def decorator(
378371

379372
# validate that the function has a return type of List[EvaluationRow]
380373
if sig.return_annotation is not List[EvaluationRow]:
381-
raise ValueError(
382-
"In listwise mode, your eval function must return a list of EvaluationRow instances"
383-
)
374+
raise ValueError("In listwise mode, your eval function must return a list of EvaluationRow instances")
384375

385376
async def execute_with_params(
386377
test_func: TestFunction,
@@ -411,7 +402,9 @@ async def execute_with_params(
411402

412403
# Calculate all possible combinations of parameters
413404
if mode == "groupwise":
414-
combinations = generate_parameter_combinations(input_dataset, None, input_dataset, evaluation_test_kwargs, max_dataset_rows, combine_datasets)
405+
combinations = generate_parameter_combinations(
406+
input_dataset, None, input_dataset, evaluation_test_kwargs, max_dataset_rows, combine_datasets
407+
)
415408
else:
416409
combinations = generate_parameter_combinations(
417410
input_dataset,
@@ -619,7 +612,7 @@ async def _execute_with_semaphore(row):
619612
all_results[i] = results
620613
elif mode == "groupwise":
621614
# rollout all the completion_params for the same row at once, and then send the output to the test_func
622-
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
615+
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
623616
tasks: List[asyncio.Task[List[EvaluationRow]]] = []
624617
# completion_groups = []
625618
for idx, cp in enumerate(original_completion_params_list):
@@ -636,7 +629,9 @@ async def _execute_with_semaphore(row):
636629

637630
async def _collect_result(config, lst, max_retry):
638631
result = []
639-
async for row in rollout_processor_with_retry(rollout_processor, lst, config, max_retry):
632+
async for row in rollout_processor_with_retry(
633+
rollout_processor, lst, config, max_retry
634+
):
640635
result.append(row)
641636
return result
642637

@@ -654,7 +649,9 @@ async def _collect_result(config, lst, max_retry):
654649
results = []
655650
for row_id, rows in row_groups.items():
656651
result = await execute_with_params(
657-
test_func, processed_dataset=rows, evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}
652+
test_func,
653+
processed_dataset=rows,
654+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
658655
)
659656
results.extend(result)
660657
all_results[i] = results
@@ -670,10 +667,7 @@ async def _collect_result(config, lst, max_retry):
670667
results = await execute_with_params(
671668
test_func,
672669
processed_dataset=input_dataset,
673-
evaluation_test_kwargs=kwargs.get(
674-
"evaluation_test_kwargs"
675-
)
676-
or {},
670+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
677671
)
678672
if results is None:
679673
raise ValueError(
@@ -698,17 +692,37 @@ async def _collect_result(config, lst, max_retry):
698692
r.eval_metadata.status = "finished"
699693
active_logger.log(r)
700694

701-
# for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
695+
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
702696
# rollout_id is used to differentiate the result from different completion_params
703697
if mode == "groupwise":
704-
results_by_group = [[[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))]
698+
results_by_group = [
699+
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))
700+
]
705701
for i, result in enumerate(all_results):
706702
for r in result:
707703
results_by_group[int(r.execution_metadata.rollout_id)][i].append(r)
708704
for i, result in enumerate(results_by_group):
709-
postprocess(result, aggregation_method, threshold, active_logger, mode, original_completion_params_list[i], test_func.__name__, num_runs)
705+
postprocess(
706+
result,
707+
aggregation_method,
708+
threshold,
709+
active_logger,
710+
mode,
711+
original_completion_params_list[i],
712+
test_func.__name__,
713+
num_runs,
714+
)
710715
else:
711-
postprocess(all_results, aggregation_method, threshold, active_logger, mode, completion_params, test_func.__name__, num_runs)
716+
postprocess(
717+
all_results,
718+
aggregation_method,
719+
threshold,
720+
active_logger,
721+
mode,
722+
completion_params,
723+
test_func.__name__,
724+
num_runs,
725+
)
712726

713727
except AssertionError:
714728
_log_eval_error("finished", data if "data" in locals() else None, passed=False)

0 commit comments

Comments
 (0)