Skip to content

Commit 0f96c80

Browse files
committed
add nonskippableerror
1 parent 8b48867 commit 0f96c80

File tree

6 files changed

+40
-47
lines changed

6 files changed

+40
-47
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
2121
from vendor.tau2.user.user_simulator import UserSimulator
2222

23-
from ...models import EvaluationRow, InputMetadata, Message
24-
from ...types import TerminationReason, Trajectory, RolloutTerminationException
23+
from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus
24+
from ...types import TerminationReason, Trajectory, NonSkippableException
2525

2626
if TYPE_CHECKING:
2727
from ..session.manager import GeneralMCPVectorEnv
@@ -137,17 +137,14 @@ async def _execute_with_semaphore(idx):
137137

138138
if trajectory.terminated:
139139
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
140+
evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED
140141
# preserve the true error message if there is one
141142
if trajectory.control_plane_summary.get("error_message"):
142143
evaluation_row.rollout_status.extra_info = {
143144
"error_message": trajectory.control_plane_summary.get("error_message")
144145
}
145-
if trajectory.termination_reason == TerminationReason.ERROR:
146-
evaluation_row.rollout_status.status = "error"
147-
else:
148-
evaluation_row.rollout_status.status = "finished"
149146
else:
150-
evaluation_row.rollout_status.status = "running"
147+
evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING
151148

152149
return evaluation_row
153150

@@ -438,25 +435,14 @@ async def _execute_rollout(
438435
logger.info(
439436
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
440437
)
441-
except RolloutTerminationException as e:
438+
except NonSkippableException as e:
439+
# terminate the rollout right away, no retry and preserve the current trajectory history.
440+
# for other types of exceptions, keep propagating them to upper layers and handle them with the retry handler.
442441
trajectory.terminated = True
443-
trajectory.termination_reason = TerminationReason.INTERRUPTED
442+
trajectory.termination_reason = TerminationReason.NON_SKIPPABLE_ERROR
444443
trajectory.control_plane_summary.update({"error_message": str(e)})
445-
logger.error(f"🚨 Rollout {rollout_idx} terminated due to user exception: {str(e)}", exc_info=True)
446-
except asyncio.CancelledError: # need explicit catch since it doesn't inherit from Exception
447-
failure_reason = "asyncio context cancelled"
448-
logger.error(
449-
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
450-
)
451-
except Exception as e:
452-
error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error"
453-
logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True)
454-
failure_reason = error_msg
444+
logger.error(f"🚨 Rollout {rollout_idx} terminated due to non-skippable error: {str(e)}", exc_info=True)
455445
finally:
456-
if failure_reason:
457-
trajectory.terminated = True
458-
trajectory.termination_reason = TerminationReason.ERROR
459-
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
460446
try:
461447
await envs.connection_manager.reset_session(session)
462448
except Exception as e:

eval_protocol/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from datetime import datetime
3+
from enum import Enum
34
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
45

56
from openai.types import CompletionUsage
@@ -289,7 +290,13 @@ class RolloutStatus(BaseModel):
289290
finished: Rollout finished.
290291
error: Rollout failed due to unexpected error. The rollout record should be discarded.
291292
"""
292-
status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
293+
294+
class Status(str, Enum):
295+
RUNNING = "running"
296+
FINISHED = "finished"
297+
ERROR = "error"
298+
299+
status: Status = Field(Status.RUNNING, description="Status of the rollout.")
293300
termination_reason: Optional[TerminationReason] = Field(
294301
None, description="reason of the rollout status, mapped to values in TerminationReason"
295302
)

eval_protocol/pytest/utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Callable, Dict, List, Literal, Optional, Union
77

88
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
9-
from eval_protocol.models import EvalMetadata, EvaluationRow
9+
from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus
1010
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1111
from eval_protocol.pytest.types import (
1212
CompletionParams,
@@ -248,7 +248,7 @@ async def rollout_processor_with_retry(
248248
"""
249249

250250
try:
251-
queue = asyncio.Queue()
251+
queue: asyncio.Queue[EvaluationRow] = asyncio.Queue()
252252
retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset}
253253
failed_permanently = []
254254

@@ -257,9 +257,9 @@ async def retry_handler(failed_row: EvaluationRow):
257257
current_attempts = retry_counts.get(rollout_id, 0)
258258

259259
if current_attempts >= max_retry:
260-
assert failed_row.rollout_status and failed_row.rollout_status.status == "error", (
261-
f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status"
262-
)
260+
assert (
261+
failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR
262+
), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status"
263263
failed_permanently.append(failed_row)
264264
await queue.put(failed_row) # put failed row on queue
265265
return
@@ -273,10 +273,10 @@ async def retry_handler(failed_row: EvaluationRow):
273273

