Skip to content

Commit 062e448

Browse files
committed
Address comments
1 parent 7356361 commit 062e448

13 files changed

+116
-375
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ jobs:
9292
--ignore=tests/pytest/test_frozen_lake.py \
9393
--ignore=tests/pytest/test_lunar_lander.py \
9494
--ignore=tests/pytest/test_tau_bench_airline.py \
95+
--ignore=tests/pytest/test_apps_coding.py \
9596
--ignore=tests/test_tau_bench_airline_smoke.py \
9697
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10
9798

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import csv
23
import io
34
import re
@@ -60,7 +61,7 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
6061
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
6162

6263

63-
async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]:
64+
def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[asyncio.Task[EvaluationRow]]:
6465
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor."""
6566
processed: List[EvaluationRow] = []
6667
for r in rows:
@@ -72,7 +73,7 @@ async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) ->
7273
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
7374
]
7475
processed.append(r)
75-
return await default_single_turn_rollout_processor(processed, config)
76+
return default_single_turn_rollout_processor(processed, config)
7677

7778

7879
@export_benchmark("gpqa")

eval_protocol/mcp/execution/manager.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ class ExecutionManager:
3535
Manage rollout for MCP environments.
3636
"""
3737

38-
async def execute_rollouts(
38+
def execute_rollouts(
3939
self,
4040
envs: "GeneralMCPVectorEnv",
4141
policy: Union["LLMBasePolicy", Callable],
4242
steps: int = 512,
4343
openai_format_log_file: Optional[str] = None,
4444
max_concurrent_rollouts: int = 8,
4545
evaluation_rows: Optional[List[EvaluationRow]] = None,
46-
) -> AsyncIterator[EvaluationRow]:
46+
) -> List[asyncio.Task[EvaluationRow]]:
4747
"""
4848
Execute general rollouts using tool calling interface with automatic record/playback.
4949
@@ -66,7 +66,7 @@ async def execute_rollouts(
6666
- Set and file exists: Playback mode (uses recorded data)
6767
6868
Returns:
69-
AsyncIterator of EvaluationRow objects with unified evaluation data format
69+
List of asyncio.Task objects, one per environment, each resolving to an EvaluationRow; the caller is responsible for awaiting them
7070
"""
7171
start_time = time.time()
7272

@@ -151,18 +151,7 @@ async def _execute_with_semaphore(idx):
151151

152152
# Create all tasks
153153
tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)]
154-
155-
# Yield results as they complete (note that they're not necessarily in original order)
156-
try:
157-
for task in asyncio.as_completed(tasks):
158-
try:
159-
yield await task
160-
except Exception:
161-
logger.exception("Error processing rollout")
162-
finally:
163-
for t in tasks:
164-
t.cancel()
165-
await asyncio.gather(*tasks, return_exceptions=True)
154+
return tasks
166155

