@@ -98,10 +98,12 @@ async def execute_rollouts(
9898
9999 async def _execute_with_semaphore (idx ):
100100 async with semaphore :
101- return await self ._execute_rollout (
101+ result = await self ._execute_rollout (
102102 envs , policy , idx , steps , openai_logger , recording_mode , playback_mode , start_time
103103 )
104104
105+ return result
106+
105107 tasks = [_execute_with_semaphore (i ) for i in range (envs .n )]
106108 # exceptions will be try catched inside single _execute_rollout
107109 trajectories = await asyncio .gather (* tasks )
@@ -114,7 +116,7 @@ async def _execute_with_semaphore(idx):
114116 shared_tool_schema = envs .tool_schemas
115117
116118 # Clean up
117- await envs .close ()
119+ # await envs.close()
118120
119121 # Enhanced reporting with control plane info
120122 successful = sum (1 for traj in trajectories if traj .total_reward > 0 )
@@ -227,6 +229,7 @@ async def _execute_rollout(
227229 "total_tokens" : 0 ,
228230 },
229231 )
232+ failure_reason = None
230233 try :
231234 current_observation , tool_schema = await envs .reset (session )
232235 system_prompt = dataset_row .system_prompt
@@ -467,11 +470,25 @@ async def _execute_rollout(
467470 logger .info (
468471 f"✅ Rollout { rollout_idx } completed: { trajectory .steps } steps, reward: { trajectory .total_reward :.2f} , termination: { trajectory .termination_reason } , in thread { threading .current_thread ().name } "
469472 )
473+
474+ except asyncio .CancelledError :
475+ logger .error (f"🚨 AsyncIO Cancel Error in roll out { rollout_idx } " , exc_info = True )
476+ failure_reason = "asyncio context cancelled"
470477 except Exception as e :
471478 logger .error (f"🚨 Error in rollout { rollout_idx } : { e } " , exc_info = True )
479+ failure_reason = str (e )
480+ finally :
472481 trajectory .terminated = True
473482 trajectory .termination_reason = TerminationReason .ERROR
474- trajectory .control_plane_summary .update ({"error_message" : str (e )})
483+ trajectory .control_plane_summary .update ({"error_message" : f"{ failure_reason } " })
484+ try :
485+ await envs .connection_manager .reset_session (session )
486+ except :
487+ logger .error (f"Error resetting session { session .session_id } " )
488+ try :
489+ await envs .connection_manager .close_session (session )
490+ except :
491+ logger .error (f"Error closing session { session .session_id } " )
475492 return trajectory
476493
477494 async def _get_control_plane_status (self , session ) -> Optional [Dict [str , Any ]]:
0 commit comments