274274
try:
275275
retry_result = await retry_tasks[0]
276-
retry_result.rollout_status.status = "finished"
276+
retry_result.rollout_status.status = RolloutStatus.Status.FINISHED
277277
await queue.put(retry_result)
278278
except Exception as e:
279-
failed_row.rollout_status.status = "error"
279+
failed_row.rollout_status.status = RolloutStatus.Status.ERROR
280280
failed_row.rollout_status.termination_reason = str(e)
281281
asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry
282282

@@ -299,11 +299,11 @@ async def initial_processor():
299299

300300
try:
301301
result = await task
302-
result.rollout_status.status = "finished"
302+
result.rollout_status.status = RolloutStatus.Status.FINISHED
303303
await queue.put(result)
304304
except Exception as e:
305305
failed_row = fresh_dataset[task_index]
306-
failed_row.rollout_status.status = "error"
306+
failed_row.rollout_status.status = RolloutStatus.Status.ERROR
307307
failed_row.rollout_status.termination_reason = str(e)
308308
asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task
309309

@@ -317,7 +317,7 @@ async def initial_processor():
317317
finished_row = await queue.get()
318318

319319
# only permanent failure rows are put on the queue, so we can check for them here
320-
if finished_row.rollout_status and finished_row.rollout_status.status == "error":
320+
if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR:
321321
if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false":
322322
raise RuntimeError(
323323
f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}"

eval_protocol/types/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory
2-
from .errors import RolloutTerminationException
2+
from .errors import NonSkippableException
33

4-
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "RolloutTerminationException"]
4+
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "NonSkippableException"]

eval_protocol/types/errors.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
class RolloutTerminationException(Exception):
1+
class NonSkippableException(Exception):
22
"""
3-
Exception raised during rollout. This error means the rollout needs to be terminated and its not retriable,
4-
and the trajectory need to be perserved for future evalution or analysis.
3+
A type of custom exception raised during rollout or evaluation. This error means the rollout/evaluation result is not skippable and needs to be
4+
processed explicitly.
55
66
For example, if the policy (llm) returns 400 User error, we need to end the rollout but keep the trajectory.
77
It differs from other exceptions such as network error, which are retriable and the trajectory should be discarded if
88
it fails eventually after retries.
9-
10-
This will cause trajectory.termination_reason to be set to TerminationReason.INTERRUPTED.
119
"""
1210

1311
pass

eval_protocol/types/types.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class TerminationReason(str, Enum):
1212
MAX_STEPS: Trajectory ends because we hit the step limit
1313
CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
1414
USER_STOP: Trajectory ends because the simulated user signals to stop
15-
ERROR: Trajectory ends because of an error
16-
INTERRUPTED: Trajectory is interrupted by some non-system related reason (e.g. policy returns unexpected response and we need to terminate the rollout).
15+
SKIPPABLE_ERROR: Trajectory ends because of an error; this trajectory can be discarded/skipped during postprocessing/evaluation.
16+
NON_SKIPPABLE_ERROR: Trajectory is interrupted due to some non-skippable error (e.g. policy returns unexpected response and we need to terminate the rollout).
1717
STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop")
1818
LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length")
1919
TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls")
@@ -22,8 +22,8 @@ class TerminationReason(str, Enum):
2222
MAX_STEPS = "max_steps"
2323
CONTROL_PLANE_SIGNAL = "control_plane_signal"
2424
USER_STOP = "user_stop"
25-
ERROR = "error"
26-
INTERRUPTED = "interrupted"
25+
SKIPPABLE_ERROR = "skippable_error"
26+
NON_SKIPPABLE_ERROR = "non_skippable_error"
2727
STOP = "stop"
2828
LENGTH = "length"
2929
TOOL_CALLS = "tool_calls"
@@ -40,8 +40,10 @@ def from_str(cls, value: str) -> "TerminationReason":
4040
return cls.CONTROL_PLANE_SIGNAL
4141
elif value == "user_stop":
4242
return cls.USER_STOP
43-
elif value == "error":
44-
return cls.ERROR
43+
elif value == "skippable_error":
44+
return cls.SKIPPABLE_ERROR
45+
elif value == "non_skippable_error":
46+
return cls.NON_SKIPPABLE_ERROR
4547
elif value == "tool_calls":
4648
return cls.TOOL_CALLS
4749
else:

0 commit comments

Comments
 (0)