Skip to content

Commit 3415d43

Browse files
committed
more change
1 parent a6adeb5 commit 3415d43

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ async def execute_rollouts(
7777
steps: int = 512,
7878
openai_format_log_file: Optional[str] = None,
7979
max_concurrent_rollouts: int = 8,
80+
evaluation_rows: Optional[List[EvaluationRow]] = None,
8081
) -> List[EvaluationRow]:
8182
"""
8283
Execute general rollouts using tool calling interface with automatic record/playback.
@@ -170,9 +171,11 @@ async def _execute_with_semaphore(idx):
170171
# Add note about control plane separation
171172
logger.info(f"🎛️ Trajectories include control plane separation")
172173

173-
# Convert trajectories to unified EvaluationRow format
174-
evaluation_rows = []
175-
for trajectory in trajectories:
174+
# Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
175+
if evaluation_rows is None:
176+
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]
177+
178+
for idx, trajectory in enumerate(trajectories):
176179
# Handle multimodal content by extracting text from complex content structures
177180
messages = []
178181
for msg in trajectory.conversation_history:
@@ -190,26 +193,15 @@ async def _execute_with_semaphore(idx):
190193

191194
messages.append(Message.model_validate(msg_dict))
192195

193-
input_metadata = InputMetadata(
194-
row_id=trajectory.session.dataset_row.id if trajectory.session.dataset_row else None,
195-
dataset_info=asdict(trajectory.session.dataset_row) if trajectory.session.dataset_row else {},
196-
completion_params=CompletionParams(
197-
model=policy.model_id,
198-
temperature=getattr(policy, "temperature", None),
199-
max_tokens=getattr(policy, "max_tokens", None),
200-
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
201-
),
202-
session_data={
203-
"timestamp": time.time(),
204-
},
205-
)
206-
evaluation_row = EvaluationRow(
207-
messages=messages,
208-
tools=shared_tool_schema,
209-
input_metadata=input_metadata,
210-
usage=trajectory.usage,
196+
evaluation_rows[idx].messages = messages
197+
evaluation_rows[idx].tools = shared_tool_schema
198+
evaluation_rows[idx].usage = trajectory.usage
199+
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
200+
model=policy.model_id,
201+
temperature=getattr(policy, "temperature", None),
202+
max_tokens=getattr(policy, "max_tokens", None),
203+
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
211204
)
212-
evaluation_rows.append(evaluation_row)
213205

214206
return evaluation_rows
215207

0 commit comments

Comments
 (0)