Skip to content

Commit 6751af4

Browse files
committed
Fix CI to use only valid scores, and show agg_score and raw_score
1 parent 87fdc2c commit 6751af4

File tree

2 files changed

+247
-4
lines changed

2 files changed

+247
-4
lines changed

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,17 @@ def postprocess(
3939
]
4040
agg_score = aggregate(scores, aggregation_method)
4141

42+
# Calculate raw score (sum of all scores / number of scored rows, including invalid scores)
43+
all_scores = [r.evaluation_result.score for sublist in all_results for r in sublist if r.evaluation_result]
44+
raw_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
45+
4246
# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
4347
ci_low: float | None = None
4448
ci_high: float | None = None
4549
standard_error: float | None = None
4650
if aggregation_method == "mean":
4751
try:
48-
result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
52+
result_ci = compute_fixed_set_mu_ci([item for sublist in valid_results for item in sublist])
4953
_, mu_ci_low, mu_ci_high, se = result_ci
5054
if mu_ci_low is not None and mu_ci_high is not None and se is not None:
5155
ci_low = float(mu_ci_low)
@@ -140,12 +144,17 @@ def postprocess(
140144
if should_print:
141145
if ci_low is not None and ci_high is not None and standard_error is not None:
142146
print(
143-
f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}",
147+
f"EP Summary | suite={suite_name} model={model_used} runs={num_runs} rows={total_rows}\n"
148+
f" agg_score={summary_obj['agg_score']:.3f} (valid scores only)\n"
149+
f" raw_score={raw_score:.3f} (includes invalid scores)\n"
150+
f" se={summary_obj['standard_error']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}]",
144151
file=sys.__stderr__,
145152
)
146153
else:
147154
print(
148-
f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}",
155+
f"EP Summary | suite={suite_name} model={model_used} runs={num_runs} rows={total_rows}\n"
156+
f" agg_score={summary_obj['agg_score']:.3f} (valid scores only)\n"
157+
f" raw_score={raw_score:.3f} (includes invalid scores)",
149158
file=sys.__stderr__,
150159
)
151160
# As per project convention, avoid printing per-metric CI lines to reduce noise

tests/test_evaluation_postprocess.py

Lines changed: 235 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import pytest
44
from unittest.mock import Mock, patch
55

6-
from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata
6+
from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata, Message
77
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
8+
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
89

910

1011
class TestPostprocess:
@@ -205,3 +206,236 @@ def test_all_invalid_scores(self):
205206

