Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 19 additions & 34 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,16 @@
import threading
import time
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import anyio
import httpx
from openai.types import CompletionUsage

from vendor.tau2.data_model.message import AssistantMessage, UserMessage
from vendor.tau2.user.user_simulator import UserSimulator

from ...models import EvaluationRow, InputMetadata, Message
from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus
from ...types import TerminationReason, Trajectory, NonSkippableException

if TYPE_CHECKING:
from ..session.manager import GeneralMCPVectorEnv
Expand Down Expand Up @@ -107,7 +106,7 @@ async def _execute_with_semaphore(idx):
)

# Convert trajectory to EvaluationRow immediately
evaluation_row = evaluation_rows[idx]
evaluation_row: EvaluationRow = evaluation_rows[idx]

# Handle multimodal content by extracting text from complex content structures
messages = []
Expand Down Expand Up @@ -137,16 +136,15 @@ async def _execute_with_semaphore(idx):
}

if trajectory.terminated:
if trajectory.termination_reason == TerminationReason.ERROR:
evaluation_row.rollout_status.status = "error"
evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get(
"error_message", None
)
else:
evaluation_row.rollout_status.status = "finished"
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED
# preserve the true error message if there is one
if trajectory.control_plane_summary.get("error_message"):
evaluation_row.rollout_status.extra_info = {
"error_message": trajectory.control_plane_summary.get("error_message")
}
else:
evaluation_row.rollout_status.status = "running"
evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING

return evaluation_row

Expand Down Expand Up @@ -437,31 +435,18 @@ async def _execute_rollout(
logger.info(
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
)

except asyncio.CancelledError:
failure_reason = "asyncio context cancelled"
logger.error(
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
)
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
failure_reason = "anyioconnection/resource error"
logger.error(
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
)
except Exception as e:
error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error"
logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True)
failure_reason = error_msg
except NonSkippableException as e:
# Terminate the rollout right away: no retry, and preserve the current trajectory history.
# Other exception types keep propagating to the upper layers, where the retry handler deals with them.
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.NON_SKIPPABLE_ERROR
trajectory.control_plane_summary.update({"error_message": str(e)})
logger.error(f"🚨 Rollout {rollout_idx} terminated due to non-skippable error: {str(e)}", exc_info=True)
finally:
if failure_reason:
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.ERROR
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
try:
await envs.connection_manager.reset_session(session)
except Exception as e:
logger.warning(f"Failed to reset session {session.session_id}: {type(e).__name__}: {e}", exc_info=True)

try:
await envs.connection_manager.close_session(session)
except Exception as e:
Expand Down
20 changes: 14 additions & 6 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union

from openai.types import CompletionUsage
Expand All @@ -11,6 +12,7 @@

from eval_protocol.get_pep440_version import get_pep440_version
from eval_protocol.human_id import generate_id
from eval_protocol.types import TerminationReason


class ChatCompletionContentPartTextParam(BaseModel):
Expand Down Expand Up @@ -285,14 +287,20 @@ class RolloutStatus(BaseModel):

