|
| 1 | +from typing_extensions import override |
1 | 2 | import pytest |
2 | 3 | from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
3 | 4 | from eval_protocol.models import EvaluateResult, EvaluationRow, Message |
@@ -48,12 +49,16 @@ class TrackingLogger(DatasetLogger): |
48 | 49 | """Custom logger that ensures that the final row is in an error state.""" |
49 | 50 |
|
50 | 51 | def __init__(self, rollouts: dict[str, EvaluationRow]): |
51 | | - self.rollouts = rollouts |
| 52 | + self.rollouts: dict[str, EvaluationRow] = rollouts |
52 | 53 |
|
| 54 | + @override |
53 | 55 | def log(self, row: EvaluationRow): |
| 56 | + if row.execution_metadata.rollout_id is None: |
| 57 | + raise ValueError("Rollout ID is None") |
54 | 58 | self.rollouts[row.execution_metadata.rollout_id] = row |
55 | 59 |
|
56 | | - def read(self): |
| 60 | + @override |
| 61 | + def read(self, row_id: str | None = None) -> list[EvaluationRow]: |
57 | 62 | return [] |
58 | 63 |
|
59 | 64 | input_messages = [ |
@@ -82,11 +87,13 @@ def read(self): |
82 | 87 | def eval_fn(row: EvaluationRow) -> EvaluationRow: |
83 | 88 | return row |
84 | 89 |
|
85 | | - await eval_fn(input_messages=input_messages, completion_params=completion_params_list[0]) |
| 90 | + await eval_fn(input_messages=input_messages[0], completion_params=completion_params_list[0]) # pyright: ignore[reportCallIssue] |
86 | 91 |
|
87 | 92 | # ensure that the row has tools that were set during AgentRolloutProcessor |
88 | 93 | assert len(rollouts) == 1 |
89 | 94 | row = list(rollouts.values())[0] |
90 | | - assert sorted([tool["function"].name for tool in row.tools]) == sorted( |
| 95 | + if row.tools is None: |
| 96 | + raise ValueError("Row has no tools") |
| 97 | + assert sorted([tool["function"].name for tool in row.tools]) == sorted( # pyright: ignore[reportAny] |
91 | 98 | ["list_servers", "get_channels", "read_messages"] |
92 | 99 | ) |
0 commit comments