Skip to content

Commit 8b48867

Browse files
committed
custom rollout exception and termination_reason=interrupted
1 parent 7231c78 commit 8b48867

File tree

5 files changed

+38
-22
lines changed

5 files changed

+38
-22
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,16 @@
1212
import threading
1313
import time
1414
from dataclasses import asdict
15-
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union
15+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
1616

1717
import anyio
18-
import httpx
1918
from openai.types import CompletionUsage
2019

2120
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
2221
from vendor.tau2.user.user_simulator import UserSimulator
2322

2423
from ...models import EvaluationRow, InputMetadata, Message
25-
from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
24+
from ...types import TerminationReason, Trajectory, RolloutTerminationException
2625

2726
if TYPE_CHECKING:
2827
from ..session.manager import GeneralMCPVectorEnv
@@ -107,7 +106,7 @@ async def _execute_with_semaphore(idx):
107106
)
108107

109108
# Convert trajectory to EvaluationRow immediately
110-
evaluation_row = evaluation_rows[idx]
109+
evaluation_row: EvaluationRow = evaluation_rows[idx]
111110

112111
# Handle multimodal content by extracting text from complex content structures
113112
messages = []
@@ -137,14 +136,16 @@ async def _execute_with_semaphore(idx):
137136
}
138137

139138
if trajectory.terminated:
139+
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
140+
# preserve the true error message if there is one
141+
if trajectory.control_plane_summary.get("error_message"):
142+
evaluation_row.rollout_status.extra_info = {
143+
"error_message": trajectory.control_plane_summary.get("error_message")
144+
}
140145
if trajectory.termination_reason == TerminationReason.ERROR:
141146
evaluation_row.rollout_status.status = "error"
142-
evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get(
143-
"error_message", None
144-
)
145147
else:
146148
evaluation_row.rollout_status.status = "finished"
147-
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
148149
else:
149150
evaluation_row.rollout_status.status = "running"
150151

@@ -437,17 +438,16 @@ async def _execute_rollout(
437438
logger.info(
438439
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
439440
)
440-
441-
except asyncio.CancelledError:
441+
except RolloutTerminationException as e:
442+
trajectory.terminated = True
443+
trajectory.termination_reason = TerminationReason.INTERRUPTED
444+
trajectory.control_plane_summary.update({"error_message": str(e)})
445+
logger.error(f"🚨 Rollout {rollout_idx} terminated due to user exception: {str(e)}", exc_info=True)
446+
except asyncio.CancelledError: # need explicit catch since it doesn't inherit from Exception
442447
failure_reason = "asyncio context cancelled"
443448
logger.error(
444449
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
445450
)
446-
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
447-
failure_reason = "anyioconnection/resource error"
448-
logger.error(
449-
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
450-
)
451451
except Exception as e:
452452
error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error"
453453
logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True)
@@ -461,7 +461,6 @@ async def _execute_rollout(
461461
await envs.connection_manager.reset_session(session)
462462
except Exception as e:
463463
logger.warning(f"Failed to reset session {session.session_id}: {type(e).__name__}: {e}", exc_info=True)
464-
465464
try:
466465
await envs.connection_manager.close_session(session)
467466
except Exception as e:

eval_protocol/models.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from eval_protocol.get_pep440_version import get_pep440_version
1313
from eval_protocol.human_id import generate_id
14+
from eval_protocol.types import TerminationReason
1415

1516

1617
class ChatCompletionContentPartTextParam(BaseModel):
@@ -285,14 +286,14 @@ class RolloutStatus(BaseModel):
285286

286287
"""
287288
running: Unfinished rollout which is still in progress.
288-
finished: Rollout finished successfully.
289-
error: Rollout failed.
290-
stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop).
289+
finished: Rollout finished.
290+
error: Rollout failed due to an unexpected error. The rollout record should be discarded.
291291
"""
292292
status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
293-
termination_reason: Optional[str] = Field(
294-
"", description="reason of the rollout status, mapped to values in TerminationReason"
293+
termination_reason: Optional[TerminationReason] = Field(
294+
None, description="reason of the rollout status, mapped to values in TerminationReason"
295295
)
296+
extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.")
296297

297298

298299
class EvaluationRow(BaseModel):

eval_protocol/types/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory
2+
from .errors import RolloutTerminationException
23

3-
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow"]
4+
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "RolloutTerminationException"]

eval_protocol/types/errors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class RolloutTerminationException(Exception):
2+
"""
3+
Exception raised during rollout. This error means the rollout needs to be terminated and it is not retriable,
4+
and the trajectory needs to be preserved for future evaluation or analysis.
5+
6+
For example, if the policy (llm) returns 400 User error, we need to end the rollout but keep the trajectory.
7+
It differs from other exceptions such as network errors, which are retriable and whose trajectory should be discarded if
8+
it fails eventually after retries.
9+
10+
This will cause trajectory.termination_reason to be set to TerminationReason.INTERRUPTED.
11+
"""
12+
13+
pass

eval_protocol/types/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class TerminationReason(str, Enum):
1313
CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
1414
USER_STOP: Trajectory ends because the simulated user signals to stop
1515
ERROR: Trajectory ends because of an error
16+
INTERRUPTED: Trajectory is interrupted by some non-system related reason (e.g. policy returns unexpected response and we need to terminate the rollout).
1617
STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop")
1718
LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length")
1819
TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls")
@@ -22,6 +23,7 @@ class TerminationReason(str, Enum):
2223
CONTROL_PLANE_SIGNAL = "control_plane_signal"
2324
USER_STOP = "user_stop"
2425
ERROR = "error"
26+
INTERRUPTED = "interrupted"
2527
STOP = "stop"
2628
LENGTH = "length"
2729
TOOL_CALLS = "tool_calls"

0 commit comments

Comments
 (0)