|
1 | | -from typing import List |
| 1 | +import os |
2 | 2 | from unittest.mock import Mock, patch |
3 | 3 |
|
4 | | -import eval_protocol.dataset_logger as dataset_logger |
5 | | -from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
6 | | -from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore |
7 | | -from eval_protocol.models import EvaluationRow |
8 | | -from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor |
9 | | -from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row |
10 | | - |
11 | 4 |
|
async def test_ensure_logging(monkeypatch):
    """
    Ensure the default SQLite-backed logger is exercised during an evaluation run.

    Strategy: patch the ``SqliteEvaluationRowStore`` constructor as seen by
    ``sqlite_dataset_logger_adapter`` so that whatever adapter/logger is built
    receives our mock store. Then run a no-op evaluation over a small dataset
    and assert that rows were persisted via ``upsert_row`` with the expected
    payload shape (a dict carrying ``execution_metadata.rollout_id``).

    NOTE(review): ``monkeypatch`` is accepted but unused here — ``patch`` does
    the work; confirm whether the fixture can be dropped.
    """
    # Mock the SqliteEvaluationRowStore to track calls.
    # A plain Mock (no spec) is used so the adapter can set/read any attribute.
    mock_store = Mock()
    mock_store.upsert_row = Mock()
    mock_store.read_rows = Mock(return_value=[])
    mock_store.db_path = "/tmp/test.db"

    # Patch the store constructor at the adapter's import site so that when
    # SqliteDatasetLoggerAdapter creates its store, it gets our mock instead.
    # The project imports below happen INSIDE this context on purpose:
    # presumably a logger/store may be constructed at module import time, and
    # it must see the patched constructor — TODO confirm against the adapter.
    with patch(
        "eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store
    ):
        from eval_protocol.models import EvaluationRow
        from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
        from eval_protocol.pytest.evaluation_test import evaluation_test
        from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row

        @evaluation_test(
            input_dataset=[
                "tests/pytest/data/markdown_dataset.jsonl",
            ],
            completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
            dataset_adapter=markdown_dataset_to_evaluation_row,
            rollout_processor=default_no_op_rollout_processor,
            mode="pointwise",
            combine_datasets=False,
            num_runs=2,
            # No explicit logger: the default logger is used, whose underlying
            # SqliteEvaluationRowStore constructor is patched above.
        )
        def eval_fn(row: EvaluationRow) -> EvaluationRow:
            # Identity evaluation — we only care about the logging side effect.
            return row

        await eval_fn(
            dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
            completion_params={"temperature": 0.0, "model": "dummy/local-model"},
        )

        # Verify that the store's upsert_row method was called at all.
        assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called"

        # Sanity-check the call count (at least one call; one per logged row).
        call_count = mock_store.upsert_row.call_count
        assert call_count > 0, f"Expected upsert_row to be called at least once, but it was called {call_count} times"

        # Verify every call carried a well-formed row payload.
        for call in mock_store.upsert_row.call_args_list:
            args, kwargs = call
            # upsert_row is called with data= keyword (or positionally).
            data = args[0] if args else kwargs.get("data")
            assert data is not None, "upsert_row should be called with data parameter"
            assert isinstance(data, dict), "data should be a dictionary"
            assert "execution_metadata" in data, "data should contain execution_metadata"
            assert "rollout_id" in data["execution_metadata"], "data should contain rollout_id in execution_metadata"
0 commit comments