Skip to content

Commit 5c6de06

Browse files
committed
updated logger
1 parent bd384ed commit 5c6de06

9 files changed

Lines changed: 207 additions & 260 deletions

eval_protocol/mcp/execution/manager.py

Lines changed: 69 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import threading
1313
import time
1414
from dataclasses import asdict
15-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
15+
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union
1616

1717
from openai.types import CompletionUsage
1818

@@ -42,7 +42,7 @@ async def execute_rollouts(
4242
openai_format_log_file: Optional[str] = None,
4343
max_concurrent_rollouts: int = 8,
4444
evaluation_rows: Optional[List[EvaluationRow]] = None,
45-
) -> List[EvaluationRow]:
45+
) -> AsyncIterator[EvaluationRow]:
4646
"""
4747
Execute general rollouts using tool calling interface with automatic record/playback.
4848
@@ -65,7 +65,7 @@ async def execute_rollouts(
6565
- Set and file exists: Playback mode (uses recorded data)
6666
6767
Returns:
68-
List of EvaluationRow objects with unified evaluation data format
68+
AsyncIterator of EvaluationRow objects with unified evaluation data format
6969
"""
7070
start_time = time.time()
7171

@@ -91,103 +91,84 @@ async def execute_rollouts(
9191

9292
logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")
9393

94-
results = {}
94+
if evaluation_rows is None:
95+
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)]
96+
97+
shared_tool_schema = envs.tool_schemas
9598

9699
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
97100

98101
async def _execute_with_semaphore(idx):
99102
async with semaphore:
100-
result = await self._execute_rollout(
103+
trajectory = await self._execute_rollout(
101104
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
102105
)
103106

104-
return result
105-
106-
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
107-
# exceptions will be try catched inside single _execute_rollout
108-
trajectories = await asyncio.gather(*tasks)
109-
110-
# Calculate durations
111-
total_duration = time.time() - start_time
112-
for trajectory in trajectories:
113-
trajectory.duration = total_duration
114-
115-
shared_tool_schema = envs.tool_schemas
116-
117-
# Enhanced reporting with control plane info
118-
successful = sum(1 for traj in trajectories if traj.total_reward > 0)
119-
terminated_by_control_plane = sum(
120-
1
121-
for traj in trajectories
122-
if traj.control_plane_summary.get("termination_reason") == "control_plane_signal"
123-
)
107+
# Convert trajectory to EvaluationRow immediately
108+
evaluation_row = evaluation_rows[idx]
109+
110+
# Handle multimodal content by extracting text from complex content structures
111+
messages = []
112+
for msg in trajectory.conversation_history:
113+
# Create a copy to avoid modifying the original
114+
msg_dict = dict(msg)
115+
116+
# Handle multimodal content (list of content blocks) by extracting text
117+
if isinstance(msg_dict.get("content"), list):
118+
text_content = None
119+
for content_block in msg_dict["content"]:
120+
if isinstance(content_block, dict) and content_block.get("type") == "text":
121+
text_content = content_block.get("text")
122+
break
123+
msg_dict["content"] = text_content or ""
124+
125+
messages.append(Message.model_validate(msg_dict))
126+
127+
evaluation_row.messages = messages
128+
evaluation_row.tools = shared_tool_schema
129+
evaluation_row.usage = CompletionUsage(**trajectory.usage)
130+
evaluation_row.input_metadata.completion_params = CompletionParams(
131+
model=policy.model_id,
132+
temperature=getattr(policy, "temperature", None),
133+
max_tokens=getattr(policy, "max_tokens", None),
134+
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
135+
)
124136

125-
logger.info(f"📊 Rollout complete: {successful}/{len(trajectories)} reached goal")
126-
logger.info(f"🎛️ Control plane terminations: {terminated_by_control_plane}/{len(trajectories)}")
127-
logger.info(f"⏱️ Total duration: {total_duration:.2f}s")
128-
logger.info(f"🧵 Used {max_concurrent_rollouts} concurrent threads")
137+
if trajectory.terminated:
138+
if trajectory.termination_reason in {
139+
TerminationReason.CONTROL_PLANE_SIGNAL,
140+
TerminationReason.USER_STOP,
141+
}:
142+
evaluation_row.rollout_status.status = "finished"
143+
elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
144+
evaluation_row.rollout_status.status = "stopped"
145+
evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get(
146+
"termination_reason", trajectory.termination_reason
147+
)
148+
else:
149+
evaluation_row.rollout_status.status = "error"
150+
evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get(
151+
"error_message", None
152+
)
153+
else:
154+
evaluation_row.rollout_status.status = "running"
129155

130-
# Print log file locations if created
131-
if openai_format_log_file:
132-
logger.info(f"💬 OpenAI format log: {openai_format_log_file}")
133-
if recording_mode:
134-
logger.info(f"📝 Recorded trajectory: {playback_file}")
135-
# Add note about control plane separation
136-
logger.info(f"🎛️ Trajectories include control plane separation")
156+
return evaluation_row
137157

138-
# Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
139-
if evaluation_rows is None:
140-
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]
141-
142-
for idx, trajectory in enumerate(trajectories):
143-
# Handle multimodal content by extracting text from complex content structures
144-
messages = []
145-
for msg in trajectory.conversation_history:
146-
# Create a copy to avoid modifying the original
147-
msg_dict = dict(msg)
148-
149-
# Handle multimodal content (list of content blocks) by extracting text
150-
if isinstance(msg_dict.get("content"), list):
151-
text_content = None
152-
for content_block in msg_dict["content"]:
153-
if isinstance(content_block, dict) and content_block.get("type") == "text":
154-
text_content = content_block.get("text")
155-
break
156-
msg_dict["content"] = text_content or ""
157-
158-
messages.append(Message.model_validate(msg_dict))
159-
160-
evaluation_rows[idx].messages = messages
161-
# evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
162-
# evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
163-
evaluation_rows[idx].tools = shared_tool_schema
164-
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
165-
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
166-
model=policy.model_id,
167-
temperature=getattr(policy, "temperature", None),
168-
max_tokens=getattr(policy, "max_tokens", None),
169-
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
170-
)
171-
if trajectory.terminated:
172-
if trajectory.termination_reason in {
173-
TerminationReason.CONTROL_PLANE_SIGNAL,
174-
TerminationReason.USER_STOP,
175-
}:
176-
evaluation_rows[idx].rollout_status.status = "finished"
177-
elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
178-
evaluation_rows[idx].rollout_status.status = "stopped"
179-
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
180-
"termination_reason", trajectory.termination_reason
181-
)
182-
else:
183-
evaluation_rows[idx].rollout_status.status = "error"
184-
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
185-
"error_message", None
186-
)
187-
else:
188-
evaluation_rows[idx].rollout_status.status = "running"
158+
# Create all tasks
159+
tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)]
189160

