Skip to content

Commit 4f8c8b8

Browse files
author
Dylan Huang
committed
fix tests
1 parent 90933d3 commit 4f8c8b8

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

eval_protocol/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ class InputMetadata(BaseModel):
198198
model_config = ConfigDict(extra="allow")
199199

200200
row_id: Optional[str] = Field(default_factory=generate_id, description="Unique string to ID the row")
201-
completion_params: CompletionParams = Field(..., description="Completion endpoint parameters used")
201+
completion_params: CompletionParams = Field(
202+
default_factory=dict, description="Completion endpoint parameters used"
203+
)
202204
dataset_info: Optional[Dict[str, Any]] = Field(
203205
None, description="Dataset row details: seed, system_prompt, environment_context, etc"
204206
)

tests/test_logs_server.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
class TestWebSocketManager:
3030
"""Test WebSocketManager class."""
3131

32+
input_metadata = InputMetadata(row_id="test-123", completion_params={"model": "gpt-4o"})
33+
3234
def test_initialization(self):
3335
"""Test WebSocketManager initialization."""
3436
manager = WebSocketManager()
@@ -64,7 +66,7 @@ async def test_connect_sends_initial_logs(self):
6466
mock_logs = [
6567
EvaluationRow(
6668
messages=[Message(role="user", content="test")],
67-
input_metadata=InputMetadata(row_id="test-123"),
69+
input_metadata=self.input_metadata,
6870
)
6971
]
7072

@@ -82,7 +84,7 @@ def test_broadcast_row_upserted(self):
8284
manager = WebSocketManager()
8385
test_row = EvaluationRow(
8486
messages=[Message(role="user", content="test")],
85-
input_metadata=InputMetadata(row_id="test-123"),
87+
input_metadata=self.input_metadata,
8688
)
8789

8890
# Test that broadcast doesn't fail when no connections
@@ -96,6 +98,7 @@ def test_broadcast_row_upserted(self):
9698
assert "row" in data
9799
assert data["row"]["messages"][0]["content"] == "test"
98100
assert data["row"]["input_metadata"]["row_id"] == "test-123"
101+
assert data["row"]["input_metadata"]["completion_params"]["model"] == "gpt-4o"
99102

100103
@pytest.mark.asyncio
101104
async def test_broadcast_loop(self):
@@ -221,7 +224,7 @@ def test_should_update_status_stopped_process(self, mock_process):
221224

222225
test_row = EvaluationRow(
223226
messages=[Message(role="user", content="test")],
224-
input_metadata=InputMetadata(row_id="test-123"),
227+
input_metadata=self.input_metadata,
225228
eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"),
226229
pid=12345,
227230
)
@@ -240,7 +243,7 @@ def test_should_update_status_no_such_process(self, mock_process):
240243

241244
test_row = EvaluationRow(
242245
messages=[Message(role="user", content="test")],
243-
input_metadata=InputMetadata(row_id="test-123"),
246+
input_metadata=self.input_metadata,
244247
eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"),
245248
pid=999,
246249
)
@@ -255,7 +258,7 @@ def test_should_update_status_not_running(self):
255258

256259
test_row = EvaluationRow(
257260
messages=[Message(role="user", content="test")],
258-
input_metadata=InputMetadata(row_id="test-123"),
261+
input_metadata=self.input_metadata,
259262
eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="finished"),
260263
pid=12345,
261264
)
@@ -270,7 +273,7 @@ def test_should_update_status_no_pid(self):
270273

271274
test_row = EvaluationRow(
272275
messages=[Message(role="user", content="test")],
273-
input_metadata=InputMetadata(row_id="test-123"),
276+
input_metadata=self.input_metadata,
274277
eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"),
275278
pid=None,
276279
)
@@ -326,7 +329,7 @@ async def test_handle_event(self, temp_build_dir):
326329
# Test handling a log event
327330
test_row = {
328331
"messages": [{"role": "user", "content": "test"}],
329-
"input_metadata": {"row_id": "test-123"},
332+
"input_metadata": self.input_metadata.model_dump(),
330333
}
331334

332335
server._handle_event(LOG_EVENT_TYPE, test_row)
@@ -543,7 +546,7 @@ async def test_websocket_connection_lifecycle(self):
543546
# Test broadcasting without starting the loop
544547
test_row = EvaluationRow(
545548
messages=[Message(role="user", content="test")],
546-
input_metadata=InputMetadata(row_id="test-123"),
549+
input_metadata=self.input_metadata,
547550
)
548551
manager.broadcast_row_upserted(test_row)
549552

@@ -573,7 +576,7 @@ async def test_multiple_websocket_connections(self):
573576
# Test broadcasting to all without starting the loop
574577
test_row = EvaluationRow(
575578
messages=[Message(role="user", content="test")],
576-
input_metadata=InputMetadata(row_id="test-123"),
579+
input_metadata=self.input_metadata,
577580
)
578581
manager.broadcast_row_upserted(test_row)
579582

0 commit comments

Comments
 (0)