-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_pytest_propagate_error.py
More file actions
74 lines (61 loc) · 2.74 KB
/
test_pytest_propagate_error.py
File metadata and controls
74 lines (61 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing_extensions import override
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.default_agent_rollout_processor import AgentRolloutProcessor
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
class TrackingLogger(DatasetLogger):
    """Dataset logger that captures logged rows in a caller-supplied dict.

    Rows are stored under their rollout ID, so logging a row for a rollout
    that was already seen replaces the earlier entry; the dict therefore ends
    up holding the most recently logged row for each rollout.
    """

    def __init__(self, rollouts: dict[str, EvaluationRow]):
        # Hold a reference to the caller's dict (not a copy) so the caller can
        # inspect what has been logged.
        self.rollouts: dict[str, EvaluationRow] = rollouts

    @override
    def log(self, row: EvaluationRow):
        """Store ``row`` under its rollout ID.

        Raises:
            ValueError: If the row's execution metadata has no rollout ID.
        """
        rollout_id = row.execution_metadata.rollout_id
        if rollout_id is None:
            raise ValueError("Rollout ID is None")
        self.rollouts[rollout_id] = row

    @override
    def read(self, row_id: str | None = None) -> list[EvaluationRow]:
        """Return an empty list; ``row_id`` is ignored by this logger."""
        no_rows: list[EvaluationRow] = []
        return no_rows
async def test_pytest_propagate_error():
    """
    Verify that rollout failures surface in ``eval_metadata.status``.

    The evaluation is pointed at a broken MCP configuration, which should make
    rollout processing fail. Every row handed to the logger must then carry an
    error status, so the UI can render the rollout as failed and a developer
    can identify and investigate the cause.
    """
    from eval_protocol.pytest.evaluation_test import evaluation_test

    # Used both as ``num_runs`` and as the number of logged rollouts expected.
    expected_runs = 5

    system_prompt = Message(
        role="system",
        content="You are a helpful assistant that can answer questions about Fireworks.",
    )
    message_sets = [[system_prompt]]
    param_sets = [
        {"model": "dummy/local-model"},
    ]

    # The logger writes into this dict, keyed by rollout ID.
    captured: dict[str, EvaluationRow] = {}
    tracker = TrackingLogger(captured)

    # The evaluation itself is a pass-through: only the rollout outcome matters.
    @evaluation_test(
        input_messages=[message_sets],
        completion_params=param_sets,
        rollout_processor=AgentRolloutProcessor(),
        mode="pointwise",
        num_runs=expected_runs,
        mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config_broken.json",
        logger=tracker,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        return row

    # Drive every completion-params combination by hand inside this one test.
    for completion_params in param_sets:
        await eval_fn(input_messages=message_sets, completion_params=completion_params)  # pyright: ignore[reportCallIssue]

    # One logged row per run, each of which must report an error status.
    assert len(captured) == expected_runs
    for logged_row in captured.values():
        metadata = logged_row.eval_metadata
        if metadata is None:
            raise ValueError("Row has no eval_metadata")
        status = metadata.status
        if status is None:
            raise ValueError("Eval metadata has no status")
        assert status.is_error()

    # At least one rollout must carry the TaskGroup failure message.
    expected_fragment = "unhandled errors in a TaskGroup"
    assert any(expected_fragment in logged_row.rollout_status.message for logged_row in captured.values())