206207
# Should still call logger.log for all rows
207208
assert mock_logger.log.call_count == 2
209+
210+
211+
class TestComputeFixedSetMuCi:
212+
"""Tests for compute_fixed_set_mu_ci function."""
213+
214+
@patch.dict("os.environ", {"EP_NO_UPLOAD": "1"}) # Disable uploads
215+
def test_compute_fixed_set_mu_ci_with_flattened_results(self):
216+
"""Test that postprocess correctly calls compute_fixed_set_mu_ci with flattened all_results structure."""
217+
218+
q1_run1 = EvaluationRow(
219+
messages=[Message(role="user", content="What is 2+2?")],
220+
evaluation_result=EvaluateResult(score=0.5, is_score_valid=True, reason="correct"),
221+
input_metadata=InputMetadata(row_id="q1", completion_params={"model": "test"}),
222+
execution_metadata=ExecutionMetadata(),
223+
eval_metadata=EvalMetadata(
224+
name="test",
225+
description="test",
226+
version="1.0",
227+
status=None,
228+
num_runs=3,
229+
aggregation_method="mean",
230+
passed_threshold=None,
231+
passed=None,
232+
),
233+
)
234+
q1_run2 = EvaluationRow(
235+
messages=[Message(role="user", content="What is 2+2?")],
236+
evaluation_result=EvaluateResult(score=0.4, is_score_valid=True, reason="incorrect"),
237+
input_metadata=InputMetadata(row_id="q1", completion_params={"model": "test"}),
238+
execution_metadata=ExecutionMetadata(),
239+
eval_metadata=EvalMetadata(
240+
name="test",
241+
description="test",
242+
version="1.0",
243+
status=None,
244+
num_runs=3,
245+
aggregation_method="mean",
246+
passed_threshold=None,
247+
passed=None,
248+
),
249+
)
250+
q1_run3 = EvaluationRow(
251+
messages=[Message(role="user", content="What is 2+2?")],
252+
evaluation_result=EvaluateResult(score=0.45, is_score_valid=True, reason="incorrect"),
253+
input_metadata=InputMetadata(row_id="q1", completion_params={"model": "test"}),
254+
execution_metadata=ExecutionMetadata(),
255+
eval_metadata=EvalMetadata(
256+
name="test",
257+
description="test",
258+
version="1.0",
259+
status=None,
260+
num_runs=3,
261+
aggregation_method="mean",
262+
passed_threshold=None,
263+
passed=None,
264+
),
265+
)
266+
q2_run1 = EvaluationRow(
267+
messages=[Message(role="user", content="What is 3+3?")],
268+
evaluation_result=EvaluateResult(score=0.8, is_score_valid=True, reason="incorrect"),
269+
input_metadata=InputMetadata(row_id="q2", completion_params={"model": "test"}),
270+
execution_metadata=ExecutionMetadata(),
271+
eval_metadata=EvalMetadata(
272+
name="test",
273+
description="test",
274+
version="1.0",
275+
status=None,
276+
num_runs=3,
277+
aggregation_method="mean",
278+
passed_threshold=None,
279+
passed=None,
280+
),
281+
)
282+
q2_run2 = EvaluationRow(
283+
messages=[Message(role="user", content="What is 3+3?")],
284+
evaluation_result=EvaluateResult(score=0.9, is_score_valid=True, reason="correct"),
285+
input_metadata=InputMetadata(row_id="q2", completion_params={"model": "test"}),
286+
execution_metadata=ExecutionMetadata(),
287+
eval_metadata=EvalMetadata(
288+
name="test",
289+
description="test",
290+
version="1.0",
291+
status=None,
292+
num_runs=3,
293+
aggregation_method="mean",
294+
passed_threshold=None,
295+
passed=None,
296+
),
297+
)
298+
q2_run3 = EvaluationRow(
299+
messages=[Message(role="user", content="What is 3+3?")],
300+
evaluation_result=EvaluateResult(score=0.95, is_score_valid=True, reason="correct"),
301+
input_metadata=InputMetadata(row_id="q2", completion_params={"model": "test"}),
302+
execution_metadata=ExecutionMetadata(),
303+
eval_metadata=EvalMetadata(
304+
name="test",
305+
description="test",
306+
version="1.0",
307+
status=None,
308+
num_runs=3,
309+
aggregation_method="mean",
310+
passed_threshold=None,
311+
passed=None,
312+
),
313+
)
314+
q3_run1 = EvaluationRow(
315+
messages=[Message(role="user", content="What is 4+4?")],
316+
evaluation_result=EvaluateResult(score=0.1, is_score_valid=True, reason="incorrect"),
317+
input_metadata=InputMetadata(row_id="q3", completion_params={"model": "test"}),
318+
execution_metadata=ExecutionMetadata(),
319+
eval_metadata=EvalMetadata(
320+
name="test",
321+
description="test",
322+
version="1.0",
323+
status=None,
324+
num_runs=3,
325+
aggregation_method="mean",
326+
passed_threshold=None,
327+
passed=None,
328+
),
329+
)
330+
q3_run2 = EvaluationRow(
331+
messages=[Message(role="user", content="What is 4+4?")],
332+
evaluation_result=EvaluateResult(score=0.2, is_score_valid=True, reason="correct"),
333+
input_metadata=InputMetadata(row_id="q3", completion_params={"model": "test"}),
334+
execution_metadata=ExecutionMetadata(),
335+
eval_metadata=EvalMetadata(
336+
name="test",
337+
description="test",
338+
version="1.0",
339+
status=None,
340+
num_runs=3,
341+
aggregation_method="mean",
342+
passed_threshold=None,
343+
passed=None,
344+
),
345+
)
346+
q3_run3_valid = EvaluationRow(
347+
messages=[Message(role="user", content="What is 4+4?")],
348+
evaluation_result=EvaluateResult(score=0.3, is_score_valid=True, reason="correct"),
349+
input_metadata=InputMetadata(row_id="q3", completion_params={"model": "test"}),
350+
execution_metadata=ExecutionMetadata(),
351+
eval_metadata=EvalMetadata(
352+
name="test",
353+
description="test",
354+
version="1.0",
355+
status=None,
356+
num_runs=3,
357+
aggregation_method="mean",
358+
passed_threshold=None,
359+
passed=None,
360+
),
361+
)
362+
q3_run3_invalid = EvaluationRow(
363+
messages=[Message(role="user", content="What is 4+4?")],
364+
evaluation_result=EvaluateResult(score=0.3, is_score_valid=False, reason="correct"),
365+
input_metadata=InputMetadata(row_id="q3", completion_params={"model": "test"}),
366+
execution_metadata=ExecutionMetadata(),
367+
eval_metadata=EvalMetadata(
368+
name="test",
369+
description="test",
370+
version="1.0",
371+
status=None,
372+
num_runs=3,
373+
aggregation_method="mean",
374+
passed_threshold=None,
375+
passed=None,
376+
),
377+
)
378+
379+
rows = [[q1_run1, q2_run1, q3_run1], [q1_run2, q2_run2, q1_run3], [q2_run3, q3_run2, q3_run3_valid]]
380+
rows_with_invalid_score = [
381+
[q1_run1, q2_run1, q3_run1],
382+
[q1_run2, q2_run2, q1_run3],
383+
[q2_run3, q3_run2, q3_run3_invalid],
384+
]
385+
386+
# Store results for assertions
387+
first_result = None
388+
second_result = None
389+
390+
# Test first case (all valid scores)
391+
with patch("eval_protocol.pytest.evaluation_test_postprocess.compute_fixed_set_mu_ci") as mock_ci:
392+
mock_ci.side_effect = lambda input_rows, **kwargs: compute_fixed_set_mu_ci(input_rows, **kwargs)
393+
394+
postprocess(
395+
all_results=rows,
396+
aggregation_method="mean",
397+
threshold=None,
398+
active_logger=Mock(),
399+
mode="pointwise",
400+
completion_params={"model": "test-model"},
401+
test_func_name="test_ci_flattened",
402+
num_runs=3,
403+
experiment_duration_seconds=10.0,
404+
)
405+
406+
first_result = mock_ci.return_value
407+
408+
# Test second case (with invalid score)
409+
with patch("eval_protocol.pytest.evaluation_test_postprocess.compute_fixed_set_mu_ci") as mock_ci:
410+
mock_ci.side_effect = lambda input_rows, **kwargs: compute_fixed_set_mu_ci(input_rows, **kwargs)
411+
412+
postprocess(
413+
all_results=rows_with_invalid_score,
414+
aggregation_method="mean",
415+
threshold=None,
416+
active_logger=Mock(),
417+
mode="pointwise",
418+
completion_params={"model": "test-model"},
419+
test_func_name="test_ci_flattened_invalid",
420+
num_runs=3,
421+
experiment_duration_seconds=10.0,
422+
)
423+
424+
second_result = mock_ci.return_value
425+
426+
# Assert exact values
427+
# First case: (0.5111111111111111, 0.18101430525778583, 0.8412079169644363, 0.168416737680268)
428+
if first_result and len(first_result) == 4:
429+
mu_hat1, ci_low1, ci_high1, se1 = first_result
430+
assert abs(mu_hat1 - 0.5111111111111111) < 1e-10
431+
assert abs(ci_low1 - 0.18101430525778583) < 1e-10
432+
assert abs(ci_high1 - 0.8412079169644363) < 1e-10
433+
assert abs(se1 - 0.168416737680268) < 1e-10
434+
435+
# Second case: (0.49444444444444446, 0.13494616580367125, 0.8539427230852177, 0.18341748910243533)
436+
if second_result and len(second_result) == 4:
437+
mu_hat2, ci_low2, ci_high2, se2 = second_result
438+
assert abs(mu_hat2 - 0.49444444444444446) < 1e-10
439+
assert abs(ci_low2 - 0.13494616580367125) < 1e-10
440+
assert abs(ci_high2 - 0.8539427230852177) < 1e-10
441+
assert abs(se2 - 0.18341748910243533) < 1e-10

0 commit comments

Comments
 (0)