Skip to content

Commit 86294ed

Browse files
committed
record rollout status
1 parent d274bce commit 86294ed

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ async def _execute_with_semaphore(idx):
103103
)
104104

105105
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
106-
# exceptions should be try catched inside single _execute_rollout
107-
# exceptions should be try catched inside single _execute_rollout
106+
# exceptions are caught inside each individual _execute_rollout call
108107
trajectories = await asyncio.gather(*tasks)
109108

110109
# Calculate durations
@@ -171,6 +170,21 @@ async def _execute_with_semaphore(idx):
171170
max_tokens=getattr(policy, "max_tokens", None),
172171
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
173172
)
173+
if trajectory.terminated:
174+
if trajectory.termination_reason in {
175+
TerminationReason.CONTROL_PLANE_SIGNAL,
176+
TerminationReason.USER_STOP,
177+
}:
178+
evaluation_rows[idx].rollout_status.status = "finished"
179+
elif trajectory.termination_reason == TerminationReason.MAX_STEPS:
180+
evaluation_rows[idx].rollout_status.status = "stopped"
181+
else:
182+
evaluation_rows[idx].rollout_status.status = "error"
183+
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
184+
"error_message", None
185+
)
186+
else:
187+
evaluation_rows[idx].rollout_status.status = "running"
174188

175189
return evaluation_rows
176190

@@ -458,8 +472,7 @@ async def _execute_rollout(
458472
logger.error(f"🚨 Error in rollout {rollout_idx}: {e}", exc_info=True)
459473
trajectory.terminated = True
460474
trajectory.termination_reason = TerminationReason.ERROR
461-
trajectory.input_metadata.session_data["error"] = True
462-
trajectory.input_metadata.session_data["error_message"] = str(e)
475+
trajectory.control_plane_summary.update({"error_message": str(e)})
463476
return trajectory
464477

465478
async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:

eval_protocol/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,21 @@ class EvalMetadata(BaseModel):
220220
passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold")
221221

222222

223+
class RolloutStatus(BaseModel):
224+
"""Status of the rollout."""
225+
226+
"""
227+
running: Unfinished rollout which is still in progress.
228+
finished: Rollout finished successfully (ended by a control plane signal or a user stop).
229+
error: Rollout failed.
230+
stopped: Rollout was cut short before finishing on its own (e.g. the max step limit was reached).
231+
"""
232+
status: Literal["running", "finished", "error", "stopped"] = Field(
233+
"finished", description="Status of the rollout."
234+
)
235+
error_message: Optional[str] = Field(None, description="Error message if the rollout failed.")
236+
237+
223238
class EvaluationRow(BaseModel):
224239
"""
225240
Unified data structure for a single evaluation unit that contains messages,
@@ -244,6 +259,11 @@ class EvaluationRow(BaseModel):
244259
description="Metadata related to the input (dataset info, model config, session data, etc.).",
245260
)
246261

262+
rollout_status: RolloutStatus = Field(
263+
default_factory=RolloutStatus,
264+
description="The status of the rollout.",
265+
)
266+
247267
# Ground truth reference (moved from EvaluateResult to top level)
248268
ground_truth: Optional[str] = Field(
249269
default=None, description="Optional ground truth reference for this evaluation."

0 commit comments

Comments
 (0)