Skip to content

Commit f04e52a

Browse files
author
Dylan Huang
committed
more complex example
1 parent fcff843 commit f04e52a

File tree

7 files changed

+182
-2
lines changed

7 files changed

+182
-2
lines changed

tests/chinook/agent.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,33 @@ def setup_agent(orchestrator_agent_model: Model):
2626
)
2727

2828
@agent.tool(retries=5)
def execute_sql(ctx: RunContext, query: str) -> str:
    """
    Execute a SQL query and return its result set as a markdown table.

    Args:
        ctx: Run context passed in by the agent framework (not used here).
        query: SQL text executed on the cursor from the enclosing scope.

    Returns:
        A markdown table (header row, separator row, one line per result
        row), or the string "No results found." when the statement yields
        no columns or no rows.

    Raises:
        ModelRetry: If anything fails while executing or formatting. The
            connection is rolled back first and the error text is included
            so the model can try a different query.
    """

    def _cell(value) -> str:
        # None renders as an empty cell. Pipes are escaped and line breaks
        # are flattened because either one would split the table row.
        if value is None:
            return ""
        return " ".join(str(value).replace("|", "\\|").splitlines())

    def _line(cells) -> str:
        return "| " + " | ".join(cells) + " |"

    try:
        cursor.execute(query)
        # description is falsy for statements that produce no result set
        columns = [_cell(desc[0]) for desc in cursor.description] if cursor.description else []
        rows = cursor.fetchall()

        if not columns or not rows:
            return "No results found."

        table_lines = [_line(columns), _line(["---"] * len(columns))]
        table_lines.extend(_line([_cell(cell) for cell in row]) for row in rows)
        return "\n".join(table_lines)
    except Exception as e:
        connection.rollback()
        raise ModelRetry("Please try again with a different query. Here is the error: " + str(e)) from e

tests/chinook/dataset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import List
2+
import os
3+
import glob
4+
5+
from eval_protocol.models import EvaluationRow, Message
6+
7+
8+
def collect_dataset() -> List[EvaluationRow]:
    """
    Build one EvaluationRow per task folder in the "dataset" directory
    located next to this module.

    For each directory named "task_<n>", reads "task.txt" and
    "ground_truth.md" and creates an EvaluationRow where:
    - messages contains a single user message with the stripped task text
    - ground_truth contains the stripped contents of ground_truth.md

    Folders are visited in numeric order of <n> (so task_10 comes after
    task_9); names with a non-numeric suffix come last, alphabetically.
    Entries matching "task_*" that are not directories are ignored.

    Returns:
        The list of EvaluationRow objects, one per task folder.

    Raises:
        FileNotFoundError: If a task folder lacks task.txt or ground_truth.md.
    """
    dataset_path = os.path.join(os.path.dirname(__file__), "dataset")

    def _task_order(folder: str):
        # Sort by the integer after "task_"; a plain string sort would put
        # task_10 before task_2.
        suffix = os.path.basename(folder)[len("task_"):]
        if suffix.isdecimal():
            return (0, int(suffix), suffix)
        return (1, 0, suffix)

    def _read_required(path: str, label: str) -> str:
        # Fail loudly with the offending path when a dataset file is missing.
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            return f.read().strip()

    # Only directories are tasks; stray files matching task_* are skipped.
    task_folders = [
        path
        for path in glob.glob(os.path.join(dataset_path, "task_*"))
        if os.path.isdir(path)
    ]

    dataset_rows = []
    for task_folder in sorted(task_folders, key=_task_order):
        task_content = _read_required(os.path.join(task_folder, "task.txt"), "Task file")
        ground_truth_content = _read_required(
            os.path.join(task_folder, "ground_truth.md"), "Ground truth file"
        )

        user_message = Message(role="user", content=task_content)
        dataset_rows.append(
            EvaluationRow(messages=[user_message], ground_truth=ground_truth_content)
        )

    return dataset_rows
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
| customer_name | favorite_genre | total_invoices | total_spent | spending_rank |
2+
| ------------------ | -------------- | -------------- | ----------- | ------------- |
3+
| Helena Holý | Rock | 7 | 49.62 | 1 |
4+
| Richard Cunningham | Rock | 7 | 47.62 | 2 |
5+
| Luis Rojas | Rock | 7 | 46.62 | 3 |
6+
| Ladislav Kovács | Rock | 7 | 45.62 | 4 |
7+
| Hugh O'Reilly | Rock | 7 | 45.62 | 4 |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Find the top 5 customers by total spending, including their favorite genre. Show customer name, favorite genre, total invoices, total spent, and spending rank.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
| genre_name | usa_revenue | canada_revenue | germany_revenue | france_revenue | brazil_revenue | total_revenue | total_unique_customers | percentage_of_total |
2+
| ------------------ | ----------- | -------------- | --------------- | -------------- | -------------- | ------------- | ---------------------- | ------------------- |
3+
| Rock | 1526.14 | 989.03 | 577.17 | 525.73 | 691.02 | 4309.09 | 35 | 35.47 |
4+
| Latin | 754.67 | 510.84 | 163.35 | 265.32 | 501.93 | 2196.11 | 33 | 18.08 |
5+
| Metal | 554.45 | 335.61 | 189.09 | 192.06 | 74.25 | 1345.46 | 34 | 11.07 |
6+
| Alternative & Punk | 415.92 | 300.96 | 105.12 | 207.90 | 71.28 | 1101.18 | 29 | 9.06 |
7+
| Jazz | 202.95 | 126.72 | 27.72 | 139.59 | 0 | 496.98 | 19 | 4.09 |
8+
| Blues | 126.72 | 15.84 | 97.02 | 11.88 | 30.69 | 282.15 | 13 | 2.32 |
9+
| TV Shows | 191.70 | 3.98 | 44.73 | 16.86 | 0 | 257.27 | 10 | 2.12 |
10+
| Reggae | 73.26 | 58.41 | 0 | 13.86 | 83.16 | 228.69 | 8 | 1.88 |
11+
| Soundtrack | 45.54 | 0 | 59.40 | 54.45 | 55.44 | 214.83 | 7 | 1.77 |
12+
| Drama | 143.16 | 13.89 | 14.91 | 38.69 | 0 | 210.65 | 7 | 1.73 |
13+
| Classical | 87.20 | 24.75 | 0 | 57.48 | 39.60 | 209.03 | 9 | 1.72 |
14+
| R&B/Soul | 75.28 | 69.30 | 0 | 0 | 29.70 | 174.28 | 10 | 1.43 |
15+
| Alternative | 79.30 | 0 | 0.99 | 67.44 | 0 | 147.73 | 3 | 1.22 |
16+
| Hip Hop/Rap | 6.93 | 57.45 | 0 | 33.72 | 27.72 | 125.82 | 7 | 1.04 |
17+
| Pop | 33.66 | 0 | 13.86 | 27.72 | 36.63 | 111.87 | 7 | 0.92 |
18+
| World | 0 | 83.16 | 0 | 0 | 27.72 | 110.88 | 5 | 0.91 |
19+
| Heavy Metal | 50.49 | 0 | 41.58 | 0 | 0 | 92.07 | 3 | 0.76 |
20+
| Comedy | 90.44 | 0 | 0 | 0 | 0 | 90.44 | 3 | 0.74 |
21+
| Sci Fi & Fantasy | 71.62 | 0 | 0 | 7.96 | 7.96 | 87.54 | 4 | 0.72 |
22+
| Bossa Nova | 43.56 | 28.71 | 0 | 13.86 | 0 | 86.13 | 7 | 0.71 |
23+
| Rock And Roll | 41.58 | 27.72 | 0 | 13.86 | 0 | 83.16 | 4 | 0.68 |
24+
| Electronica/Dance | 0 | 43.59 | 0 | 33.72 | 0 | 77.31 | 3 | 0.64 |
25+
| Easy Listening | 41.58 | 0 | 27.72 | 0 | 0 | 69.30 | 2 | 0.57 |
26+
| Science Fiction | 10.91 | 0 | 29.82 | 0 | 0 | 40.73 | 2 | 0.34 |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Create a genre popularity matrix by country. Show genre name, revenue by country (USA, Canada, Germany, France, Brazil), total revenue, unique customers, and revenue percentage of total sales.

