1212import threading
1313import time
1414from dataclasses import asdict
15- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
15+ from typing import TYPE_CHECKING , Any , AsyncIterator , Callable , Dict , List , Optional , Union
1616
1717import anyio
1818from openai .types import CompletionUsage
@@ -43,7 +43,7 @@ async def execute_rollouts(
4343 openai_format_log_file : Optional [str ] = None ,
4444 max_concurrent_rollouts : int = 8 ,
4545 evaluation_rows : Optional [List [EvaluationRow ]] = None ,
46- ) -> List [EvaluationRow ]:
46+ ) -> AsyncIterator [EvaluationRow ]:
4747 """
4848 Execute general rollouts using tool calling interface with automatic record/playback.
4949
@@ -66,7 +66,7 @@ async def execute_rollouts(
6666 - Set and file exists: Playback mode (uses recorded data)
6767
6868 Returns:
69- List of EvaluationRow objects with unified evaluation data format
69+ AsyncIterator of EvaluationRow objects with unified evaluation data format
7070 """
7171 start_time = time .time ()
7272
@@ -92,96 +92,77 @@ async def execute_rollouts(
9292
9393 logger .info (f"🧵 Starting { envs .n } rollouts with max { max_concurrent_rollouts } concurrent threads..." )
9494
95- results = {}
95+ if evaluation_rows is None :
96+ evaluation_rows = [EvaluationRow (messages = [], input_metadata = InputMetadata ()) for _ in range (envs .n )]
97+
98+ shared_tool_schema = envs .tool_schemas
9699
97100 semaphore = asyncio .Semaphore (max_concurrent_rollouts )
98101
99102 async def _execute_with_semaphore (idx ):
100103 async with semaphore :
101- result = await self ._execute_rollout (
104+ trajectory = await self ._execute_rollout (
102105 envs , policy , idx , steps , openai_logger , recording_mode , playback_mode , start_time
103106 )
104107
105- return result
106-
107- tasks = [_execute_with_semaphore (i ) for i in range (envs .n )]
108- # exceptions will be try catched inside single _execute_rollout
109- trajectories = await asyncio .gather (* tasks )
110-
111- # Calculate durations
112- total_duration = time .time () - start_time
113- for trajectory in trajectories :
114- trajectory .duration = total_duration
115-
116- shared_tool_schema = envs .tool_schemas
117-
118- # Enhanced reporting with control plane info
119- successful = sum (1 for traj in trajectories if traj .total_reward > 0 )
120- terminated_by_control_plane = sum (
121- 1
122- for traj in trajectories
123- if traj .control_plane_summary .get ("termination_reason" ) == "control_plane_signal"
124- )
108+ # Convert trajectory to EvaluationRow immediately
109+ evaluation_row = evaluation_rows [idx ]
110+
111+ # Handle multimodal content by extracting text from complex content structures
112+ messages = []
113+ for msg in trajectory .conversation_history :
114+ # Create a copy to avoid modifying the original
115+ msg_dict = dict (msg )
116+
117+ # Handle multimodal content (list of content blocks) by extracting text
118+ if isinstance (msg_dict .get ("content" ), list ):
119+ text_content = None
120+ for content_block in msg_dict ["content" ]:
121+ if isinstance (content_block , dict ) and content_block .get ("type" ) == "text" :
122+ text_content = content_block .get ("text" )
123+ break
124+ msg_dict ["content" ] = text_content or ""
125+
126+ messages .append (Message .model_validate (msg_dict ))
127+
128+ evaluation_row .messages = messages
129+ evaluation_row .tools = shared_tool_schema
130+ evaluation_row .usage = CompletionUsage (** trajectory .usage )
131+ evaluation_row .input_metadata .completion_params = CompletionParams (
132+ model = policy .model_id ,
133+ temperature = getattr (policy , "temperature" , None ),
134+ max_tokens = getattr (policy , "max_tokens" , None ),
135+ max_tool_calls = getattr (policy , "max_tools_per_turn" , None ),
136+ )
125137
126- logger .info (f"📊 Rollout complete: { successful } /{ len (trajectories )} reached goal" )
127- logger .info (f"🎛️ Control plane terminations: { terminated_by_control_plane } /{ len (trajectories )} " )
128- logger .info (f"⏱️ Total duration: { total_duration :.2f} s" )
129- logger .info (f"🧵 Used { max_concurrent_rollouts } concurrent threads" )
138+ if trajectory .terminated :
139+ if trajectory .termination_reason == TerminationReason .ERROR :
140+ evaluation_row .rollout_status .status = "error"
141+ evaluation_row .rollout_status .error_message = trajectory .control_plane_summary .get (
142+ "error_message" , None
143+ )
144+ else :
145+ evaluation_row .rollout_status .status = "finished"
146+ evaluation_row .rollout_status .termination_reason = trajectory .termination_reason
147+ else :
148+ evaluation_row .rollout_status .status = "running"
130149
131- # Print log file locations if created
132- if openai_format_log_file :
133- logger .info (f"💬 OpenAI format log: { openai_format_log_file } " )
134- if recording_mode :
135- logger .info (f"📝 Recorded trajectory: { playback_file } " )
136- # Add note about control plane separation
137- logger .info (f"🎛️ Trajectories include control plane separation" )
150+ return evaluation_row
138151
139- # Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
140- if evaluation_rows is None :
141- evaluation_rows = [EvaluationRow (messages = [], input_metadata = InputMetadata ()) for _ in trajectories ]
142-
143- for idx , trajectory in enumerate (trajectories ):
144- # Handle multimodal content by extracting text from complex content structures
145- messages = []
146- for msg in trajectory .conversation_history :
147- # Create a copy to avoid modifying the original
148- msg_dict = dict (msg )
149-
150- # Handle multimodal content (list of content blocks) by extracting text
151- if isinstance (msg_dict .get ("content" ), list ):
152- text_content = None
153- for content_block in msg_dict ["content" ]:
154- if isinstance (content_block , dict ) and content_block .get ("type" ) == "text" :
155- text_content = content_block .get ("text" )
156- break
157- msg_dict ["content" ] = text_content or ""
158-
159- messages .append (Message .model_validate (msg_dict ))
160-
161- evaluation_rows [idx ].messages = messages
162- # evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
163- # evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
164- evaluation_rows [idx ].tools = shared_tool_schema
165- evaluation_rows [idx ].usage = CompletionUsage (** trajectory .usage )
166- evaluation_rows [idx ].input_metadata .completion_params = CompletionParams (
167- model = policy .model_id ,
168- temperature = getattr (policy , "temperature" , None ),
169- max_tokens = getattr (policy , "max_tokens" , None ),
170- max_tool_calls = getattr (policy , "max_tools_per_turn" , None ),
171- )
172- if trajectory .terminated :
173- if trajectory .termination_reason == TerminationReason .ERROR :
174- evaluation_rows [idx ].rollout_status .status = "error"
175- evaluation_rows [idx ].rollout_status .termination_reason = trajectory .control_plane_summary .get (
176- "error_message" , None
177- )
178- else :
179- evaluation_rows [idx ].rollout_status .status = "finished"
180- evaluation_rows [idx ].rollout_status .termination_reason = trajectory .termination_reason
181- else :
182- evaluation_rows [idx ].rollout_status .status = "running"
152+ # Create all tasks
153+ tasks = [asyncio .create_task (_execute_with_semaphore (i )) for i in range (envs .n )]
183154
184- return evaluation_rows
155+ # Yield results as they complete (note that they're not necessarily in original order)
156+ try :
157+ for task in asyncio .as_completed (tasks ):
158+ try :
159+ yield await task
160+ except Exception :
161+ logger .exception ("Error processing rollout" )
162+ finally :
163+ for t in tasks :
164+ t .cancel ()
165+ await asyncio .gather (* tasks , return_exceptions = True )
185166
186167 async def _execute_rollout (
187168 self ,
0 commit comments