"""
running: Unfinished rollout which is still in progress.
finished: Rollout finished successfully.
error: Rollout failed.
stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop).
finished: Rollout finished.
error: Rollout failed due to an unexpected error. The rollout record should be discarded.
"""
status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
termination_reason: Optional[str] = Field(
"", description="reason of the rollout status, mapped to values in TerminationReason"

class Status(str, Enum):
RUNNING = "running"
FINISHED = "finished"
ERROR = "error"

status: Status = Field(Status.RUNNING, description="Status of the rollout.")
termination_reason: Optional[TerminationReason] = Field(
None, description="reason of the rollout status, mapped to values in TerminationReason"
)
extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.")


class EvaluationRow(BaseModel):
Expand Down
16 changes: 8 additions & 8 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvalMetadata, EvaluationRow
from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import (
CompletionParams,
Expand Down Expand Up @@ -248,7 +248,7 @@ async def rollout_processor_with_retry(
"""

try:
queue = asyncio.Queue()
queue: asyncio.Queue[EvaluationRow] = asyncio.Queue()
retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset}
failed_permanently = []

Expand All @@ -257,7 +257,7 @@ async def retry_handler(failed_row: EvaluationRow):
current_attempts = retry_counts.get(rollout_id, 0)

if current_attempts >= max_retry:
assert failed_row.rollout_status and failed_row.rollout_status.status == "error", (
assert failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR, (
f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status"
)
failed_permanently.append(failed_row)
Expand All @@ -273,10 +273,10 @@ async def retry_handler(failed_row: EvaluationRow):

try:
retry_result = await retry_tasks[0]
retry_result.rollout_status.status = "finished"
retry_result.rollout_status.status = RolloutStatus.Status.FINISHED
await queue.put(retry_result)
except Exception as e:
failed_row.rollout_status.status = "error"
failed_row.rollout_status.status = RolloutStatus.Status.ERROR
failed_row.rollout_status.termination_reason = str(e)
asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry

Expand All @@ -299,11 +299,11 @@ async def initial_processor():

try:
result = await task
result.rollout_status.status = "finished"
result.rollout_status.status = RolloutStatus.Status.FINISHED
await queue.put(result)
except Exception as e:
failed_row = fresh_dataset[task_index]
failed_row.rollout_status.status = "error"
failed_row.rollout_status.status = RolloutStatus.Status.ERROR
failed_row.rollout_status.termination_reason = str(e)
asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task

Expand All @@ -317,7 +317,7 @@ async def initial_processor():
finished_row = await queue.get()

# only permanent failure rows are put on the queue, so we can check for them here
if finished_row.rollout_status and finished_row.rollout_status.status == "error":
if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR:
if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false":
raise RuntimeError(
f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}"
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory
from .errors import NonSkippableException

__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow"]
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "NonSkippableException"]
11 changes: 11 additions & 0 deletions eval_protocol/types/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class NonSkippableException(Exception):
    """Error raised during rollout or evaluation whose result must not be skipped.

    Raising this exception signals that the rollout/evaluation result is not
    skippable and needs to be processed explicitly.

    For example, if the policy (LLM) returns a 400 user error, the rollout has
    to end but the trajectory must be kept. This differs from other exceptions
    such as network errors, which are retriable and whose trajectory should be
    discarded if the rollout still fails after retries.
    """
12 changes: 8 additions & 4 deletions eval_protocol/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class TerminationReason(str, Enum):
MAX_STEPS: Trajectory ends because we hit the step limit
CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
USER_STOP: Trajectory ends because the simulated user signals to stop
ERROR: Trajectory ends because of an error
SKIPPABLE_ERROR: Trajectory ends because of an error, this trajectory can be discarded/skipped during postprocessing/evaluation.
NON_SKIPPABLE_ERROR: Trajectory is interrupted due to some non-skippable error (e.g. policy returns unexpected response and we need to terminate the rollout).
STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop")
LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length")
TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls")
Expand All @@ -21,7 +22,8 @@ class TerminationReason(str, Enum):
MAX_STEPS = "max_steps"
CONTROL_PLANE_SIGNAL = "control_plane_signal"
USER_STOP = "user_stop"
ERROR = "error"
SKIPPABLE_ERROR = "skippable_error"
NON_SKIPPABLE_ERROR = "non_skippable_error"
STOP = "stop"
LENGTH = "length"
TOOL_CALLS = "tool_calls"
Expand All @@ -38,8 +40,10 @@ def from_str(cls, value: str) -> "TerminationReason":
return cls.CONTROL_PLANE_SIGNAL
elif value == "user_stop":
return cls.USER_STOP
elif value == "error":
return cls.ERROR
elif value == "skippable_error":
return cls.SKIPPABLE_ERROR
elif value == "non_skippable_error":
return cls.NON_SKIPPABLE_ERROR
elif value == "tool_calls":
return cls.TOOL_CALLS
else:
Expand Down
Loading