|
| 1 | +from typing import List |
| 2 | +from unittest.mock import Mock, patch |
| 3 | + |
| 4 | +import eval_protocol.dataset_logger as dataset_logger |
| 5 | +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
| 6 | +from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore |
| 7 | +from eval_protocol.models import EvaluationRow |
| 8 | +from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor |
| 9 | +from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row |
| 10 | + |
| 11 | + |
async def test_ensure_logging(monkeypatch):
    """Verify that evaluation rows flow through the configured logger into the row store.

    The evaluation run is configured with a ``DatasetLogger`` whose backing
    ``SqliteEvaluationRowStore`` is a mock, so no real SQLite database is
    touched; the test asserts that ``upsert_row`` was invoked with
    well-formed row payloads.
    """
    from eval_protocol.pytest.evaluation_test import evaluation_test

    # Mock the SqliteEvaluationRowStore so persistence calls can be observed.
    mock_store = Mock(spec=SqliteEvaluationRowStore)
    mock_store.upsert_row = Mock()
    mock_store.read_rows = Mock(return_value=[])
    mock_store.db_path = "/tmp/test.db"

    # Create a custom logger that uses our mocked store.
    class MockSqliteLogger(DatasetLogger):
        """Minimal DatasetLogger that forwards rows to the (mocked) store."""

        def __init__(self, store: SqliteEvaluationRowStore):
            self._store = store

        def log(self, row: EvaluationRow) -> None:
            # Serialize the row the same way the real SQLite logger would.
            data = row.model_dump(exclude_none=True, mode="json")
            self._store.upsert_row(data=data)

        def read(self, rollout_id=None) -> List[EvaluationRow]:
            results = self._store.read_rows(rollout_id=rollout_id)
            return [EvaluationRow(**data) for data in results]

    mock_logger = MockSqliteLogger(mock_store)

    @evaluation_test(
        input_dataset=[
            "tests/pytest/data/markdown_dataset.jsonl",
        ],
        completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
        dataset_adapter=markdown_dataset_to_evaluation_row,
        rollout_processor=default_no_op_rollout_processor,
        mode="pointwise",
        combine_datasets=False,
        num_runs=2,
        logger=mock_logger,  # Use our mocked logger
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        # Identity evaluation: only the logging side effect matters here.
        return row

    await eval_fn(
        dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
        completion_params={"temperature": 0.0, "model": "dummy/local-model"},
    )

    # The run must have persisted at least one row through the logger.
    # (Single assertion replaces the former redundant pair of checks, and the
    # failure message now matches what is actually being asserted.)
    call_count = mock_store.upsert_row.call_count
    assert call_count > 0, (
        "Expected SqliteEvaluationRowStore.upsert_row to be called at least "
        f"once, but it was called {call_count} times"
    )

    # Every persisted payload must be a dict carrying the rollout identity.
    for call in mock_store.upsert_row.call_args_list:
        args, kwargs = call
        # The mock logger passes data by keyword, but accept positional too.
        data = args[0] if args else kwargs.get("data")
        assert data is not None, "upsert_row should be called with data parameter"
        assert isinstance(data, dict), "data should be a dictionary"
        assert "execution_metadata" in data, "data should contain execution_metadata"
        assert "rollout_id" in data["execution_metadata"], "data should contain rollout_id in execution_metadata"
0 commit comments