diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1cf6aec..a0184b62 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,6 +92,7 @@ jobs: --ignore=tests/pytest/test_frozen_lake.py \ --ignore=tests/pytest/test_lunar_lander.py \ --ignore=tests/pytest/test_tau_bench_airline.py \ + --ignore=tests/pytest/test_apps_coding.py \ --ignore=tests/test_tau_bench_airline_smoke.py \ --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10 diff --git a/eval_protocol/benchmarks/suites/aime25.py b/eval_protocol/benchmarks/suites/aime25.py index 3558eaa1..92d7bedc 100644 --- a/eval_protocol/benchmarks/suites/aime25.py +++ b/eval_protocol/benchmarks/suites/aime25.py @@ -3,7 +3,7 @@ from eval_protocol.benchmarks.registry import export_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -72,7 +72,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]: "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", } ], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=8, diff --git a/eval_protocol/benchmarks/suites/gpqa.py b/eval_protocol/benchmarks/suites/gpqa.py index 76967beb..ced8ac9f 100644 --- a/eval_protocol/benchmarks/suites/gpqa.py +++ b/eval_protocol/benchmarks/suites/gpqa.py @@ -1,3 +1,4 @@ +import asyncio import csv import io import re @@ -8,9 +9,11 @@ from eval_protocol.benchmarks.registry import export_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test +from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig SYSTEM_PROMPT = ( "You are a helpful assistant. Read the question and options carefully. " @@ -60,19 +63,31 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]: return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))] -async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]: - """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor.""" - processed: List[EvaluationRow] = [] - for r in rows: - gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")] - if gt_tokens: - gt_val = gt_tokens[-1].split(":", 1)[1].strip() - r.ground_truth = gt_val - r.messages = [ - m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:")) +class GPQAStripGTRolloutProcessor(RolloutProcessor): + """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to SingleTurnRolloutProcessor.""" + + def __init__(self): + super().__init__() + self.single_turn_processor = SingleTurnRolloutProcessor() + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Preprocess rows and delegate to SingleTurnRolloutProcessor.""" + processed: List[EvaluationRow] = [] + + for r in rows: + gt_tokens = [ + m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:") ] - processed.append(r) - return await default_single_turn_rollout_processor(processed, config) + if gt_tokens: + gt_val = gt_tokens[-1].split(":", 1)[1].strip() + r.ground_truth = gt_val + r.messages = [ + m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:")) + ] + processed.append(r) + + # Delegate to SingleTurnRolloutProcessor + return self.single_turn_processor(processed, config) @export_benchmark("gpqa") @@ -81,7 +96,7 @@ async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> completion_params=[ {"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], - rollout_processor=gpqa_strip_gt_rollout_processor, + rollout_processor=GPQAStripGTRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=8, diff --git a/eval_protocol/benchmarks/suites/livebench_data_analysis.py b/eval_protocol/benchmarks/suites/livebench_data_analysis.py index fc5abb4e..da384439 100644 --- a/eval_protocol/benchmarks/suites/livebench_data_analysis.py +++ b/eval_protocol/benchmarks/suites/livebench_data_analysis.py @@ -5,7 +5,7 @@ from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -375,7 +375,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _CTA_ROWS], rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=4, @@ -418,7 +418,7 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS], rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=4, @@ -462,7 +462,7 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS], rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=4, diff --git a/eval_protocol/benchmarks/suites/tau_bench_retail.py b/eval_protocol/benchmarks/suites/tau_bench_retail.py index 8e8aaea0..6c0a8a36 100644 --- a/eval_protocol/benchmarks/suites/tau_bench_retail.py +++ b/eval_protocol/benchmarks/suites/tau_bench_retail.py @@ -13,7 +13,7 @@ from eval_protocol.benchmarks.registry import export_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor from vendor.tau2.data_model.message import ( AssistantMessage, SystemMessage, @@ -73,7 +73,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", } ], - rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor=MCPGymRolloutProcessor(), rollout_processor_kwargs={"domain": "retail"}, num_runs=8, mode="pointwise", diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 405e72b4..b0359d79 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -35,7 +35,7 @@ class ExecutionManager: Manage rollout for MCP environments. """ - async def execute_rollouts( + def execute_rollouts( self, envs: "GeneralMCPVectorEnv", policy: Union["LLMBasePolicy", Callable], @@ -43,7 +43,7 @@ async def execute_rollouts( openai_format_log_file: Optional[str] = None, max_concurrent_rollouts: int = 8, evaluation_rows: Optional[List[EvaluationRow]] = None, - ) -> AsyncIterator[EvaluationRow]: + ) -> List[asyncio.Task[EvaluationRow]]: """ Execute general rollouts using tool calling interface with automatic record/playback. @@ -66,7 +66,7 @@ async def execute_rollouts( - Set and file exists: Playback mode (uses recorded data) Returns: - AsyncIterator of EvaluationRow objects with unified evaluation data format + List of asyncio.Task objects for external handling """ start_time = time.time() @@ -138,7 +138,7 @@ async def _execute_with_semaphore(idx): if trajectory.terminated: if trajectory.termination_reason == TerminationReason.ERROR: evaluation_row.rollout_status.status = "error" - evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get( + evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get( "error_message", None ) else: @@ -151,18 +151,7 @@ async def _execute_with_semaphore(idx): # Create all tasks tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)] - - # Yield results as they complete (note that they're not necessarily in original order) - try: - for task in asyncio.as_completed(tasks): - try: - yield await task - except Exception: - logger.exception("Error processing rollout") - finally: - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + return tasks async def _execute_rollout( self, diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index 5d930a4e..f5d09ba0 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -236,7 +236,7 @@ def make( return mcp_envs -async def rollout( +def rollout( envs: GeneralMCPVectorEnv, policy: Union[FireworksPolicy, LLMBasePolicy, Callable], *, @@ -246,7 +246,7 @@ async def rollout( steps: int = 512, openai_format_log_file: Optional[str] = None, max_concurrent_rollouts: int = 8, -) -> AsyncIterator[EvaluationRow]: +) -> List[asyncio.Task[EvaluationRow]]: """ Execute general rollouts using tool calling interface with automatic record/playback. @@ -274,14 +274,14 @@ async def rollout( - Set and file exists: Playback mode (uses recorded data) Returns: - List of EvaluationRow objects + List of asyncio.Task objects for external handling Example: # Live mode - evaluation_rows = await ep.rollout(envs, policy) + tasks = ep.rollout(envs, policy) # Create environments automatically - trajectories = await ep.rollout( + tasks = ep.rollout( "http://localhost:8000/mcp/", policy, evaluation_rows=my_evaluation_rows, @@ -290,10 +290,10 @@ async def rollout( # Recording mode os.environ["EP_PLAYBACK_FILE"] = "record.jsonl" - evaluation_rows = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl") + tasks = ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl") # Playback mode (after recording file exists) - evaluation_rows = await ep.rollout(envs, policy) + tasks = ep.rollout(envs, policy) """ # Automatically create environments if a base URL is provided if isinstance(envs, str): @@ -301,15 +301,15 @@ async def rollout( raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL") auto_model_id = model_id or getattr(policy, "model_id", "unknown") - envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id) + envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id) # Use the new ExecutionManager for execution execution_manager = ExecutionManager() - async for evaluation_row in execution_manager.execute_rollouts( + tasks = execution_manager.execute_rollouts( envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows - ): - yield evaluation_row + ) + return tasks async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]: @@ -336,7 +336,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]: policy = FireworksPolicy("test-model") # Run short rollout - evaluation_rows = await rollout(envs, policy=policy, steps=10) + evaluation_rows = rollout(envs, policy=policy, steps=10) if evaluation_rows and len(evaluation_rows[0].messages) > 1: results["successful"] += 1 diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index 2d2576d6..171fa3dc 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -1,18 +1,19 @@ -from .default_agent_rollout_processor import default_agent_rollout_processor +from .default_agent_rollout_processor import AgentRolloutProcessor from .default_dataset_adapter import default_dataset_adapter -from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor -from .default_no_op_rollout_process import default_no_op_rollout_processor -from .default_single_turn_rollout_process import default_single_turn_rollout_processor +from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor +from .default_no_op_rollout_processor import NoOpRolloutProcessor +from .default_single_turn_rollout_process import SingleTurnRolloutProcessor from .evaluation_test import evaluation_test -from .types import RolloutProcessor, RolloutProcessorConfig +from .rollout_processor import RolloutProcessor +from .types import RolloutProcessorConfig __all__ = [ - "default_agent_rollout_processor", - "default_mcp_gym_rollout_processor", - "default_no_op_rollout_processor", - "default_single_turn_rollout_processor", - "default_dataset_adapter", + "AgentRolloutProcessor", + "MCPGymRolloutProcessor", "RolloutProcessor", + "SingleTurnRolloutProcessor", + "NoOpRolloutProcessor", + "default_dataset_adapter", "RolloutProcessorConfig", "evaluation_test", ] diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index 50f12231..65428b4b 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -13,6 +13,7 @@ from eval_protocol.mcp.execution.policy import LiteLLMPolicy from eval_protocol.mcp.mcp_multi_client import MCPMultiClient from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig logger = logging.getLogger(__name__) @@ -115,46 +116,36 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex return tool_result.content -async def default_agent_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: - """Process agent rollouts with bounded concurrency and yield as they complete.""" +class AgentRolloutProcessor(RolloutProcessor): + """Agent rollout processor for tool-calling agents.""" - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Create agent rollout tasks and return them for external handling.""" - async def process_row(row: EvaluationRow) -> EvaluationRow: - """Process a single row with agent rollout.""" - agent = Agent( - model=config.completion_params["model"], row=row, config_path=config.mcp_config_path, logger=config.logger - ) - try: - await agent.setup() - await agent.call_agent() - return agent.evaluation_row - finally: - if agent.mcp_client: - await agent.mcp_client.cleanup() - - async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: - async with semaphore: - try: - return await process_row(r) - except Exception as e: - logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}") - return r - - # Create all tasks - tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 + semaphore = asyncio.Semaphore(max_concurrent) - # Yield results as they complete (note that they're not necessarily in original order) - try: - for task in asyncio.as_completed(tasks): + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row with agent rollout.""" + agent = Agent( + model=config.completion_params["model"], + row=row, + config_path=config.mcp_config_path, + logger=config.logger, + ) try: - yield await task - except Exception: - logger.exception("Error processing row") - finally: - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + await agent.setup() + await agent.call_agent() + return agent.evaluation_row + finally: + if agent.mcp_client: + await agent.mcp_client.cleanup() + + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + result = await process_row(r) + return result + + # Create and return tasks for external handling + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 2b90239d..b7376e9c 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -6,10 +6,11 @@ import subprocess import time from pathlib import Path -from typing import AsyncIterator, List, Optional +from typing import List, Optional import eval_protocol as ep -from eval_protocol.models import EvaluationRow, Message +from eval_protocol.models import EvaluationRow +from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig @@ -192,53 +193,73 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # Don't suppress exceptions -async def default_mcp_gym_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: +class MCPGymRolloutProcessor(RolloutProcessor): """ Rollout processor for tau bench environments. - This processor starts an MCP server, creates tau bench environments, and runs rollouts - using the eval_protocol framework, yielding results as they complete. + This processor starts an MCP server, creates tau bench environments, and returns rollout tasks + using the eval_protocol framework with proper cleanup handling. + """ - Args: - rows: List of EvaluationRow objects containing messages and dataset info in input_metadata - config: RolloutProcessorConfig with model and other parameters + def __init__(self): + self.server = None + self.policy = None - Returns: - AsyncIterator of EvaluationRow objects with completed conversations - """ - if config.server_script_path is None: - raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor") - server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) - - try: - server.start() - - policy = ep.LiteLLMPolicy( - model_id=config.completion_params.model, - temperature=config.completion_params.get("temperature", 0.0), - max_tokens=config.completion_params.get("max_tokens", 4096), - reasoning_effort=config.completion_params.get("reasoning_effort", None), - ) + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Process evaluation rows with MCP gym environments.""" + start_server = config.kwargs.get("start_server", True) if config.kwargs else True + + if start_server: + # Create fresh MCP server and environments for this run + if config.server_script_path is None: + raise ValueError("server_script_path is required for MCPGymRolloutProcessor") + + self.server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) + + try: + self.server.start() + + self.policy = ep.LiteLLMPolicy( + model_id=config.completion_params.get("model", None), + temperature=config.completion_params.get("temperature", 0.0), + max_tokens=config.completion_params.get("max_tokens", 4096), + reasoning_effort=config.completion_params.get("reasoning_effort", None), + ) + + except Exception as e: + if self.server: + self.server.stop() + self.server = None + self.policy = None + raise e + + else: + # Reuse existing MCP environments for retry + if not self.server or not self.policy: + raise RuntimeError( + "Cannot retry without existing server/environments. Call with start_server=True first." + ) # Create MCP environments directly from evaluation_rows envs = ep.make( "http://localhost:9700/mcp/", evaluation_rows=rows, - model_id=policy.model_id, + model_id=self.policy.model_id, ) - # Run rollout with environments and policy - async for evaluation_row in ep.rollout( + # Get rollout tasks from ep.rollout + tasks = ep.rollout( envs, - policy=policy, + policy=self.policy, evaluation_rows=rows, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts, - ): - yield evaluation_row - - finally: - # Always clean up the server - server.stop() + ) + return tasks + + def cleanup(self) -> None: + """Cleanup MCP server and environments.""" + if self.server: + self.server.stop() + self.server = None + self.policy = None diff --git a/eval_protocol/pytest/default_no_op_rollout_process.py b/eval_protocol/pytest/default_no_op_rollout_process.py deleted file mode 100644 index 47cb17be..00000000 --- a/eval_protocol/pytest/default_no_op_rollout_process.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import AsyncIterator, List - -from eval_protocol.models import EvaluationRow -from eval_protocol.pytest.types import RolloutProcessorConfig - - -async def default_no_op_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: - """ - Simply passes input dataset through to the test function. This can be useful - if you want to run the rollout yourself. - """ - for row in rows: - yield row diff --git a/eval_protocol/pytest/default_no_op_rollout_processor.py b/eval_protocol/pytest/default_no_op_rollout_processor.py new file mode 100644 index 00000000..973d6083 --- /dev/null +++ b/eval_protocol/pytest/default_no_op_rollout_processor.py @@ -0,0 +1,27 @@ +import asyncio +from typing import List + +from eval_protocol.models import EvaluationRow +from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + + +class NoOpRolloutProcessor(RolloutProcessor): + """ + No-op rollout processor that passes input dataset through unchanged. + + Simply returns the input rows as completed tasks. This is useful for testing + or when you want to handle rollout processing manually. + """ + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Process rows by returning them unchanged (no-op implementation).""" + + async def return_row(row: EvaluationRow) -> EvaluationRow: + return row + + # Create tasks that immediately return the rows (no-op) + tasks = [asyncio.create_task(return_row(row)) for row in rows] + return tasks + + # Inherits cleanup() from RolloutProcessor - no override needed diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index ef2ad48b..bf43b7da 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -2,129 +2,117 @@ import logging import os import time -from typing import AsyncIterator, List +from typing import List -import litellm from litellm import acompletion from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall from eval_protocol.dataset_logger import default_logger from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig logger = logging.getLogger(__name__) -async def default_single_turn_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: - """Generate a single response from any supported model provider using LiteLLM.""" - - # Quiet LiteLLM logs in test runs unless user overrode - try: - if os.environ.get("LITELLM_LOG") is None: - os.environ["LITELLM_LOG"] = "ERROR" - _llog = logging.getLogger("LiteLLM") - _llog.setLevel(logging.CRITICAL) - _llog.propagate = False - for _h in list(_llog.handlers): - _llog.removeHandler(_h) - except Exception: - pass - - # Do not modify global LiteLLM cache. Disable caching per-request instead. - - async def process_row(row: EvaluationRow) -> EvaluationRow: - """Process a single row asynchronously.""" - if len(row.messages) == 0: - raise ValueError("Messages is empty. Please provide a non-empty dataset") - - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] - - request_params = {"messages": messages_payload, **config.completion_params} - # Ensure caching is disabled only for this request (review feedback) - request_params["cache"] = {"no-cache": True} - # Single-level reasoning effort: expect `reasoning_effort` only - effort_val = None - - if "reasoning_effort" in config.completion_params: - effort_val = str(config.completion_params["reasoning_effort"]) # flat shape - elif ( - isinstance(config.completion_params.get("extra_body"), dict) - and "reasoning_effort" in config.completion_params["extra_body"] - ): - # Accept if user passed it directly inside extra_body - effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body - - if effort_val: - # Always under extra_body so LiteLLM forwards to provider-specific param set - request_params.setdefault("extra_body", {}) - request_params["extra_body"]["reasoning_effort"] = effort_val - # Ensure unsupported top-level keys are not present - if "reasoning_effort" in request_params: - request_params.pop("reasoning_effort", None) - - if row.tools is not None: - request_params["tools"] = row.tools - - # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet - import importlib - - _litellm = importlib.import_module("litellm") - acompletion = getattr(_litellm, "acompletion") - response = await acompletion(**request_params) - - assistant_content = response.choices[0].message.content or "" - tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None - - converted_tool_calls = None - if tool_calls: - converted_tool_calls = [ - ChatCompletionMessageToolCall( - id=tool_call.id, - type=tool_call.type, - function={ - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, +class SingleTurnRolloutProcessor(RolloutProcessor): + """Single turn rollout processor for direct LLM calls.""" + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Generate single turn rollout tasks and return them for external handling.""" + + # Quiet LiteLLM logs in test runs unless user overrode + try: + if os.environ.get("LITELLM_LOG") is None: + os.environ["LITELLM_LOG"] = "ERROR" + _llog = logging.getLogger("LiteLLM") + _llog.setLevel(logging.CRITICAL) + _llog.propagate = False + for _h in list(_llog.handlers): + _llog.removeHandler(_h) + except Exception: + pass + + # Do not modify global LiteLLM cache. Disable caching per-request instead. + + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row asynchronously.""" + if len(row.messages) == 0: + raise ValueError("Messages is empty. Please provide a non-empty dataset") + + messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + + request_params = {"messages": messages_payload, **config.completion_params} + # Ensure caching is disabled only for this request (review feedback) + request_params["cache"] = {"no-cache": True} + # Single-level reasoning effort: expect `reasoning_effort` only + effort_val = None + + if "reasoning_effort" in config.completion_params: + effort_val = str(config.completion_params["reasoning_effort"]) # flat shape + elif ( + isinstance(config.completion_params.get("extra_body"), dict) + and "reasoning_effort" in config.completion_params["extra_body"] + ): + # Accept if user passed it directly inside extra_body + effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body + + if effort_val: + # Always under extra_body so LiteLLM forwards to provider-specific param set + request_params.setdefault("extra_body", {}) + request_params["extra_body"]["reasoning_effort"] = effort_val + # Ensure unsupported top-level keys are not present + if "reasoning_effort" in request_params: + request_params.pop("reasoning_effort", None) + + if row.tools is not None: + request_params["tools"] = row.tools + + # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet + import importlib + + _litellm = importlib.import_module("litellm") + acompletion = getattr(_litellm, "acompletion") + response = await acompletion(**request_params) + + assistant_content = response.choices[0].message.content or "" + tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None + + converted_tool_calls = None + if tool_calls: + converted_tool_calls = [ + ChatCompletionMessageToolCall( + id=tool_call.id, + type=tool_call.type, + function={ + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + ) + for tool_call in tool_calls + ] + + messages = list(row.messages) + [ + Message( + role="assistant", + content=assistant_content, + tool_calls=converted_tool_calls, ) - for tool_call in tool_calls ] - messages = list(row.messages) + [ - Message( - role="assistant", - content=assistant_content, - tool_calls=converted_tool_calls, - ) - ] - - row.messages = messages - default_logger.log(row) - return row - - # Process rows with bounded concurrency and yield as they complete - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) - - async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: - async with semaphore: - try: - return await process_row(r) - except Exception as e: - return r - - # Create all tasks - tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] - - # Yield results as they complete (note that they're not necessarily in original order) - try: - for task in asyncio.as_completed(tasks): - try: - yield await task - except Exception: - logger.exception("Error processing row") - finally: - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + row.messages = messages + default_logger.log(row) + return row + + # Process rows with bounded concurrency + max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 + semaphore = asyncio.Semaphore(max_concurrent) + + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + result = await process_row(r) + return result + + # Create and return tasks for external handling + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index dd7ecb04..6127c7b9 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -8,6 +8,7 @@ import re import statistics import time +from dataclasses import replace from typing import Any, Callable, Dict, List, Literal, Optional, Union import pytest @@ -24,7 +25,8 @@ Message, ) from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter -from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor +from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor +from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import ( Dataset, DatasetPathParam, @@ -32,7 +34,6 @@ EvaluationTestMode, InputMessagesParam, ModelParam, - RolloutProcessor, RolloutProcessorConfig, RolloutProcessorInputParam, TestFunction, @@ -41,8 +42,14 @@ AggregationMethod, aggregate, create_dynamically_parameterized_wrapper, + deep_update_dict, execute_function, + extract_effort_tag, + generate_parameter_combinations, log_eval_status_and_rows, + parse_ep_max_rows, + rollout_processor_with_retry, + sanitize_filename, ) from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci @@ -55,7 +62,7 @@ def evaluation_test( # noqa: C901 input_messages: Optional[List[InputMessagesParam]] = None, input_dataset: Optional[List[DatasetPathParam]] = None, dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter, - rollout_processor: RolloutProcessor = default_no_op_rollout_processor, + rollout_processor: RolloutProcessor = NoOpRolloutProcessor(), evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None, rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None, aggregation_method: AggregationMethod = "mean", @@ -200,76 +207,15 @@ async def execute_with_params( return test_func(**kwargs) # Calculate all possible combinations of parameters - def _parse_ep_max_rows(default_value: int | None) -> int | None: - """Read EP_MAX_DATASET_ROWS env override as int or None.""" - raw = os.getenv("EP_MAX_DATASET_ROWS") - if raw is None: - return default_value - s = raw.strip().lower() - if s == "none": - return None - try: - return int(s) - except ValueError: - return default_value - - def _deep_update_dict(base: dict, override: dict) -> dict: - """Recursively update nested dictionaries in-place and return base.""" - for key, value in override.items(): - if isinstance(value, dict) and isinstance(base.get(key), dict): - _deep_update_dict(base[key], value) - else: - base[key] = value - return base - - def generate_combinations(): - combinations = [] - - # Handle optional parameters with defaults - # Optionally combine multiple dataset paths into one logical dataset, - # or parameterize to run one dataset per test invocation. - if input_dataset is not None: - if combine_datasets: - datasets: List[Optional[List[DatasetPathParam]]] = [input_dataset] # type: ignore - else: - # Fan out: one dataset path per parameterization - if isinstance(input_dataset, list): # type: ignore - datasets = [[p] for p in input_dataset] # type: ignore - else: - datasets = [[input_dataset]] # type: ignore - else: - datasets = [None] - cps: List[Optional[CompletionParams]] = completion_params if completion_params is not None else [None] # type: ignore - # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over - # each row. Instead, pass the entire sliced list through in a single test run - # so summaries aggregate all rows together (AIME-style behavior). - if input_messages is not None and isinstance(input_messages, list): - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) - if effective_max_rows is not None: - sliced_messages = input_messages[:effective_max_rows] # type: ignore - else: - sliced_messages = input_messages # type: ignore - # Wrap as a single parameter payload - messages = [sliced_messages] # type: ignore - else: - messages = [None] # type: ignore - kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore - - # Generate all combinations - for ds in datasets: - for cp in cps: - for im in messages: - for etk in kwargs: - # if no dataset and no messages, raise an error - if ds is None and im is None: - raise ValueError( - "No dataset or messages provided. Please provide at least one of input_dataset or input_messages." - ) - combinations.append((ds, cp, im, etk)) - - return combinations - combinations = generate_combinations() + combinations = generate_parameter_combinations( + input_dataset, + completion_params, + input_messages, + evaluation_test_kwargs, + max_dataset_rows, + combine_datasets, + ) if len(combinations) == 0: raise ValueError( "No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages." @@ -331,7 +277,7 @@ def _log_eval_error( else: data_jsonl = load_jsonl(ds_arg) # Apply env override for max rows if present - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) + effective_max_rows = parse_ep_max_rows(max_dataset_rows) if effective_max_rows is not None: data_jsonl = data_jsonl[:effective_max_rows] data = dataset_adapter(data_jsonl) @@ -367,7 +313,7 @@ def _log_eval_error( if _env_override: override_obj = _json.loads(_env_override) if isinstance(override_obj, dict): - completion_params = _deep_update_dict(dict(completion_params), override_obj) + completion_params = deep_update_dict(dict(completion_params), override_obj) except Exception: pass @@ -410,6 +356,8 @@ def _log_eval_error( kwargs=rollout_processor_kwargs or {}, ) + max_retry = int(os.getenv("EP_MAX_RETRY", "0")) + for i in range(num_runs): # Regenerate outputs each run by deep-copying the pristine dataset # so model responses are not reused across runs. @@ -428,8 +376,6 @@ def _log_eval_error( for row in fresh_dataset: active_logger.log(row) - rollout_result = rollout_processor(fresh_dataset, config) - if mode == "pointwise": # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution semaphore = asyncio.Semaphore(max_concurrent_rollouts) @@ -437,6 +383,8 @@ def _log_eval_error( async def _execute_with_semaphore(row): async with semaphore: + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func result = await execute_with_params( test_func, processed_row=row, @@ -448,7 +396,10 @@ async def _execute_with_semaphore(row): ) return result - async for row in rollout_processor(fresh_dataset, config): + # Use wrapper that handles retry logic internally + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, max_retry + ): tasks.append(asyncio.create_task(_execute_with_semaphore(row))) all_results[i] = await asyncio.gather(*tasks) @@ -456,9 +407,12 @@ async def _execute_with_semaphore(row): else: # Batch mode: collect all results first, then evaluate (no pipelining) input_dataset = [] - async for row in rollout_result: + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, max_retry + ): input_dataset.append(row) - + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func results = await execute_with_params( test_func, processed_dataset=input_dataset, @@ -517,11 +471,10 @@ async def _execute_with_semaphore(row): passed = success_passed and std_passed - # Update eval metadata status and passed field for all results + # Update eval metadata passed field for all results for result in all_results: for r in result: if r.eval_metadata is not None: - r.eval_metadata.status = "finished" r.eval_metadata.passed = passed active_logger.log(r) @@ -530,7 +483,7 @@ async def _execute_with_semaphore(row): should_print = os.getenv("EP_PRINT_SUMMARY") == "1" summary_path = os.getenv("EP_SUMMARY_JSON") suite_name = test_func.__name__ - model_used = config.completion_params.model + model_used = config.completion_params["model"] total_rows = len([item for sublist in all_results for item in sublist]) summary_obj = { "suite": suite_name, @@ -587,35 +540,9 @@ async def _execute_with_semaphore(row): ) # As per project convention, avoid printing per-metric CI lines to reduce noise if summary_path: - - def _sanitize_filename(text: str) -> str: - safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) - return safe[:120] - - def _extract_effort_tag(params: dict) -> str | None: - try: - if not isinstance(params, dict): - return None - # Common locations - if "extra_body" in params and isinstance(params["extra_body"], dict): - eb = params["extra_body"] - if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]: - return str(eb["reasoning"]["effort"]).lower() - if "reasoning_effort" in eb: - return str(eb["reasoning_effort"]).lower() - if ( - "reasoning" in params - and isinstance(params["reasoning"], dict) - and "effort" in params["reasoning"] - ): - return str(params["reasoning"]["effort"]).lower() - except Exception: - return None - return None - - model_slug = _sanitize_filename(model_used) - effort_tag = _extract_effort_tag(completion_params) or "" - effort_suffix = f"__effort-{_sanitize_filename(effort_tag)}" if effort_tag else "" + model_slug = sanitize_filename(model_used) + effort_tag = extract_effort_tag(completion_params) or "" + effort_suffix = f"__effort-{sanitize_filename(effort_tag)}" if effort_tag else "" base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" p = pathlib.Path(summary_path) @@ -633,7 +560,7 @@ def _extract_effort_tag(params: dict) -> str | None: parent.mkdir(parents=True, exist_ok=True) # If we detected an effort tag, fan out to separate files; otherwise write to the exact file if effort_tag: - out_file = parent / f"{p.stem}__{_sanitize_filename(effort_tag)}{p.suffix}" + out_file = parent / f"{p.stem}__{sanitize_filename(effort_tag)}{p.suffix}" else: out_file = p @@ -822,7 +749,7 @@ def run_evaluation_test_direct( input_dataset: Optional[List[DatasetPathParam]] = None, dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter, completion_params: Optional[CompletionParams] = None, - rollout_processor: RolloutProcessor = default_no_op_rollout_processor, + rollout_processor: RolloutProcessor = NoOpRolloutProcessor(), rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None, aggregation_method: AggregationMethod = "mean", passed_threshold: Optional[Union[EvaluationThreshold, float]] = None, @@ -844,26 +771,6 @@ def run_evaluation_test_direct( if passed_threshold is not None and not isinstance(passed_threshold, EvaluationThreshold): passed_threshold = EvaluationThreshold(success=passed_threshold) - def _parse_ep_max_rows(default_value: int | None) -> int | None: - raw = os.getenv("EP_MAX_DATASET_ROWS") - if raw is None: - return default_value - s = raw.strip().lower() - if s == "none": - return None - try: - return int(s) - except ValueError: - return default_value - - def _deep_update_dict(base: dict, override: dict) -> dict: - for key, value in override.items(): - if isinstance(value, dict) and isinstance(base.get(key), dict): - _deep_update_dict(base[key], value) - else: - base[key] = value - return base - # Build dataset/messages data: List[EvaluationRow] = [] if input_dataset is not None: @@ -871,12 +778,12 @@ def _deep_update_dict(base: dict, override: dict) -> dict: data_jsonl: List[Dict[str, Any]] = [] for p in input_dataset: data_jsonl.extend(load_jsonl(p)) - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) + effective_max_rows = parse_ep_max_rows(max_dataset_rows) if effective_max_rows is not None: data_jsonl = data_jsonl[:effective_max_rows] data = dataset_adapter(data_jsonl) elif input_messages is not None: - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) + effective_max_rows = parse_ep_max_rows(max_dataset_rows) msgs = input_messages if effective_max_rows is not None and isinstance(msgs, list): msgs = msgs[:effective_max_rows] # type: ignore @@ -896,7 +803,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict: if _env_override: override_obj = _json.loads(_env_override) if isinstance(override_obj, dict): - completion_params = _deep_update_dict(dict(completion_params), override_obj) + completion_params = deep_update_dict(dict(completion_params), override_obj) except Exception: pass @@ -990,7 +897,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict: total_rows = len(all_results) summary_obj = { "suite": suite_name, - "model": config.completion_params.model, + "model": config.completion_params["model"], "agg_score": float(agg_score) if agg_score is not None else None, "num_runs": num_runs, "rows": total_rows, @@ -1001,45 +908,20 @@ def _deep_update_dict(base: dict, override: dict) -> dict: if should_print: if ci_low is not None and ci_high is not None: print( - f"EP Summary | suite={suite_name} model={config.completion_params.model} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}" + f"EP Summary | suite={suite_name} model={config.completion_params['model']} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}" ) else: print( - f"EP Summary | suite={suite_name} model={config.completion_params.model} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}" + f"EP Summary | suite={suite_name} model={config.completion_params['model']} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}" ) if summary_path: import json as _json import pathlib as _pathlib - import re as _re import time as _time - def _sanitize_filename(text: str) -> str: - safe = _re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) - return safe[:120] - - def _extract_effort_tag(params: dict) -> str | None: - try: - if not isinstance(params, dict): - return None - if "extra_body" in params and isinstance(params["extra_body"], dict): - eb = params["extra_body"] - if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]: - return str(eb["reasoning"]["effort"]).lower() - if "reasoning_effort" in eb: - return str(eb["reasoning_effort"]).lower() - if ( - "reasoning" in params - and isinstance(params["reasoning"], dict) - and "effort" in params["reasoning"] - ): - return str(params["reasoning"]["effort"]).lower() - except Exception: - return None - return None - - model_slug = _sanitize_filename(config.completion_params.model) - effort_tag = _extract_effort_tag(completion_params) or "" - effort_suffix = f"__effort-{_sanitize_filename(effort_tag)}" if effort_tag else "" + model_slug = sanitize_filename(config.completion_params["model"]) + effort_tag = extract_effort_tag(completion_params) or "" + effort_suffix = f"__effort-{sanitize_filename(effort_tag)}" if effort_tag else "" base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" p = _pathlib.Path(summary_path) @@ -1052,7 +934,7 @@ def _extract_effort_tag(params: dict) -> str | None: parent = p.parent parent.mkdir(parents=True, exist_ok=True) if effort_tag: - out_file = parent / f"{p.stem}__{_sanitize_filename(effort_tag)}{p.suffix}" + out_file = parent / f"{p.stem}__{sanitize_filename(effort_tag)}{p.suffix}" else: out_file = p with open(out_file, "w", encoding="utf-8") as f: diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 3a5ec0e2..4522caef 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -59,6 +59,23 @@ def pytest_addoption(parser) -> None: "Values: low|medium|high" ), ) + group.addoption( + "--ep-max-retry", + action="store", + type=int, + default=None, + help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."), + ) + group.addoption( + "--ep-fail-on-permanent-failure", + action="store", + default=None, + choices=["true", "false"], + help=( + "Whether to fail the entire rollout when permanent failures occur after max retries. " + "Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts." + ), + ) def _normalize_max_rows(val: Optional[str]) -> Optional[str]: @@ -100,6 +117,14 @@ def pytest_configure(config) -> None: if summary_json_path: os.environ["EP_SUMMARY_JSON"] = summary_json_path + max_retry = config.getoption("--ep-max-retry") + if max_retry is not None: + os.environ["EP_MAX_RETRY"] = str(max_retry) + + fail_on_permanent_failure = config.getoption("--ep-fail-on-permanent-failure") + if fail_on_permanent_failure is not None: + os.environ["EP_FAIL_ON_PERMANENT_FAILURE"] = fail_on_permanent_failure + # Allow ad-hoc overrides of input params via CLI flags try: import json as _json diff --git a/eval_protocol/pytest/rollout_processor.py b/eval_protocol/pytest/rollout_processor.py new file mode 100644 index 00000000..824dd015 --- /dev/null +++ b/eval_protocol/pytest/rollout_processor.py @@ -0,0 +1,21 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import List + +from eval_protocol.models import EvaluationRow +from eval_protocol.pytest.types import RolloutProcessorConfig + + +class RolloutProcessor(ABC): + """ + Abstract base class for all rollout processor strategies. + """ + + @abstractmethod + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Process evaluation rows and return async tasks. Must be implemented by subclasses.""" + pass + + def cleanup(self) -> None: + """Cleanup resources. Override in subclasses if cleanup is needed.""" + pass diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 1a80254b..8a3be489 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -2,8 +2,9 @@ Parameter types """ +import asyncio from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from eval_protocol.dataset_logger import default_logger from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -49,6 +50,3 @@ class RolloutProcessorConfig: steps: int = 30 # max number of rollout steps logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor - - -RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], AsyncIterator[EvaluationRow]] diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 23a5722d..24b60028 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -1,9 +1,20 @@ import asyncio import inspect -from typing import Any, Callable, List, Literal, Optional +import os +import re +from dataclasses import replace +from typing import Any, Callable, Dict, List, Literal, Optional, Union from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from eval_protocol.models import EvalMetadata, EvaluationRow +from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.types import ( + CompletionParams, + DatasetPathParam, + EvaluationInputParam, + InputMessagesParam, + RolloutProcessorConfig, +) def execute_function(func: Callable, **kwargs) -> Any: @@ -124,3 +135,223 @@ def log_eval_status_and_rows( if r.eval_metadata is not None: r.eval_metadata.status = status logger.log(r) + + +def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]: + """Read EP_MAX_DATASET_ROWS env override as int or None.""" + raw = os.getenv("EP_MAX_DATASET_ROWS") + if raw is None: + return default_value + s = raw.strip().lower() + if s == "none": + return None + try: + return int(s) + except ValueError: + return default_value + + +def deep_update_dict(base: dict, override: dict) -> dict: + """Recursively update nested dictionaries in-place and return base.""" + for key, value in override.items(): + if isinstance(value, dict) and isinstance(base.get(key), dict): + deep_update_dict(base[key], value) + else: + base[key] = value + return base + + +def generate_parameter_combinations( + input_dataset: Optional[List[DatasetPathParam]], + completion_params: List[CompletionParams], + input_messages: Optional[List[InputMessagesParam]], + evaluation_test_kwargs: Optional[List[EvaluationInputParam]], + max_dataset_rows: Optional[int], + combine_datasets: bool, +) -> List[tuple]: + """ + Generate all combinations of parameters for pytest parameterization. + + Args: + input_dataset: Dataset paths to use + completion_params: Completion parameters to test + input_messages: Input messages to use + evaluation_test_kwargs: Additional kwargs for evaluation tests + max_dataset_rows: Maximum number of dataset rows to process + combine_datasets: Whether to combine multiple datasets into one test + + Returns: + List of parameter tuples for pytest.mark.parametrize + """ + combinations = [] + + # Handle optional parameters with defaults + # Optionally combine multiple dataset paths into one logical dataset, + # or parameterize to run one dataset per test invocation. + if input_dataset is not None: + if combine_datasets: + datasets: List[Optional[List[DatasetPathParam]]] = [input_dataset] # type: ignore + else: + # Fan out: one dataset path per parameterization + if isinstance(input_dataset, list): # type: ignore + datasets = [[p] for p in input_dataset] # type: ignore + else: + datasets = [[input_dataset]] # type: ignore + else: + datasets = [None] + + cps: List[Optional[CompletionParams]] = completion_params if completion_params is not None else [None] # type: ignore + + # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over + # each row. Instead, pass the entire sliced list through in a single test run + # so summaries aggregate all rows together (AIME-style behavior). + if input_messages is not None and isinstance(input_messages, list): + effective_max_rows = parse_ep_max_rows(max_dataset_rows) + if effective_max_rows is not None: + sliced_messages = input_messages[:effective_max_rows] # type: ignore + else: + sliced_messages = input_messages # type: ignore + # Wrap as a single parameter payload + messages = [sliced_messages] # type: ignore + else: + messages = [None] # type: ignore + + kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore + + # Generate all combinations + for ds in datasets: + for cp in cps: + for im in messages: + for etk in kwargs: + # if no dataset and no messages, raise an error + if ds is None and im is None: + raise ValueError( + "No dataset or messages provided. Please provide at least one of input_dataset or input_messages." + ) + combinations.append((ds, cp, im, etk)) + + return combinations + + +async def rollout_processor_with_retry( + rollout_processor: RolloutProcessor, + fresh_dataset: List[EvaluationRow], + config: RolloutProcessorConfig, + max_retry: int, +): + """ + Wrapper around rollout_processor that handles retry logic internally. + Uses async queue pattern to yield results immediately as they become available. + Yields both successful and failed results, leaving it up to the user to handle them in test_func. + """ + + try: + queue = asyncio.Queue() + retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} + failed_permanently = [] + + async def retry_handler(failed_row: EvaluationRow): + rollout_id = failed_row.execution_metadata.rollout_id + current_attempts = retry_counts.get(rollout_id, 0) + + if current_attempts >= max_retry: + assert ( + failed_row.rollout_status and failed_row.rollout_status.status == "error" + ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + failed_permanently.append(failed_row) + await queue.put(failed_row) # put failed row on queue + return + + retry_counts[rollout_id] = current_attempts + 1 + + # add kwargs start_server=False to config so we don't start new MCP server + retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) + + retry_tasks = rollout_processor([failed_row], retry_config) + + try: + retry_result = await retry_tasks[0] + retry_result.rollout_status.status = "finished" + await queue.put(retry_result) + except Exception as e: + failed_row.rollout_status.status = "error" + failed_row.rollout_status.termination_reason = str(e) + asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry + + async def initial_processor(): + """Process initial batch and spawn retries for failures""" + base_tasks = rollout_processor(fresh_dataset, config) + pending = set(base_tasks) + + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + task_index = base_tasks.index(task) + + try: + result = await task + result.rollout_status.status = "finished" + await queue.put(result) + except Exception as e: + failed_row = fresh_dataset[task_index] + failed_row.rollout_status.status = "error" + failed_row.rollout_status.termination_reason = str(e) + asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task + + processor_task = asyncio.create_task(initial_processor()) + + # yield results as they become available + completed_count = 0 + total_expected = len(fresh_dataset) + + while completed_count < total_expected: + finished_row = await queue.get() + + # only permanent failure rows are put on the queue, so we can check for them here + if finished_row.rollout_status and finished_row.rollout_status.status == "error": + if os.getenv("EP_FAIL_ON_PERMANENT_FAILURE", "true") != "false": + raise RuntimeError( + f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" + ) + + completed_count += 1 + yield finished_row + + await processor_task # explicitly wait for task completion and catch any exceptions + + finally: + rollout_processor.cleanup() + + +def sanitize_filename(text: str) -> str: + """Sanitize text for use in filenames by replacing special characters with dashes.""" + safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) + return safe[:120] + + +def extract_effort_tag(params: dict) -> Optional[str]: + """ + Extract effort tag from completion parameters for use in file naming. + + Args: + params: Completion parameters dictionary + + Returns: + Effort tag string if found, None otherwise + """ + try: + if not isinstance(params, dict): + return None + # Common locations + if "extra_body" in params and isinstance(params["extra_body"], dict): + eb = params["extra_body"] + if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]: + return str(eb["reasoning"]["effort"]).lower() + if "reasoning_effort" in eb: + return str(eb["reasoning_effort"]).lower() + if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]: + return str(params["reasoning"]["effort"]).lower() + except Exception: + return None + return None diff --git a/examples/gpqa/tests/test_gpqa.py b/examples/gpqa/tests/test_gpqa.py index dcbf7b53..d67e64a1 100644 --- a/examples/gpqa/tests/test_gpqa.py +++ b/examples/gpqa/tests/test_gpqa.py @@ -7,7 +7,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -66,7 +66,7 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]: completion_params=[ {"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], # default to low effort; override via CLI plugin - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=8, diff --git a/examples/healthbench/tests/test_evaluation.py b/examples/healthbench/tests/test_evaluation.py index a40c5d96..e0c7917b 100644 --- a/examples/healthbench/tests/test_evaluation.py +++ b/examples/healthbench/tests/test_evaluation.py @@ -3,7 +3,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -51,7 +51,7 @@ completion_params=[ {"temperature": 0.2, "max_tokens": 512, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=1, diff --git a/tests/pytest/test_apps_coding.py b/tests/pytest/test_apps_coding.py index 7cb976ac..9350a381 100644 --- a/tests/pytest/test_apps_coding.py +++ b/tests/pytest/test_apps_coding.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, Message -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.apps_coding_reward import evaluate_apps_solution @@ -30,7 +30,7 @@ def apps_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"} ], passed_threshold=0.33, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), num_runs=1, mode="pointwise", ) diff --git a/tests/pytest/test_basic_coding.py b/tests/pytest/test_basic_coding.py index 2b1c2a4a..4945d378 100644 --- a/tests/pytest/test_basic_coding.py +++ b/tests/pytest/test_basic_coding.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, Message -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.code_execution import execute_python_code, extract_code_blocks @@ -32,7 +32,7 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"} ], passed_threshold=0.8, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), num_runs=1, mode="pointwise", ) diff --git a/tests/pytest/test_frozen_lake.py b/tests/pytest/test_frozen_lake.py index bea42bed..24e32b56 100644 --- a/tests/pytest/test_frozen_lake.py +++ b/tests/pytest/test_frozen_lake.py @@ -9,7 +9,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]: @@ -41,7 +41,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation completion_params=[ {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"} ], - rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor=MCPGymRolloutProcessor(), passed_threshold=0.66, num_runs=1, max_concurrent_rollouts=3, diff --git a/tests/pytest/test_hallucination.py b/tests/pytest/test_hallucination.py index b29fb53c..fe8f32f0 100644 --- a/tests/pytest/test_hallucination.py +++ b/tests/pytest/test_hallucination.py @@ -12,7 +12,7 @@ import litellm from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test # Configure the judge model for LiteLLM JUDGE_MODEL = "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct" @@ -35,7 +35,7 @@ def hallucination_dataset_adapter(data: List[Dict[str, Any]]) -> List[Evaluation completion_params=[ {"temperature": 0.0, "max_tokens": 512, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"} ], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), passed_threshold=0.33, num_runs=1, mode="pointwise", diff --git a/tests/pytest/test_lunar_lander.py b/tests/pytest/test_lunar_lander.py index 3fddac62..00f966a5 100644 --- a/tests/pytest/test_lunar_lander.py +++ b/tests/pytest/test_lunar_lander.py @@ -9,7 +9,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]: @@ -39,7 +39,7 @@ def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio input_dataset=["tests/pytest/data/lunar_lander_dataset.jsonl"], dataset_adapter=lunar_lander_to_evaluation_row, completion_params=[{"temperature": 0.0, "max_tokens": 4096, "model": "gpt-4.1"}], - rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor=MCPGymRolloutProcessor(), passed_threshold=0.0, num_runs=1, mode="pointwise", diff --git a/tests/pytest/test_markdown_highlighting.py b/tests/pytest/test_markdown_highlighting.py index 9c70721f..c393ee60 100644 --- a/tests/pytest/test_markdown_highlighting.py +++ b/tests/pytest/test_markdown_highlighting.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]: @@ -32,7 +32,7 @@ def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], passed_threshold=0.5, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), num_runs=1, mode="pointwise", ) diff --git a/tests/pytest/test_pytest_default_agent_rollout_processor.py b/tests/pytest/test_pytest_default_agent_rollout_processor.py index 8320ec8a..bfabe35c 100644 --- a/tests/pytest/test_pytest_default_agent_rollout_processor.py +++ b/tests/pytest/test_pytest_default_agent_rollout_processor.py @@ -2,7 +2,7 @@ from typing import List from eval_protocol.models import EvaluationRow, Message -from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test +from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test @evaluation_test( @@ -16,7 +16,7 @@ ) ] ], - rollout_processor=default_agent_rollout_processor, + rollout_processor=AgentRolloutProcessor(), completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}], ) def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]: diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py index c9884756..e57b3c8c 100644 --- a/tests/pytest/test_pytest_ensure_logging.py +++ b/tests/pytest/test_pytest_ensure_logging.py @@ -18,7 +18,7 @@ async def test_ensure_logging(monkeypatch): "eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store ): from eval_protocol.models import EvaluationRow - from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor + from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor from eval_protocol.pytest.evaluation_test import evaluation_test from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row @@ -28,7 +28,7 @@ async def test_ensure_logging(monkeypatch): ], completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}], dataset_adapter=markdown_dataset_to_evaluation_row, - rollout_processor=default_no_op_rollout_processor, + rollout_processor=NoOpRolloutProcessor(), mode="pointwise", combine_datasets=False, num_runs=2, diff --git a/tests/pytest/test_pytest_flaky_sometimes.py b/tests/pytest/test_pytest_flaky_sometimes.py index 65e1e63d..bde5e34c 100644 --- a/tests/pytest/test_pytest_flaky_sometimes.py +++ b/tests/pytest/test_pytest_flaky_sometimes.py @@ -5,7 +5,7 @@ import pytest from eval_protocol.models import EvaluateResult, EvaluationRow, Message -from eval_protocol.pytest import default_no_op_rollout_processor, evaluation_test +from eval_protocol.pytest import NoOpRolloutProcessor, evaluation_test # skip in CI since it will intentionally fail. This is useful for local generation of logs @@ -13,7 +13,7 @@ @evaluation_test( input_messages=[[Message(role="user", content="Return HEADS or TAILS at random.")]], completion_params=[{"model": "dummy/local-model"}], - rollout_processor=default_no_op_rollout_processor, + rollout_processor=NoOpRolloutProcessor(), mode="pointwise", num_runs=5, ) diff --git a/tests/pytest/test_pytest_function_calling.py b/tests/pytest/test_pytest_function_calling.py index 63488dbe..60f38b0d 100644 --- a/tests/pytest/test_pytest_function_calling.py +++ b/tests/pytest/test_pytest_function_calling.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluationRow -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.function_calling import exact_tool_match_reward @@ -23,7 +23,7 @@ def function_calling_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evalu completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], mode="pointwise", dataset_adapter=function_calling_to_evaluation_row, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), ) async def test_pytest_function_calling(row: EvaluationRow) -> EvaluationRow: """Run pointwise evaluation on sample dataset using pytest interface.""" diff --git a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py index 045d2a19..b6bb4a35 100644 --- a/tests/pytest/test_pytest_ids.py +++ b/tests/pytest/test_pytest_ids.py @@ -3,7 +3,7 @@ import eval_protocol.dataset_logger as dataset_logger from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from eval_protocol.models import EvaluationRow -from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor +from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row @@ -30,7 +30,7 @@ async def test_evaluation_test_decorator(monkeypatch): ], completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}], dataset_adapter=markdown_dataset_to_evaluation_row, - rollout_processor=default_no_op_rollout_processor, + rollout_processor=NoOpRolloutProcessor(), mode="pointwise", combine_datasets=False, num_runs=2, @@ -71,7 +71,7 @@ async def test_evaluation_test_decorator_ids_single(monkeypatch): {"temperature": 1.0, "model": "dummy/local-model"}, ], dataset_adapter=markdown_dataset_to_evaluation_row, - rollout_processor=default_no_op_rollout_processor, + rollout_processor=NoOpRolloutProcessor(), mode="pointwise", combine_datasets=False, num_runs=5, diff --git a/tests/pytest/test_pytest_input_messages.py b/tests/pytest/test_pytest_input_messages.py index dc460aa5..7b4f8d9e 100644 --- a/tests/pytest/test_pytest_input_messages.py +++ b/tests/pytest/test_pytest_input_messages.py @@ -1,7 +1,7 @@ from typing import List from eval_protocol.models import EvaluationRow, Message -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test @evaluation_test( @@ -11,7 +11,7 @@ ] ], completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), ) def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]: """Run math evaluation on sample dataset using pytest interface.""" diff --git a/tests/pytest/test_pytest_json_schema.py b/tests/pytest/test_pytest_json_schema.py index 158874f1..c5a20c5d 100644 --- a/tests/pytest/test_pytest_json_schema.py +++ b/tests/pytest/test_pytest_json_schema.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluationRow -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.json_schema import json_schema_reward @@ -26,7 +26,7 @@ def json_schema_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evaluation input_dataset=["tests/pytest/data/json_schema.jsonl"], completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], mode="pointwise", - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), dataset_adapter=json_schema_to_evaluation_row, ) async def test_pytest_function_calling(row: EvaluationRow) -> EvaluationRow: diff --git a/tests/pytest/test_pytest_math_example.py b/tests/pytest/test_pytest_math_example.py index 23010797..55c525be 100644 --- a/tests/pytest/test_pytest_math_example.py +++ b/tests/pytest/test_pytest_math_example.py @@ -1,5 +1,5 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.math import math_reward from examples.math_example.main import check_think_answer_format from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row @@ -11,7 +11,7 @@ completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], max_dataset_rows=5, passed_threshold=0.0, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), mode="pointwise", evaluation_test_kwargs=[ {"math_reward_kwargs": {"tolerance": 0.001, "absolute_tolerance": 1e-8, "require_units": False}} diff --git a/tests/pytest/test_pytest_math_format_length.py b/tests/pytest/test_pytest_math_format_length.py index 5bba5c0e..3da732a0 100644 --- a/tests/pytest/test_pytest_math_format_length.py +++ b/tests/pytest/test_pytest_math_format_length.py @@ -1,7 +1,7 @@ import math from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.length import count_tokens from eval_protocol.rewards.math import math_reward from examples.math_with_format_and_length.main import check_think_answer_format @@ -14,7 +14,7 @@ completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], max_dataset_rows=5, passed_threshold=0.0, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), mode="pointwise", evaluation_test_kwargs=[ { diff --git a/tests/pytest/test_pytest_mcp_config.py b/tests/pytest/test_pytest_mcp_config.py index dde15aa9..c578d07c 100644 --- a/tests/pytest/test_pytest_mcp_config.py +++ b/tests/pytest/test_pytest_mcp_config.py @@ -2,7 +2,7 @@ from typing import List from eval_protocol.models import EvaluateResult, EvaluationRow, Message -from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test +from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test @evaluation_test( @@ -19,7 +19,7 @@ ) ] ], - rollout_processor=default_agent_rollout_processor, + rollout_processor=AgentRolloutProcessor(), completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b"}], mode="pointwise", mcp_config_path="tests/pytest/mcp_configurations/mock_discord_mcp_config.json", diff --git a/tests/pytest/test_pytest_mcp_url.py b/tests/pytest/test_pytest_mcp_url.py index 01c06c45..ce265da5 100644 --- a/tests/pytest/test_pytest_mcp_url.py +++ b/tests/pytest/test_pytest_mcp_url.py @@ -1,5 +1,5 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, Message -from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test +from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test @evaluation_test( @@ -18,7 +18,7 @@ ), ] ], - rollout_processor=default_agent_rollout_processor, + rollout_processor=AgentRolloutProcessor(), completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}], mode="pointwise", mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config.json", diff --git a/tests/pytest/test_pytest_word_count_example.py b/tests/pytest/test_pytest_word_count_example.py index 339c5152..72c9bc2f 100644 --- a/tests/pytest/test_pytest_word_count_example.py +++ b/tests/pytest/test_pytest_word_count_example.py @@ -1,7 +1,7 @@ from haikus import haikus from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from tests.pytest.helper.word_count_to_evaluation_row import word_count_to_evaluation_row @@ -11,7 +11,7 @@ completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], max_dataset_rows=5, passed_threshold=0.3, # Reasonable threshold for word count evaluation - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), mode="pointwise", # Use pointwise mode for elegant row-by-row evaluation ) def test_word_count_evaluate(row: EvaluationRow) -> EvaluationRow: diff --git a/tests/pytest/test_tau_bench_airline.py b/tests/pytest/test_tau_bench_airline.py index 0eeba626..f3a7c65f 100644 --- a/tests/pytest/test_tau_bench_airline.py +++ b/tests/pytest/test_tau_bench_airline.py @@ -12,7 +12,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor from vendor.tau2.data_model.message import ( AssistantMessage, SystemMessage, @@ -72,7 +72,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", } ], - rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor=MCPGymRolloutProcessor(), passed_threshold={"success": 0.4, "standard_deviation": 0.1}, num_runs=8, mode="pointwise", diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py new file mode 100644 index 00000000..a483f0e1 --- /dev/null +++ b/tests/test_retry_mechanism.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +""" +Simple test to verify the retry mechanism works with evaluation_test. +""" + +import asyncio +import os +from collections import Counter +from typing import List +from unittest.mock import Mock + +import pytest + +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus +from eval_protocol.pytest.evaluation_test import evaluation_test +from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + +os.environ["EP_MAX_RETRY"] = "2" # Allow up to 2 retries + + +class MockRolloutProcessorWithRetries(RolloutProcessor): + """Mock rollout processor that fails second task alphabetically on first attempt, succeeds on retry""" + + def __init__(self): + self.mock_tracker = Mock() + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + # Track this batch call + self.mock_tracker.batch_call(len(rows)) + + row_setup = { + 0: {"delay": 0.01, "should_fail": False}, + 1: {"delay": 0.01, "should_fail": True}, # Will be adjusted based on attempt number + 2: {"delay": 0.01, "should_fail": False}, + 3: {"delay": 0.01, "should_fail": False}, + 4: {"delay": 0.01, "should_fail": False}, + } + + async def process_single_row( + row: EvaluationRow, delay: float, base_should_fail: bool = False + ) -> EvaluationRow: + rollout_id = row.execution_metadata.rollout_id + + # Track individual row processing call + self.mock_tracker.process_row_call(rollout_id) + + # Determine attempt number by counting previous calls for this rollout_id + previous_calls = [ + call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == rollout_id + ] + attempt_number = len(previous_calls) + + # Determine if this specific attempt should fail + # Row 1 fails on first attempt (attempt_number == 1), succeeds on retry (attempt_number == 2) + should_fail = base_should_fail and attempt_number == 1 + + print(f"šŸ”„ ATTEMPTING rollout_id={rollout_id}, attempt={attempt_number}, will_fail={should_fail}") + + await asyncio.sleep(delay) + print(f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'}: {row.execution_metadata.rollout_id}") + + if should_fail: + raise Exception("Simulated failure for testing") + + return row + + # Create and return tasks (let evaluation_test handle them) + tasks = [ + asyncio.create_task(process_single_row(row, row_setup[i]["delay"], row_setup[i]["should_fail"])) + for i, row in enumerate(rows) + ] + + return tasks + + +# Create a shared processor instance for testing +shared_processor = MockRolloutProcessorWithRetries() + + +@evaluation_test( + completion_params=[{"model": "gpt-4o-mini", "temperature": 0}], + input_messages=[ + [Message(role="user", content="Task A")], + [Message(role="user", content="Task B")], + [Message(role="user", content="Task C")], + [Message(role="user", content="Task D")], + [Message(role="user", content="Task E")], + ], + rollout_processor=shared_processor, + num_runs=1, + mode="pointwise", +) +def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow: + """MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry.""" + print( + f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + ) + + # Assign a score based on success/failure + score = 1.0 if row.rollout_status.status == "finished" else 0.0 + row.evaluation_result = EvaluateResult(score=score) + + return row + + +def test_retry_mechanism_mock_verification(): + """Test that verifies the retry mechanism worked by checking the mock calls""" + # Get our mock tracker + mock_tracker = shared_processor.mock_tracker + + print(f"\nšŸ”„ MOCK CALL ANALYSIS:") + print(f" Batch calls made: {mock_tracker.batch_call.call_count}") + print(f" Total row processing calls: {mock_tracker.process_row_call.call_count}") + + if mock_tracker.process_row_call.call_count == 0: + print("āš ļø No calls recorded yet. The evaluation test may not have run or completed.") + return + + # Get all rollout_ids that were processed + call_args = mock_tracker.process_row_call.call_args_list + rollout_ids = [call[0][0] for call in call_args] + + # Count calls per rollout_id + call_counts = Counter(rollout_ids) + + print(f" Call counts per rollout_id: {dict(call_counts)}") + print(f" Individual calls:") + for i, call_arg in enumerate(call_args, 1): + rollout_id = call_arg[0][0] + attempt_num = rollout_ids[:i].count(rollout_id) + print(f" {i}. rollout_id={rollout_id}, attempt={attempt_num}") + + # ASSERTIONS USING MOCK DATA + # Should have exactly 6 total row processing calls (5 initial + 1 retry) + assert ( + mock_tracker.process_row_call.call_count == 6 + ), f"Expected 6 total calls, got {mock_tracker.process_row_call.call_count}" + + # Should have exactly 2 batch calls (initial batch + retry batch) + assert mock_tracker.batch_call.call_count == 2, f"Expected 2 batch calls, got {mock_tracker.batch_call.call_count}" + + # First batch should have 5 rows, second batch should have 1 row (the retry) + batch_call_args = mock_tracker.batch_call.call_args_list + assert batch_call_args[0][0][0] == 5, f"Expected first batch to have 5 rows, got {batch_call_args[0][0][0]}" + assert batch_call_args[1][0][0] == 1, f"Expected second batch to have 1 row, got {batch_call_args[1][0][0]}" + + # Exactly one rollout_id should be called twice, others called once + call_count_values = list(call_counts.values()) + assert ( + call_count_values.count(2) == 1 + ), f"Expected exactly 1 rollout_id to be called twice, got counts: {dict(call_counts)}" + assert ( + call_count_values.count(1) == 4 + ), f"Expected exactly 4 rollout_ids to be called once, got counts: {dict(call_counts)}" + + print("āœ… All mock-based assertions passed! Retry mechanism is working correctly.") diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 1b92d5aa..8d176780 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -239,8 +239,10 @@ def mock_step_side_effect(env_index, tool_call): policy = MockPolicy(["right", "down", "right"]) # Execute rollout + tasks = self.execution_manager.execute_rollouts(mock_env, policy, steps=10) evaluation_rows = [] - async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=10): + for task in tasks: + row = await task evaluation_rows.append(row) # Validate results @@ -459,8 +461,10 @@ async def test_rollout_handles_control_plane_failure_gracefully(self): # Execute rollout with control plane failure policy = MockPolicy(["right"]) + tasks = self.execution_manager.execute_rollouts(mock_env, policy, steps=1) evaluation_rows = [] - async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=1): + for task in tasks: + row = await task evaluation_rows.append(row) # Should still work, but without control plane info @@ -497,7 +501,7 @@ async def test_rollout_creates_envs_from_url(self): policy = MockPolicy(["right"]) with ( - patch("eval_protocol.mcp_env.make", new_callable=AsyncMock) as mock_make, + patch("eval_protocol.mcp_env.make") as mock_make, patch("eval_protocol.mcp_env.ExecutionManager") as MockManager, ): mock_env = MagicMock() @@ -505,24 +509,30 @@ async def test_rollout_creates_envs_from_url(self): manager_instance = MockManager.return_value - # Mock execute_rollouts to return an async generator and track calls + # Mock execute_rollouts to return tasks and track calls call_args = [] - async def mock_execute_rollouts(*args, **kwargs): + async def mock_task(): + return "ok" + + def mock_execute_rollouts(*args, **kwargs): call_args.append((args, kwargs)) - for item in ["ok"]: - yield item + import asyncio + + return [asyncio.create_task(mock_task())] manager_instance.execute_rollouts = mock_execute_rollouts result = [] - async for row in ep.rollout( + tasks = ep.rollout( "http://localhost:1234/mcp/", policy, dataset=dataset, model_id="test_model", steps=5, - ): + ) + for task in tasks: + row = await task result.append(row) mock_make.assert_called_once_with( diff --git a/tests/test_tau_bench_airline_smoke.py b/tests/test_tau_bench_airline_smoke.py index 200f7ca8..044447b7 100644 --- a/tests/test_tau_bench_airline_smoke.py +++ b/tests/test_tau_bench_airline_smoke.py @@ -13,7 +13,7 @@ from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor from vendor.tau2.data_model.message import ( AssistantMessage, SystemMessage, @@ -72,7 +72,7 @@ def tau_bench_airline_smoke_to_evaluation_row(data: List[Dict[str, Any]]) -> Lis "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", } ], - rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor=MCPGymRolloutProcessor(), passed_threshold=0.36, num_runs=1, # Smoke test: single run for quick feedback mode="pointwise",