Skip to content

Commit 3c27801

Browse files
author
Dylan Huang
authored
fix ensure_logging test (#78)
1 parent 1f2dadc commit 3c27801

2 files changed

Lines changed: 71 additions & 70 deletions

File tree

eval_protocol/dataset_logger/__init__.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,31 @@
33
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
44
from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
55

6+
67
# Allow disabling sqlite logger to avoid environment-specific constraints in simple CLI runs.
7-
if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() != "1":
8-
default_logger = SqliteDatasetLoggerAdapter()
9-
else:
8+
def _get_default_logger():
9+
if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() != "1":
10+
return SqliteDatasetLoggerAdapter()
11+
else:
12+
13+
class _NoOpLogger(DatasetLogger):
14+
def log(self, row):
15+
return None
16+
17+
def read(self, rollout_id=None):
18+
return []
19+
20+
return _NoOpLogger()
21+
22+
23+
# Lazy proxy: builds the underlying logger on each call rather than at import time
24+
class _LazyLogger(DatasetLogger):
25+
26+
def log(self, row):
27+
return _get_default_logger().log(row)
1028

11-
class _NoOpLogger(DatasetLogger):
12-
def log(self, row):
13-
return None
29+
def read(self, rollout_id=None):
30+
return _get_default_logger().read(rollout_id)
1431

15-
def read(self, rollout_id=None):
16-
return []
1732

18-
default_logger = _NoOpLogger()
33+
default_logger: DatasetLogger = _LazyLogger()
Lines changed: 47 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,59 @@
1-
from typing import List
1+
import os
22
from unittest.mock import Mock, patch
33

4-
import eval_protocol.dataset_logger as dataset_logger
5-
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
6-
from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore
7-
from eval_protocol.models import EvaluationRow
8-
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
9-
from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
10-
114

125
async def test_ensure_logging(monkeypatch):
136
"""
147
Ensure that the default SQLite logger is invoked by mocking its storage and checking that the storage is called.
158
"""
16-
from eval_protocol.pytest.evaluation_test import evaluation_test
17-
189
# Mock the SqliteEvaluationRowStore to track calls
19-
mock_store = Mock(spec=SqliteEvaluationRowStore)
10+
mock_store = Mock()
2011
mock_store.upsert_row = Mock()
2112
mock_store.read_rows = Mock(return_value=[])
2213
mock_store.db_path = "/tmp/test.db"
2314

24-
# Create a custom logger that uses our mocked store
25-
class MockSqliteLogger(DatasetLogger):
26-
def __init__(self, store: SqliteEvaluationRowStore):
27-
self._store = store
28-
29-
def log(self, row: EvaluationRow) -> None:
30-
data = row.model_dump(exclude_none=True, mode="json")
31-
self._store.upsert_row(data=data)
32-
33-
def read(self, rollout_id=None) -> List[EvaluationRow]:
34-
results = self._store.read_rows(rollout_id=rollout_id)
35-
return [EvaluationRow(**data) for data in results]
36-
37-
mock_logger = MockSqliteLogger(mock_store)
38-
39-
@evaluation_test(
40-
input_dataset=[
41-
"tests/pytest/data/markdown_dataset.jsonl",
42-
],
43-
completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
44-
dataset_adapter=markdown_dataset_to_evaluation_row,
45-
rollout_processor=default_no_op_rollout_processor,
46-
mode="pointwise",
47-
combine_datasets=False,
48-
num_runs=2,
49-
logger=mock_logger, # Use our mocked logger
50-
)
51-
def eval_fn(row: EvaluationRow) -> EvaluationRow:
52-
return row
53-
54-
await eval_fn(
55-
dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
56-
completion_params={"temperature": 0.0, "model": "dummy/local-model"},
57-
)
58-
59-
# Verify that the store's upsert_row method was called
60-
assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called"
61-
62-
# Check that it was called multiple times (once for each row)
63-
call_count = mock_store.upsert_row.call_count
64-
assert call_count > 0, f"Expected upsert_row to be called at least once, but it was called {call_count} times"
65-
66-
# Verify the calls were made with proper data structure
67-
for call in mock_store.upsert_row.call_args_list:
68-
args, kwargs = call
69-
data = args[0] if args else kwargs.get("data")
70-
assert data is not None, "upsert_row should be called with data parameter"
71-
assert isinstance(data, dict), "data should be a dictionary"
72-
assert "execution_metadata" in data, "data should contain execution_metadata"
73-
assert "rollout_id" in data["execution_metadata"], "data should contain rollout_id in execution_metadata"
15+
# Mock the SqliteEvaluationRowStore constructor so that when SqliteDatasetLoggerAdapter
16+
# creates its store, it gets our mock instead
17+
with patch(
18+
"eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store
19+
):
20+
from eval_protocol.models import EvaluationRow
21+
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
22+
from eval_protocol.pytest.evaluation_test import evaluation_test
23+
from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
24+
25+
@evaluation_test(
26+
input_dataset=[
27+
"tests/pytest/data/markdown_dataset.jsonl",
28+
],
29+
completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
30+
dataset_adapter=markdown_dataset_to_evaluation_row,
31+
rollout_processor=default_no_op_rollout_processor,
32+
mode="pointwise",
33+
combine_datasets=False,
34+
num_runs=2,
35+
# Don't pass logger parameter - let it use the default_logger (whose SqliteEvaluationRowStore is patched above)
36+
)
37+
def eval_fn(row: EvaluationRow) -> EvaluationRow:
38+
return row
39+
40+
await eval_fn(
41+
dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
42+
completion_params={"temperature": 0.0, "model": "dummy/local-model"},
43+
)
44+
45+
# Verify that the store's upsert_row method was called
46+
assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called"
47+
48+
# Check that it was called at least once (the assertion only requires call_count > 0)
49+
call_count = mock_store.upsert_row.call_count
50+
assert call_count > 0, f"Expected upsert_row to be called at least once, but it was called {call_count} times"
51+
52+
# Verify the calls were made with proper data structure
53+
for call in mock_store.upsert_row.call_args_list:
54+
args, kwargs = call
55+
data = args[0] if args else kwargs.get("data")
56+
assert data is not None, "upsert_row should be called with data parameter"
57+
assert isinstance(data, dict), "data should be a dictionary"
58+
assert "execution_metadata" in data, "data should contain execution_metadata"
59+
assert "rollout_id" in data["execution_metadata"], "data should contain rollout_id in execution_metadata"

0 commit comments

Comments
 (0)