190-
return evaluation_rows
161+
# Yield results as they complete (note that they're not necessarily in original order)
162+
try:
163+
for task in asyncio.as_completed(tasks):
164+
try:
165+
yield await task
166+
except Exception:
167+
logger.exception("Error processing rollout")
168+
finally:
169+
for t in tasks:
170+
t.cancel()
171+
await asyncio.gather(*tasks, return_exceptions=True)
191172

192173
async def _execute_rollout(
193174
self,

eval_protocol/mcp_env.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,20 @@
4141
"""
4242

4343
import asyncio
44+
import hashlib
45+
import json
4446

4547
# For legacy compatibility - import the facade functions
4648
import logging
4749
import random
48-
from typing import Any, Callable, Dict, List, Optional, Union
50+
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
4951

5052
# Import all functionality from the new modular components
5153
from .mcp.execution.manager import ExecutionManager
5254
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LiteLLMPolicy, LLMBasePolicy, OpenAIPolicy
5355
from .mcp.session.manager import GeneralMCPVectorEnv
5456
from .models import EvaluationRow
5557
from .types import DatasetRow, MCPSession, MCPToolCall
56-
import asyncio
57-
import hashlib
58-
import json
5958

6059
logger = logging.getLogger(__name__)
6160

@@ -247,7 +246,7 @@ async def rollout(
247246
steps: int = 512,
248247
openai_format_log_file: Optional[str] = None,
249248
max_concurrent_rollouts: int = 8,
250-
) -> List[EvaluationRow]:
249+
) -> AsyncIterator[EvaluationRow]:
251250
"""
252251
Execute general rollouts using tool calling interface with automatic record/playback.
253252
@@ -307,9 +306,10 @@ async def rollout(
307306
# Use the new ExecutionManager for execution
308307
execution_manager = ExecutionManager()
309308

310-
return await execution_manager.execute_rollouts(
309+
async for evaluation_row in execution_manager.execute_rollouts(
311310
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
312-
)
311+
):
312+
yield evaluation_row
313313

314314

315315
async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from .default_agent_rollout_processor import default_agent_rollout_processor
22
from .default_dataset_adapter import default_dataset_adapter
3+
from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
34
from .default_no_op_rollout_process import default_no_op_rollout_processor
45
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
56
from .evaluation_test import evaluation_test
67
from .types import RolloutProcessor, RolloutProcessorConfig
78

89
__all__ = [
910
"default_agent_rollout_processor",
11+
"default_mcp_gym_rollout_processor",
1012
"default_no_op_rollout_processor",
1113
"default_single_turn_rollout_processor",
1214
"default_dataset_adapter",

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import json
3+
import logging
34
import os
4-
from typing import Any, List, Optional, Union
5+
from typing import Any, AsyncIterator, List, Optional, Union
56

67
from mcp.types import CallToolResult, TextContent
78
from openai import NOT_GIVEN, NotGiven
@@ -14,6 +15,8 @@
1415
from eval_protocol.models import EvaluationRow, Message
1516
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
class Agent:
1922
"""
@@ -114,13 +117,42 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex
114117

115118
async def default_agent_rollout_processor(
116119
rows: List[EvaluationRow], config: RolloutProcessorConfig
117-
) -> List[EvaluationRow]:
118-
dataset: Dataset = []
119-
for row in rows:
120+
) -> AsyncIterator[EvaluationRow]:
121+
"""Process agent rollouts with bounded concurrency and yield as they complete."""
122+
123+
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
124+
semaphore = asyncio.Semaphore(max_concurrent)
125+
126+
async def process_row(row: EvaluationRow) -> EvaluationRow:
127+
"""Process a single row with agent rollout."""
120128
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
121-
await agent.setup()
122-
await agent.call_agent()
123-
dataset.append(agent.evaluation_row)
124-
if agent.mcp_client:
125-
await agent.mcp_client.cleanup()
126-
return dataset
129+
try:
130+
await agent.setup()
131+
await agent.call_agent()
132+
return agent.evaluation_row
133+
finally:
134+
if agent.mcp_client:
135+
await agent.mcp_client.cleanup()
136+
137+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
138+
async with semaphore:
139+
try:
140+
return await process_row(r)
141+
except Exception as e:
142+
logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}")
143+
return r
144+
145+
# Create all tasks
146+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
147+
148+
# Yield results as they complete (note that they're not necessarily in original order)
149+
try:
150+
for task in asyncio.as_completed(tasks):
151+
try:
152+
yield await task
153+
except Exception:
154+
logger.exception("Error processing row")
155+
finally:
156+
for t in tasks:
157+
t.cancel()
158+
await asyncio.gather(*tasks, return_exceptions=True)

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import subprocess
77
import time
88
from pathlib import Path
9-
from typing import List, Optional
9+
from typing import AsyncIterator, List, Optional
1010

1111
import eval_protocol as ep
1212
from eval_protocol.models import EvaluationRow, Message
@@ -194,22 +194,19 @@ def __exit__(self, exc_type, exc_val, exc_tb):
194194

195195
async def default_mcp_gym_rollout_processor(
196196
rows: List[EvaluationRow], config: RolloutProcessorConfig
197-
) -> List[EvaluationRow]:
197+
) -> AsyncIterator[EvaluationRow]:
198198
"""
199199
Rollout processor for tau bench environments.
200200
201-
202201
This processor starts an MCP server, creates tau bench environments, and runs rollouts
203-
using the eval_protocol framework, following the pattern from test_tau2_e2e.py.
204-
202+
using the eval_protocol framework, yielding results as they complete.
205203
206204
Args:
207205
rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
208206
config: RolloutProcessorConfig with model and other parameters
209207
210-
211208
Returns:
212-
List of EvaluationRow objects with completed conversations
209+
AsyncIterator of EvaluationRow objects with completed conversations
213210
"""
214211
if config.server_script_path is None:
215212
raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
@@ -233,15 +230,14 @@ async def default_mcp_gym_rollout_processor(
233230
)
234231

235232
# Run rollout with environments and policy
236-
evaluation_rows = await ep.rollout(
233+
async for evaluation_row in ep.rollout(
237234
envs,
238235
policy=policy,
239236
evaluation_rows=rows,
240237
steps=config.steps,
241238
max_concurrent_rollouts=config.max_concurrent_rollouts,
242-
)
243-
244-
return evaluation_rows
239+
):
240+
yield evaluation_row
245241

246242
finally:
247243
# Always clean up the server

0 commit comments

Comments
 (0)