Skip to content

Commit d406ce0

Browse files
cursoragentdhuang
andcommitted
Refactor: Use FireworksAPIClient for all API requests
This change centralizes API request logic into a new FireworksAPIClient class, simplifying and standardizing how the Fireworks API is interacted with across the project. It removes redundant request setup code and ensures consistent headers are sent. Co-authored-by: dhuang <dhuang@fireworks.ai>
1 parent 00339e8 commit d406ce0

6 files changed

Lines changed: 171 additions & 77 deletions

File tree

eval_protocol/auth.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
from pathlib import Path
55
from typing import Dict, Optional # Added Dict
66

7-
import requests
8-
9-
from .common_utils import get_user_agent
10-
117
logger = logging.getLogger(__name__)
128

139
# Default locations (used for tests and as fallback). Actual resolution is dynamic via _get_auth_ini_file().
@@ -244,9 +240,11 @@ def verify_api_key_and_get_account_id(
244240
if not resolved_key:
245241
return None
246242
resolved_base = api_base or get_fireworks_api_base()
247-
url = f"{resolved_base.rstrip('/')}/verifyApiKey"
248-
headers = {"Authorization": f"Bearer {resolved_key}", "User-Agent": get_user_agent()}
249-
resp = requests.get(url, headers=headers, timeout=10)
243+
244+
from .fireworks_api_client import FireworksAPIClient
245+
client = FireworksAPIClient(api_key=resolved_key, api_base=resolved_base)
246+
resp = client.get("verifyApiKey", timeout=10)
247+
250248
if resp.status_code != 200:
251249
logger.debug("verifyApiKey returned status %s", resp.status_code)
252250
return None

eval_protocol/evaluation.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
get_fireworks_api_key,
2121
verify_api_key_and_get_account_id,
2222
)
23-
from eval_protocol.common_utils import get_user_agent
23+
from eval_protocol.fireworks_api_client import FireworksAPIClient
2424
from eval_protocol.typed_interface import EvaluationMode
2525

2626
from eval_protocol.get_pep440_version import get_pep440_version
@@ -402,20 +402,15 @@ def preview(self, sample_file, max_samples=5):
402402
if "dev.api.fireworks.ai" in api_base and account_id == "fireworks":
403403
account_id = "pyroworks-dev"
404404

405-
url = f"{api_base}/v1/accounts/{account_id}/evaluators:previewEvaluator"
406-
headers = {
407-
"Authorization": f"Bearer {auth_token}",
408-
"Content-Type": "application/json",
409-
"User-Agent": get_user_agent(),
410-
}
411-
logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}")
412-
logger.debug(f"Preview API Request URL: {url}")
413-
logger.debug(f"Preview API Request Headers: {json.dumps(headers, indent=2)}")
405+
client = FireworksAPIClient(api_key=auth_token, api_base=api_base)
406+
path = f"v1/accounts/{account_id}/evaluators:previewEvaluator"
407+
408+
logger.info(f"Previewing evaluator using API endpoint: {api_base}/{path} with account: {account_id}")
414409
logger.debug(f"Preview API Request Payload: {json.dumps(payload, indent=2)}")
415410

416411
global used_preview_api
417412
try:
418-
response = requests.post(url, json=payload, headers=headers)
413+
response = client.post(path, json=payload)
419414
response.raise_for_status()
420415
result = response.json()
421416
used_preview_api = True
@@ -746,12 +741,8 @@ def create(self, evaluator_id, display_name=None, description=None, force=False)
746741
if "dev.api.fireworks.ai" in self.api_base and account_id == "fireworks":
747742
account_id = "pyroworks-dev"
748743

749-
base_url = f"{self.api_base}/v1/{parent}/evaluatorsV2"
750-
headers = {
751-
"Authorization": f"Bearer {auth_token}",
752-
"Content-Type": "application/json",
753-
"User-Agent": get_user_agent(),
754-
}
744+
client = FireworksAPIClient(api_key=auth_token, api_base=self.api_base)
745+
path = f"v1/{parent}/evaluatorsV2"
755746

756747
self._ensure_requirements_present(os.getcwd())
757748

@@ -813,7 +804,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False)
813804
upload_payload = {"name": evaluator_name, "filename_to_size": {tar_filename: tar_size}}
814805

