1111import os
1212import threading
1313import time
14- from concurrent .futures import ThreadPoolExecutor , as_completed
15- from dataclasses import asdict , dataclass
14+ from dataclasses import asdict
1615from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
1716
1817from openai .types import CompletionUsage
@@ -98,10 +97,12 @@ async def execute_rollouts(
9897
9998 async def _execute_with_semaphore (idx ):
10099 async with semaphore :
101- return await self ._execute_rollout (
100+ result = await self ._execute_rollout (
102101 envs , policy , idx , steps , openai_logger , recording_mode , playback_mode , start_time
103102 )
104103
104+ return result
105+
105106 tasks = [_execute_with_semaphore (i ) for i in range (envs .n )]
106107 # exceptions will be try catched inside single _execute_rollout
107108 trajectories = await asyncio .gather (* tasks )
@@ -113,9 +114,6 @@ async def _execute_with_semaphore(idx):
113114
114115 shared_tool_schema = envs .tool_schemas
115116
116- # Clean up
117- await envs .close ()
118-
119117 # Enhanced reporting with control plane info
120118 successful = sum (1 for traj in trajectories if traj .total_reward > 0 )
121119 terminated_by_control_plane = sum (
@@ -176,8 +174,11 @@ async def _execute_with_semaphore(idx):
176174 TerminationReason .USER_STOP ,
177175 }:
178176 evaluation_rows [idx ].rollout_status .status = "finished"
179- elif trajectory .termination_reason == TerminationReason .MAX_STEPS :
177+ elif trajectory .termination_reason in { TerminationReason .MAX_STEPS , TerminationReason . INTERRUPTED } :
180178 evaluation_rows [idx ].rollout_status .status = "stopped"
179+ evaluation_rows [idx ].rollout_status .error_message = trajectory .control_plane_summary .get (
180+ "termination_reason" , trajectory .termination_reason
181+ )
181182 else :
182183 evaluation_rows [idx ].rollout_status .status = "error"
183184 evaluation_rows [idx ].rollout_status .error_message = trajectory .control_plane_summary .get (
@@ -227,6 +228,7 @@ async def _execute_rollout(
227228 "total_tokens" : 0 ,
228229 },
229230 )
231+ failure_reason = None
230232 try :
231233 current_observation , tool_schema = await envs .reset (session )
232234 system_prompt = dataset_row .system_prompt
@@ -248,7 +250,7 @@ async def _execute_rollout(
248250
249251 # Get initial messages in tau2-bench format for user simulator
250252 user_simulator_state = user_simulator .get_init_state ()
251- user_message , user_simulator_state = user_simulator .generate_next_message (
253+ user_message , user_simulator_state = await user_simulator .generate_next_message (
252254 AssistantMessage (role = "assistant" , content = "Hi! How can I help you today?" ),
253255 user_simulator_state ,
254256 )
@@ -280,7 +282,7 @@ async def _execute_rollout(
280282 # Last message was agent, simulated user response
281283 if user_simulator_messages and isinstance (user_simulator_messages [- 1 ], AssistantMessage ):
282284 # Generate user response using the simulator
283- user_message , user_simulator_state = user_simulator .generate_next_message (
285+ user_message , user_simulator_state = await user_simulator .generate_next_message (
284286 user_simulator_messages [- 1 ], user_simulator_state
285287 )
286288 user_content = user_message .content if user_message .content else ""
@@ -312,8 +314,7 @@ async def _execute_rollout(
312314 # If there's no user simulator, no tool call means policy failed and we should terminate the rollout
313315 elif tool_calls [0 ].tool_name in ["_playback_terminate" , "_no_tool_call" ]:
314316 trajectory .terminated = True
315- trajectory .termination_reason = TerminationReason .ERROR
316- trajectory .control_plane_summary .update ({"error_message" : "No expected tool call" })
317+ trajectory .termination_reason = TerminationReason .INTERRUPTED
317318 break
318319
319320 # Execute each tool call sequentially
@@ -467,11 +468,26 @@ async def _execute_rollout(
467468 logger .info (
468469 f"✅ Rollout { rollout_idx } completed: { trajectory .steps } steps, reward: { trajectory .total_reward :.2f} , termination: { trajectory .termination_reason } , in thread { threading .current_thread ().name } "
469470 )
471+
472+ except asyncio .CancelledError :
473+ logger .error (f"🚨 AsyncIO Cancel Error in roll out { rollout_idx } " , exc_info = True )
474+ failure_reason = "asyncio context cancelled"
470475 except Exception as e :
471476 logger .error (f"🚨 Error in rollout { rollout_idx } : { e } " , exc_info = True )
472- trajectory .terminated = True
473- trajectory .termination_reason = TerminationReason .ERROR
474- trajectory .control_plane_summary .update ({"error_message" : str (e )})
477+ failure_reason = str (e )
478+ finally :
479+ if failure_reason :
480+ trajectory .terminated = True
481+ trajectory .termination_reason = TerminationReason .ERROR
482+ trajectory .control_plane_summary .update ({"error_message" : f"{ failure_reason } " })
483+ try :
484+ await envs .connection_manager .reset_session (session )
485+ except :
486+ logger .error (f"Error resetting session { session .session_id } " )
487+ try :
488+ await envs .connection_manager .close_session (session )
489+ except :
490+ logger .error (f"Error closing session { session .session_id } " )
475491 return trajectory
476492
477493 async def _get_control_plane_status (self , session ) -> Optional [Dict [str , Any ]]:
0 commit comments