Skip to content

Commit e106227

Browse files
committed
Adding thresholds for aime and gpqa
1 parent c185d84 commit e106227

File tree

3 files changed

+10 / -7 lines changed (10 additions, 7 deletions)

3 files changed

+10 / -7 lines changed (10 additions, 7 deletions)

eval_protocol/benchmarks/test_aime25.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
7272
],
7373
rollout_processor=SingleTurnRolloutProcessor(),
7474
aggregation_method="mean",
75-
passed_threshold=None,
75+
passed_threshold=0.8,
7676
num_runs=8,
7777
max_dataset_rows=2,
7878
max_concurrent_rollouts=4,

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
9696
],
9797
rollout_processor=GPQAStripGTRolloutProcessor(),
9898
aggregation_method="mean",
99-
passed_threshold=None,
99+
passed_threshold=0.6,
100100
num_runs=8,
101101
mode="pointwise",
102102
)

eval_protocol/pytest/evaluation_test.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -81,14 +81,16 @@ def postprocess(
8181
if aggregation_method == "mean":
8282
try:
8383
result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
84-
_, mu_ci_low, mu_ci_high, standard_error = result_ci
85-
if mu_ci_low is not None and mu_ci_high is not None:
84+
_, mu_ci_low, mu_ci_high, se = result_ci
85+
if mu_ci_low is not None and mu_ci_high is not None and se is not None:
8686
ci_low = float(mu_ci_low)
8787
ci_high = float(mu_ci_high)
88+
standard_error = float(se)
8889
# Keep agg_score as-is (mean over scores). For equal repeats per question these match.
8990
except Exception:
9091
ci_low = None
9192
ci_high = None
93+
standard_error = None
9294

9395
# Determine if the evaluation passed based on threshold
9496
passed = None
@@ -127,9 +129,10 @@ def postprocess(
127129
"num_runs": num_runs,
128130
"rows": total_rows,
129131
}
130-
if ci_low is not None and ci_high is not None:
132+
if ci_low is not None and ci_high is not None and standard_error is not None:
131133
summary_obj["agg_ci_low"] = ci_low
132134
summary_obj["agg_ci_high"] = ci_high
135+
summary_obj["standard_error"] = standard_error
133136

134137
# Aggregate per-metric mean and 95% CI when available
135138
metrics_summary: Dict[str, Dict[str, float]] = {}
@@ -164,9 +167,9 @@ def postprocess(
164167
if metrics_summary:
165168
summary_obj["metrics_agg"] = metrics_summary
166169
if should_print:
167-
if ci_low is not None and ci_high is not None:
170+
if ci_low is not None and ci_high is not None and standard_error is not None:
168171
print(
169-
f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}"
172+
f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}"
170173
)
171174
else:
172175
print(

0 commit comments

Comments
 (0)