Skip to content

Commit 6cbe4b8

Browse files
author
Dylan Huang
committed
support FIREWORKS_API_KEY in env
1 parent 180d73f commit 6cbe4b8

1 file changed

Lines changed: 11 additions & 1 deletion

File tree

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import asyncio
23
import logging
34
from typing import List
@@ -21,6 +22,7 @@
2122
ToolReturnPart,
2223
UserPromptPart,
2324
)
25+
from pydantic_ai.providers.fireworks import FireworksProvider
2426

2527
logger = logging.getLogger(__name__)
2628

@@ -47,9 +49,17 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
4749

4850
agent: Agent = config.kwargs["agent"]
4951

52+
if config.completion_params["provider"] == "fireworks":
53+
api_key = os.getenv("FIREWORKS_API_KEY")
54+
if not api_key:
55+
raise ValueError("FIREWORKS_API_KEY is not set")
56+
provider = FireworksProvider(api_key=api_key)
57+
else:
58+
provider = config.completion_params["provider"]
59+
5060
model = OpenAIModel(
5161
config.completion_params["model"],
52-
provider=config.completion_params["provider"],
62+
provider=provider,
5363
)
5464

5565
async def process_row(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)