From 8b48867658eb128605c370c39a099a9fe1be246d Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 18 Aug 2025 15:05:52 -0700 Subject: [PATCH 1/3] custom rollout exception and termination_reason=interrupted --- eval_protocol/mcp/execution/manager.py | 31 +++++++++++++------------- eval_protocol/models.py | 11 ++++----- eval_protocol/types/__init__.py | 3 ++- eval_protocol/types/errors.py | 13 +++++++++++ eval_protocol/types/types.py | 2 ++ 5 files changed, 38 insertions(+), 22 deletions(-) create mode 100644 eval_protocol/types/errors.py diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 5c21806e..bad40a69 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -12,17 +12,16 @@ import threading import time from dataclasses import asdict -from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import anyio -import httpx from openai.types import CompletionUsage from vendor.tau2.data_model.message import AssistantMessage, UserMessage from vendor.tau2.user.user_simulator import UserSimulator from ...models import EvaluationRow, InputMetadata, Message -from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory +from ...types import TerminationReason, Trajectory, RolloutTerminationException if TYPE_CHECKING: from ..session.manager import GeneralMCPVectorEnv @@ -107,7 +106,7 @@ async def _execute_with_semaphore(idx): ) # Convert trajectory to EvaluationRow immediately - evaluation_row = evaluation_rows[idx] + evaluation_row: EvaluationRow = evaluation_rows[idx] # Handle multimodal content by extracting text from complex content structures messages = [] @@ -137,14 +136,16 @@ async def _execute_with_semaphore(idx): } if trajectory.terminated: + evaluation_row.rollout_status.termination_reason = trajectory.termination_reason + # preserve the true error 
mesage if there are any + if trajectory.control_plane_summary.get("error_message"): + evaluation_row.rollout_status.extra_info = { + "error_message": trajectory.control_plane_summary.get("error_message") + } if trajectory.termination_reason == TerminationReason.ERROR: evaluation_row.rollout_status.status = "error" - evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get( - "error_message", None - ) else: evaluation_row.rollout_status.status = "finished" - evaluation_row.rollout_status.termination_reason = trajectory.termination_reason else: evaluation_row.rollout_status.status = "running" @@ -437,17 +438,16 @@ async def _execute_rollout( logger.info( f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" ) - - except asyncio.CancelledError: + except RolloutTerminationException as e: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.INTERRUPTED + trajectory.control_plane_summary.update({"error_message": str(e)}) + logger.error(f"🚨 Rollout {rollout_idx} terminated due to user exception: {str(e)}", exc_info=True) + except asyncio.CancelledError: # need explicit catch since it doesn't inherit from Exception failure_reason = "asyncio context cancelled" logger.error( f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True ) - except (anyio.ClosedResourceError, anyio.BrokenResourceError): - failure_reason = "anyioconnection/resource error" - logger.error( - f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True - ) except Exception as e: error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error" logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True) @@ -461,7 +461,6 @@ async def _execute_rollout( await 
envs.connection_manager.reset_session(session) except Exception as e: logger.warning(f"Failed to reset session {session.session_id}: {type(e).__name__}: {e}", exc_info=True) - try: await envs.connection_manager.close_session(session) except Exception as e: diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 9d528289..6b1750c8 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -11,6 +11,7 @@ from eval_protocol.get_pep440_version import get_pep440_version from eval_protocol.human_id import generate_id +from eval_protocol.types import TerminationReason class ChatCompletionContentPartTextParam(BaseModel): @@ -285,14 +286,14 @@ class RolloutStatus(BaseModel): """ running: Unfinished rollout which is still in progress. - finished: Rollout finished successfully. - error: Rollout failed. - stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop). + finished: Rollout finished. + error: Rollout failed due to unexpected error. The rollout record should be discard. 
""" status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.") - termination_reason: Optional[str] = Field( - "", description="reason of the rollout status, mapped to values in TerminationReason" + termination_reason: Optional[TerminationReason] = Field( + None, description="reason of the rollout status, mapped to values in TerminationReason" ) + extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.") class EvaluationRow(BaseModel): diff --git a/eval_protocol/types/__init__.py b/eval_protocol/types/__init__.py index 1b9fcd9a..1afd3547 100644 --- a/eval_protocol/types/__init__.py +++ b/eval_protocol/types/__init__.py @@ -1,3 +1,4 @@ from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory +from .errors import RolloutTerminationException -__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow"] +__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "RolloutTerminationException"] diff --git a/eval_protocol/types/errors.py b/eval_protocol/types/errors.py new file mode 100644 index 00000000..363f0a8f --- /dev/null +++ b/eval_protocol/types/errors.py @@ -0,0 +1,13 @@ +class RolloutTerminationException(Exception): + """ + Exception raised during rollout. This error means the rollout needs to be terminated and its not retriable, + and the trajectory need to be perserved for future evalution or analysis. + + For example, if the policy (llm) returns 400 User error, we need to end the rollout but keep the trajectory. + It differs from other exceptions such as network error, which are retriable and the trajectory should be discarded if + it fails eventually after retries. + + This will cause trajectory.termination_reason to be set to TerminationReason.INTERRUPTED. 
+ """ + + pass diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index a94675c4..29dd99ac 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -13,6 +13,7 @@ class TerminationReason(str, Enum): CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition) USER_STOP: Trajectory ends because the simulated user signals to stop ERROR: Trajectory ends because of an error + INTERRUPTED: Trajectory is interrupted by some non-system related reason (e.g. policy returns unexpected response and we need to terminate the rollout). STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop") LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length") TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls") @@ -22,6 +23,7 @@ class TerminationReason(str, Enum): CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" ERROR = "error" + INTERRUPTED = "interrupted" STOP = "stop" LENGTH = "length" TOOL_CALLS = "tool_calls" From 0f96c80d7efab8210360c8bf07e43854cae35452 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 18 Aug 2025 17:10:09 -0700 Subject: [PATCH 2/3] add nonskippableerror --- eval_protocol/mcp/execution/manager.py | 32 ++++++++------------------ eval_protocol/models.py | 9 +++++++- eval_protocol/pytest/utils.py | 20 ++++++++-------- eval_protocol/types/__init__.py | 4 ++-- eval_protocol/types/errors.py | 8 +++---- eval_protocol/types/types.py | 14 ++++++----- 6 files changed, 40 insertions(+), 47 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index bad40a69..6b1163e9 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -20,8 +20,8 @@ from vendor.tau2.data_model.message import AssistantMessage, UserMessage from 
vendor.tau2.user.user_simulator import UserSimulator -from ...models import EvaluationRow, InputMetadata, Message -from ...types import TerminationReason, Trajectory, RolloutTerminationException +from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus +from ...types import TerminationReason, Trajectory, NonSkippableException if TYPE_CHECKING: from ..session.manager import GeneralMCPVectorEnv @@ -137,17 +137,14 @@ async def _execute_with_semaphore(idx): if trajectory.terminated: evaluation_row.rollout_status.termination_reason = trajectory.termination_reason + evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED # preserve the true error mesage if there are any if trajectory.control_plane_summary.get("error_message"): evaluation_row.rollout_status.extra_info = { "error_message": trajectory.control_plane_summary.get("error_message") } - if trajectory.termination_reason == TerminationReason.ERROR: - evaluation_row.rollout_status.status = "error" - else: - evaluation_row.rollout_status.status = "finished" else: - evaluation_row.rollout_status.status = "running" + evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING return evaluation_row @@ -438,25 +435,14 @@ async def _execute_rollout( logger.info( f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" ) - except RolloutTerminationException as e: + except NonSkippableException as e: + # terminate the rollout right away, no retry and preserve the current trajectory history. + # for other types of exceptions, keep propagating them to upper layers and handle them with retry handler. 
trajectory.terminated = True - trajectory.termination_reason = TerminationReason.INTERRUPTED + trajectory.termination_reason = TerminationReason.NON_SKIPPABLE_ERROR trajectory.control_plane_summary.update({"error_message": str(e)}) - logger.error(f"🚨 Rollout {rollout_idx} terminated due to user exception: {str(e)}", exc_info=True) - except asyncio.CancelledError: # need explicit catch since it doesn't inherit from Exception - failure_reason = "asyncio context cancelled" - logger.error( - f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True - ) - except Exception as e: - error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error" - logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True) - failure_reason = error_msg + logger.error(f"🚨 Rollout {rollout_idx} terminated due to non-skippable error: {str(e)}", exc_info=True) finally: - if failure_reason: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.ERROR - trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"}) try: await envs.connection_manager.reset_session(session) except Exception as e: diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 6b1750c8..1d65d3e0 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -1,5 +1,6 @@ import os from datetime import datetime +from enum import Enum from typing import Any, Dict, List, Literal, Optional, TypedDict, Union from openai.types import CompletionUsage @@ -289,7 +290,13 @@ class RolloutStatus(BaseModel): finished: Rollout finished. error: Rollout failed due to unexpected error. The rollout record should be discard. 
""" - status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.") + + class Status(str, Enum): + RUNNING = "running" + FINISHED = "finished" + ERROR = "error" + + status: Status = Field(Status.RUNNING, description="Status of the rollout.") termination_reason: Optional[TerminationReason] = Field( None, description="reason of the rollout status, mapped to values in TerminationReason" ) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 617b9e85..14d2d4a7 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union from eval_protocol.dataset_logger.dataset_logger import DatasetLogger -from eval_protocol.models import EvalMetadata, EvaluationRow +from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import ( CompletionParams, @@ -248,7 +248,7 @@ async def rollout_processor_with_retry( """ try: - queue = asyncio.Queue() + queue: asyncio.Queue[EvaluationRow] = asyncio.Queue() retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} failed_permanently = [] @@ -257,9 +257,9 @@ async def retry_handler(failed_row: EvaluationRow): current_attempts = retry_counts.get(rollout_id, 0) if current_attempts >= max_retry: - assert failed_row.rollout_status and failed_row.rollout_status.status == "error", ( - f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" - ) + assert ( + failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR + ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" failed_permanently.append(failed_row) await queue.put(failed_row) # put failed row on queue return @@ -273,10 +273,10 @@ async def retry_handler(failed_row: EvaluationRow): try: 
retry_result = await retry_tasks[0] - retry_result.rollout_status.status = "finished" + retry_result.rollout_status.status = RolloutStatus.Status.FINISHED await queue.put(retry_result) except Exception as e: - failed_row.rollout_status.status = "error" + failed_row.rollout_status.status = RolloutStatus.Status.ERROR failed_row.rollout_status.termination_reason = str(e) asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry @@ -299,11 +299,11 @@ async def initial_processor(): try: result = await task - result.rollout_status.status = "finished" + result.rollout_status.status = RolloutStatus.Status.FINISHED await queue.put(result) except Exception as e: failed_row = fresh_dataset[task_index] - failed_row.rollout_status.status = "error" + failed_row.rollout_status.status = RolloutStatus.Status.ERROR failed_row.rollout_status.termination_reason = str(e) asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task @@ -317,7 +317,7 @@ async def initial_processor(): finished_row = await queue.get() # only permanent failure rows are put on the queue, so we can check for them here - if finished_row.rollout_status and finished_row.rollout_status.status == "error": + if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR: if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false": raise RuntimeError( f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. 
Errors: {finished_row.rollout_status.termination_reason}" diff --git a/eval_protocol/types/__init__.py b/eval_protocol/types/__init__.py index 1afd3547..2fff1c4e 100644 --- a/eval_protocol/types/__init__.py +++ b/eval_protocol/types/__init__.py @@ -1,4 +1,4 @@ from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory -from .errors import RolloutTerminationException +from .errors import NonSkippableException -__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "RolloutTerminationException"] +__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "NonSkippableException"] diff --git a/eval_protocol/types/errors.py b/eval_protocol/types/errors.py index 363f0a8f..e0bcfa29 100644 --- a/eval_protocol/types/errors.py +++ b/eval_protocol/types/errors.py @@ -1,13 +1,11 @@ -class RolloutTerminationException(Exception): +class NonSkippableException(Exception): """ - Exception raised during rollout. This error means the rollout needs to be terminated and its not retriable, - and the trajectory need to be perserved for future evalution or analysis. + A type of custom exception raised during rollout or evaluation. This error means the rollout/evaluation result is not skippable and needs to be + processed explicitly. For example, if the policy (llm) returns 400 User error, we need to end the rollout but keep the trajectory. It differs from other exceptions such as network error, which are retriable and the trajectory should be discarded if it fails eventually after retries. - - This will cause trajectory.termination_reason to be set to TerminationReason.INTERRUPTED. 
""" pass diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 29dd99ac..4e696aa4 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -12,8 +12,8 @@ class TerminationReason(str, Enum): MAX_STEPS: Trajectory ends because we hit the step limit CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition) USER_STOP: Trajectory ends because the simulated user signals to stop - ERROR: Trajectory ends because of an error - INTERRUPTED: Trajectory is interrupted by some non-system related reason (e.g. policy returns unexpected response and we need to terminate the rollout). + SKIPPABLE_ERROR: Trajectory ends because of an error, this trajectory can be discarded/skipped during postprocessing/evaluation. + NON_SKIPPABLE_ERROR: Trajectory is interrupted due to some non-skippable error (e.g. policy returns unexpected response and we need to terminate the rollout). STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop") LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length") TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls") @@ -22,8 +22,8 @@ class TerminationReason(str, Enum): MAX_STEPS = "max_steps" CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" - ERROR = "error" - INTERRUPTED = "interrupted" + SKIPPABLE_ERROR = "skippable_error" + NON_SKIPPABLE_ERROR = "non_skippable_error" STOP = "stop" LENGTH = "length" TOOL_CALLS = "tool_calls" @@ -40,8 +40,10 @@ def from_str(cls, value: str) -> "TerminationReason": return cls.CONTROL_PLANE_SIGNAL elif value == "user_stop": return cls.USER_STOP - elif value == "error": - return cls.ERROR + elif value == "skippable_error": + return cls.SKIPPABLE_ERROR + elif value == "non_skippable_error": + return cls.NON_SKIPPABLE_ERROR elif value == "tool_calls": return cls.TOOL_CALLS 
else: From 5a2bfe22f2e19a2f8914d79204ecb7ec119a6556 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 18 Aug 2025 17:14:11 -0700 Subject: [PATCH 3/3] format --- eval_protocol/pytest/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 14d2d4a7..e0b8328a 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -257,9 +257,9 @@ async def retry_handler(failed_row: EvaluationRow): current_attempts = retry_counts.get(rollout_id, 0) if current_attempts >= max_retry: - assert ( - failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR - ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + assert failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR, ( + f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + ) failed_permanently.append(failed_row) await queue.put(failed_row) # put failed row on queue return