815806
logger.info(f"Requesting upload endpoint for {tar_filename}")
816-
upload_response = requests.post(upload_endpoint_url, json=upload_payload, headers=headers)
807+
upload_response = client.post(upload_endpoint_url, json=upload_payload)
817808
upload_response.raise_for_status()
818809

819810
# Check for signed URLs
@@ -895,7 +886,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False)
895886
# Step 3: Validate upload
896887
validate_url = f"{self.api_base}/v1/{evaluator_name}:validateUpload"
897888
validate_payload = {"name": evaluator_name}
898-
validate_response = requests.post(validate_url, json=validate_payload, headers=headers)
889+
validate_response = client.post(validate_url, json=validate_payload)
899890
validate_response.raise_for_status()
900891

901892
validate_data = validate_response.json()
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Centralized client for making requests to Fireworks API with consistent headers."""
2+
3+
import os
4+
from typing import Any, Dict, Optional
5+
6+
import requests
7+
8+
from .common_utils import get_user_agent
9+
10+
11+
class FireworksAPIClient:
12+
"""Client for making authenticated requests to Fireworks API with proper headers.
13+
14+
This client automatically includes:
15+
- Authorization header (Bearer token)
16+
- User-Agent header for tracking eval-protocol CLI usage
17+
"""
18+
19+
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
20+
"""Initialize the Fireworks API client.
21+
22+
Args:
23+
api_key: Fireworks API key. If None, will be read from environment.
24+
api_base: API base URL. If None, defaults to https://api.fireworks.ai
25+
"""
26+
self.api_key = api_key
27+
self.api_base = api_base or os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai")
28+
self._session = requests.Session()
29+
30+
def _get_headers(self, content_type: Optional[str] = "application/json",
31+
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
32+
"""Build headers for API requests.
33+
34+
Args:
35+
content_type: Content-Type header value. If None, Content-Type won't be set.
36+
additional_headers: Additional headers to merge in.
37+
38+
Returns:
39+
Dictionary of headers including authorization and user-agent.
40+
"""
41+
headers = {
42+
"User-Agent": get_user_agent(),
43+
}
44+
45+
if self.api_key:
46+
headers["Authorization"] = f"Bearer {self.api_key}"
47+
48+
if content_type:
49+
headers["Content-Type"] = content_type
50+
51+
if additional_headers:
52+
headers.update(additional_headers)
53+
54+
return headers
55+
56+
def get(self, path: str, params: Optional[Dict[str, Any]] = None,
57+
timeout: int = 30, **kwargs) -> requests.Response:
58+
"""Make a GET request to the Fireworks API.
59+
60+
Args:
61+
path: API path (relative to api_base)
62+
params: Query parameters
63+
timeout: Request timeout in seconds
64+
**kwargs: Additional arguments passed to requests.get
65+
66+
Returns:
67+
Response object
68+
"""
69+
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
70+
headers = self._get_headers(content_type=None)
71+
if "headers" in kwargs:
72+
headers.update(kwargs.pop("headers"))
73+
return self._session.get(url, params=params, headers=headers, timeout=timeout, **kwargs)
74+
75+
def post(self, path: str, json: Optional[Dict[str, Any]] = None,
76+
data: Optional[Any] = None, files: Optional[Dict[str, Any]] = None,
77+
timeout: int = 60, **kwargs) -> requests.Response:
78+
"""Make a POST request to the Fireworks API.
79+
80+
Args:
81+
path: API path (relative to api_base)
82+
json: JSON payload
83+
data: Form data payload
84+
files: Files to upload
85+
timeout: Request timeout in seconds
86+
**kwargs: Additional arguments passed to requests.post
87+
88+
Returns:
89+
Response object
90+
"""
91+
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
92+
93+
# For file uploads, don't set Content-Type (let requests handle multipart/form-data)
94+
content_type = None if files else "application/json"
95+
headers = self._get_headers(content_type=content_type)
96+
97+
if "headers" in kwargs:
98+
headers.update(kwargs.pop("headers"))
99+
100+
return self._session.post(url, json=json, data=data, files=files,
101+
headers=headers, timeout=timeout, **kwargs)
102+
103+
def put(self, path: str, json: Optional[Dict[str, Any]] = None,
104+
timeout: int = 60, **kwargs) -> requests.Response:
105+
"""Make a PUT request to the Fireworks API."""
106+
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
107+
headers = self._get_headers()
108+
if "headers" in kwargs:
109+
headers.update(kwargs.pop("headers"))
110+
return self._session.put(url, json=json, headers=headers, timeout=timeout, **kwargs)
111+
112+
def patch(self, path: str, json: Optional[Dict[str, Any]] = None,
113+
timeout: int = 60, **kwargs) -> requests.Response:
114+
"""Make a PATCH request to the Fireworks API."""
115+
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
116+
headers = self._get_headers()
117+
if "headers" in kwargs:
118+
headers.update(kwargs.pop("headers"))
119+
return self._session.patch(url, json=json, headers=headers, timeout=timeout, **kwargs)
120+
121+
def delete(self, path: str, timeout: int = 30, **kwargs) -> requests.Response:
122+
"""Make a DELETE request to the Fireworks API."""
123+
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
124+
headers = self._get_headers(content_type=None)
125+
if "headers" in kwargs:
126+
headers.update(kwargs.pop("headers"))
127+
return self._session.delete(url, headers=headers, timeout=timeout, **kwargs)

eval_protocol/fireworks_rft.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import requests
1212

1313
from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key
14-
from .common_utils import get_user_agent
14+
from .fireworks_api_client import FireworksAPIClient
1515

1616

1717
def _map_api_host_to_app_host(api_base: str) -> str:
@@ -158,17 +158,14 @@ def create_dataset_from_jsonl(
158158
display_name: Optional[str],
159159
jsonl_path: str,
160160
) -> Tuple[str, Dict[str, Any]]:
161-
headers = {
162-
"Authorization": f"Bearer {api_key}",
163-
"Content-Type": "application/json",
164-
"User-Agent": get_user_agent(),
165-
}
161+
client = FireworksAPIClient(api_key=api_key, api_base=api_base)
162+
166163
# Count examples quickly
167164
example_count = 0
168165
with open(jsonl_path, "r", encoding="utf-8") as f:
169166
for _ in f:
170167
example_count += 1
171-
dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets"
168+
172169
payload = {
173170
"dataset": {
174171
"displayName": display_name or dataset_id,
@@ -178,16 +175,15 @@ def create_dataset_from_jsonl(
178175
},
179176
"datasetId": dataset_id,
180177
}
181-
resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60)
178+
resp = client.post(f"v1/accounts/{account_id}/datasets", json=payload, timeout=60)
182179
if resp.status_code not in (200, 201):
183180
raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}")
184181
ds = resp.json()
185182

186-
upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload"
187183
with open(jsonl_path, "rb") as f:
188184
files = {"file": f}
189-
up_headers = {"Authorization": f"Bearer {api_key}", "User-Agent": get_user_agent()}
190-
up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600)
185+
up_resp = client.post(f"v1/accounts/{account_id}/datasets/{dataset_id}:upload",
186+
files=files, timeout=600)
191187
if up_resp.status_code not in (200, 201):
192188
raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}")
193189
return dataset_id, ds
@@ -199,14 +195,10 @@ def create_reinforcement_fine_tuning_job(
199195
api_base: str,
200196
body: Dict[str, Any],
201197
) -> Dict[str, Any]:
202-
url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs"
203-
headers = {
204-
"Authorization": f"Bearer {api_key}",
205-
"Content-Type": "application/json",
206-
"Accept": "application/json",
207-
"User-Agent": get_user_agent(),
208-
}
209-
resp = requests.post(url, json=body, headers=headers, timeout=60)
198+
client = FireworksAPIClient(api_key=api_key, api_base=api_base)
199+
resp = client.post(f"v1/accounts/{account_id}/reinforcementFineTuningJobs",
200+
json=body, timeout=60,
201+
headers={"Accept": "application/json"})
210202
if resp.status_code not in (200, 201):
211203
raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}")
212204
return resp.json()

eval_protocol/platform_api.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_fireworks_api_base,
1212
get_fireworks_api_key,
1313
)
14-
from eval_protocol.common_utils import get_user_agent
14+
from eval_protocol.fireworks_api_client import FireworksAPIClient
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -93,11 +93,7 @@ def create_or_update_fireworks_secret(
9393
logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.")
9494
return False
9595

96-
headers = {
97-
"Authorization": f"Bearer {resolved_api_key}",
98-
"Content-Type": "application/json",
99-
"User-Agent": get_user_agent(),
100-
}
96+
client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base)
10197

10298
# The secret_id for GET/PATCH/DELETE operations is the key_name.
10399
# The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous.
@@ -109,10 +105,9 @@ def create_or_update_fireworks_secret(
109105

110106
# Check if secret exists using GET (path uses normalized resource id)
111107
resource_id = _normalize_secret_resource_id(key_name)
112-
get_url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}"
113108
secret_exists = False
114109
try:
115-
response = requests.get(get_url, headers=headers, timeout=10)
110+
response = client.get(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=10)
116111
if response.status_code == 200:
117112
secret_exists = True
118113
logger.info(f"Secret '{key_name}' already exists. Will attempt to update.")
@@ -133,7 +128,6 @@ def create_or_update_fireworks_secret(
133128

134129
if secret_exists:
135130
# Update existing secret (PATCH)
136-
patch_url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}"
137131
# Body for PATCH requires 'keyName' and 'value'.
138132
# Transform key_name for payload: uppercase and underscores
139133
payload_key_name = key_name.upper().replace("-", "_")
@@ -148,7 +142,8 @@ def create_or_update_fireworks_secret(
148142
payload = {"keyName": payload_key_name, "value": secret_value}
149143
try:
150144
logger.debug(f"PATCH payload for '{key_name}': {payload}")
151-
response = requests.patch(patch_url, headers=headers, json=payload, timeout=30)
145+
response = client.patch(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}",
146+
json=payload, timeout=30)
152147
response.raise_for_status()
153148
logger.info(f"Successfully updated secret '{key_name}' on Fireworks platform.")
154149
return True
@@ -160,7 +155,6 @@ def create_or_update_fireworks_secret(
160155
return False
161156
else:
162157
# Create new secret (POST)
163-
post_url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets"
164158
# Body for POST is gatewaySecret. 'name' field in payload is the resource path.
165159
# Let's assume for POST, the 'name' in payload can be omitted or is the key_name.
166160
# The API should ideally use 'keyName' from URL or a specific 'secretId' in payload for creation if 'name' is server-assigned.
@@ -185,7 +179,8 @@ def create_or_update_fireworks_secret(
185179
}
186180
try:
187181
logger.debug(f"POST payload for '{key_name}': {payload}")
188-
response = requests.post(post_url, headers=headers, json=payload, timeout=30)
182+
response = client.post(f"v1/accounts/{resolved_account_id}/secrets",
183+
json=payload, timeout=30)
189184
response.raise_for_status()
190185
logger.info(
191186
f"Successfully created secret '{key_name}' on Fireworks platform. Full name: {response.json().get('name')}"
@@ -219,12 +214,11 @@ def get_fireworks_secret(
219214
logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.")
220215
return None
221216

222-
headers = {"Authorization": f"Bearer {resolved_api_key}", "User-Agent": get_user_agent()}
217+
client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base)
223218
resource_id = _normalize_secret_resource_id(key_name)
224-
url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}"
225219

226220
try:
227-
response = requests.get(url, headers=headers, timeout=10)
221+
response = client.get(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=10)
228222
if response.status_code == 200:
229223
logger.info(f"Successfully retrieved secret '{key_name}'.")
230224
return response.json()
@@ -256,12 +250,11 @@ def delete_fireworks_secret(
256250
logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.")
257251
return False
258252

259-
headers = {"Authorization": f"Bearer {resolved_api_key}", "User-Agent": get_user_agent()}
253+
client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base)
260254
resource_id = _normalize_secret_resource_id(key_name)
261-
url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}"
262255

263256
try:
264-
response = requests.delete(url, headers=headers, timeout=30)
257+
response = client.delete(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=30)
265258
if response.status_code == 200 or response.status_code == 204: # 204 No Content is also success for DELETE
266259
logger.info(f"Successfully deleted secret '{key_name}'.")
267260
return True

0 commit comments

Comments
 (0)