Skip to content

Commit b6f6b7f

Browse files
committed
keep intermediate llm usage stats even for failure trajectories
1 parent 986452f commit b6f6b7f

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,6 @@ async def _execute_rollout(
260260
{"role": "user", "content": user_prompt},
261261
]
262262

263-
usage_stats_list: List[CompletionUsage] = []
264-
265263
logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")
266264

267265
# Run rollout loop for this specific environment
@@ -375,7 +373,9 @@ async def _execute_rollout(
375373

376374
# calc llm usage stats happened in this turn if there is aany
377375
if usage_stats:
378-
usage_stats_list.append(usage_stats)
376+
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
377+
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
378+
trajectory.usage["total_tokens"] += usage_stats.total_tokens
379379

380380
# With user simulator, increment step after an entire conversation step
381381
if user_simulator is not None:
@@ -409,7 +409,9 @@ async def _execute_rollout(
409409
# tool indicates rollout should be terminated, call policy one last time to get the final response
410410
_, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
411411
if usage_stats:
412-
usage_stats_list.append(usage_stats)
412+
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
413+
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
414+
trajectory.usage["total_tokens"] += usage_stats.total_tokens
413415

414416
# Add final control plane summary
415417
trajectory.control_plane_summary.update(
@@ -460,11 +462,6 @@ async def _execute_rollout(
460462
msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason
461463
break
462464

463-
for usage_stats in usage_stats_list:
464-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
465-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
466-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
467-
468465
logger.info(
469466
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
470467
)

0 commit comments

Comments
 (0)