Skip to content

Commit 17eb18f

Browse files
author
Dylan Huang
committed
use SDK for Dataset API calls
1 parent 3c2db59 commit 17eb18f

File tree

1 file changed

+68
-36
lines changed

1 file changed

+68
-36
lines changed

eval_protocol/fireworks_rft.py

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import importlib.util
2-
import io
32
import json
43
import os
54
import sys
@@ -9,12 +8,8 @@
98
import hashlib
109
from pathlib import Path
1110
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
12-
from urllib.parse import urlencode
1311

14-
import requests
15-
16-
from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key
17-
from .common_utils import get_user_agent
12+
from .fireworks_client import create_fireworks_client
1813

1914

2015
def _map_api_host_to_app_host(api_base: str) -> str:
@@ -142,43 +137,80 @@ def create_dataset_from_jsonl(
142137
display_name: Optional[str],
143138
jsonl_path: str,
144139
) -> Tuple[str, Dict[str, Any]]:
145-
headers = {
146-
"Authorization": f"Bearer {api_key}",
147-
"Content-Type": "application/json",
148-
"User-Agent": get_user_agent(),
149-
}
140+
"""Create a dataset and upload a JSONL file using the Fireworks SDK client.
141+
142+
This function uses the Fireworks SDK client, which properly handles authentication,
143+
including extra headers set via the FIREWORKS_EXTRA_HEADERS environment variable.
144+
145+
Args:
146+
account_id: The Fireworks account ID.
147+
api_key: Fireworks API key.
148+
api_base: Fireworks API base URL.
149+
dataset_id: The ID for the new dataset.
150+
display_name: Display name for the dataset (optional).
151+
jsonl_path: Path to the JSONL file to upload.
152+
153+
Returns:
154+
A tuple of (dataset_id, dataset_response_dict).
155+
156+
Raises:
157+
RuntimeError: If dataset creation or upload fails.
158+
"""
150159
# Count examples quickly
151160
example_count = 0
152161
with open(jsonl_path, "r", encoding="utf-8") as f:
153162
for _ in f:
154163
example_count += 1
155164

156-
dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets"
157-
payload = {
158-
"dataset": {
159-
"displayName": display_name or dataset_id,
160-
"evalProtocol": {},
161-
"format": "FORMAT_UNSPECIFIED",
162-
"exampleCount": str(example_count),
163-
},
164-
"datasetId": dataset_id,
165-
}
166-
resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60)
167-
if resp.status_code not in (200, 201):
168-
raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}")
169-
ds = resp.json()
170-
171-
upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload"
172-
with open(jsonl_path, "rb") as f:
173-
files = {"file": f}
174-
up_headers = {
175-
"Authorization": f"Bearer {api_key}",
176-
"User-Agent": get_user_agent(),
165+
# Create Fireworks client with consistent configuration
166+
client = create_fireworks_client(
167+
api_key=api_key,
168+
account_id=account_id,
169+
base_url=api_base,
170+
)
171+
172+
try:
173+
# Create the dataset
174+
dataset = client.datasets.create(
175+
account_id=account_id,
176+
dataset_id=dataset_id,
177+
dataset={
178+
"display_name": display_name or dataset_id,
179+
"eval_protocol": {},
180+
"format": "FORMAT_UNSPECIFIED",
181+
"example_count": str(example_count),
182+
},
183+
timeout=60.0,
184+
)
185+
except Exception as e:
186+
raise RuntimeError(f"Dataset creation failed: {e}") from e
187+
188+
try:
189+
# Upload the JSONL file
190+
with open(jsonl_path, "rb") as f:
191+
client.datasets.upload(
192+
dataset_id=dataset_id,
193+
account_id=account_id,
194+
file=f,
195+
timeout=600.0,
196+
)
197+
except Exception as e:
198+
raise RuntimeError(f"Dataset upload failed: {e}") from e
199+
200+
# Convert SDK response to dict for backwards compatibility
201+
ds_dict: Dict[str, Any] = {}
202+
if hasattr(dataset, "model_dump"):
203+
ds_dict = dataset.model_dump()
204+
elif hasattr(dataset, "dict"):
205+
ds_dict = dataset.dict()
206+
else:
207+
# Fallback: extract known fields
208+
ds_dict = {
209+
"name": getattr(dataset, "name", None),
210+
"state": getattr(dataset, "state", None),
177211
}
178-
up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600)
179-
if up_resp.status_code not in (200, 201):
180-
raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}")
181-
return dataset_id, ds
212+
213+
return dataset_id, ds_dict
182214

183215

184216
def build_default_dataset_id(evaluator_id: str) -> str:

0 commit comments

Comments
 (0)