167156
async def _execute_rollout(
168157
self,

eval_protocol/mcp_env.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def make(
236236
return mcp_envs
237237

238238

239-
async def rollout(
239+
def rollout(
240240
envs: GeneralMCPVectorEnv,
241241
policy: Union[FireworksPolicy, LLMBasePolicy, Callable],
242242
*,
@@ -246,7 +246,7 @@ async def rollout(
246246
steps: int = 512,
247247
openai_format_log_file: Optional[str] = None,
248248
max_concurrent_rollouts: int = 8,
249-
) -> AsyncIterator[EvaluationRow]:
249+
) -> List[asyncio.Task[EvaluationRow]]:
250250
"""
251251
Execute general rollouts using tool calling interface with automatic record/playback.
252252
@@ -274,14 +274,14 @@ async def rollout(
274274
- Set and file exists: Playback mode (uses recorded data)
275275
276276
Returns:
277-
List of EvaluationRow objects
277+
List of asyncio.Task objects, each resolving to an EvaluationRow; the caller is responsible for awaiting them
278278
279279
Example:
280280
# Live mode
281-
evaluation_rows = await ep.rollout(envs, policy)
281+
tasks = ep.rollout(envs, policy)
282282
283283
# Create environments automatically
284-
trajectories = await ep.rollout(
284+
tasks = ep.rollout(
285285
"http://localhost:8000/mcp/",
286286
policy,
287287
evaluation_rows=my_evaluation_rows,
@@ -290,26 +290,26 @@ async def rollout(
290290
291291
# Recording mode
292292
os.environ["EP_PLAYBACK_FILE"] = "record.jsonl"
293-
evaluation_rows = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
293+
tasks = ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
294294
295295
# Playback mode (after recording file exists)
296-
evaluation_rows = await ep.rollout(envs, policy)
296+
tasks = ep.rollout(envs, policy)
297297
"""
298298
# Automatically create environments if a base URL is provided
299299
if isinstance(envs, str):
300300
if evaluation_rows is None and dataset is None:
301301
raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")
302302

303303
auto_model_id = model_id or getattr(policy, "model_id", "unknown")
304-
envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
304+
envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
305305

306306
# Use the new ExecutionManager for execution
307307
execution_manager = ExecutionManager()
308308

309-
async for evaluation_row in execution_manager.execute_rollouts(
309+
tasks = execution_manager.execute_rollouts(
310310
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
311-
):
312-
yield evaluation_row
311+
)
312+
return tasks
313313

314314

315315
async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
@@ -336,7 +336,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
336336
policy = FireworksPolicy("test-model")
337337

338338
# Run short rollout
339-
evaluation_rows = await rollout(envs, policy=policy, steps=10)
339+
evaluation_rows = rollout(envs, policy=policy, steps=10)
340340

341341
if evaluation_rows and len(evaluation_rows[0].messages) > 1:
342342
results["successful"] += 1

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex
115115
return tool_result.content
116116

117117

118-
async def default_agent_rollout_processor(
118+
def default_agent_rollout_processor(
119119
rows: List[EvaluationRow], config: RolloutProcessorConfig
120-
) -> AsyncIterator[EvaluationRow]:
121-
"""Process agent rollouts with bounded concurrency and yield as they complete."""
120+
) -> List[asyncio.Task[EvaluationRow]]:
121+
"""Create agent rollout tasks and return them for external handling."""
122122

123123
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
124124
semaphore = asyncio.Semaphore(max_concurrent)
@@ -138,24 +138,9 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
138138

139139
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
140140
async with semaphore:
141-
try:
142-
return await process_row(r)
143-
except Exception as e:
144-
r.rollout_status.status = "error"
145-
r.rollout_status.termination_reason = str(e)
146-
return r
147-
148-
# Create all tasks
149-
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
141+
result = await process_row(r)
142+
return result
150143

151-
# Yield results as they complete (note that they're not necessarily in original order)
152-
try:
153-
for task in asyncio.as_completed(tasks):
154-
try:
155-
yield await task
156-
except Exception:
157-
logger.exception("Error processing row")
158-
finally:
159-
for t in tasks:
160-
t.cancel()
161-
await asyncio.gather(*tasks, return_exceptions=True)
144+
# Create and return tasks for external handling
145+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
146+
return tasks

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
194194
return False # Don't suppress exceptions
195195

196196

197-
async def default_mcp_gym_rollout_processor(
197+
def default_mcp_gym_rollout_processor(
198198
rows: List[EvaluationRow], config: RolloutProcessorConfig
199-
) -> AsyncIterator[EvaluationRow]:
199+
) -> List[asyncio.Task[EvaluationRow]]:
200200
"""
201201
Rollout processor for tau bench environments.
202202
203-
This processor starts an MCP server, creates tau bench environments, and runs rollouts
204-
using the eval_protocol framework, yielding results as they complete.
203+
This processor starts an MCP server, creates tau bench environments, and returns rollout tasks
204+
using the eval_protocol framework.
205205
206206
Args:
207207
rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
@@ -210,7 +210,7 @@ async def default_mcp_gym_rollout_processor(
210210
- start_server (bool): If True, create fresh server and environments. If False, reuse existing ones. Default: True.
211211
212212
Returns:
213-
AsyncIterator of EvaluationRow objects with completed conversations
213+
List of asyncio.Task objects, each resolving to an EvaluationRow; the caller is responsible for awaiting them
214214
"""
215215
start_server = config.kwargs.get("start_server", True) if config.kwargs else True
216216
if start_server:
@@ -260,15 +260,15 @@ async def default_mcp_gym_rollout_processor(
260260
envs = CURRENT_RUN_STATE["envs"]
261261
policy = CURRENT_RUN_STATE["policy"]
262262

263-
# Run rollout with environments and policy (automatically resets environments)
264-
async for evaluation_row in ep.rollout(
263+
# Get rollout tasks from ep.rollout
264+
tasks = ep.rollout(
265265
envs,
266266
policy=policy,
267267
evaluation_rows=rows,
268268
steps=config.steps,
269269
max_concurrent_rollouts=config.max_concurrent_rollouts,
270-
):
271-
yield evaluation_row
270+
)
271+
return tasks
272272

273273

274274
# Add cleanup method directly to the function object
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1-
from typing import AsyncIterator, List
1+
import asyncio
2+
from typing import List
23

34
from eval_protocol.models import EvaluationRow
45
from eval_protocol.pytest.types import RolloutProcessorConfig
56

67

7-
async def default_no_op_rollout_processor(
8+
def default_no_op_rollout_processor(
89
rows: List[EvaluationRow], config: RolloutProcessorConfig
9-
) -> AsyncIterator[EvaluationRow]:
10+
) -> List[asyncio.Task[EvaluationRow]]:
1011
"""
1112
Simply passes input dataset through to the test function. This can be useful
1213
if you want to run the rollout yourself.
1314
"""
14-
for row in rows:
15-
yield row
15+
16+
async def return_row(row: EvaluationRow) -> EvaluationRow:
17+
return row
18+
19+
# Create tasks that immediately return the rows (no-op)
20+
tasks = [asyncio.create_task(return_row(row)) for row in rows]
21+
return tasks

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18-
async def default_single_turn_rollout_processor(
18+
def default_single_turn_rollout_processor(
1919
rows: List[EvaluationRow], config: RolloutProcessorConfig
20-
) -> AsyncIterator[EvaluationRow]:
21-
"""Generate a single response from any supported model provider using LiteLLM."""
20+
) -> List[asyncio.Task[EvaluationRow]]:
21+
"""Generate single turn rollout tasks and return them for external handling."""
2222

2323
# Quiet LiteLLM logs in test runs unless user overrode
2424
try:
@@ -103,30 +103,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
103103
default_logger.log(row)
104104
return row
105105

106-
# Process rows with bounded concurrency and yield as they complete
106+
# Process rows with bounded concurrency
107107
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
108108
semaphore = asyncio.Semaphore(max_concurrent)
109109

110110
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
111111
async with semaphore:
112-
try:
113-
return await process_row(r)
114-
except Exception as e:
115-
r.rollout_status.status = "error"
116-
r.rollout_status.termination_reason = str(e)
117-
return r
118-
119-
# Create all tasks
120-
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
112+
result = await process_row(r)
113+
return result
121114

122-
# Yield results as they complete (note that they're not necessarily in original order)
123-
try:
124-
for task in asyncio.as_completed(tasks):
125-
try:
126-
yield await task
127-
except Exception:
128-
logger.exception("Error processing row")
129-
finally:
130-
for t in tasks:
131-
t.cancel()
132-
await asyncio.gather(*tasks, return_exceptions=True)
115+
# Create and return tasks for external handling
116+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
117+
return tasks

eval_protocol/pytest/evaluation_test.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -304,21 +304,37 @@ async def retry_handler(failed_row: EvaluationRow):
304304
# add kwargs start_server=False to config so we don't start new MCP server
305305
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
306306

307-
retry_call = rollout_processor([failed_row], retry_config)
307+
retry_tasks = rollout_processor([failed_row], retry_config)
308308

309-
retry_result = await anext(retry_call)
310-
if retry_result.rollout_status and retry_result.rollout_status.status == "finished":
309+
try:
310+
retry_result = await retry_tasks[0]
311+
retry_result.rollout_status.status = "finished"
311312
await queue.put(retry_result)
312-
else:
313-
asyncio.create_task(retry_handler(retry_result)) # retry failed, spawn another retry
313+
except Exception as e:
314+
failed_row.rollout_status.status = "error"
315+
failed_row.rollout_status.termination_reason = str(e)
316+
asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry
314317

315318
async def initial_processor():
316319
"""Process initial batch and spawn retries for failures"""
317-
async for initial_row in rollout_processor(fresh_dataset, config):
318-
if initial_row.rollout_status and initial_row.rollout_status.status == "finished":
319-
await queue.put(initial_row) # rollout succeeded, put on queue
320-
else:
321-
asyncio.create_task(retry_handler(initial_row)) # rollout errored, spawn retry task
320+
base_tasks = rollout_processor(fresh_dataset, config)
321+
pending = set(base_tasks)
322+
323+
while pending:
324+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
325+
326+
for task in done:
327+
task_index = base_tasks.index(task)
328+
329+
try:
330+
result = await task
331+
result.rollout_status.status = "finished"
332+
await queue.put(result)
333+
except Exception as e:
334+
failed_row = fresh_dataset[task_index]
335+
failed_row.rollout_status.status = "error"
336+
failed_row.rollout_status.termination_reason = str(e)
337+
asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task
322338

323339
processor_task = asyncio.create_task(initial_processor())
324340

@@ -606,7 +622,7 @@ async def _execute_with_semaphore(row):
606622
for result in all_results:
607623
for r in result:
608624
if r.eval_metadata is not None:
609-
r.eval_metadata.status = "finished"
625+
r.eval_metadata.status = "finished" # TODO: might not be needed
610626
r.eval_metadata.passed = passed
611627
active_logger.log(r)
612628

eval_protocol/pytest/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
Parameter types
33
"""
44

5+
import asyncio
56
from dataclasses import dataclass, field
6-
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional
7+
from typing import Any, Callable, Dict, List, Literal, Optional
78

89
from eval_protocol.dataset_logger import default_logger
910
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
@@ -51,4 +52,4 @@ class RolloutProcessorConfig:
5152
kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor
5253

5354

54-
RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], AsyncIterator[EvaluationRow]]
55+
RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[asyncio.Task[EvaluationRow]]]

0 commit comments

Comments
 (0)