|
9 | 9 | from agent import setup_agent |
10 | 10 | from pydantic_ai.models.openai import OpenAIModel |
11 | 11 |
|
| 12 | +from tests.chinook.dataset import collect_dataset |
| 13 | + |
12 | 14 |
|
13 | 15 | @pytest.mark.asyncio |
14 | 16 | @evaluation_test( |
@@ -75,3 +77,73 @@ class Response(BaseModel): |
75 | 77 | reason=result.output.reason, |
76 | 78 | ) |
77 | 79 | return row |
| 80 | + |
| 81 | + |
@pytest.mark.asyncio
@evaluation_test(
    input_rows=collect_dataset(),
    completion_params=[
        {
            "model": {
                "orchestrator_agent_model": {
                    "model": "accounts/fireworks/models/kimi-k2-instruct",
                    "provider": "fireworks",
                }
            }
        },
    ],
    rollout_processor=PydanticAgentRolloutProcessor(),
    rollout_processor_kwargs={"agent": setup_agent},
    num_runs=3,
    mode="pointwise",
)
async def test_complex_queries(row: EvaluationRow) -> EvaluationRow:
    """
    Complex queries for the Chinook database.

    Grades the agent's final answer against the row's ground truth with an
    LLM judge and stores the verdict on ``row.evaluation_result``. Rows with
    no usable assistant message score 0.0.

    Args:
        row: The evaluation row produced by the rollout processor.

    Returns:
        The same row, with ``evaluation_result`` populated.
    """
    last_assistant_message = row.last_assistant_message()

    # A missing message and an empty-content message are the same failure
    # mode; the original duplicated the whole EvaluateResult block for each.
    # Collapse both into a single guard clause with an early return.
    if last_assistant_message is None or not last_assistant_message.content:
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason="No assistant message found",
        )
        return row

    # Judge model used to compare the agent's answer with the expected one.
    model = OpenAIModel(
        "accounts/fireworks/models/llama-v3p1-8b-instruct",
        provider="fireworks",
    )

    class Response(BaseModel):
        """Structured verdict returned by the comparison agent."""

        # Score between 0.0 and 1.0 indicating whether the response is correct.
        # (The original placed bare string literals around these fields; those
        # are no-op statements, not docs, so they are plain comments now.)
        score: float
        # Short explanation of why the response is correct or incorrect.
        reason: str

    comparison_agent = Agent(
        system_prompt=(
            "Your job is to compare the response to the expected answer."
            "If the response is correct, return 1.0. If the response is incorrect, return 0.0."
        ),
        output_type=Response,
        model=model,
    )
    result = await comparison_agent.run(
        f"Expected answer: {row.ground_truth}\nResponse: {last_assistant_message.content}"
    )
    row.evaluation_result = EvaluateResult(
        score=result.output.score,
        reason=result.output.reason,
    )
    return row
0 commit comments