|
11 | 11 | from dataclasses import replace |
12 | 12 | from typing import Any, Callable, Dict, List, Literal, Optional, Union |
13 | 13 | from collections import defaultdict |
14 | | - |
| 14 | +import hashlib |
| 15 | +import ast |
15 | 16 | from mcp.types import Completion |
16 | 17 | import pytest |
17 | 18 |
|
@@ -244,6 +245,7 @@ def evaluation_test( # noqa: C901 |
244 | 245 | max_dataset_rows: Optional[int] = None, |
245 | 246 | mcp_config_path: Optional[str] = None, |
246 | 247 | max_concurrent_rollouts: int = 8, |
| 248 | + max_concurrent_evaluations: int = 64, |
247 | 249 | server_script_path: Optional[str] = None, |
248 | 250 | steps: int = 30, |
249 | 251 | mode: EvaluationTestMode = "pointwise", |
@@ -308,6 +310,7 @@ def evaluation_test( # noqa: C901 |
308 | 310 | max_dataset_rows: Limit dataset to the first N rows. |
309 | 311 | mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema |
310 | 312 | max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel. |
| 313 | + max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel. |
311 | 314 | server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py"). |
312 | 315 | steps: Number of rollout steps to execute (default: 30). |
313 | 316 | mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result). |
@@ -581,30 +584,42 @@ def _log_eval_error( |
581 | 584 | # log the fresh_dataset |
582 | 585 | for row in fresh_dataset: |
583 | 586 | active_logger.log(row) |
584 | | - |
585 | | - if mode == "pointwise": |
586 | | - # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution |
587 | | - semaphore = asyncio.Semaphore(max_concurrent_rollouts) |
588 | | - tasks = [] |
589 | | - |
590 | | - async def _execute_with_semaphore(row): |
591 | | - async with semaphore: |
592 | | - # NOTE: we will still evaluate errored rows (give users control over this) |
593 | | - # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func |
| 587 | + |
| 588 | + # prepare parallel eval helper function |
| 589 | + semaphore = asyncio.Semaphore(max_concurrent_evaluations) |
| 590 | + async def _execute_eval_with_semaphore(**kwargs): |
| 591 | + async with semaphore: |
| 592 | + # NOTE: we will still evaluate errored rows (give users control over this) |
| 593 | + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func |
| 594 | + if "row" in kwargs: |
594 | 595 | result = await execute_with_params( |
595 | 596 | test_func, |
596 | | - processed_row=row, |
| 597 | + processed_row=kwargs["row"], |
597 | 598 | evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, |
598 | 599 | ) |
599 | 600 | if result is None or not isinstance(result, EvaluationRow): |
600 | 601 | raise ValueError( |
601 | 602 | f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." |
602 | 603 | ) |
603 | 604 | return result |
| 605 | + if "rows" in kwargs: |
| 606 | + results = await execute_with_params( |
| 607 | + test_func, |
| 608 | + processed_dataset=kwargs["rows"], |
| 609 | + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, |
| 610 | + ) |
| 611 | + if results is None or not isinstance(results, list): |
| 612 | + raise ValueError( |
| 613 | + f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." |
| 614 | + ) |
| 615 | + return results |
604 | 616 |
|
| 617 | + if mode == "pointwise": |
| 618 | + # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution |
| 619 | + tasks = [] |
605 | 620 | # Use wrapper that handles retry logic internally |
606 | 621 | async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config): |
607 | | - tasks.append(asyncio.create_task(_execute_with_semaphore(row))) |
| 622 | + tasks.append(asyncio.create_task(_execute_eval_with_semaphore(row=row))) |
608 | 623 |
|
609 | 624 | results = await asyncio.gather(*tasks) |
610 | 625 |
|
@@ -645,14 +660,13 @@ async def _collect_result(config, lst): |
645 | 660 | for result in rollout_results: |
646 | 661 | for row in result: |
647 | 662 | row_groups[row.input_metadata.row_id].append(row) |
648 | | - results = [] |
| 663 | + tasks = [] |
649 | 664 | for row_id, rows in row_groups.items(): |
650 | | - result = await execute_with_params( |
651 | | - test_func, |
652 | | - processed_dataset=rows, |
653 | | - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, |
654 | | - ) |
655 | | - results.extend(result) |
| 665 | + tasks.append(asyncio.create_task(_execute_eval_with_semaphore(rows=rows))) |
| 666 | + results = [] |
| 667 | + for task in tasks: |
| 668 | + res = await task |
| 669 | + results.extend(res) |
656 | 670 | all_results[i] = results |
657 | 671 | else: |
658 | 672 | # Batch mode: collect all results first, then evaluate (no pipelining) |
@@ -788,6 +802,24 @@ async def dual_mode_wrapper(*args, **kwargs): |
788 | 802 |
|
789 | 803 | # If not a direct call, use the pytest wrapper |
790 | 804 | return await pytest_wrapper(*args, **kwargs) |
| 805 | + |
| 806 | + dual_mode_wrapper._origin_func = test_func |
| 807 | + dual_mode_wrapper._evaluator_id = test_func.__name__ |
| 808 | + # Generate (stable) evaluator ID from function source code hash |
| 809 | + try: |
| 810 | + func_source = inspect.getsource(test_func) |
| 811 | + parsed = ast.parse(func_source) |
| 812 | + normalized_source = ast.unparse(parsed) |
| 813 | + clean_source = ''.join(normalized_source.split()) + test_func.__name__ |
| 814 | + func_hash = hashlib.sha256(clean_source.encode('utf-8')).hexdigest()[:12] |
| 815 | + dual_mode_wrapper._version = f"{test_func.__name__}_{func_hash}" |
| 816 | + except (OSError, TypeError, SyntaxError): |
| 817 | + dual_mode_wrapper._version = test_func.__name__ |
| 818 | + dual_mode_wrapper._metainfo = { |
| 819 | + "mode": mode, |
| 820 | + "max_rollout_concurrency": max_concurrent_rollouts, |
| 821 | + "max_evaluation_concurrency": max_concurrent_evaluations, |
| 822 | + } |
791 | 823 |
|
792 | 824 | # Copy all attributes from the pytest wrapper to our dual mode wrapper |
793 | 825 | import functools |
|
0 commit comments