|
1 | 1 | import importlib.util |
2 | | -import io |
3 | 2 | import json |
4 | 3 | import os |
5 | 4 | import sys |
|
9 | 8 | import hashlib |
10 | 9 | from pathlib import Path |
11 | 10 | from typing import Any, Callable, Dict, Iterable, Optional, Tuple |
12 | | -from urllib.parse import urlencode |
13 | 11 |
|
14 | | -import requests |
15 | | - |
16 | | -from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key |
17 | | -from .common_utils import get_user_agent |
| 12 | +from .fireworks_client import create_fireworks_client |
18 | 13 |
|
19 | 14 |
|
20 | 15 | def _map_api_host_to_app_host(api_base: str) -> str: |
@@ -142,43 +137,80 @@ def create_dataset_from_jsonl( |
142 | 137 | display_name: Optional[str], |
143 | 138 | jsonl_path: str, |
144 | 139 | ) -> Tuple[str, Dict[str, Any]]: |
145 | | - headers = { |
146 | | - "Authorization": f"Bearer {api_key}", |
147 | | - "Content-Type": "application/json", |
148 | | - "User-Agent": get_user_agent(), |
149 | | - } |
| 140 | + """Create a dataset and upload a JSONL file using the Fireworks SDK client. |
| 141 | +
|
| 142 | + This function uses the Fireworks SDK client which properly handles authentication |
| 143 | + including extra headers set via FIREWORKS_EXTRA_HEADERS environment variable. |
| 144 | +
|
| 145 | + Args: |
| 146 | + account_id: The Fireworks account ID. |
| 147 | + api_key: Fireworks API key. |
| 148 | + api_base: Fireworks API base URL. |
| 149 | + dataset_id: The ID for the new dataset. |
| 150 | + display_name: Display name for the dataset (optional). |
| 151 | + jsonl_path: Path to the JSONL file to upload. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + A tuple of (dataset_id, dataset_response_dict). |
| 155 | +
|
| 156 | + Raises: |
| 157 | + RuntimeError: If dataset creation or upload fails. |
| 158 | + """ |
150 | 159 | # Count examples quickly |
151 | 160 | example_count = 0 |
152 | 161 | with open(jsonl_path, "r", encoding="utf-8") as f: |
153 | 162 | for _ in f: |
154 | 163 | example_count += 1 |
155 | 164 |
|
156 | | - dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets" |
157 | | - payload = { |
158 | | - "dataset": { |
159 | | - "displayName": display_name or dataset_id, |
160 | | - "evalProtocol": {}, |
161 | | - "format": "FORMAT_UNSPECIFIED", |
162 | | - "exampleCount": str(example_count), |
163 | | - }, |
164 | | - "datasetId": dataset_id, |
165 | | - } |
166 | | - resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60) |
167 | | - if resp.status_code not in (200, 201): |
168 | | - raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}") |
169 | | - ds = resp.json() |
170 | | - |
171 | | - upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" |
172 | | - with open(jsonl_path, "rb") as f: |
173 | | - files = {"file": f} |
174 | | - up_headers = { |
175 | | - "Authorization": f"Bearer {api_key}", |
176 | | - "User-Agent": get_user_agent(), |
| 165 | + # Create Fireworks client with consistent configuration |
| 166 | + client = create_fireworks_client( |
| 167 | + api_key=api_key, |
| 168 | + account_id=account_id, |
| 169 | + base_url=api_base, |
| 170 | + ) |
| 171 | + |
| 172 | + try: |
| 173 | + # Create the dataset |
| 174 | + dataset = client.datasets.create( |
| 175 | + account_id=account_id, |
| 176 | + dataset_id=dataset_id, |
| 177 | + dataset={ |
| 178 | + "display_name": display_name or dataset_id, |
| 179 | + "eval_protocol": {}, |
| 180 | + "format": "FORMAT_UNSPECIFIED", |
| 181 | + "example_count": str(example_count), |
| 182 | + }, |
| 183 | + timeout=60.0, |
| 184 | + ) |
| 185 | + except Exception as e: |
| 186 | + raise RuntimeError(f"Dataset creation failed: {e}") from e |
| 187 | + |
| 188 | + try: |
| 189 | + # Upload the JSONL file |
| 190 | + with open(jsonl_path, "rb") as f: |
| 191 | + client.datasets.upload( |
| 192 | + dataset_id=dataset_id, |
| 193 | + account_id=account_id, |
| 194 | + file=f, |
| 195 | + timeout=600.0, |
| 196 | + ) |
| 197 | + except Exception as e: |
| 198 | + raise RuntimeError(f"Dataset upload failed: {e}") from e |
| 199 | + |
| 200 | + # Convert SDK response to dict for backwards compatibility |
| 201 | + ds_dict: Dict[str, Any] = {} |
| 202 | + if hasattr(dataset, "model_dump"): |
| 203 | + ds_dict = dataset.model_dump() |
| 204 | + elif hasattr(dataset, "dict"): |
| 205 | + ds_dict = dataset.dict() |
| 206 | + else: |
| 207 | + # Fallback: extract known fields |
| 208 | + ds_dict = { |
| 209 | + "name": getattr(dataset, "name", None), |
| 210 | + "state": getattr(dataset, "state", None), |
177 | 211 | } |
178 | | - up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) |
179 | | - if up_resp.status_code not in (200, 201): |
180 | | - raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") |
181 | | - return dataset_id, ds |
| 212 | + |
| 213 | + return dataset_id, ds_dict |
182 | 214 |
|
183 | 215 |
|
184 | 216 | def build_default_dataset_id(evaluator_id: str) -> str: |
|
0 commit comments