Skip to content

Commit 370d4da

Browse files
committed
Preserve metadata, evaluator id, etc. on the wrapped eval func
1 parent d957771 commit 370d4da

File tree

2 files changed

+98
-20
lines changed

2 files changed

+98
-20
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from dataclasses import replace
1212
from typing import Any, Callable, Dict, List, Literal, Optional, Union
1313
from collections import defaultdict
14-
14+
import hashlib
15+
import ast
1516
from mcp.types import Completion
1617
import pytest
1718

@@ -244,6 +245,7 @@ def evaluation_test( # noqa: C901
244245
max_dataset_rows: Optional[int] = None,
245246
mcp_config_path: Optional[str] = None,
246247
max_concurrent_rollouts: int = 8,
248+
max_concurrent_evaluations: int = 64,
247249
server_script_path: Optional[str] = None,
248250
steps: int = 30,
249251
mode: EvaluationTestMode = "pointwise",
@@ -308,6 +310,7 @@ def evaluation_test( # noqa: C901
308310
max_dataset_rows: Limit dataset to the first N rows.
309311
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
310312
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
313+
max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel.
311314
server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py").
312315
steps: Number of rollout steps to execute (default: 30).
313316
mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
@@ -581,30 +584,42 @@ def _log_eval_error(
581584
# log the fresh_dataset
582585
for row in fresh_dataset:
583586
active_logger.log(row)
584-
585-
if mode == "pointwise":
586-
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
587-
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
588-
tasks = []
589-
590-
async def _execute_with_semaphore(row):
591-
async with semaphore:
592-
# NOTE: we will still evaluate errored rows (give users control over this)
593-
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
587+
588+
# prepare parallel eval helper function
589+
semaphore = asyncio.Semaphore(max_concurrent_evaluations)
590+
async def _execute_eval_with_semaphore(**kwargs):
591+
async with semaphore:
592+
# NOTE: we will still evaluate errored rows (give users control over this)
593+
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
594+
if "row" in kwargs:
594595
result = await execute_with_params(
595596
test_func,
596-
processed_row=row,
597+
processed_row=kwargs["row"],
597598
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
598599
)
599600
if result is None or not isinstance(result, EvaluationRow):
600601
raise ValueError(
601602
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
602603
)
603604
return result
605+
if "rows" in kwargs:
606+
results = await execute_with_params(
607+
test_func,
608+
processed_dataset=kwargs["rows"],
609+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
610+
)
611+
if results is None or not isinstance(results, list):
612+
raise ValueError(
613+
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
614+
)
615+
return results
604616

617+
if mode == "pointwise":
618+
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
619+
tasks = []
605620
# Use wrapper that handles retry logic internally
606621
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
607-
tasks.append(asyncio.create_task(_execute_with_semaphore(row)))
622+
tasks.append(asyncio.create_task(_execute_eval_with_semaphore(row=row)))
608623

609624
results = await asyncio.gather(*tasks)
610625

@@ -645,14 +660,13 @@ async def _collect_result(config, lst):
645660
for result in rollout_results:
646661
for row in result:
647662
row_groups[row.input_metadata.row_id].append(row)
648-
results = []
663+
tasks = []
649664
for row_id, rows in row_groups.items():
650-
result = await execute_with_params(
651-
test_func,
652-
processed_dataset=rows,
653-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
654-
)
655-
results.extend(result)
665+
tasks.append(asyncio.create_task(_execute_eval_with_semaphore(rows=rows)))
666+
results = []
667+
for task in tasks:
668+
res = await task
669+
results.extend(res)
656670
all_results[i] = results
657671
else:
658672
# Batch mode: collect all results first, then evaluate (no pipelining)
@@ -788,6 +802,24 @@ async def dual_mode_wrapper(*args, **kwargs):
788802

789803
# If not a direct call, use the pytest wrapper
790804
return await pytest_wrapper(*args, **kwargs)
805+
806+
dual_mode_wrapper._origin_func = test_func
807+
dual_mode_wrapper._evaluator_id = test_func.__name__
808+
# Generate (stable) evaluator ID from function source code hash
809+
try:
810+
func_source = inspect.getsource(test_func)
811+
parsed = ast.parse(func_source)
812+
normalized_source = ast.unparse(parsed)
813+
clean_source = ''.join(normalized_source.split()) + test_func.__name__
814+
func_hash = hashlib.sha256(clean_source.encode('utf-8')).hexdigest()[:12]
815+
dual_mode_wrapper._evaluator_id = f"eval_{func_hash}"
dual_mode_wrapper._version = f"{test_func.__name__}_{func_hash}"
816+
except (OSError, TypeError, SyntaxError):
817+
pass
818+
dual_mode_wrapper._metainfo = {
819+
"mode": mode,
820+
"max_rollout_concurrency": max_concurrent_rollouts,
821+
"max_evaluation_concurrency": max_concurrent_evaluations,
822+
}
791823

792824
# Copy all attributes from the pytest wrapper to our dual mode wrapper
793825
import functools

tests/pytest/test_get_metadata.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import asyncio
2+
from typing import Dict, List
3+
4+
from eval_protocol.pytest import evaluation_test
5+
from eval_protocol.models import EvaluationRow, Message
6+
7+
@evaluation_test(
    # Two single-turn conversations; each becomes one rollout row.
    input_messages=[
        [
            Message(role="user", content="What is the capital of France?"),
        ],
        [
            Message(role="user", content="What is the capital of the moon?"),
        ],
    ],
    # One completion config per input conversation (hence the * 2).
    completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}] * 2,
    # Groupwise mode: the test function receives a whole group of rows at once.
    mode="groupwise",
    # These two values are surfaced by the decorator for introspection;
    # the companion metadata test asserts they are preserved on the wrapper.
    max_concurrent_rollouts=5,
    max_concurrent_evaluations=10,
)
def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Identity groupwise evaluation: return the rows unchanged.

    Exists only to give the decorator something to wrap so that
    metadata preservation (`_origin_func`, `_metainfo`, `_evaluator_id`)
    can be inspected by ``test_pytest_func_metainfo``.
    """
    return rows
24+
25+
26+
27+
def test_pytest_func_metainfo():
    """Verify that @evaluation_test preserves metadata on the wrapped function.

    Checks the decorator-attached attributes: the original (sync) function,
    the configuration mirror in ``_metainfo``, and the stable evaluator ID
    derived from the function's source hash.
    """
    # The decorator must keep a reference to the original, undecorated function.
    assert hasattr(test_pytest_async, "_origin_func")
    origin_func = test_pytest_async._origin_func
    # The original is a plain sync def; the wrapper is the async entry point.
    assert not asyncio.iscoroutinefunction(origin_func)
    assert asyncio.iscoroutinefunction(test_pytest_async)

    # Decorator kwargs are mirrored into _metainfo for introspection.
    assert test_pytest_async._metainfo["mode"] == "groupwise"
    assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5
    assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10

    # Evaluator ID: "eval_" + 12-hex-digit source hash.
    assert hasattr(test_pytest_async, "_evaluator_id")
    evaluator_id = test_pytest_async._evaluator_id
    assert evaluator_id.startswith("eval_")
    suffix = evaluator_id[len("eval_"):]
    # Check the hash suffix explicitly instead of a bare len(id) == 17,
    # which would pass coincidentally for any 17-character fallback name
    # (e.g. "test_pytest_async" itself is 17 characters).
    assert len(suffix) == 12
    assert all(c in "0123456789abcdef" for c in suffix)
42+
43+
44+
45+
46+

0 commit comments

Comments
 (0)