@@ -239,7 +239,9 @@ def mock_step_side_effect(env_index, tool_call):
239239 policy = MockPolicy (["right" , "down" , "right" ])
240240
241241 # Execute rollout
242- evaluation_rows = await self .execution_manager .execute_rollouts (mock_env , policy , steps = 10 )
242+ evaluation_rows = []
243+ async for row in self .execution_manager .execute_rollouts (mock_env , policy , steps = 10 ):
244+ evaluation_rows .append (row )
243245
244246 # Validate results
245247 assert len (evaluation_rows ) == 1 , "Should have one evaluation row"
@@ -457,7 +459,9 @@ async def test_rollout_handles_control_plane_failure_gracefully(self):
457459
458460 # Execute rollout with control plane failure
459461 policy = MockPolicy (["right" ])
460- evaluation_rows = await self .execution_manager .execute_rollouts (mock_env , policy , steps = 1 )
462+ evaluation_rows = []
463+ async for row in self .execution_manager .execute_rollouts (mock_env , policy , steps = 1 ):
464+ evaluation_rows .append (row )
461465
462466 # Should still work, but without control plane info
463467 assert len (evaluation_rows ) == 1
@@ -500,15 +504,26 @@ async def test_rollout_creates_envs_from_url(self):
500504 mock_make .return_value = mock_env
501505
502506 manager_instance = MockManager .return_value
503- manager_instance .execute_rollouts = AsyncMock (return_value = ["ok" ])
504507
505- result = await ep .rollout (
508+ # Mock execute_rollouts to return an async generator and track calls
509+ call_args = []
510+
511+ async def mock_execute_rollouts (* args , ** kwargs ):
512+ call_args .append ((args , kwargs ))
513+ for item in ["ok" ]:
514+ yield item
515+
516+ manager_instance .execute_rollouts = mock_execute_rollouts
517+
518+ result = []
519+ async for row in ep .rollout (
506520 "http://localhost:1234/mcp/" ,
507521 policy ,
508522 dataset = dataset ,
509523 model_id = "test_model" ,
510524 steps = 5 ,
511- )
525+ ):
526+ result .append (row )
512527
513528 mock_make .assert_called_once_with (
514529 "http://localhost:1234/mcp/" ,
@@ -517,14 +532,12 @@ async def test_rollout_creates_envs_from_url(self):
517532 model_id = "test_model" ,
518533 )
519534
520- manager_instance .execute_rollouts .assert_called_once_with (
521- mock_make .return_value ,
522- policy ,
523- 5 ,
524- None ,
525- 8 ,
526- None ,
527- )
535+ # Verify execute_rollouts was called with correct arguments
536+ assert len (call_args ) == 1 , "execute_rollouts should be called once"
537+ args , kwargs = call_args [0 ]
538+ assert args [0 ] == mock_make .return_value , "First arg should be mock env"
539+ assert args [1 ] == policy , "Second arg should be policy"
540+ assert args [2 ] == 5 , "Third arg should be steps"
528541
529542 assert result == ["ok" ]
530543
0 commit comments