Skip to content

Commit 7edd65e

Browse files
xzrderek and benjibc authored
Pipelining (#46)
* Add AIME2025, GPQA, HealthBench evaluation_test suites; unify row-limiting via pytest flag; clean up examples
* evaluation with aggregated scores
* WIP: vibe coded as an MVP
* merge
* remove
* updated logger
* formatting
* formatting
* fixing tests
---------
Co-authored-by: benjibc <youfychenbc5000@gmail.com>
1 parent 7317549 commit 7edd65e

13 files changed

Lines changed: 271 additions & 234 deletions

eval_protocol/mcp/execution/manager.py

Lines changed: 62 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import threading
1313
import time
1414
from dataclasses import asdict
15-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
15+
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union
1616

1717
import anyio
1818
from openai.types import CompletionUsage
@@ -43,7 +43,7 @@ async def execute_rollouts(
4343
openai_format_log_file: Optional[str] = None,
4444
max_concurrent_rollouts: int = 8,
4545
evaluation_rows: Optional[List[EvaluationRow]] = None,
46-
) -> List[EvaluationRow]:
46+
) -> AsyncIterator[EvaluationRow]:
4747
"""
4848
Execute general rollouts using tool calling interface with automatic record/playback.
4949
@@ -66,7 +66,7 @@ async def execute_rollouts(
6666
- Set and file exists: Playback mode (uses recorded data)
6767
6868
Returns:
69-
List of EvaluationRow objects with unified evaluation data format
69+
AsyncIterator of EvaluationRow objects with unified evaluation data format
7070
"""
7171
start_time = time.time()
7272

@@ -92,96 +92,77 @@ async def execute_rollouts(
9292

9393
logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")
9494

95-
results = {}
95+
if evaluation_rows is None:
96+
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)]
97+
98+
shared_tool_schema = envs.tool_schemas
9699

97100
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
98101

99102
async def _execute_with_semaphore(idx):
100103
async with semaphore:
101-
result = await self._execute_rollout(
104+
trajectory = await self._execute_rollout(
102105
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
103106
)
104107

105-
return result
106-
107-
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
108-
# exceptions will be try catched inside single _execute_rollout
109-
trajectories = await asyncio.gather(*tasks)
110-
111-
# Calculate durations
112-
total_duration = time.time() - start_time
113-
for trajectory in trajectories:
114-
trajectory.duration = total_duration
115-
116-
shared_tool_schema = envs.tool_schemas
117-
118-
# Enhanced reporting with control plane info
119-
successful = sum(1 for traj in trajectories if traj.total_reward > 0)
120-
terminated_by_control_plane = sum(
121-
1
122-
for traj in trajectories
123-
if traj.control_plane_summary.get("termination_reason") == "control_plane_signal"
124-
)
108+
# Convert trajectory to EvaluationRow immediately
109+
evaluation_row = evaluation_rows[idx]
110+
111+
# Handle multimodal content by extracting text from complex content structures
112+
messages = []
113+
for msg in trajectory.conversation_history:
114+
# Create a copy to avoid modifying the original
115+
msg_dict = dict(msg)
116+
117+
# Handle multimodal content (list of content blocks) by extracting text
118+
if isinstance(msg_dict.get("content"), list):
119+
text_content = None
120+
for content_block in msg_dict["content"]:
121+
if isinstance(content_block, dict) and content_block.get("type") == "text":
122+
text_content = content_block.get("text")
123+
break
124+
msg_dict["content"] = text_content or ""
125+
126+
messages.append(Message.model_validate(msg_dict))
127+
128+
evaluation_row.messages = messages
129+
evaluation_row.tools = shared_tool_schema
130+
evaluation_row.usage = CompletionUsage(**trajectory.usage)
131+
evaluation_row.input_metadata.completion_params = CompletionParams(
132+
model=policy.model_id,
133+
temperature=getattr(policy, "temperature", None),
134+
max_tokens=getattr(policy, "max_tokens", None),
135+
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
136+
)
125137

126-
logger.info(f"📊 Rollout complete: {successful}/{len(trajectories)} reached goal")
127-
logger.info(f"🎛️ Control plane terminations: {terminated_by_control_plane}/{len(trajectories)}")
128-
logger.info(f"⏱️ Total duration: {total_duration:.2f}s")
129-
logger.info(f"🧵 Used {max_concurrent_rollouts} concurrent threads")
138+
if trajectory.terminated:
139+
if trajectory.termination_reason == TerminationReason.ERROR:
140+
evaluation_row.rollout_status.status = "error"
141+
evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get(
142+
"error_message", None
143+
)
144+
else:
145+
evaluation_row.rollout_status.status = "finished"
146+
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
147+
else:
148+
evaluation_row.rollout_status.status = "running"
130149

131-
# Print log file locations if created
132-
if openai_format_log_file:
133-
logger.info(f"💬 OpenAI format log: {openai_format_log_file}")
134-
if recording_mode:
135-
logger.info(f"📝 Recorded trajectory: {playback_file}")
136-
# Add note about control plane separation
137-
logger.info(f"🎛️ Trajectories include control plane separation")
150+
return evaluation_row
138151

139-
# Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
140-
if evaluation_rows is None:
141-
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]
142-
143-
for idx, trajectory in enumerate(trajectories):
144-
# Handle multimodal content by extracting text from complex content structures
145-
messages = []
146-
for msg in trajectory.conversation_history:
147-
# Create a copy to avoid modifying the original
148-
msg_dict = dict(msg)
149-
150-
# Handle multimodal content (list of content blocks) by extracting text
151-
if isinstance(msg_dict.get("content"), list):
152-
text_content = None
153-
for content_block in msg_dict["content"]:
154-
if isinstance(content_block, dict) and content_block.get("type") == "text":
155-
text_content = content_block.get("text")
156-
break
157-
msg_dict["content"] = text_content or ""
158-
159-
messages.append(Message.model_validate(msg_dict))
160-
161-
evaluation_rows[idx].messages = messages
162-
# evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
163-
# evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
164-
evaluation_rows[idx].tools = shared_tool_schema
165-
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
166-
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
167-
model=policy.model_id,
168-
temperature=getattr(policy, "temperature", None),
169-
max_tokens=getattr(policy, "max_tokens", None),
170-
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
171-
)
172-
if trajectory.terminated:
173-
if trajectory.termination_reason == TerminationReason.ERROR:
174-
evaluation_rows[idx].rollout_status.status = "error"
175-
evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get(
176-
"error_message", None
177-
)
178-
else:
179-
evaluation_rows[idx].rollout_status.status = "finished"
180-
evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason
181-
else:
182-
evaluation_rows[idx].rollout_status.status = "running"
152+
# Create all tasks
153+
tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)]
183154

