Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union

import anyio
import httpx
from openai.types import CompletionUsage

from vendor.tau2.data_model.message import AssistantMessage, UserMessage
Expand Down Expand Up @@ -221,7 +222,7 @@ async def _execute_rollout(
current_observation = user_message.content if user_message.content else ""

user_prompt = envs.format_user_prompt(rollout_idx, current_observation)
conversation_history = [
trajectory.conversation_history = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
Expand All @@ -241,7 +242,7 @@ async def _execute_rollout(

if user_simulator and user_simulator_state:
# Get user simulator messages and find the last assistant message
user_simulator_messages = self._get_user_simulator_messages(conversation_history)
user_simulator_messages = self._get_user_simulator_messages(trajectory.conversation_history)

# Last message was agent, simulated user response
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
Expand All @@ -252,7 +253,7 @@ async def _execute_rollout(
user_content = user_message.content if user_message.content else ""

user_prompt = envs.format_user_prompt(rollout_idx, user_content)
conversation_history.append({"role": "user", "content": user_prompt})
trajectory.conversation_history.append({"role": "user", "content": user_prompt})

# Check if user simulator signaled termination
if UserSimulator.is_stop(user_message):
Expand All @@ -262,7 +263,7 @@ async def _execute_rollout(
# In each turn: keep looping until assistant is ready to provide final response
while not turn_completed and not trajectory.terminated:
tool_calls, usage_stats, finish_reason = await policy(
tool_schema, rollout_idx, conversation_history
tool_schema, rollout_idx, trajectory.conversation_history
)

# calc llm usage stats happened in this turn if there is aany
Expand Down Expand Up @@ -294,7 +295,7 @@ async def _execute_rollout(
rollout_idx,
tool_call,
tool_response,
conversation_history,
trajectory.conversation_history,
reward,
env_end,
info,
Expand Down Expand Up @@ -325,12 +326,14 @@ async def _execute_rollout(
"num_tool_calls": 1,
}
print(f"🔍 control_plane_step: {control_plane_step}")
conversation_history[-1]["control_plane_step"] = control_plane_step
trajectory.conversation_history[-1]["control_plane_step"] = control_plane_step
trajectory.control_plane_steps.append(control_plane_step)

# Log conversation state for playback if in recording mode
if recording_mode:
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
policy.log_conversation_state_for_playback(
rollout_idx, step - 1, trajectory.conversation_history
)

if env_end:
# if the env marks the end of the rollout, break the tool call loop
Expand Down Expand Up @@ -364,17 +367,21 @@ async def _execute_rollout(
"tool_calls": tool_calls_summary,
"num_tool_calls": len(tool_calls),
}
conversation_history[-1]["control_plane_step"] = control_plane_step
trajectory.conversation_history[-1]["control_plane_step"] = control_plane_step
trajectory.control_plane_steps.append(control_plane_step)

# Log conversation state for playback if in recording mode
if recording_mode:
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
policy.log_conversation_state_for_playback(
rollout_idx, step - 1, trajectory.conversation_history
)

# if the env marks end, update control plane summary and do one last policy call, then break the agent loop
# this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior
if env_end:
_, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history)
_, usage_stats, finish_reason = await policy(
tool_schema, rollout_idx, trajectory.conversation_history
)
if usage_stats:
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
Expand All @@ -392,10 +399,10 @@ async def _execute_rollout(

# Log final OpenAI conversation for terminated trajectories only
if openai_logger:
if conversation_history and len(conversation_history) > 0:
if trajectory.conversation_history and len(trajectory.conversation_history) > 0:
openai_logger(
{
"messages": conversation_history,
"messages": trajectory.conversation_history,
"metadata": {
"session_id": session.session_id,
"seed": session.seed,
Expand All @@ -421,8 +428,6 @@ async def _execute_rollout(
if not trajectory.termination_reason and step >= steps:
trajectory.termination_reason = TerminationReason.MAX_STEPS

trajectory.conversation_history = conversation_history

# Add termination_reason to the final control_plane_step
for msg in reversed(trajectory.conversation_history):
if msg.get("control_plane_step"):
Expand Down
Loading