Skip to content

Commit 3cf966a

Browse files
committed
WIP: vibe coded as an mvp
1 parent 4aa9e5c commit 3cf966a

5 files changed

Lines changed: 136 additions & 50 deletions

File tree

eval_protocol/pytest/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from .default_agent_rollout_processor import default_agent_rollout_processor
2+
from .default_dataset_adapter import default_dataset_adapter
3+
from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
24
from .default_no_op_rollout_process import default_no_op_rollout_processor
35
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
46
from .evaluation_test import evaluation_test
57
from .types import RolloutProcessor, RolloutProcessorConfig
6-
from .default_dataset_adapter import default_dataset_adapter
78

89
__all__ = [
910
"default_agent_rollout_processor",
1011
"default_no_op_rollout_processor",
1112
"default_single_turn_rollout_processor",
13+
"default_mcp_gym_rollout_processor",
1214
"default_dataset_adapter",
1315
"RolloutProcessor",
1416
"RolloutProcessorConfig",

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import asyncio
2-
from typing import List
2+
import logging
3+
import time
4+
from typing import AsyncIterator, List
35

4-
from litellm import acompletion
56
import litellm
7+
from litellm import acompletion
68
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
79

810
from eval_protocol.dataset_logger import default_logger
911
from eval_protocol.models import EvaluationRow, Message
1012
from eval_protocol.pytest.types import RolloutProcessorConfig
1113

14+
logger = logging.getLogger(__name__)
15+
1216

1317
async def default_single_turn_rollout_processor(
1418
rows: List[EvaluationRow], config: RolloutProcessorConfig
15-
) -> List[EvaluationRow]:
19+
) -> AsyncIterator[EvaluationRow]:
1620
"""Generate a single response from any supported model provider using LiteLLM."""
1721

1822
# Explicitly disable LiteLLM caching to avoid reused responses across runs
@@ -70,17 +74,45 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
7074

7175
row.messages = messages
7276
default_logger.log(row)
77+
logger.info(f"FINISHED PROCESSING ROW: {row.input_metadata.row_id} at time {time.time()}")
7378
return row
7479

75-
# Process rows with bounded concurrency if configured
80+
# Process rows with bounded concurrency and yield as they complete
7681
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
7782
semaphore = asyncio.Semaphore(max_concurrent)
7883

7984
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
8085
async with semaphore:
8186
return await process_row(r)
8287

83-
tasks = [_sem_wrapper(row) for row in rows]
84-
dataset = list(await asyncio.gather(*tasks))
88+
# Create all tasks
89+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
8590

86-
return dataset
91+
# Yield results as they complete (not in original order)
92+
try:
93+
while tasks:
94+
# Wait for at least one task to complete
95+
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
96+
97+
# Yield completed results
98+
for task in done:
99+
try:
100+
result = await task
101+
yield result
102+
except Exception as e:
103+
# Log error but continue processing other tasks
104+
print(f"Error processing row: {e}")
105+
# Could yield an error row or skip
106+
107+
# Update tasks list to only pending tasks
108+
tasks = list(pending)
109+
110+
finally:
111+
# Clean up any remaining tasks
112+
for task in tasks:
113+
if not task.done():
114+
task.cancel()
115+
try:
116+
await task
117+
except asyncio.CancelledError:
118+
pass

eval_protocol/pytest/evaluation_test.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import inspect
2-
import os
1+
import asyncio
32
import copy
3+
import inspect
44
import math
5+
import os
56
import statistics
67
from typing import Any, Callable, Dict, List, Optional
78

@@ -33,7 +34,7 @@
3334
from ..common_utils import load_jsonl
3435

3536