tests/chinook/test_pydantic_chinook.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from agent import setup_agent
1010
from pydantic_ai.models.openai import OpenAIModel
1111

12+
from tests.chinook.dataset import collect_dataset
13+
1214

1315
@pytest.mark.asyncio
1416
@evaluation_test(
@@ -75,3 +77,73 @@ class Response(BaseModel):
7577
reason=result.output.reason,
7678
)
7779
return row
80+
81+
82+
@pytest.mark.asyncio
@evaluation_test(
    input_rows=collect_dataset(),
    completion_params=[
        {
            "model": {
                "orchestrator_agent_model": {
                    "model": "accounts/fireworks/models/kimi-k2-instruct",
                    "provider": "fireworks",
                }
            }
        },
    ],
    rollout_processor=PydanticAgentRolloutProcessor(),
    rollout_processor_kwargs={"agent": setup_agent},
    num_runs=3,
    mode="pointwise",
)
async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:
    """
    Complex queries for the Chinook database.

    Scores the last assistant message of the row against row.ground_truth
    using a second model as judge, and stores the outcome in
    row.evaluation_result.

    Args:
        row: Evaluation row whose last assistant message is graded.

    Returns:
        The same row with evaluation_result set. The score is 0.0 when
        there is no assistant message or it has no content; otherwise the
        score and reason come from the judge model.
    """
    last_assistant_message = row.last_assistant_message()

    # Guard clauses: nothing to grade, so score zero without calling the judge.
    if last_assistant_message is None:
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason="No assistant message found",
        )
        return row
    if not last_assistant_message.content:
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason="Assistant message has no content",
        )
        return row

    model = OpenAIModel(
        "accounts/fireworks/models/llama-v3p1-8b-instruct",
        provider="fireworks",
    )

    # NOTE(review): the class docstring may be surfaced to the judge model
    # through the output schema, so its wording is left unchanged.
    class Response(BaseModel):
        """
        A score between 0.0 and 1.0 indicating whether the response is correct.
        """

        score: float

        # A short explanation of why the response is correct or incorrect.
        reason: str

    comparison_agent = Agent(
        system_prompt=(
            # The trailing space keeps the two sentences from running together
            # when the adjacent literals are concatenated.
            "Your job is to compare the response to the expected answer. "
            "If the response is correct, return 1.0. If the response is incorrect, return 0.0."
        ),
        output_type=Response,
        model=model,
    )
    result = await comparison_agent.run(
        f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}"
    )
    row.evaluation_result = EvaluateResult(
        score=result.output.score,
        reason=result.output.reason,
    )
    return row

0 commit comments

Comments
 (0)