184-
return evaluation_rows
155+
# Yield results as they complete (note that they're not necessarily in original order)
156+
try:
157+
for task in asyncio.as_completed(tasks):
158+
try:
159+
yield await task
160+
except Exception:
161+
logger.exception("Error processing rollout")
162+
finally:
163+
for t in tasks:
164+
t.cancel()
165+
await asyncio.gather(*tasks, return_exceptions=True)
185166

186167
async def _execute_rollout(
187168
self,

eval_protocol/mcp_env.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,20 @@
4141
"""
4242

4343
import asyncio
44+
import hashlib
45+
import json
4446

4547
# For legacy compatibility - import the facade functions
4648
import logging
4749
import random
48-
from typing import Any, Callable, Dict, List, Optional, Union
50+
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
4951

5052
# Import all functionality from the new modular components
5153
from .mcp.execution.manager import ExecutionManager
5254
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LiteLLMPolicy, LLMBasePolicy, OpenAIPolicy
5355
from .mcp.session.manager import GeneralMCPVectorEnv
5456
from .models import EvaluationRow
5557
from .types import DatasetRow, MCPSession, MCPToolCall
56-
import asyncio
57-
import hashlib
58-
import json
5958

6059
logger = logging.getLogger(__name__)
6160

@@ -247,7 +246,7 @@ async def rollout(
247246
steps: int = 512,
248247
openai_format_log_file: Optional[str] = None,
249248
max_concurrent_rollouts: int = 8,
250-
) -> List[EvaluationRow]:
249+
) -> AsyncIterator[EvaluationRow]:
251250
"""
252251
Execute general rollouts using tool calling interface with automatic record/playback.
253252
@@ -307,9 +306,10 @@ async def rollout(
307306
# Use the new ExecutionManager for execution
308307
execution_manager = ExecutionManager()
309308

310-
return await execution_manager.execute_rollouts(
309+
async for evaluation_row in execution_manager.execute_rollouts(
311310
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
312-
)
311+
):
312+
yield evaluation_row
313313

314314

315315
async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:

eval_protocol/pytest/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from .default_agent_rollout_processor import default_agent_rollout_processor
2+
from .default_dataset_adapter import default_dataset_adapter
3+
from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
24
from .default_no_op_rollout_process import default_no_op_rollout_processor
35
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
46
from .evaluation_test import evaluation_test
57
from .types import RolloutProcessor, RolloutProcessorConfig
6-
from .default_dataset_adapter import default_dataset_adapter
78

89
__all__ = [
910
"default_agent_rollout_processor",
11+
"default_mcp_gym_rollout_processor",
1012
"default_no_op_rollout_processor",
1113
"default_single_turn_rollout_processor",
1214
"default_dataset_adapter",

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import json
3+
import logging
34
import os
4-
from typing import Any, List, Optional, Union
5+
from typing import Any, AsyncIterator, List, Optional, Union
56

67
from mcp.types import CallToolResult, TextContent
78
from openai import NOT_GIVEN, NotGiven
@@ -14,6 +15,8 @@
1415
from eval_protocol.models import EvaluationRow, Message
1516
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
class Agent:
1922
"""
@@ -114,13 +117,42 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex
114117

115118
async def default_agent_rollout_processor(
116119
rows: List[EvaluationRow], config: RolloutProcessorConfig
117-
) -> List[EvaluationRow]:
118-
dataset: Dataset = []
119-
for row in rows:
120+
) -> AsyncIterator[EvaluationRow]:
121+
"""Process agent rollouts with bounded concurrency and yield as they complete."""
122+
123+
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
124+
semaphore = asyncio.Semaphore(max_concurrent)
125+
126+
async def process_row(row: EvaluationRow) -> EvaluationRow:
127+
"""Process a single row with agent rollout."""
120128
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
121-
await agent.setup()
122-
await agent.call_agent()
123-
dataset.append(agent.evaluation_row)
124-
if agent.mcp_client:
125-
await agent.mcp_client.cleanup()
126-
return dataset
129+
try:
130+
await agent.setup()
131+
await agent.call_agent()
132+
return agent.evaluation_row
133+
finally:
134+
if agent.mcp_client:
135+
await agent.mcp_client.cleanup()
136+
137+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
138+
async with semaphore:
139+
try:
140+
return await process_row(r)
141+
except Exception as e:
142+
logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}")
143+
return r
144+
145+
# Create all tasks
146+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
147+
148+
# Yield results as they complete (note that they're not necessarily in original order)
149+
try:
150+
for task in asyncio.as_completed(tasks):
151+
try:
152+
yield await task
153+
except Exception:
154+
logger.exception("Error processing row")
155+
finally:
156+
for t in tasks:
157+
t.cancel()
158+
await asyncio.gather(*tasks, return_exceptions=True)

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import subprocess
77
import time
88
from pathlib import Path
9-
from typing import List, Optional
9+
from typing import AsyncIterator, List, Optional
1010

1111
import eval_protocol as ep
1212
from eval_protocol.models import EvaluationRow, Message
@@ -194,22 +194,19 @@ def __exit__(self, exc_type, exc_val, exc_tb):
194194

195195
async def default_mcp_gym_rollout_processor(
196196
rows: List[EvaluationRow], config: RolloutProcessorConfig
197-
) -> List[EvaluationRow]:
197+
) -> AsyncIterator[EvaluationRow]:
198198
"""
199199
Rollout processor for tau bench environments.
200200
201-
202201
This processor starts an MCP server, creates tau bench environments, and runs rollouts
203-
using the eval_protocol framework, following the pattern from test_tau2_e2e.py.
204-
202+
using the eval_protocol framework, yielding results as they complete.
205203
206204
Args:
207205
rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
208206
config: RolloutProcessorConfig with model and other parameters
209207
210-
211208
Returns:
212-
List of EvaluationRow objects with completed conversations
209+
AsyncIterator of EvaluationRow objects with completed conversations
213210
"""
214211
if config.server_script_path is None:
215212
raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
@@ -233,15 +230,14 @@ async def default_mcp_gym_rollout_processor(
233230
)
234231

235232
# Run rollout with environments and policy
236-
evaluation_rows = await ep.rollout(
233+
async for evaluation_row in ep.rollout(
237234
envs,
238235
policy=policy,
239236
evaluation_rows=rows,
240237
steps=config.steps,
241238
max_concurrent_rollouts=config.max_concurrent_rollouts,
242-
)
243-
244-
return evaluation_rows
239+
):
240+
yield evaluation_row
245241

246242
finally:
247243
# Always clean up the server
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from typing import List
1+
from typing import AsyncIterator, List
22

33
from eval_protocol.models import EvaluationRow
44
from eval_protocol.pytest.types import RolloutProcessorConfig
55

66

7-
def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]:
7+
async def default_no_op_rollout_processor(
8+
rows: List[EvaluationRow], config: RolloutProcessorConfig
9+
) -> AsyncIterator[EvaluationRow]:
810
"""
911
Simply passes input dataset through to the test function. This can be useful
1012
if you want to run the rollout yourself.
1113
"""
12-
return rows
14+
for row in rows:
15+
yield row

0 commit comments

Comments (0)