diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 5c21806e..6b1163e9 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -12,17 +12,16 @@ import threading import time from dataclasses import asdict -from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import anyio -import httpx from openai.types import CompletionUsage from vendor.tau2.data_model.message import AssistantMessage, UserMessage from vendor.tau2.user.user_simulator import UserSimulator -from ...models import EvaluationRow, InputMetadata, Message -from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory +from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus +from ...types import TerminationReason, Trajectory, NonSkippableException if TYPE_CHECKING: from ..session.manager import GeneralMCPVectorEnv @@ -107,7 +106,7 @@ async def _execute_with_semaphore(idx): ) # Convert trajectory to EvaluationRow immediately - evaluation_row = evaluation_rows[idx] + evaluation_row: EvaluationRow = evaluation_rows[idx] # Handle multimodal content by extracting text from complex content structures messages = [] @@ -137,16 +136,15 @@ async def _execute_with_semaphore(idx): } if trajectory.terminated: - if trajectory.termination_reason == TerminationReason.ERROR: - evaluation_row.rollout_status.status = "error" - evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get( - "error_message", None - ) - else: - evaluation_row.rollout_status.status = "finished" - evaluation_row.rollout_status.termination_reason = trajectory.termination_reason + evaluation_row.rollout_status.termination_reason = trajectory.termination_reason + evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED + # preserve the true error mesage if there are any + if trajectory.control_plane_summary.get("error_message"): + evaluation_row.rollout_status.extra_info = { + "error_message": trajectory.control_plane_summary.get("error_message") + } else: - evaluation_row.rollout_status.status = "running" + evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING return evaluation_row @@ -437,31 +435,18 @@ async def _execute_rollout( logger.info( f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" ) - - except asyncio.CancelledError: - failure_reason = "asyncio context cancelled" - logger.error( - f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True - ) - except (anyio.ClosedResourceError, anyio.BrokenResourceError): - failure_reason = "anyioconnection/resource error" - logger.error( - f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True - ) - except Exception as e: - error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error" - logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True) - failure_reason = error_msg + except NonSkippableException as e: + # terminate the rollout right away, no retry and preserve the current trajectory history. + # for other types of exceptions, keep propagate them to upper layers and handle them with retry handler. + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.NON_SKIPPABLE_ERROR + trajectory.control_plane_summary.update({"error_message": str(e)}) + logger.error(f"🚨 Rollout {rollout_idx} terminated due to non-skippable error: {str(e)}", exc_info=True) finally: - if failure_reason: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.ERROR - trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"}) try: await envs.connection_manager.reset_session(session) except Exception as e: logger.warning(f"Failed to reset session {session.session_id}: {type(e).__name__}: {e}", exc_info=True) - try: await envs.connection_manager.close_session(session) except Exception as e: diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 9d528289..1d65d3e0 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -1,5 +1,6 @@ import os from datetime import datetime +from enum import Enum from typing import Any, Dict, List, Literal, Optional, TypedDict, Union from openai.types import CompletionUsage @@ -11,6 +12,7 @@ from eval_protocol.get_pep440_version import get_pep440_version from eval_protocol.human_id import generate_id +from eval_protocol.types import TerminationReason class ChatCompletionContentPartTextParam(BaseModel): @@ -285,14 +287,20 @@ class RolloutStatus(BaseModel): """ running: Unfinished rollout which is still in progress. - finished: Rollout finished successfully. - error: Rollout failed. - stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop). + finished: Rollout finished. + error: Rollout failed due to unexpected error. The rollout record should be discard. """ - status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.") - termination_reason: Optional[str] = Field( - "", description="reason of the rollout status, mapped to values in TerminationReason" + + class Status(str, Enum): + RUNNING = "running" + FINISHED = "finished" + ERROR = "error" + + status: Status = Field(Status.RUNNING, description="Status of the rollout.") + termination_reason: Optional[TerminationReason] = Field( + None, description="reason of the rollout status, mapped to values in TerminationReason" ) + extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.") class EvaluationRow(BaseModel): diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 617b9e85..e0b8328a 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union from eval_protocol.dataset_logger.dataset_logger import DatasetLogger -from eval_protocol.models import EvalMetadata, EvaluationRow +from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import ( CompletionParams, @@ -248,7 +248,7 @@ async def rollout_processor_with_retry( """ try: - queue = asyncio.Queue() + queue: asyncio.Queue[EvaluationRow] = asyncio.Queue() retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} failed_permanently = [] @@ -257,7 +257,7 @@ async def retry_handler(failed_row: EvaluationRow): current_attempts = retry_counts.get(rollout_id, 0) if current_attempts >= max_retry: - assert failed_row.rollout_status and failed_row.rollout_status.status == "error", ( + assert failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR, ( f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" ) failed_permanently.append(failed_row) @@ -273,10 +273,10 @@ async def retry_handler(failed_row: EvaluationRow): try: retry_result = await retry_tasks[0] - retry_result.rollout_status.status = "finished" + retry_result.rollout_status.status = RolloutStatus.Status.FINISHED await queue.put(retry_result) except Exception as e: - failed_row.rollout_status.status = "error" + failed_row.rollout_status.status = RolloutStatus.Status.ERROR failed_row.rollout_status.termination_reason = str(e) asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry @@ -299,11 +299,11 @@ async def initial_processor(): try: result = await task - result.rollout_status.status = "finished" + result.rollout_status.status = RolloutStatus.Status.FINISHED await queue.put(result) except Exception as e: failed_row = fresh_dataset[task_index] - failed_row.rollout_status.status = "error" + failed_row.rollout_status.status = RolloutStatus.Status.ERROR failed_row.rollout_status.termination_reason = str(e) asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task @@ -317,7 +317,7 @@ async def initial_processor(): finished_row = await queue.get() # only permanent failure rows are put on the queue, so we can check for them here - if finished_row.rollout_status and finished_row.rollout_status.status == "error": + if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR: if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false": raise RuntimeError( f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" diff --git a/eval_protocol/types/__init__.py b/eval_protocol/types/__init__.py index 1b9fcd9a..2fff1c4e 100644 --- a/eval_protocol/types/__init__.py +++ b/eval_protocol/types/__init__.py @@ -1,3 +1,4 @@ from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory +from .errors import NonSkippableException -__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow"] +__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "NonSkippableException"] diff --git a/eval_protocol/types/errors.py b/eval_protocol/types/errors.py new file mode 100644 index 00000000..e0bcfa29 --- /dev/null +++ b/eval_protocol/types/errors.py @@ -0,0 +1,11 @@ +class NonSkippableException(Exception): + """ + A type of custom exception raised during rollout or evaluation. This error means the rollout/evaluation result is not skippable and need to be + processed explicitly. + + For example, if the policy (llm) returns 400 User error, we need to end the rollout but keep the trajectory. + It differs from other exceptions such as network error, which are retriable and the trajectory should be discarded if + it fails eventually after retries. + """ + + pass diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index a94675c4..4e696aa4 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -12,7 +12,8 @@ class TerminationReason(str, Enum): MAX_STEPS: Trajectory ends because we hit the step limit CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition) USER_STOP: Trajectory ends because the simulated user signals to stop - ERROR: Trajectory ends because of an error + SKIPPABLE_ERROR: Trajectory ends because of an error, this trajectory can be discarded/skipped during postprocessing/evaluation. + NON_SKIPPABLE_ERROR: Trajectory is interrupted due to some non-skippable error (e.g. policy returns unexpected response and we need to terminate the rollout). STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop") LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length") TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls") @@ -21,7 +22,8 @@ class TerminationReason(str, Enum): MAX_STEPS = "max_steps" CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" - ERROR = "error" + SKIPPABLE_ERROR = "skippable_error" + NON_SKIPPABLE_ERROR = "non_skippable_error" STOP = "stop" LENGTH = "length" TOOL_CALLS = "tool_calls" @@ -38,8 +40,10 @@ def from_str(cls, value: str) -> "TerminationReason": return cls.CONTROL_PLANE_SIGNAL elif value == "user_stop": return cls.USER_STOP - elif value == "error": - return cls.ERROR + elif value == "skippable_error": + return cls.SKIPPABLE_ERROR + elif value == "non_skippable_error": + return cls.NON_SKIPPABLE_ERROR elif value == "tool_calls": return cls.TOOL_CALLS else: