-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_single_turn_rollout_process.py
More file actions
42 lines (31 loc) · 1.51 KB
/
default_single_turn_rollout_process.py
File metadata and controls
42 lines (31 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import asyncio
from typing import List
from openai import AsyncOpenAI
from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig
async def default_single_turn_rollout_processor(
rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
"""Generate a single response from a Fireworks model concurrently."""
api_key = get_fireworks_api_key()
api_base = get_fireworks_api_base()
client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row asynchronously."""
if len(row.messages) == 0:
raise ValueError("Messages is empty. Please provide a non-empty dataset")
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
response = await client.chat.completions.create(
model=config.model, messages=messages_payload, **config.input_params
)
assistant_content = response.choices[0].message.content or ""
messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
return EvaluationRow(
messages=messages,
**row.model_dump(exclude={"messages"}),
)
# Process all rows concurrently
tasks = [process_row(row) for row in rows]
dataset = await asyncio.gather(*tasks)
return dataset