36-
def evaluation_test(
37+
def evaluation_test( # noqa: C901
3738
*,
3839
model: List[ModelParam],
3940
input_messages: Optional[List[InputMessagesParam]] = None,
@@ -221,7 +222,7 @@ def generate_combinations():
221222
# Create wrapper function with exact signature that pytest expects
222223
def create_wrapper_with_signature() -> Callable:
223224
# Create the function body that will be used
224-
def wrapper_body(**kwargs):
225+
async def wrapper_body(**kwargs):
225226
model_name = kwargs["model"]
226227
eval_metadata = None
227228
all_results: List[EvaluationRow] = []
@@ -300,10 +301,14 @@ def wrapper_body(**kwargs):
300301
# Regenerate outputs each run by deep-copying the pristine dataset
301302
# so model responses are not reused across runs.
302303
fresh_rows = [copy.deepcopy(r) for r in data]
303-
input_dataset = execute_function(rollout_processor, rows=fresh_rows, config=config)
304+
305+
# All rollout processors now return AsyncIterator for pipelining
306+
rollout_result = rollout_processor(fresh_rows, config)
307+
304308
if mode == "pointwise":
305-
# Pointwise mode: apply the evaluator function to each row
306-
for row in input_dataset:
309+
# Pointwise mode: true pipelining with concurrent evaluations
310+
async def process_evaluation(row):
311+
"""Process a single evaluation and return the result."""
307312
result = execute_with_params(
308313
test_func,
309314
row=row,
@@ -313,8 +318,25 @@ def wrapper_body(**kwargs):
313318
raise ValueError(
314319
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
315320
)
316-
all_results.append(result)
321+
return result
322+
323+
# Start evaluations as rollouts complete - true pipelining
324+
eval_tasks = []
325+
async for row in rollout_result:
326+
# Start evaluation immediately when rollout completes
327+
eval_task = asyncio.create_task(process_evaluation(row))
328+
eval_tasks.append(eval_task)
329+
330+
# Collect all evaluation results
331+
if eval_tasks:
332+
eval_results = await asyncio.gather(*eval_tasks)
333+
all_results.extend(eval_results)
317334
else:
335+
# Batch mode: collect all results first, then evaluate
336+
input_dataset = []
337+
async for row in rollout_result:
338+
input_dataset.append(row)
339+
318340
# Batch mode: call the test function with the full dataset
319341
results = execute_with_params(
320342
test_func,
@@ -353,8 +375,12 @@ def wrapper_body(**kwargs):
353375
sample_std = statistics.stdev(scores)
354376
se = sample_std / math.sqrt(n)
355377
margin = 1.96 * se
356-
ci_low = float(max(0.0, (agg_score or 0.0) - margin)) if agg_score is not None else None
357-
ci_high = float(min(1.0, (agg_score or 0.0) + margin)) if agg_score is not None else None
378+
ci_low = (
379+
float(max(0.0, (agg_score or 0.0) - margin)) if agg_score is not None else None
380+
)
381+
ci_high = (
382+
float(min(1.0, (agg_score or 0.0) + margin)) if agg_score is not None else None
383+
)
358384
except Exception:
359385
ci_low = None
360386
ci_high = None
@@ -392,6 +418,7 @@ def wrapper_body(**kwargs):
392418
# Aggregate per-metric mean and 95% CI when available
393419
metrics_summary: Dict[str, Dict[str, float]] = {}
394420
from collections import defaultdict
421+
395422
metric_scores: Dict[str, list] = defaultdict(list)
396423
for r in all_results:
397424
if r.evaluation_result and r.evaluation_result.metrics:
@@ -435,12 +462,16 @@ def wrapper_body(**kwargs):
435462
parts = []
436463
for m_name, entry in metrics_summary.items():
437464
if "ci_low" in entry and "ci_high" in entry:
438-
parts.append(f"{m_name}={entry['mean']:.3f} ci95=[{entry['ci_low']:.3f},{entry['ci_high']:.3f}]")
465+
parts.append(
466+
f"{m_name}={entry['mean']:.3f} ci95=[{entry['ci_low']:.3f},{entry['ci_high']:.3f}]"
467+
)
439468
else:
440469
parts.append(f"{m_name}={entry['mean']:.3f}")
441470
print(f"EP Metrics | " + ", ".join(parts))
442471
if summary_path:
443-
import json, pathlib, time
472+
import json
473+
import pathlib
474+
import time
444475

445476
p = pathlib.Path(summary_path)
446477
p.parent.mkdir(parents=True, exist_ok=True)
@@ -483,6 +514,7 @@ def wrapper_body(**kwargs):
483514
# Create the pytest wrapper
484515
pytest_wrapper = create_wrapper_with_signature()
485516
pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper)
517+
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
486518

487519
def create_dual_mode_wrapper() -> Callable:
488520
"""
@@ -500,17 +532,21 @@ def create_dual_mode_wrapper() -> Callable:
500532
"""
501533
import asyncio
502534

503-
# Check if the test function is async
504-
is_async = asyncio.iscoroutinefunction(test_func)
535+
# Check if the pytest wrapper is async (it should be now)
536+
is_pytest_wrapper_async = asyncio.iscoroutinefunction(pytest_wrapper)
537+
is_test_func_async = asyncio.iscoroutinefunction(test_func)
505538

506-
if is_async:
539+
if is_pytest_wrapper_async:
507540

508541
async def dual_mode_wrapper(*args, **kwargs):
509542
# Check if this is a direct call with the expected signature
510543
if mode == "pointwise":
511544
# For pointwise mode, check if called with a single row argument
512545
if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
513-
return await test_func(row=args[0])
546+
if is_test_func_async:
547+
return await test_func(row=args[0])
548+
else:
549+
return test_func(row=args[0])
514550
else:
515551
# For batch mode, check if called with rows argument
516552
if (
@@ -519,18 +555,24 @@ async def dual_mode_wrapper(*args, **kwargs):
519555
and all(isinstance(r, EvaluationRow) for r in args[0])
520556
and not kwargs
521557
):
522-
return await test_func(rows=args[0])
558+
if is_test_func_async:
559+
return await test_func(rows=args[0])
560+
else:
561+
return test_func(rows=args[0])
523562
# Also check if called with keyword argument 'rows'
524563
if (
525564
len(args) == 0
526565
and "rows" in kwargs
527566
and isinstance(kwargs["rows"], list)
528567
and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
529568
):
530-
return await test_func(**kwargs)
569+
if is_test_func_async:
570+
return await test_func(**kwargs)
571+
else:
572+
return test_func(**kwargs)
531573

532574
# If not a direct call, use the pytest wrapper
533-
return pytest_wrapper(*args, **kwargs)
575+
return await pytest_wrapper(*args, **kwargs)
534576

535577
else:
536578

eval_protocol/pytest/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,18 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param
8484
"""
8585
from functools import wraps
8686

87-
@wraps(test_func)
88-
def wrapper(**kwargs):
89-
return wrapper_body(**kwargs)
87+
# Check if wrapper_body is async and create appropriate wrapper
88+
if asyncio.iscoroutinefunction(wrapper_body):
89+
90+
@wraps(test_func)
91+
async def wrapper(**kwargs):
92+
return await wrapper_body(**kwargs)
93+
94+
else:
95+
96+
@wraps(test_func)
97+
def wrapper(**kwargs):
98+
return wrapper_body(**kwargs)
9099

91100
parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names]
92101
wrapper.__signature__ = inspect.Signature(parameters)

tests/pytest/test_basic_coding.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
and comparing the output against expected results in a pointwise manner.
66
"""
77

8+
import logging
9+
import time
810
from typing import Any, Dict, List
911

1012
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
1113
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
12-
from eval_protocol.rewards.code_execution import extract_code_blocks, execute_python_code
14+
from eval_protocol.rewards.code_execution import execute_python_code, extract_code_blocks
15+
16+
logger = logging.getLogger(__name__)
1317

1418

1519
def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -18,8 +22,8 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat
1822
"""
1923
return [
2024
EvaluationRow(
21-
messages=[Message(role="user", content=f"{row['prompt']} Input: {row['input']}")],
22-
ground_truth=row["expected_output"]
25+
messages=[Message(role="user", content=f"{row['prompt']} Input: {row['input']}")],
26+
ground_truth=row["expected_output"],
2327
)
2428
for row in data
2529
]
@@ -38,55 +42,52 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat
3842
def test_coding_code_evaluation(row: EvaluationRow) -> EvaluationRow:
3943
"""
4044
Evaluation function that tests code correctness by executing it locally.
41-
45+
4246
This function:
4347
1. Extracts Python code from the assistant's response
4448
2. Executes the code locally with timeout=10
4549
3. Compares the output to ground_truth
4650
4. Returns a score of 1.0 if output matches, 0.0 otherwise
47-
51+
4852
Args:
4953
row: EvaluationRow containing the conversation messages and expected_output in ground_truth
50-
54+
5155
Returns:
5256
EvaluationRow with the evaluation result
5357
"""
58+
logger.info(f"STARTING TO EVALUATE ROW: {row.input_metadata.row_id} at time {time.time()}")
5459
# Check if we have an assistant response
5560
if len(row.messages) < 2 or row.messages[-1].role != "assistant":
5661
row.evaluation_result = EvaluateResult(score=0.0, reason="No assistant response found")
5762
return row
58-
63+
5964
assistant_content = row.messages[-1].content or ""
6065
expected_output = (row.ground_truth or "").strip()
61-
66+
6267
# Extract Python code blocks
6368
code_blocks = extract_code_blocks(assistant_content, language="python")
6469
if not code_blocks:
6570
row.evaluation_result = EvaluateResult(score=0.0, reason="No Python code block found")
6671
return row
67-
72+
6873
code = code_blocks[0]["code"]
69-
74+
7075
# Execute the code locally
7176
execution_result = execute_python_code(code, timeout=10)
72-
77+
7378
if not execution_result.get("success", False):
7479
error_msg = execution_result.get("error", "Code execution failed")
7580
row.evaluation_result = EvaluateResult(score=0.0, reason=f"Execution error: {error_msg}")
7681
return row
77-
82+
7883
# Compare output with expected
7984
actual_output = (execution_result.get("output", "") or "").strip()
80-
85+
8186
if actual_output == expected_output:
82-
row.evaluation_result = EvaluateResult(
83-
score=1.0,
84-
reason=f"✅ Output matches: '{actual_output}'"
85-
)
87+
row.evaluation_result = EvaluateResult(score=1.0, reason=f"✅ Output matches: '{actual_output}'")
8688
else:
8789
row.evaluation_result = EvaluateResult(
88-
score=0.0,
89-
reason=f"❌ Expected: '{expected_output}', Got: '{actual_output}'"
90+
score=0.0, reason=f"❌ Expected: '{expected_output}', Got: '{actual_output}'"
9091
)
91-
92+
9293
return row

0 commit comments

Comments
 (0)