@@ -260,8 +260,6 @@ async def _execute_rollout(
260260 {"role" : "user" , "content" : user_prompt },
261261 ]
262262
263- usage_stats_list : List [CompletionUsage ] = []
264-
265263 logger .info (f"🎯 Starting rollout { rollout_idx } in thread { threading .current_thread ().name } " )
266264
267265 # Run rollout loop for this specific environment
@@ -375,7 +373,9 @@ async def _execute_rollout(
375373
376374 # accumulate LLM usage stats from this turn, if there are any
377375 if usage_stats :
378- usage_stats_list .append (usage_stats )
376+ trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
377+ trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
378+ trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
379379
380380 # With user simulator, increment step after an entire conversation step
381381 if user_simulator is not None :
@@ -409,7 +409,9 @@ async def _execute_rollout(
409409 # tool indicates rollout should be terminated, call policy one last time to get the final response
410410 _ , usage_stats = await policy (tool_schema , rollout_idx , conversation_history )
411411 if usage_stats :
412- usage_stats_list .append (usage_stats )
412+ trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
413+ trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
414+ trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
413415
414416 # Add final control plane summary
415417 trajectory .control_plane_summary .update (
@@ -460,11 +462,6 @@ async def _execute_rollout(
460462 msg ["control_plane_step" ]["termination_reason" ] = trajectory .termination_reason
461463 break
462464
463- for usage_stats in usage_stats_list :
464- trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
465- trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
466- trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
467-
468465 logger .info (
469466 f"✅ Rollout { rollout_idx } completed: { trajectory .steps } steps, reward: { trajectory .total_reward :.2f} , termination: { trajectory .termination_reason } , in thread { threading .current_thread ().name } "
470467 )
0 commit comments