Skip to content

Commit edf99ac

Browse files
committed
fixing tests
1 parent a1d6a52 commit edf99ac

File tree

3 files changed

+36
-20
lines changed

3 files changed

+36
-20
lines changed
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from typing import List
1+
from typing import AsyncIterator, List
22

33
from eval_protocol.models import EvaluationRow
44
from eval_protocol.pytest.types import RolloutProcessorConfig
55

66

7-
def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]:
7+
async def default_no_op_rollout_processor(
8+
rows: List[EvaluationRow], config: RolloutProcessorConfig
9+
) -> AsyncIterator[EvaluationRow]:
810
"""
911
Simply passes input dataset through to the test function. This can be useful
1012
if you want to run the rollout yourself.
1113
"""
12-
return rows
14+
for row in rows:
15+
yield row

tests/pytest/test_pytest_ids.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def read(self):
1919
return list(self._rows.values())
2020

2121

22-
def test_evaluation_test_decorator(monkeypatch):
22+
async def test_evaluation_test_decorator(monkeypatch):
2323
from eval_protocol.pytest.evaluation_test import evaluation_test
2424

2525
logger = InMemoryLogger()
@@ -45,13 +45,13 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
4545

4646
# Manually invoke all parameter combinations within a single test
4747
for ds_path in dataset_paths:
48-
eval_fn(model="dummy/local-model", dataset_path=[ds_path])
48+
await eval_fn(model="dummy/local-model", dataset_path=[ds_path])
4949

5050
# Assertions on IDs generated by the decorator logic
5151
assert len(logger.read()) == 38
5252

5353

54-
def test_evaluation_test_decorator_ids_single(monkeypatch):
54+
async def test_evaluation_test_decorator_ids_single(monkeypatch):
5555
in_memory_logger = InMemoryLogger()
5656
unique_run_ids = set()
5757
unique_experiment_ids = set()
@@ -92,7 +92,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
9292
# Manually invoke all parameter combinations within a single test
9393
for ds_path in dataset_paths:
9494
for params in input_params_list:
95-
eval_fn(model="dummy/local-model", dataset_path=[ds_path], input_params=params)
95+
await eval_fn(model="dummy/local-model", dataset_path=[ds_path], input_params=params)
9696

9797
# Assertions on IDs generated by the decorator logic
9898
assert len(unique_invocation_ids) == 1

tests/test_rollout_control_plane_integration.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def mock_step_side_effect(env_index, tool_call):
239239
policy = MockPolicy(["right", "down", "right"])
240240

241241
# Execute rollout
242-
evaluation_rows = await self.execution_manager.execute_rollouts(mock_env, policy, steps=10)
242+
evaluation_rows = []
243+
async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=10):
244+
evaluation_rows.append(row)
243245

244246
# Validate results
245247
assert len(evaluation_rows) == 1, "Should have one evaluation row"
@@ -457,7 +459,9 @@ async def test_rollout_handles_control_plane_failure_gracefully(self):
457459

458460
# Execute rollout with control plane failure
459461
policy = MockPolicy(["right"])
460-
evaluation_rows = await self.execution_manager.execute_rollouts(mock_env, policy, steps=1)
462+
evaluation_rows = []
463+
async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=1):
464+
evaluation_rows.append(row)
461465

462466
# Should still work, but without control plane info
463467
assert len(evaluation_rows) == 1
@@ -500,15 +504,26 @@ async def test_rollout_creates_envs_from_url(self):
500504
mock_make.return_value = mock_env
501505

502506
manager_instance = MockManager.return_value
503-
manager_instance.execute_rollouts = AsyncMock(return_value=["ok"])
504507

505-
result = await ep.rollout(
508+
# Mock execute_rollouts to return an async generator and track calls
509+
call_args = []
510+
511+
async def mock_execute_rollouts(*args, **kwargs):
512+
call_args.append((args, kwargs))
513+
for item in ["ok"]:
514+
yield item
515+
516+
manager_instance.execute_rollouts = mock_execute_rollouts
517+
518+
result = []
519+
async for row in ep.rollout(
506520
"http://localhost:1234/mcp/",
507521
policy,
508522
dataset=dataset,
509523
model_id="test_model",
510524
steps=5,
511-
)
525+
):
526+
result.append(row)
512527

513528
mock_make.assert_called_once_with(
514529
"http://localhost:1234/mcp/",
@@ -517,14 +532,12 @@ async def test_rollout_creates_envs_from_url(self):
517532
model_id="test_model",
518533
)
519534

520-
manager_instance.execute_rollouts.assert_called_once_with(
521-
mock_make.return_value,
522-
policy,
523-
5,
524-
None,
525-
8,
526-
None,
527-
)
535+
# Verify execute_rollouts was called with correct arguments
536+
assert len(call_args) == 1, "execute_rollouts should be called once"
537+
args, kwargs = call_args[0]
538+
assert args[0] == mock_make.return_value, "First arg should be mock env"
539+
assert args[1] == policy, "Second arg should be policy"
540+
assert args[2] == 5, "Third arg should be steps"
528541

529542
assert result == ["ok"]
530543

0 commit comments

Comments
 (0)