-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_pytest_function_calling.py
More file actions
33 lines (29 loc) · 1.23 KB
/
test_pytest_function_calling.py
File metadata and controls
33 lines (29 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import json
from typing import Any, Dict, List
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
from eval_protocol.rewards.function_calling import exact_tool_match_reward
def function_calling_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
    """Build one ``EvaluationRow`` per raw dataset record.

    Only the first entry of each record's ``messages`` list is kept
    (``messages[:1]``). Presumably that is the prompt, with the model's reply
    produced later by the rollout processor (TODO confirm against the dataset).
    The ``tools`` and ``ground_truth`` fields are passed through unchanged.

    Args:
        rows: Parsed dataset records; each must provide the ``"messages"``,
            ``"tools"`` and ``"ground_truth"`` keys.

    Returns:
        A list of ``EvaluationRow`` objects, in the same order as ``rows``.

    Raises:
        KeyError: If a record is missing one of the required keys.
    """
    return [
        EvaluationRow(
            # Keep only the opening message; later turns are dropped here.
            messages=record["messages"][:1],
            tools=record["tools"],
            ground_truth=record["ground_truth"],
        )
        for record in rows
    ]
@evaluation_test(
    input_dataset=["tests/pytest/data/function_calling.jsonl"],
    model=["accounts/fireworks/models/kimi-k2-instruct"],
    mode="pointwise",
    dataset_adapter=function_calling_to_evaluation_row,
    rollout_processor=default_single_turn_rollout_processor,
)
async def test_pytest_function_calling(row: EvaluationRow) -> EvaluationRow:
    """Score a single rolled-out row by exact tool-call matching.

    Decodes ``row.ground_truth`` with ``json.loads``, so it is expected to be
    a JSON-encoded string. That value and ``row.messages`` are handed to
    ``exact_tool_match_reward``. Presumably ``row.messages`` now includes the
    model response appended by the rollout processor (TODO confirm).

    The reward result is stored on ``row.evaluation_result`` and printed.
    The same row object is returned.
    """
    # ground_truth arrives serialized; decode it before comparing.
    expected_calls = json.loads(row.ground_truth)
    score = exact_tool_match_reward(row.messages, expected_calls)
    row.evaluation_result = score
    # Echo the reward so it is visible in pytest output (e.g. with -s).
    print(score)
    return row