1212import threading
1313import time
1414from dataclasses import asdict
15- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
15+ from typing import TYPE_CHECKING , Any , AsyncIterator , Callable , Dict , List , Optional , Union
1616
1717from openai .types import CompletionUsage
1818
@@ -42,7 +42,7 @@ async def execute_rollouts(
4242 openai_format_log_file : Optional [str ] = None ,
4343 max_concurrent_rollouts : int = 8 ,
4444 evaluation_rows : Optional [List [EvaluationRow ]] = None ,
45- ) -> List [EvaluationRow ]:
45+ ) -> AsyncIterator [EvaluationRow ]:
4646 """
4747 Execute general rollouts using tool calling interface with automatic record/playback.
4848
@@ -65,7 +65,7 @@ async def execute_rollouts(
6565 - Set and file exists: Playback mode (uses recorded data)
6666
6767 Returns:
68- List of EvaluationRow objects with unified evaluation data format
68+ AsyncIterator of EvaluationRow objects with unified evaluation data format
6969 """
7070 start_time = time .time ()
7171
@@ -91,103 +91,84 @@ async def execute_rollouts(
9191
9292 logger .info (f"🧵 Starting { envs .n } rollouts with max { max_concurrent_rollouts } concurrent threads..." )
9393
94- results = {}
94+ if evaluation_rows is None :
95+ evaluation_rows = [EvaluationRow (messages = [], input_metadata = InputMetadata ()) for _ in range (envs .n )]
96+
97+ shared_tool_schema = envs .tool_schemas
9598
9699 semaphore = asyncio .Semaphore (max_concurrent_rollouts )
97100
98101 async def _execute_with_semaphore (idx ):
99102 async with semaphore :
100- result = await self ._execute_rollout (
103+ trajectory = await self ._execute_rollout (
101104 envs , policy , idx , steps , openai_logger , recording_mode , playback_mode , start_time
102105 )
103106
104- return result
105-
106- tasks = [_execute_with_semaphore (i ) for i in range (envs .n )]
107- # exceptions will be try catched inside single _execute_rollout
108- trajectories = await asyncio .gather (* tasks )
109-
110- # Calculate durations
111- total_duration = time .time () - start_time
112- for trajectory in trajectories :
113- trajectory .duration = total_duration
114-
115- shared_tool_schema = envs .tool_schemas
116-
117- # Enhanced reporting with control plane info
118- successful = sum (1 for traj in trajectories if traj .total_reward > 0 )
119- terminated_by_control_plane = sum (
120- 1
121- for traj in trajectories
122- if traj .control_plane_summary .get ("termination_reason" ) == "control_plane_signal"
123- )
107+ # Convert trajectory to EvaluationRow immediately
108+ evaluation_row = evaluation_rows [idx ]
109+
110+ # Handle multimodal content by extracting text from complex content structures
111+ messages = []
112+ for msg in trajectory .conversation_history :
113+ # Create a copy to avoid modifying the original
114+ msg_dict = dict (msg )
115+
116+ # Handle multimodal content (list of content blocks) by extracting text
117+ if isinstance (msg_dict .get ("content" ), list ):
118+ text_content = None
119+ for content_block in msg_dict ["content" ]:
120+ if isinstance (content_block , dict ) and content_block .get ("type" ) == "text" :
121+ text_content = content_block .get ("text" )
122+ break
123+ msg_dict ["content" ] = text_content or ""
124+
125+ messages .append (Message .model_validate (msg_dict ))
126+
127+ evaluation_row .messages = messages
128+ evaluation_row .tools = shared_tool_schema
129+ evaluation_row .usage = CompletionUsage (** trajectory .usage )
130+ evaluation_row .input_metadata .completion_params = CompletionParams (
131+ model = policy .model_id ,
132+ temperature = getattr (policy , "temperature" , None ),
133+ max_tokens = getattr (policy , "max_tokens" , None ),
134+ max_tool_calls = getattr (policy , "max_tools_per_turn" , None ),
135+ )
124136
125- logger .info (f"📊 Rollout complete: { successful } /{ len (trajectories )} reached goal" )
126- logger .info (f"🎛️ Control plane terminations: { terminated_by_control_plane } /{ len (trajectories )} " )
127- logger .info (f"⏱️ Total duration: { total_duration :.2f} s" )
128- logger .info (f"🧵 Used { max_concurrent_rollouts } concurrent threads" )
137+ if trajectory .terminated :
138+ if trajectory .termination_reason in {
139+ TerminationReason .CONTROL_PLANE_SIGNAL ,
140+ TerminationReason .USER_STOP ,
141+ }:
142+ evaluation_row .rollout_status .status = "finished"
143+ elif trajectory .termination_reason in {TerminationReason .MAX_STEPS , TerminationReason .INTERRUPTED }:
144+ evaluation_row .rollout_status .status = "stopped"
145+ evaluation_row .rollout_status .error_message = trajectory .control_plane_summary .get (
146+ "termination_reason" , trajectory .termination_reason
147+ )
148+ else :
149+ evaluation_row .rollout_status .status = "error"
150+ evaluation_row .rollout_status .error_message = trajectory .control_plane_summary .get (
151+ "error_message" , None
152+ )
153+ else :
154+ evaluation_row .rollout_status .status = "running"
129155
130- # Print log file locations if created
131- if openai_format_log_file :
132- logger .info (f"💬 OpenAI format log: { openai_format_log_file } " )
133- if recording_mode :
134- logger .info (f"📝 Recorded trajectory: { playback_file } " )
135- # Add note about control plane separation
136- logger .info (f"🎛️ Trajectories include control plane separation" )
156+ return evaluation_row
137157
138- # Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
139- if evaluation_rows is None :
140- evaluation_rows = [EvaluationRow (messages = [], input_metadata = InputMetadata ()) for _ in trajectories ]
141-
142- for idx , trajectory in enumerate (trajectories ):
143- # Handle multimodal content by extracting text from complex content structures
144- messages = []
145- for msg in trajectory .conversation_history :
146- # Create a copy to avoid modifying the original
147- msg_dict = dict (msg )
148-
149- # Handle multimodal content (list of content blocks) by extracting text
150- if isinstance (msg_dict .get ("content" ), list ):
151- text_content = None
152- for content_block in msg_dict ["content" ]:
153- if isinstance (content_block , dict ) and content_block .get ("type" ) == "text" :
154- text_content = content_block .get ("text" )
155- break
156- msg_dict ["content" ] = text_content or ""
157-
158- messages .append (Message .model_validate (msg_dict ))
159-
160- evaluation_rows [idx ].messages = messages
161- # evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
162- # evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
163- evaluation_rows [idx ].tools = shared_tool_schema
164- evaluation_rows [idx ].usage = CompletionUsage (** trajectory .usage )
165- evaluation_rows [idx ].input_metadata .completion_params = CompletionParams (
166- model = policy .model_id ,
167- temperature = getattr (policy , "temperature" , None ),
168- max_tokens = getattr (policy , "max_tokens" , None ),
169- max_tool_calls = getattr (policy , "max_tools_per_turn" , None ),
170- )
171- if trajectory .terminated :
172- if trajectory .termination_reason in {
173- TerminationReason .CONTROL_PLANE_SIGNAL ,
174- TerminationReason .USER_STOP ,
175- }:
176- evaluation_rows [idx ].rollout_status .status = "finished"
177- elif trajectory .termination_reason in {TerminationReason .MAX_STEPS , TerminationReason .INTERRUPTED }:
178- evaluation_rows [idx ].rollout_status .status = "stopped"
179- evaluation_rows [idx ].rollout_status .error_message = trajectory .control_plane_summary .get (
180- "termination_reason" , trajectory .termination_reason
181- )
182- else :
183- evaluation_rows [idx ].rollout_status .status = "error"
184- evaluation_rows [idx ].rollout_status .error_message = trajectory .control_plane_summary .get (
185- "error_message" , None
186- )
187- else :
188- evaluation_rows [idx ].rollout_status .status = "running"
158+ # Create all tasks
159+ tasks = [asyncio .create_task (_execute_with_semaphore (i )) for i in range (envs .n )]
189160
190- return evaluation_rows
161+ # Yield results as they complete (note that they're not necessarily in original order)
162+ try :
163+ for task in asyncio .as_completed (tasks ):
164+ try :
165+ yield await task
166+ except Exception :
167+ logger .exception ("Error processing rollout" )
168+ finally :
169+ for t in tasks :
170+ t .cancel ()
171+ await asyncio .gather (* tasks , return_exceptions = True )
191172
192173 async def _execute_rollout (
193174 self ,
0 commit comments